JLNN: Certifiable AI and Formal Verification with Xarray¶
Proof of safety and real-time semantic traceability
This tutorial demonstrates how JLNN (JAX Logical Neural Networks) addresses the “black box” problem in modern AI. Unlike conventional neural networks, JLNN enables formal verification and semantic grounding.
Run the autonomous optimization cycle in your browser.
Browse the full source code and results.
Key Features¶
Formal Verification: Mathematically proving that a model stays within safety bounds in a defined input space.
Semantic Grounding: Using Xarray to map raw data to human-understandable logical predicates.
Truth Intervals: Operating with [Lower, Upper] bounds instead of single-point scalars to capture uncertainty.
Theory: Certifiability and Intervals¶
In JLNN, we do not work with simple scalars, but with truth intervals [L, U].
L (Lower bound): Minimum necessary truth.
U (Upper bound): Maximum possible truth (everything that cannot be disproven).
Formal verification occurs by feeding an “input box” (e.g., all data where petal length is between 5.5 and 7.0) into the model. JLNN calculates a single output interval. If the upper bound $U$ is low (e.g., < 0.1), we have a mathematical proof that the model never misclassifies a sample in this region.
Environment Setup¶
First, ensure you have JLNN and its dependencies installed:
pip install git+https://github.com/RadimKozl/JLNN.git
pip install xarray optax scikit-learn matplotlib
Data Preparation with Xarray¶
We use Xarray to maintain semantic metadata, ensuring we track features by name rather than indices.
import xarray as xr
import numpy as np
from sklearn.datasets import load_iris
iris = load_iris()
da = xr.DataArray(
iris.data[:, 2:],
coords={"sample": range(150), "feature": ["petal_length", "petal_width"]},
dims=("sample", "feature")
)
y = (iris.target == 0).astype(int) # Target: Setosa
# Normalization
da_min, da_max = da.min(dim="sample"), da.max(dim="sample")
da_norm = (da - da_min) / (da_max - da_min + 1e-6)
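As a quick sanity check (a standalone sketch assuming scikit-learn and xarray are installed), you can confirm that name-based selection works and that the normalized values fall in [0, 1]:

```python
import xarray as xr
from sklearn.datasets import load_iris

iris = load_iris()
da = xr.DataArray(
    iris.data[:, 2:],
    coords={"sample": range(150), "feature": ["petal_length", "petal_width"]},
    dims=("sample", "feature"),
)
# Min-max normalization per feature, as in the tutorial
da_norm = (da - da.min(dim="sample")) / (da.max(dim="sample") - da.min(dim="sample") + 1e-6)

# Select by semantic name, not by positional index
pl = da_norm.sel(feature="petal_length")
print(float(pl.min()), float(pl.max()))  # roughly 0.0 and just under 1.0
```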
Model Definition¶
The logical rule for detecting Iris Setosa is defined as: Setosa = NOT (high_length) AND NOT (high_width).
from jlnn.symbolic.compiler import LNNFormula
from flax import nnx
# Rule with priority 1.0
formula = "1.0::(~high_length & ~high_width)"
model = LNNFormula(formula, nnx.Rngs(42))
Training and Safety Verification¶
We train the model using Optax while monitoring the Safety Risk (U-max) in a predefined “dangerous box” (where Setosa should never be predicted).
```
try:
    import jlnn
    import jraph
    import numpyro
    import optax
    import trimesh
    from flax import nnx
    import jax.numpy as jnp
    import networkx as nx
    import numpy as np
    import xarray as xr
    import pandas as pd
    import grain.python as grain
    import optuna
    import matplotlib.pyplot as plt
    import sklearn
    print("✅ JLNN and JAX are ready.")
except ImportError:
    print("🚀 Installing JLNN from GitHub and fixing JAX for Colab...")
    # Install the framework
    # !pip install jax-lnn --quiet
    !pip install git+https://github.com/RadimKozl/JLNN.git --quiet
    !pip install optuna optuna-dashboard pandas scikit-learn matplotlib --quiet
    !pip install arviz --quiet
    !pip install seaborn --quiet
    !pip install numpyro jraph --quiet
    !pip install grain --quiet
    !pip install networkx --quiet
    !pip install trimesh --quiet
    !pip install xarray --quiet
    !pip install kagglehub --quiet
    !pip install optax --quiet
    # Fix JAX/CUDA compatibility in Colab
    !pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html --quiet
    !pip install scikit-learn pandas --quiet
    import os
    print("\n🔄 RESTARTING ENVIRONMENT... Please wait a second and then run the cell again.")
    os.kill(os.getpid(), 9)  # After this line, the cell stops and the environment restarts
```
# Imports
import jax
import jax.numpy as jnp
import xarray as xr
import numpy as np
from flax import nnx
import optax
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import HTML, display
from sklearn.datasets import load_iris
from jlnn.symbolic.compiler import LNNFormula
from jlnn.nn.constraints import apply_constraints, clip_predicates
from jlnn.training.losses import total_lnn_loss
# DATA PREPARATION (Xarray Semantic Mapping)
iris = load_iris()
da = xr.DataArray(
iris.data[:, 2:],
coords={"sample": range(150), "feature": ["petal_length", "petal_width"]},
dims=("sample", "feature")
)
y = (iris.target == 0).astype(int)
# Normalization
da_min = da.min(dim="sample")
da_max = da.max(dim="sample")
da_norm = (da - da_min) / (da_max - da_min + 1e-6)
# Preparing training inputs
pl_raw = da_norm.sel(feature="petal_length").values.reshape(-1, 1).astype(np.float32)
pw_raw = da_norm.sel(feature="petal_width").values.reshape(-1, 1).astype(np.float32)
# Inputs for logical predicates
inputs = {
"high_length": jnp.array(pl_raw), # (150, 1)
"high_width": jnp.array(pw_raw), # (150, 1)
}
# Target as a truth interval [Lower, Upper]
target_interval = jnp.where(
y[:, None] > 0.5,
jnp.array([[0.9, 1.0]]),
jnp.array([[0.0, 0.1]])
)
print(f"Input shape: {inputs['high_length'].shape}") # (150, 1)
print(f"Target shape: {target_interval.shape}") # (150, 2)
# MODEL DEFINITION (LNN Formula)
formula = "1.5::(~high_length & ~high_width)"
model = LNNFormula(formula, nnx.Rngs(42))
# Initialization of Predicates (Fine-tuning the "slopes" for better gradients)
for name, pred_node in model.predicates.items():
    pred = pred_node.predicate
    center = 0.5
    pred.slope_l[...] = jnp.array([5.0])
    pred.offset_l[...] = jnp.array([center + 0.02])  # L transition above
    pred.slope_u[...] = jnp.array([5.0])
    pred.offset_u[...] = jnp.array([center - 0.02])  # U transition below → L<=U
    print(f"  Predicate '{name}': slope=5.0, offset_l={center+0.02:.2f}, offset_u={center-0.02:.2f}")
# Initialize weights to 3.5 to allow for contraction during learning
for path, node in nnx.iter_graph(model):
    if isinstance(node, nnx.Param):
        param_name = str(path[-1]) if path else ""
        if 'weight' in param_name:
            # Initialize to 3.5 – gradient can reduce weights down to 1.0
            node[...] = jnp.full_like(node[...], 3.5)
print("\nWeights after re-initialization:")
for path, node in nnx.iter_graph(model):
    if isinstance(node, nnx.Param):
        param_name = str(path[-1]) if path else ""
        if 'weight' in param_name or 'beta' in param_name:
            print(f"  {param_name}: {node[...]}")
# Pre-training diagnostics
print("\n=== Diagnostic step 0 ===")
pred0 = model(inputs)
print(f" Output shape: {pred0.shape}")
print(f" Output sample 0: {pred0[0]}")
print(f" L mean: {float(pred0[:,0].mean()):.4f} U mean: {float(pred0[:,1].mean()):.4f}")
print(f" Interval width: {float((pred0[:,1]-pred0[:,0]).mean()):.4f}")
loss0 = total_lnn_loss(pred0, target_interval)
print(f" Loss step 0: {float(loss0):.6f}")
_, grads0 = nnx.value_and_grad(
lambda m: total_lnn_loss(m(inputs), target_interval)
)(model)
# Check if gradients are non-zero
nonzero_grads = []
for path, node in nnx.iter_graph(grads0):
    if isinstance(node, nnx.Param) and hasattr(node, '__jax_array__'):
        try:
            val = node[...]
            norm = float(jnp.linalg.norm(val))
            if norm > 1e-9:
                nonzero_grads.append((str(path[-1]) if path else "?", norm))
        except Exception:
            pass
print(f" Parameters with non-zero gradients: {len(nonzero_grads)}")
# FORMAL VERIFICATION SETUP (Safety Box)
def norm_val(val, feat):
    v_min = da_min.sel(feature=feat).item()
    v_max = da_max.sel(feature=feat).item()
    return float((val - v_min) / (v_max - v_min + 1e-6))
# Most dangerous point: large flower (never Setosa)
formal_input = {
"high_length": jnp.array([[norm_val(5.5, "petal_length")]]), # (1, 1)
"high_width": jnp.array([[norm_val(1.8, "petal_width")]]),
}
# TRAIN STEP
schedule = optax.exponential_decay(0.02, transition_steps=300, decay_rate=0.5)
optimizer = nnx.Optimizer(model, optax.adam(schedule), wrt=nnx.Param)
@nnx.jit
def train_step(model, optimizer, inputs, target):
    def loss_fn(m):
        pred = m(inputs)
        pred = jnp.nan_to_num(pred, nan=0.5, posinf=1.0, neginf=0.0)
        return total_lnn_loss(pred, target)
    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(model, grads)
    # ✅ Only clip_predicates (preserves L<=U)
    # clip_weights OMITTED – weights initialized to 3.5, can learn
    clip_predicates(model)
    return loss
# TRAINING
num_steps = 601
history = xr.Dataset(
coords={"step": np.arange(num_steps)},
data_vars={
"loss": ("step", np.zeros(num_steps)),
"safety_u": ("step", np.zeros(num_steps)),
"unc": ("step", np.zeros(num_steps)),
"acc": ("step", np.zeros(num_steps)),
}
)
print("🚀 Starting training...\n")
for step in range(num_steps):
    loss_val = train_step(model, optimizer, inputs, target_interval)
    preds_f = model(formal_input)
    preds_f = jnp.nan_to_num(preds_f, nan=0.5)
    u_max = float(preds_f[0, 0, 1])
    preds_tr = model(inputs)
    preds_tr = jnp.nan_to_num(preds_tr, nan=0.5)
    # Corrected indexing: Access the 0th output's bounds
    avg_unc = float((preds_tr[:, 0, 1] - preds_tr[:, 0, 0]).mean())
    center = (preds_tr[:, 0, 0] + preds_tr[:, 0, 1]) / 2.0
    acc = float(((center > 0.5).astype(float) == y).mean())
    history.loss.values[step] = float(loss_val)
    history.safety_u.values[step] = u_max
    history.unc.values[step] = avg_unc
    history.acc.values[step] = acc
    if step % 50 == 0:
        print(f"Step {step:3d} | Loss: {loss_val:.5f} | Safety U: {u_max:.4f} "
              f"| Unc: {avg_unc:.4f} | Acc: {acc*100:.1f}%")
print("\n✅ Training completed!")
# Animation
fig, ax1 = plt.subplots(figsize=(12, 5))
ax2 = ax1.twinx()
line1, = ax1.plot([], [], color='#1f77b4', lw=2.5, label='Training Loss')
line2, = ax2.plot([], [], color='#d62728', lw=2.5, label='Safety Risk (U-max)')
ax2.axhline(0.1, color='green', ls='--', alpha=0.7, lw=1.5)
ax2.text(num_steps * 0.02, 0.12, 'Safety Limit (0.1)', color='green', fontsize=9)
loss_min = float(history.loss.min())
loss_max = float(history.loss.max())
loss_pad = (loss_max - loss_min) * 0.1 + 1e-6
ax1.set_xlim(0, num_steps)
ax1.set_ylim(loss_min - loss_pad, loss_max + loss_pad)
ax2.set_ylim(-0.05, 1.05)
ax1.set_xlabel('Step', fontsize=11)
ax1.set_ylabel('Loss', color='#1f77b4', fontsize=11)
ax2.set_ylabel('Risk (U-max)', color='#d62728', fontsize=11)
ax1.tick_params(axis='y', labelcolor='#1f77b4')
ax2.tick_params(axis='y', labelcolor='#d62728')
ax1.legend([line1, line2], ['Training Loss', 'Safety Risk (U-max)'],
loc='upper right', fontsize=9, framealpha=0.85)
title = ax1.set_title('Step 0', fontsize=12)
def animate(i):
    idx = min(i + 1, num_steps)
    xs = history.step.values[:idx]
    line1.set_data(xs, history.loss.values[:idx])
    line2.set_data(xs, history.safety_u.values[:idx])
    title.set_text(f'Step {xs[-1] if len(xs) else 0} / {num_steps - 1}')
    return line1, line2, title
ani = animation.FuncAnimation(
fig, animate, frames=range(0, num_steps, 5),
blit=True, interval=50, repeat=False
)
plt.tight_layout()
plt.close(fig)
display(HTML(ani.to_jshtml()))
# Certification
print("\n" + "=" * 60)
final_u = float(history.safety_u.values[-1])
final_acc = float(history.acc.values[-1])
final_loss = float(history.loss.values[-1])
print(f"FINAL CERTIFICATION: {'✅ SAFE' if final_u <= 0.1 else '❌ RISK'}")
print(f" Final Loss: {final_loss:.5f}")
print(f" Classification accuracy: {final_acc*100:.1f}%")
print(f" Max U (dangerous box): {final_u:.4f}")
print("=" * 60)
history.to_netcdf("training_log_fixed.nc")
print("Log saved: training_log_fixed.nc")
Final Certification Report¶
After training, the model provides a final verification status. If the Max U in the safety box is below the threshold (0.1), the region is officially Certified Safe.
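The certification decision itself reduces to a threshold check on the verified upper bound. A minimal sketch (the 0.1 threshold follows the tutorial; the `certify` helper is illustrative, not part of the JLNN API):

```python
def certify(u_max: float, threshold: float = 0.1) -> str:
    # A region is certified safe iff the verified upper bound stays at or
    # below the threshold for every input in the box
    return "CERTIFIED SAFE" if u_max <= threshold else "RISK"

print(certify(0.0842))  # CERTIFIED SAFE
print(certify(0.35))    # RISK
```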
============================================================
JLNN CERTIFICATION REPORT
============================================================
Status: ✅ CERTIFIED SAFE
Final Risk (U-max): 0.0842
Model Accuracy: 100.0%
------------------------------------------------------------
PROOF: Mathematically guaranteed that in the defined region,
the model's belief in the target class never exceeds U-max.
============================================================
Conclusion¶
JLNN provides a unique bridge between neural learning and symbolic logic. By using interval arithmetic, it moves beyond “probabilistic guesses” toward verifiable AI guarantees.
Download¶
You can also download the raw notebook file for local use:
JLNN_safety_verification.ipynb
Tip
To run the notebook locally, make sure you have installed the package using pip install -e .[test].