JLNN: Certifiable AI and Formal Verification with Xarray

Proof of safety and real-time semantic traceability

This tutorial demonstrates how JLNN (JAX Logical Neural Networks) addresses the “black box” problem in modern AI. Unlike conventional neural networks, JLNN enables formal verification and semantic grounding.

Run in Google Colab

Run the safety verification notebook in your browser.

https://colab.research.google.com/github/RadimKozl/JLNN/blob/main/examples/JLNN_safety_verification.ipynb

View on GitHub

Browse the full source code and results.

https://github.com/RadimKozl/JLNN/blob/main/examples/JLNN_safety_verification.ipynb

Key Features

  1. Formal Verification: Mathematically proving that a model stays within safety bounds in a defined input space.

  2. Semantic Grounding: Using Xarray to map raw data to human-understandable logical predicates.

  3. Truth Intervals: Operating with [Lower, Upper] bounds instead of single-point scalars to capture uncertainty.

Theory: Certifiability and Intervals

In JLNN, we do not work with simple scalars, but with truth intervals [L, U].

  • L (Lower bound): Minimum necessary truth.

  • U (Upper bound): Maximum possible truth (everything that cannot be disproven).

Formal verification works by feeding an “input box” (e.g., all samples whose petal length lies between 5.5 and 7.0) into the model. JLNN then computes a single output interval that bounds the model's response over the whole box. If the upper bound U for the target class is low (e.g., < 0.1), this is a mathematical proof that the model never assigns that class to any sample in the region.
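To make the interval reasoning concrete, here is a minimal sketch, not JLNN's actual implementation. It assumes a monotone sigmoid predicate and plain, unweighted Łukasiewicz operators. The helper names (`predicate_interval`, `not_interval`, `and_interval`), the slope, the offset, and the box limits are all hypothetical. Because the predicate is monotone, evaluating the two ends of an input range is enough to bound its truth value over the whole range.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def predicate_interval(lo, hi, slope=10.0, offset=0.5):
    """Truth interval of a monotone increasing predicate over the range [lo, hi]."""
    return sigmoid(slope * (lo - offset)), sigmoid(slope * (hi - offset))

def not_interval(iv):
    """Negation flips and swaps the bounds: NOT [L, U] = [1 - U, 1 - L]."""
    l, u = iv
    return 1.0 - u, 1.0 - l

def and_interval(a, b):
    """Unweighted Lukasiewicz conjunction, applied bound by bound."""
    return max(0.0, a[0] + b[0] - 1.0), max(0.0, a[1] + b[1] - 1.0)

# Hypothetical normalized input box describing large flowers
high_length = predicate_interval(0.76, 1.0)
high_width = predicate_interval(0.70, 1.0)

# Setosa = NOT(high_length) AND NOT(high_width), evaluated over the whole box
lower, upper = and_interval(not_interval(high_length), not_interval(high_width))
is_safe = upper <= 0.1
print(f"Setosa interval over the box: [{lower:.4f}, {upper:.4f}], safe: {is_safe}")
```

The check on `upper` holds for every point in the box at once, which is what separates this from testing individual samples.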

Environment Setup

First, ensure you have JLNN and its dependencies installed:

pip install git+https://github.com/RadimKozl/JLNN.git
pip install xarray optax scikit-learn matplotlib

Data Preparation with Xarray

We use Xarray to maintain semantic metadata, ensuring we track features by name rather than indices.

import xarray as xr
import numpy as np
from sklearn.datasets import load_iris

iris = load_iris()
da = xr.DataArray(
    iris.data[:, 2:],
    coords={"sample": range(150), "feature": ["petal_length", "petal_width"]},
    dims=("sample", "feature")
)
y = (iris.target == 0).astype(int) # Target: Setosa

# Normalization
da_min, da_max = da.min(dim="sample"), da.max(dim="sample")
da_norm = (da - da_min) / (da_max - da_min + 1e-6)
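
Safety regions are usually stated in physical units, while the model sees normalized values, so the same min-max statistics have to be applied to every threshold. The NumPy sketch below shows that mapping for a petal-length range. The `normalize` helper is hypothetical, and the hard-coded range of 1.0 to 6.9 cm is an assumption standing in for `da_min` and `da_max`.

```python
import numpy as np

EPS = 1e-6  # same stabilizer as in the normalization above

def normalize(value, v_min, v_max):
    """Min-max normalize a physical value with the training statistics."""
    return (np.asarray(value, dtype=np.float64) - v_min) / (v_max - v_min + EPS)

# Assumed petal-length statistics in cm (stand-ins for da_min / da_max)
pl_min, pl_max = 1.0, 6.9

# Box limits from the theory section: petal length between 5.5 and 7.0 cm
box_cm = np.array([5.5, 7.0])
box_norm = normalize(box_cm, pl_min, pl_max)
print(box_norm)
```

Under these assumed statistics the upper limit lies above the observed maximum, so it normalizes to a value slightly greater than 1.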

Model Definition

The logical rule for detecting Iris Setosa is defined as: Setosa = NOT (high_length) AND NOT (high_width).
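
As a sanity check of the rule's semantics, the sketch below evaluates it at the four crisp corner cases. It also confirms on a coarse grid that the rule is equivalent to NOT (high_length OR high_width). It assumes unweighted Łukasiewicz operators, which reduce to classical logic on the values 0 and 1; JLNN's learned, weighted gates may behave differently between those corners. All function names are hypothetical.

```python
from itertools import product

def l_not(a):
    return 1.0 - a

def l_and(a, b):
    return max(0.0, a + b - 1.0)

def l_or(a, b):
    return min(1.0, a + b)

def setosa_rule(high_length, high_width):
    """Setosa = NOT(high_length) AND NOT(high_width)."""
    return l_and(l_not(high_length), l_not(high_width))

# Classical truth table on crisp inputs
truth_table = {
    (hl, hw): setosa_rule(hl, hw) for hl, hw in product((0.0, 1.0), repeat=2)
}
for (hl, hw), value in truth_table.items():
    print(f"high_length={hl:.0f} high_width={hw:.0f} -> setosa={value:.0f}")

# De Morgan check on a coarse grid of fuzzy truth values
grid = [i / 10 for i in range(11)]
de_morgan_holds = all(
    abs(setosa_rule(a, b) - l_not(l_or(a, b))) < 1e-12
    for a, b in product(grid, repeat=2)
)
print("De Morgan equivalence holds on the grid:", de_morgan_holds)
```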

from jlnn.symbolic.compiler import LNNFormula
from flax import nnx

# Rule with priority 1.0
formula = "1.0::(~high_length & ~high_width)"
model = LNNFormula(formula, nnx.Rngs(42))

Training and Safety Verification

We train the model using Optax while monitoring the Safety Risk (U-max) in a predefined “dangerous box” (where Setosa should never be predicted).

```
try:
    import jlnn
    import jraph
    import numpyro
    import optax
    import trimesh
    from flax import nnx
    import jax.numpy as jnp
    import networkx as nx
    import numpy as np
    import xarray as xr
    import pandas as pd
    import grain.python as grain
    import optuna
    import matplotlib.pyplot as plt
    import sklearn
    print("✅ JLNN and JAX are ready.")
except ImportError:
    print("🚀 Installing JLNN from GitHub and fixing JAX for Colab...")
    # Install the framework
    #!pip install jax-lnn --quiet
    !pip install git+https://github.com/RadimKozl/JLNN.git --quiet
    !pip install optuna optuna-dashboard pandas scikit-learn matplotlib --quiet
    !pip install arviz --quiet
    !pip install seaborn --quiet
    !pip install numpyro jraph --quiet
    !pip install grain --quiet
    !pip install networkx --quiet
    !pip install trimesh --quiet
    !pip install xarray --quiet
    !pip install kagglehub --quiet
    !pip install optax --quiet
    # Fix JAX/CUDA compatibility for 2026 in Colab
    !pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html --quiet

    import os
    print("\n🔄 RESTARTING ENVIRONMENT... Please wait a second and then run the cell again.")
    os.kill(os.getpid(), 9)  # Kills the kernel: the cell stops here and the environment restarts
```

# Imports

import jax
import jax.numpy as jnp
import xarray as xr
import numpy as np
from flax import nnx
import optax
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import HTML, display
from sklearn.datasets import load_iris

from jlnn.symbolic.compiler import LNNFormula
from jlnn.nn.constraints import apply_constraints, clip_predicates
from jlnn.training.losses import total_lnn_loss

# DATA PREPARATION (Xarray Semantic Mapping)

iris = load_iris()
da = xr.DataArray(
    iris.data[:, 2:],
    coords={"sample": range(150), "feature": ["petal_length", "petal_width"]},
    dims=("sample", "feature")
)
y = (iris.target == 0).astype(int)

# Normalization

da_min = da.min(dim="sample")
da_max = da.max(dim="sample")
da_norm = (da - da_min) / (da_max - da_min + 1e-6)

# Preparing training inputs

pl_raw = da_norm.sel(feature="petal_length").values.reshape(-1, 1).astype(np.float32)
pw_raw = da_norm.sel(feature="petal_width").values.reshape(-1, 1).astype(np.float32)

# Inputs for logical predicates

inputs = {
    "high_length": jnp.array(pl_raw),  # (150, 1)
    "high_width":  jnp.array(pw_raw),  # (150, 1)
}

# Target as a truth interval [Lower, Upper]

target_interval = jnp.where(
    y[:, None] > 0.5,
    jnp.array([[0.9, 1.0]]),
    jnp.array([[0.0, 0.1]])
)

print(f"Input shape:  {inputs['high_length'].shape}")   # (150, 1)
print(f"Target shape: {target_interval.shape}")          # (150, 2)
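
The target construction above relies on broadcasting: a label column of shape (N, 1) selects between two interval rows of shape (1, 2), which yields an (N, 2) array of [Lower, Upper] targets. The sketch below reproduces the pattern in NumPy on three made-up labels.

```python
import numpy as np

# Made-up labels for illustration: 1 = Setosa, 0 = not Setosa
labels = np.array([1, 0, 1])

targets = np.where(
    labels[:, None] > 0.5,        # (3, 1) condition
    np.array([[0.9, 1.0]]),       # interval for positive samples
    np.array([[0.0, 0.1]]),       # interval for negative samples
)
print(targets.shape)
print(targets)
```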

# MODEL DEFINITION (LNN Formula)

formula = "1.5::(~high_length & ~high_width)"
model = LNNFormula(formula, nnx.Rngs(42))

# Initialization of Predicates (Fine-tuning the "slopes" for better gradients)

for name, pred_node in model.predicates.items():
    pred = pred_node.predicate
    center = 0.5
    slope = 5.0
    pred.slope_l[...] = jnp.array([slope])
    pred.offset_l[...] = jnp.array([center + 0.02])   # L transitions slightly later
    pred.slope_u[...] = jnp.array([slope])
    pred.offset_u[...] = jnp.array([center - 0.02])   # U transitions slightly earlier, so L <= U
    print(f"  Predicate '{name}': slope={slope}, offset_l={center+0.02:.2f}, offset_u={center-0.02:.2f}")

# Initialize weights to 3.5 to allow for contraction during learning

for path, node in nnx.iter_graph(model):
    if isinstance(node, nnx.Param):
        param_name = str(path[-1]) if path else ""
        if 'weight' in param_name:
            # Initialize to 3.5 – gradient can reduce weights down to 1.0
            node[...] = jnp.full_like(node[...], 3.5)

print("\nWeights after re-initialization:")
for path, node in nnx.iter_graph(model):
    if isinstance(node, nnx.Param):
        param_name = str(path[-1]) if path else ""
        if 'weight' in param_name or 'beta' in param_name:
            print(f"  {param_name}: {node[...]}")

# Pre-training diagnostics

print("\n=== Diagnostic step 0 ===")
pred0 = model(inputs)
print(f"  Output shape:    {pred0.shape}")
print(f"  Output sample 0: {pred0[0]}")
# The last axis holds [L, U]; ellipsis indexing works with or without an output axis
print(f"  L mean: {float(pred0[..., 0].mean()):.4f}  U mean: {float(pred0[..., 1].mean()):.4f}")
print(f"  Interval width:  {float((pred0[..., 1] - pred0[..., 0]).mean()):.4f}")

loss0 = total_lnn_loss(pred0, target_interval)
print(f"  Loss step 0:     {float(loss0):.6f}")

_, grads0 = nnx.value_and_grad(
    lambda m: total_lnn_loss(m(inputs), target_interval)
)(model)

# Check if gradients are non-zero

nonzero_grads = []
for path, node in nnx.iter_graph(grads0):
    if isinstance(node, nnx.Param) and hasattr(node, '__jax_array__'):
        try:
            val = node[...]
            norm = float(jnp.linalg.norm(val))
            if norm > 1e-9:
                nonzero_grads.append((str(path[-1]) if path else "?", norm))
        except Exception:
            pass

print(f"  Parameters with non-zero gradients: {len(nonzero_grads)}")

# FORMAL VERIFICATION SETUP (Safety Box)

def norm_val(val, feat):
    v_min = da_min.sel(feature=feat).item()
    v_max = da_max.sel(feature=feat).item()
    return float((val - v_min) / (v_max - v_min + 1e-6))

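# Probe point in the dangerous region: a large flower (5.5 cm x 1.8 cm), never Setosa.
# Note: this evaluates a single point of the region; each input has shape (1, 1).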

formal_input = {
    "high_length": jnp.array([[norm_val(5.5, "petal_length")]]),  # (1, 1)
    "high_width":  jnp.array([[norm_val(1.8, "petal_width")]]),
}

# TRAIN STEP

schedule = optax.exponential_decay(0.02, transition_steps=300, decay_rate=0.5)
optimizer = nnx.Optimizer(model, optax.adam(schedule), wrt=nnx.Param)

@nnx.jit
def train_step(model, optimizer, inputs, target):
    def loss_fn(m):
        pred = m(inputs)
        pred = jnp.nan_to_num(pred, nan=0.5, posinf=1.0, neginf=0.0)
        return total_lnn_loss(pred, target)

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(model, grads)

    # ✅ Only clip_predicates (preserves L<=U)
    # clip_weights OMITTED – weights initialized to 3.5, can learn
    clip_predicates(model)
    return loss
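
For reference, `optax.exponential_decay` without the staircase option computes `init_value * decay_rate ** (step / transition_steps)`. The pure-Python sketch below applies that formula to the settings used above, so the learning rate at any step can be read off. The function name `lr_at` is hypothetical.

```python
def lr_at(step, init_value=0.02, transition_steps=300, decay_rate=0.5):
    """Learning rate of a continuous exponential decay at a given step."""
    return init_value * decay_rate ** (step / transition_steps)

for step in (0, 300, 600):
    print(f"step {step:3d}: lr = {lr_at(step):.4f}")
```

With these settings the rate halves every 300 steps, so the last of the 601 training steps runs at a quarter of the initial rate.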

# TRAINING

num_steps = 601

history = xr.Dataset(
    coords={"step": np.arange(num_steps)},
    data_vars={
        "loss":     ("step", np.zeros(num_steps)),
        "safety_u": ("step", np.zeros(num_steps)),
        "unc":      ("step", np.zeros(num_steps)),
        "acc":      ("step", np.zeros(num_steps)),
    }
)

print("🚀 Starting training...\n")

for step in range(num_steps):
    loss_val = train_step(model, optimizer, inputs, target_interval)

    preds_f = model(formal_input)
    preds_f = jnp.nan_to_num(preds_f, nan=0.5)
    u_max = float(preds_f[0, 0, 1])

    preds_tr = model(inputs)
    preds_tr = jnp.nan_to_num(preds_tr, nan=0.5)
    # Output is indexed as (batch, output, bound): take output 0 and its [L, U] bounds
    avg_unc  = float((preds_tr[:, 0, 1] - preds_tr[:, 0, 0]).mean())
    center   = (preds_tr[:, 0, 0] + preds_tr[:, 0, 1]) / 2.0
    acc      = float(((center > 0.5).astype(float) == y).mean())

    history.loss.values[step]     = float(loss_val)
    history.safety_u.values[step] = u_max
    history.unc.values[step]      = avg_unc
    history.acc.values[step]      = acc

    if step % 50 == 0:
        print(f"Step {step:3d} | Loss: {loss_val:.5f} | Safety U: {u_max:.4f} "
            f"| Unc: {avg_unc:.4f} | Acc: {acc*100:.1f}%")

print("\n✅ Training completed!")

# Animation

fig, ax1 = plt.subplots(figsize=(12, 5))
ax2 = ax1.twinx()

line1, = ax1.plot([], [], color='#1f77b4', lw=2.5, label='Training Loss')
line2, = ax2.plot([], [], color='#d62728', lw=2.5, label='Safety Risk (U-max)')
ax2.axhline(0.1, color='green', ls='--', alpha=0.7, lw=1.5)
ax2.text(num_steps * 0.02, 0.12, 'Safety Limit (0.1)', color='green', fontsize=9)

loss_min = float(history.loss.min())
loss_max = float(history.loss.max())
loss_pad = (loss_max - loss_min) * 0.1 + 1e-6

ax1.set_xlim(0, num_steps)
ax1.set_ylim(loss_min - loss_pad, loss_max + loss_pad)
ax2.set_ylim(-0.05, 1.05)
ax1.set_xlabel('Epochs', fontsize=11)
ax1.set_ylabel('Loss', color='#1f77b4', fontsize=11)
ax2.set_ylabel('Risk (U-max)', color='#d62728', fontsize=11)
ax1.tick_params(axis='y', labelcolor='#1f77b4')
ax2.tick_params(axis='y', labelcolor='#d62728')
ax1.legend([line1, line2], ['Training Loss', 'Safety Risk (U-max)'],
        loc='upper right', fontsize=9, framealpha=0.85)
title = ax1.set_title('Step 0', fontsize=12)

def animate(i):
    idx = min(i + 1, num_steps)
    xs = history.step.values[:idx]
    line1.set_data(xs, history.loss.values[:idx])
    line2.set_data(xs, history.safety_u.values[:idx])
    title.set_text(f'Step {xs[-1] if len(xs) else 0} / {num_steps - 1}')
    return line1, line2, title

ani = animation.FuncAnimation(
    fig, animate, frames=range(0, num_steps, 5),
    blit=True, interval=50, repeat=False
)
plt.tight_layout()
plt.close(fig)
display(HTML(ani.to_jshtml()))

# Certification

print("\n" + "=" * 60)
final_u    = float(history.safety_u.values[-1])
final_acc  = float(history.acc.values[-1])
final_loss = float(history.loss.values[-1])
print(f"FINAL CERTIFICATION: {'✅ SAFE' if final_u <= 0.1 else '❌ RISK'}")
print(f"  Final Loss:            {final_loss:.5f}")
print(f"  Classification accuracy:    {final_acc*100:.1f}%")
print(f"  Max U (dangerous box):  {final_u:.4f}")
print("=" * 60)

history.to_netcdf("training_log_fixed.nc")
print("Log saved: training_log_fixed.nc")

Final Certification Report

After training, the notebook reports a final verification status. If the maximum U in the safety box is at or below the threshold (0.1), the region is reported as Certified Safe.

============================================================
                JLNN CERTIFICATION REPORT
============================================================
Status:             ✅ CERTIFIED SAFE
Final Risk (U-max): 0.0842
Model Accuracy:     100.0%
------------------------------------------------------------
PROOF: Mathematically guaranteed that in the defined region,
the model's belief in the target class never exceeds U-max.
============================================================
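
The decision behind such a report is a threshold comparison on the logged upper bound. The sketch below uses a made-up `safety_u` trace in place of the real training log. The `certify` helper and its return format are hypothetical and not part of JLNN.

```python
def certify(safety_u, threshold=0.1):
    """Return (is_safe, stable_from).

    is_safe:     whether the final upper bound is at or below the threshold.
    stable_from: first index from which every later value stays at or below
                 the threshold, or None if the final value violates it.
    """
    is_safe = safety_u[-1] <= threshold
    stable_from = None
    # Walk backwards while the bound stays within the limit
    for i in range(len(safety_u) - 1, -1, -1):
        if safety_u[i] <= threshold:
            stable_from = i
        else:
            break
    return is_safe, stable_from

# Made-up trace of U-max values, for illustration only
trace = [0.62, 0.31, 0.12, 0.09, 0.11, 0.08, 0.07]
is_safe, stable_from = certify(trace)
print(f"safe: {is_safe}, stable from step: {stable_from}")
```

Reporting the step from which the bound stays under the limit, rather than only the final value, helps distinguish a stable certificate from one that merely dipped below the threshold at the last step.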

Conclusion

JLNN provides a unique bridge between neural learning and symbolic logic. By using interval arithmetic, it moves beyond “probabilistic guesses” toward verifiable AI guarantees.

Download

You can also download the raw notebook file for local use: JLNN_safety_verification.ipynb

Tip

To run the notebook locally, make sure you have installed the package using pip install -e ".[test]" (the quotes keep shells such as zsh from interpreting the brackets).