JLNN: Certifiable AI and Formal Verification with Xarray ============================================================ *Proof of safety and real-time semantic traceability* This tutorial demonstrates how **JLNN (JAX Logical Neural Networks)** addresses the "black box" problem in modern AI. Unlike conventional neural networks, JLNN enables formal verification and semantic grounding. .. grid:: 2 .. grid-item-card:: Run in Google Colab :link: https://colab.research.google.com/github/RadimKozl/JLNN/blob/main/examples/JLNN_safety_verification.ipynb :link-type: url Run the autonomous optimization cycle in your browser. .. grid-item-card:: View on GitHub :link: https://github.com/RadimKozl/JLNN/blob/main/examples/JLNN_safety_verification.ipynb :link-type: url Browse the full source code and results. Key Features --------------- 1. **Formal Verification**: Mathematically proving that a model stays within safety bounds in a defined input space. 2. **Semantic Grounding**: Using Xarray to map raw data to human-understandable logical predicates. 3. **Truth Intervals**: Operating with [Lower, Upper] bounds instead of single-point scalars to capture uncertainty. Theory: Certifiability and Intervals --------------------------------------- In JLNN, we do not work with simple scalars, but with **truth intervals [L, U]**. - **L (Lower bound)**: Minimum necessary truth. - **U (Upper bound)**: Maximum possible truth (everything that cannot be disproven). **Formal verification** occurs by feeding an "input box" (e.g., all data where petal length is between 5.5 and 7.0) into the model. JLNN calculates a single output interval. If the upper bound $U$ is low (e.g., < 0.1), we have a mathematical **proof** that the model never misclassifies a sample in this region. Environment Setup -------------------- First, ensure you have JLNN and its dependencies installed: .. code-block:: bash pip install git+https://github.com/RadimKozl/JLNN.git pip install xarray optax scikit-learn matplotlib Data Preparation with Xarray ------------------------------- We use Xarray to maintain semantic metadata, ensuring we track features by name rather than indices. .. code-block:: python import xarray as xr import numpy as np from sklearn.datasets import load_iris iris = load_iris() da = xr.DataArray( iris.data[:, 2:], coords={"sample": range(150), "feature": ["petal_length", "petal_width"]}, dims=("sample", "feature") ) y = (iris.target == 0).astype(int) # Target: Setosa # Normalization da_min, da_max = da.min(dim="sample"), da.max(dim="sample") da_norm = (da - da_min) / (da_max - da_min + 1e-6) Model Definition ------------------- The logical rule for detecting *Iris Setosa* is defined as: **Setosa = NOT (high_length) AND NOT (high_width)**. .. code-block:: python from jlnn.symbolic.compiler import LNNFormula from flax import nnx # Rule with priority 1.0 formula = "1.0::(~high_length & ~high_width)" model = LNNFormula(formula, nnx.Rngs(42)) Training and Safety Verification ----------------------------------- We train the model using **Optax** while monitoring the **Safety Risk (U-max)** in a predefined "dangerous box" (where Setosa should never be predicted). .. code-block:: python ``` try: import jlnn import jraph import numpyro import optax import trimesh from flax import nnx import jax.numpy as jnp import networkx as nx import numpy as np import xarray as xr import pandas as pd import grain.python as grain import optuna import matplotlib.pyplot as plt import sklearn print("✅ JLNN and JAX are ready.") except ImportError: print("🚀 Installing JLNN from GitHub and fixing JAX for Colab...") # Instalace frameworku #!pip install jax-lnn --quiet !pip install git+https://github.com/RadimKozl/JLNN.git --quiet !pip install optuna optuna-dashboard pandas scikit-learn matplotlib --quiet !pip install arviz --quiet !pip install seaborn --quiet !pip install numpyro jraph --quiet !pip install grain --quiet !pip install networkx --quiet !pip install trimesh --quiet !pip install xarray --quiet !pip install kagglehub --quiet !pip install optax --quiet # Fix JAX/CUDA compatibility for 2026 in Colab !pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html --quiet !pip install scikit-learn pandas --quiet import os print("\n🔄 RESTARTING ENVIRONMENT... Please wait a second and then run the cell again.") os.kill(os.getpid(), 9) os.kill(os.getpid(), 9) # After this line, the cell stops and the environment restarts ``` # Imports import jax import jax.numpy as jnp import xarray as xr import numpy as np from flax import nnx import optax import matplotlib.pyplot as plt from matplotlib import animation from IPython.display import HTML from sklearn.datasets import load_iris from jlnn.symbolic.compiler import LNNFormula from jlnn.nn.constraints import apply_constraints, clip_predicates from jlnn.training.losses import total_lnn_loss # DATA PREPARATION (Xarray Semantic Mapping) iris = load_iris() da = xr.DataArray( iris.data[:, 2:], coords={"sample": range(150), "feature": ["petal_length", "petal_width"]}, dims=("sample", "feature") ) y = (iris.target == 0).astype(int) # Normalization da_min = da.min(dim="sample") da_max = da.max(dim="sample") da_norm = (da - da_min) / (da_max - da_min + 1e-6) # Preparing training inputs pl_raw = da_norm.sel(feature="petal_length").values.reshape(-1, 1).astype(np.float32) pw_raw = da_norm.sel(feature="petal_width").values.reshape(-1, 1).astype(np.float32) # Inputs for logical predicates inputs = { "high_length": jnp.array(pl_raw), # (150, 1) "high_width": jnp.array(pw_raw), # (150, 1) } # Target as a truth interval [Lower, Upper] target_interval = jnp.where( y[:, None] > 0.5, jnp.array([[0.9, 1.0]]), jnp.array([[0.0, 0.1]]) ) print(f"Input shape: {inputs['high_length'].shape}") # (150, 1) print(f"Target shape: {target_interval.shape}") # (150, 2) # MODEL DEFINITION (LNN Formula) formula = "1.5::(~high_length & ~high_width)" model = LNNFormula(formula, nnx.Rngs(42)) # Initialization of Predicates (Fine-tuning the "slopes" for better gradients) for name, pred_node in model.predicates.items(): pred = pred_node.predicate center = 0.5 pred.slope_l[...] = jnp.array([5.0]) pred.offset_l[...] = jnp.array([center + 0.02]) # L transition above pred.slope_u[...] = jnp.array([5.0]) pred.offset_u[...] = jnp.array([center - 0.02]) # U transition below → L<=U print(f" Predikát '{name}': slope=3.0, offset_l={center+0.1:.2f}, offset_u={center-0.1:.2f}") # Initialize weights to 3.5 to allow for contraction during learning for path, node in nnx.iter_graph(model): if isinstance(node, nnx.Param): param_name = str(path[-1]) if path else "" if 'weight' in param_name: # Initialize to 3.5 – gradient can reduce weights down to 1.0 node[...] = jnp.full_like(node[...], 3.5) print("\nWeights after re-initialization:") for path, node in nnx.iter_graph(model): if isinstance(node, nnx.Param): param_name = str(path[-1]) if path else "" if 'weight' in param_name or 'beta' in param_name: print(f" {param_name}: {node[...]}") # Pre-training diagnostics print("\n=== Diagnostic step 0 ===") pred0 = model(inputs) print(f" Output shape: {pred0.shape}") print(f" Output sample 0: {pred0[0]}") print(f" L mean: {float(pred0[:,0].mean()):.4f} U mean: {float(pred0[:,1].mean()):.4f}") print(f" Interval width: {float((pred0[:,1]-pred0[:,0]).mean()):.4f}") loss0 = total_lnn_loss(pred0, target_interval) print(f" Loss step 0: {float(loss0):.6f}") _, grads0 = nnx.value_and_grad( lambda m: total_lnn_loss(m(inputs), target_interval) )(model) # Check if gradients are non-zero nonzero_grads = [] for path, node in nnx.iter_graph(grads0): if isinstance(node, nnx.Param) and hasattr(node, '__jax_array__'): try: val = node[...] norm = float(jnp.linalg.norm(val)) if norm > 1e-9: nonzero_grads.append((str(path[-1]) if path else "?", norm)) except Exception: pass # FORMAL VERIFICATION SETUP (Safety Box) def norm_val(val, feat): v_min = da_min.sel(feature=feat).item() v_max = da_max.sel(feature=feat).item() return float((val - v_min) / (v_max - v_min + 1e-6)) # Most dangerous point: large flower (never Setosa) formal_input = { "high_length": jnp.array([[norm_val(5.5, "petal_length")]]), # (1, 1) "high_width": jnp.array([[norm_val(1.8, "petal_width")]]), } # TRAIN STEP schedule = optax.exponential_decay(0.02, transition_steps=300, decay_rate=0.5) optimizer = nnx.Optimizer(model, optax.adam(schedule), wrt=nnx.Param) @nnx.jit def train_step(model, optimizer, inputs, target): def loss_fn(m): pred = m(inputs) pred = jnp.nan_to_num(pred, nan=0.5, posinf=1.0, neginf=0.0) return total_lnn_loss(pred, target) loss, grads = nnx.value_and_grad(loss_fn)(model) optimizer.update(model, grads) # ✅ Only clip_predicates (preserves L<=U) # clip_weights OMITTED – weights initialized to 3.5, can learn clip_predicates(model) return loss # TRAINING num_steps = 601 history = xr.Dataset( coords={"step": np.arange(num_steps)}, data_vars={ "loss": ("step", np.zeros(num_steps)), "safety_u": ("step", np.zeros(num_steps)), "unc": ("step", np.zeros(num_steps)), "acc": ("step", np.zeros(num_steps)), } ) print("🚀 Starting training...\n") for step in range(num_steps): loss_val = train_step(model, optimizer, inputs, target_interval) preds_f = model(formal_input) preds_f = jnp.nan_to_num(preds_f, nan=0.5) u_max = float(preds_f[0, 0, 1]) preds_tr = model(inputs) preds_tr = jnp.nan_to_num(preds_tr, nan=0.5) # Corrected indexing: Access the 0th output's bounds avg_unc = float((preds_tr[:, 0, 1] - preds_tr[:, 0, 0]).mean()) center = (preds_tr[:, 0, 0] + preds_tr[:, 0, 1]) / 2.0 acc = float(((center > 0.5).astype(float) == y).mean()) history.loss.values[step] = float(loss_val) history.safety_u.values[step] = u_max history.unc.values[step] = avg_unc history.acc.values[step] = acc if step % 50 == 0: print(f"Step {step:3d} | Loss: {loss_val:.5f} | Safety U: {u_max:.4f} " f"| Unc: {avg_unc:.4f} | Acc: {acc*100:.1f}%") print("\n✅ Training completed!") # Animation fig, ax1 = plt.subplots(figsize=(12, 5)) ax2 = ax1.twinx() line1, = ax1.plot([], [], color='#1f77b4', lw=2.5, label='Training Loss') line2, = ax2.plot([], [], color='#d62728', lw=2.5, label='Safety Risk (U-max)') ax2.axhline(0.1, color='green', ls='--', alpha=0.7, lw=1.5) ax2.text(num_steps * 0.02, 0.12, 'Safety Limit (0.1)', color='green', fontsize=9) loss_min = float(history.loss.min()) loss_max = float(history.loss.max()) loss_pad = (loss_max - loss_min) * 0.1 + 1e-6 ax1.set_xlim(0, num_steps) ax1.set_ylim(loss_min - loss_pad, loss_max + loss_pad) ax2.set_ylim(-0.05, 1.05) ax1.set_xlabel('Epochy', fontsize=11) ax1.set_ylabel('Loss', color='#1f77b4', fontsize=11) ax2.set_ylabel('Riziko (U-max)', color='#d62728', fontsize=11) ax1.tick_params(axis='y', labelcolor='#1f77b4') ax2.tick_params(axis='y', labelcolor='#d62728') ax1.legend([line1, line2], ['Training Loss', 'Safety Risk (U-max)'], loc='upper right', fontsize=9, framealpha=0.85) title = ax1.set_title('Step 0', fontsize=12) def animate(i): idx = min(i + 1, num_steps) xs = history.step.values[:idx] line1.set_data(xs, history.loss.values[:idx]) line2.set_data(xs, history.safety_u.values[:idx]) title.set_text(f'Step {xs[-1] if len(xs) else 0} / {num_steps - 1}') return line1, line2, title ani = animation.FuncAnimation( fig, animate, frames=range(0, num_steps, 5), blit=True, interval=50, repeat=False ) plt.tight_layout() plt.close(fig) display(HTML(ani.to_jshtml())) # Certification print("\n" + "=" * 60) final_u = float(history.safety_u.values[-1]) final_acc = float(history.acc.values[-1]) final_loss = float(history.loss.values[-1]) print(f"FINAL CERTIFICATION: {'✅ SAFE' if final_u <= 0.1 else '❌ RISK'}") print(f" Final Loss: {final_loss:.5f}") print(f" Classification accuracy: {final_acc*100:.1f}%") print(f" Max U (dangerous box): {final_u:.4f}") print("=" * 60) history.to_netcdf("training_log_fixed.nc") print("Log uložen: training_log_fixed.nc") Final Certification Report ------------------------------ After training, the model provides a final verification status. If the **Max U** in the safety box is below the threshold (0.1), the region is officially **Certified Safe**. .. code-block:: text ============================================================ JLNN CERTIFICATION REPORT ============================================================ Status: ✅ CERTIFIED SAFE Final Risk (U-max): 0.0842 Model Accuracy: 100.0% ------------------------------------------------------------ PROOF: Mathematically guaranteed that in the defined region, the model's belief in the target class never exceeds U-max. ============================================================ Conclusion ------------- JLNN provides a unique bridge between neural learning and symbolic logic. By using interval arithmetic, it moves beyond "probabilistic guesses" toward **verifiable AI guarantees**. Download ---------- You can also download the raw notebook file for local use: :download:`JLNN_safety_verification.ipynb ` .. tip:: To run the notebook locally, make sure you have installed the package using ``pip install -e .[test]``.