JLNN + Xarray: Real-time Rule Emergence Monitoring in Nano-MoE¶
This tutorial shows how to use JAX Logical Neural Networks (JLNN) to audit Mixture-of-Experts models in real time. Fuzzy logic rules monitor and visualize router specialization while the model trains on a TPU.
Note
The interactive notebook is hosted externally to ensure the best viewing experience and to allow immediate execution in the cloud.
Execute the code directly in your browser without any local setup.
View source code and outputs in the GitHub notebook browser.
The Vision: Neuro-symbolic Self-Reflection¶
In modern Mixture-of-Experts (MoE) architectures, the router is often a black box, making it difficult to understand why a specific expert was chosen. To address this, we build a neuro-symbolic monitoring system on top of JLNN.
By evaluating fuzzy logic rules over neural activations in real-time, we can audit the interaction between routing efficiency, expert load, and training stability.
The Audit Logic: Multi-Rule Monitoring¶
Instead of relying on a single metric, we define a set of formal logic rules that monitor neuro-symbolic emergence within the MoE layer. The audit tracks six rules written in JLNN syntax, each combining two monitored predicates.
Monitored Hypotheses¶
Expert Specialization: High routing confidence combined with low activation entropy signals that experts are becoming niche-specific.
Systemic Imbalance: Detection of high expert imbalance paired with increasing weight magnitudes (a potential warning sign of expert collapse).
Routing Maturity: Stable or decreasing loss combined with decisive routing indicates a maturing modular system.
Underutilization: Identifying states where the model is neither confident nor decisive, failing to leverage its modular capacity.
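The scalar signals these hypotheses refer to can be derived from the router's softmax output. The sketch below is one plausible way to compute them, mirroring the metrics used in this tutorial's training loop. The function name `routing_metrics` is hypothetical and not part of JLNN, and the division by 0.5 is simply the scaling used in the tutorial code.

```python
import jax
import jax.numpy as jnp

def routing_metrics(probs):
    """Summarize router probabilities of shape (batch, tokens, n_experts)."""
    n_experts = probs.shape[-1]
    # Mean probability assigned to the top expert of each token.
    confidence = jnp.max(probs, axis=-1).mean()
    # Mean Shannon entropy per token, normalized by log(n_experts).
    entropy = -jnp.sum(probs * jnp.log(probs + 1e-9), axis=-1).mean()
    entropy_norm = entropy / jnp.log(n_experts)
    # Fraction of tokens whose top-1 choice is each expert.
    top1 = jnp.argmax(probs, axis=-1)
    usage = jnp.mean(jax.nn.one_hot(top1, n_experts), axis=(0, 1))
    imbalance = jnp.std(usage) / 0.5
    return {
        "confidence": float(confidence),
        "entropy_norm": float(entropy_norm),
        "imbalance": float(imbalance),
    }

# Undecided router: every expert gets probability 0.25.
m_uniform = routing_metrics(jnp.full((2, 8, 4), 0.25))
# Decisive and balanced: tokens cycle through the four experts.
m_balanced = routing_metrics(jax.nn.one_hot(jnp.arange(16).reshape(2, 8) % 4, 4))
# Decisive but collapsed: every token goes to expert 0.
m_collapsed = routing_metrics(jax.nn.one_hot(jnp.zeros((2, 8), dtype=jnp.int32), 4))
```

The balanced and collapsed cases have identical confidence and entropy, which is why the imbalance signal is tracked separately.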
Core Symbolic Rules (JLNN Syntax)¶
0.85 :: high_routing_confidence & low_activation_entropy (Specialization)
0.75 :: high_expert_imbalance & increasing_weight_magnitude (Imbalance Risk)
0.80 :: high_routing_confidence & stable_loss (Routing Stability)
0.70 :: decreasing_loss & low_activation_entropy (Learning Progress)
0.65 :: ~low_activation_entropy & ~high_routing_confidence (Underutilization)
0.75 :: low_activation_entropy & ~high_expert_imbalance (Healthy Specialization)
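How JLNN evaluates `&` and `~` internally is not shown in this tutorial. As a rough mental model only, the sketch below assumes plain, unweighted Łukasiewicz operators applied to truth bounds, which is a common choice in logical neural networks. JLNN's learned, weighted gates may produce different numbers, and the helper names here are hypothetical.

```python
import numpy as np

def interval_and(a, b):
    """Lukasiewicz conjunction on [L, U] bounds: max(0, x + y - 1) per bound."""
    return np.clip(a + b - 1.0, 0.0, 1.0)

def interval_not(a):
    """Negation flips and swaps the bounds: [L, U] -> [1 - U, 1 - L]."""
    return 1.0 - a[::-1]

# Example truth intervals for two predicates (illustrative values).
high_conf = np.array([0.8, 0.9])
low_entropy = np.array([0.7, 0.95])

# high_routing_confidence & low_activation_entropy
specialization = interval_and(high_conf, low_entropy)
# ~low_activation_entropy & ~high_routing_confidence
underutilization = interval_and(interval_not(low_entropy), interval_not(high_conf))
```

Under these assumed semantics the specialization rule evaluates to roughly [0.5, 0.85], while the underutilization rule collapses to [0, 0], since both of its negated inputs are close to false.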
Implementation Details¶
The system is optimized for TPU acceleration using JAX and Flax NNX. Key components include:
Fuzzy Ramp Sensing¶
To bridge the gap between neural scalars and logical predicates, we use a fuzzy_ramp function. This converts continuous metrics like entropy or confidence into truth intervals \([L, U]\).
def fuzzy_ramp(x, slope=12, offset=0.5):
    L = jax.nn.sigmoid(slope * (x - (offset + 0.05)))
    U = jax.nn.sigmoid(slope * (x - (offset - 0.05)))
    return jnp.stack([L, U], axis=-1).reshape(1, 1, 2)
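To get a feel for the mapping, the snippet below repeats the function and evaluates it at a few inputs. The sample values 0.2, 0.5, and 0.8 are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

def fuzzy_ramp(x, slope=12, offset=0.5):
    # Lower bound uses a slightly higher threshold than the upper bound,
    # so L <= U always holds and the gap encodes uncertainty.
    L = jax.nn.sigmoid(slope * (x - (offset + 0.05)))
    U = jax.nn.sigmoid(slope * (x - (offset - 0.05)))
    return jnp.stack([L, U], axis=-1).reshape(1, 1, 2)

intervals = {}
for value in (0.2, 0.5, 0.8):
    lower, upper = (float(v) for v in fuzzy_ramp(value).reshape(-1))
    intervals[value] = (lower, upper)
    print(f"x={value:.1f} -> [L, U] = [{lower:.3f}, {upper:.3f}]")
```

At the offset itself (x = 0.5) the interval is widest, roughly [0.35, 0.65]. Inputs well below or above the offset saturate towards [0, 0] or [1, 1], and a larger `slope` makes that transition sharper.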
Robust Extraction¶
Depending on how the monitor is executed, a rule's result may arrive either as a node-like object (exposing .output or .value) or as a raw JAX/NumPy array. The safe_extract utility retrieves the final [L, U] pair in both cases.
def safe_extract(data_or_node):
    target = getattr(data_or_node, "output", getattr(data_or_node, "value", data_or_node))
    arr = np.array(target).reshape(-1)
    return arr[-2:]
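The following usage sketch exercises the function on three kinds of input. The `SimpleNamespace` object only stands in for a node with a `.value` attribute and is not a real JLNN type.

```python
import numpy as np
from types import SimpleNamespace

def safe_extract(data_or_node):
    # Prefer .output, then .value, else treat the input as array-like.
    target = getattr(data_or_node, "output", getattr(data_or_node, "value", data_or_node))
    arr = np.array(target).reshape(-1)
    return arr[-2:]

# Raw array shaped like a fuzzy_ramp result.
raw = safe_extract(np.array([[[0.2, 0.7]]]))

# Stand-in for a node object that carries its data in .value.
node_like = safe_extract(SimpleNamespace(value=np.array([0.1, 0.4])))

# With several intervals stacked, only the LAST pair is returned.
batched = safe_extract(np.array([[0.0, 0.1], [0.6, 0.9]]))
```

The last case is worth keeping in mind: flattening and slicing `[-2:]` silently discards every interval except the final one, which is fine for a single-rule output but would hide data in a batched result.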
Xarray Telemetry¶
We use xarray to store the high-dimensional audit history, allowing for sophisticated multi-rule visualization with uncertainty bounds.
history = xr.Dataset(
    data_vars={
        "loss": (["step"], np.zeros(steps)),
        **{f"{rule}_U": (["step"], np.zeros(steps)) for rule in rules_list},
        **{f"{rule}_L": (["step"], np.zeros(steps)) for rule in rules_list}
    },
    coords={"step": np.arange(steps)}
)
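Once the dataset is filled, the usual xarray operations apply to each rule. The sketch below builds a small dataset with the same layout from synthetic values (invented purely for illustration) and shows three typical queries: interval width, a rolling mean, and a label-based slice of late training steps.

```python
import numpy as np
import xarray as xr

steps = 100
ramp = np.linspace(0.0, 1.0, steps)

# Synthetic audit history for one rule; real values come from the monitor.
history = xr.Dataset(
    data_vars={
        "loss": (["step"], 4.0 - 2.0 * ramp),
        "expert_specialization_L": (["step"], 0.8 * ramp),
        "expert_specialization_U": (["step"], 0.8 * ramp + 0.1),
    },
    coords={"step": np.arange(steps)},
)

# Uncertainty of the rule at every step (upper minus lower bound).
width = history["expert_specialization_U"] - history["expert_specialization_L"]

# Smoothed lower bound, as used for plotting.
smoothed = history["expert_specialization_L"].rolling(step=10, min_periods=1).mean()

# Label-based selection of the last 20 steps (slice bounds are inclusive).
late = history.sel(step=slice(80, 99))
late_mean_lower = float(late["expert_specialization_L"].mean())
```

Keeping the lower and upper bounds as separate variables on a shared `step` coordinate is what makes these one-line queries possible.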
All example code¶
########################
# Installation & Imports
########################
'''
import os

def setup_jlnn_tpu_environment():
    print("🧹 Cleaning environment from potential conflicts...")
    # Remove the old JAX to force a clean TPU installation
    !pip uninstall -y jax jaxlib --quiet
    print("🚀 Installing JAX with TPU support...")
    # Install the TPU-specific build
    !pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html --quiet
    print("📦 Installing jax-lnn and visual tools...")
    # Install the project directly from GitHub
    !pip install git+https://github.com/RadimKozl/JLNN.git --quiet
    !pip install seaborn xarray --quiet
    print("\n🔄 RESTARTING KERNEL to apply TPU changes...")
    os.kill(os.getpid(), 9)

try:
    import jlnn
    import jax
    # Checking if we see TPU
    if 'TPU_NAME' in os.environ:
        import jax.tools.colab_tpu
        jax.tools.colab_tpu.setup_tpu()
    print("✅ jax-lnn is ready.")
    print(f"✅ Devices: {jax.devices()}")
    # Here you should see: [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), ...]
except (ImportError, RuntimeError):
    setup_jlnn_tpu_environment()
'''
import jax
import os
if 'TPU_NAME' in os.environ:
    import jax.tools.colab_tpu
    jax.tools.colab_tpu.setup_tpu()
print(f"Confirmed Devices: {jax.devices()}")
# Download real data (TinyShakespeare)
'''
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt -O input.txt
'''
# Imports
import jax
import jax.numpy as jnp
from flax import nnx
import optax
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
from tqdm import tqdm
import seaborn as sns
# JLNN imports
from jlnn.symbolic.compiler import LNNFormula
from jlnn.nn.constraints import apply_constraints, clip_predicates
from jlnn.training.losses import total_lnn_loss
#################################
# DATASET SETUP (TinyShakespeare)
#################################
with open('input.txt', 'r') as f: text = f.read()
chars = sorted(list(set(text)))
vocab_size = len(chars)
stoi = { ch:i for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s]
data = jnp.array(encode(text), dtype=jnp.int32)
print(f"Dataset loaded: {len(data)} characters, dictionary: {vocab_size} unique characters.")
##################################
# NEURO-SYMBOLIC AUDIT: JLNN Rules
##################################
class MoELogicMonitor(nnx.Module):
    """Modular monitor inspired by QuantumLogicModel, holding independent LNN formulas."""
    def __init__(self, rules, rngs):
        self.rules = nnx.List([LNNFormula(r, rngs) for r in rules])

    def __call__(self, x):
        return jnp.stack([r(x) for r in self.rules])

rule_strings = [
    "0.85 :: high_routing_confidence & low_activation_entropy",
    "0.75 :: high_expert_imbalance & increasing_weight_magnitude",
    "0.80 :: high_routing_confidence & stable_loss",
    "0.70 :: decreasing_loss & low_activation_entropy",
    "0.65 :: ~low_activation_entropy & ~high_routing_confidence",
    "0.75 :: low_activation_entropy & ~high_expert_imbalance"
]
monitor = MoELogicMonitor(rule_strings, nnx.Rngs(45))

def fuzzy_ramp(x, slope=12, offset=0.5):
    """Maps a scalar value to a [Lower, Upper] fuzzy interval."""
    L = jax.nn.sigmoid(slope * (x - (offset + 0.05)))
    U = jax.nn.sigmoid(slope * (x - (offset - 0.05)))
    return jnp.stack([L, U], axis=-1).reshape(1, 1, 2)

def safe_extract(data_or_node):
    """
    Robustly extracts [L, U] intervals.
    Handles both JLNN Nodes (via .output/.value) and raw JAX/NumPy arrays.
    """
    # If it's a node, get its data. If it's already an array, use it directly.
    target = getattr(data_or_node, "output", getattr(data_or_node, "value", data_or_node))
    # Convert to numpy and flatten to get the last two elements [L, U]
    arr = np.array(target).reshape(-1)
    return arr[-2:]
##############################################
# ARCHITECTURE: Nano-MoE with Causal Attention
##############################################
class CausalAttention(nnx.Module):
    """Multi-head Causal Self-Attention mechanism."""
    def __init__(self, n_embd, n_head, rngs):
        self.n_head = n_head
        self.qkv = nnx.Linear(n_embd, 3 * n_embd, rngs=rngs)
        self.proj = nnx.Linear(n_embd, n_embd, rngs=rngs)

    def __call__(self, x):
        B, T, C = x.shape
        q, k, v = jnp.split(self.qkv(x), 3, axis=-1)
        q = q.reshape(B, T, self.n_head, C // self.n_head).transpose(0, 2, 1, 3)
        k = k.reshape(B, T, self.n_head, C // self.n_head).transpose(0, 2, 1, 3)
        v = v.reshape(B, T, self.n_head, C // self.n_head).transpose(0, 2, 1, 3)
        mask = jnp.tril(jnp.ones((T, T)))
        attn = (q @ k.transpose(0, 1, 3, 2)) * (1.0 / jnp.sqrt(k.shape[-1]))
        attn = jnp.where(mask == 0, -jnp.inf, attn)
        attn = jax.nn.softmax(attn, axis=-1)
        y = (attn @ v).transpose(0, 2, 1, 3).reshape(B, T, C)
        return self.proj(y)
class MoELayer(nnx.Module):
    """Mixture-of-Experts layer with Noisy Top-2 Routing."""
    def __init__(self, n_embd, n_experts, rngs):
        self.router = nnx.Linear(n_embd, n_experts, rngs=rngs)
        self.experts = nnx.List([
            nnx.Sequential(nnx.Linear(n_embd, 4*n_embd, rngs=rngs), nnx.gelu, nnx.Linear(4*n_embd, n_embd, rngs=rngs))
            for _ in range(n_experts)
        ])
        self.n_experts = n_experts

    def __call__(self, x, train=True):
        logits = self.router(x)
        if train:
            # Routing noise. The key is fixed, so the same noise pattern is reused on every call.
            logits += jax.random.normal(jax.random.PRNGKey(0), logits.shape) * 1e-2
        probs = jax.nn.softmax(logits, axis=-1)
        # Mean routing entropy over all tokens
        ent = -jnp.sum(probs * jnp.log(probs + 1e-9), axis=-1).mean()
        # Indices of the two highest-scoring experts per token (the gate values `val` are unused)
        val, idx = jax.lax.top_k(logits, 2)
        out = jnp.zeros_like(x)
        for i, exp in enumerate(self.experts):
            # Dense evaluation: every expert runs on all tokens, then unselected tokens are masked out
            mask = jnp.any(idx == i, axis=-1, keepdims=True)
            out += exp(x) * mask.astype(x.dtype)
        return out, ent, probs
class NanoMoE(nnx.Module):
    """Language Model using Mixture-of-Experts blocks."""
    def __init__(self, vocab_size, n_embd, n_head, n_experts, rngs):
        self.embed = nnx.Embed(vocab_size, n_embd, rngs=rngs)
        self.pos_embed = nnx.Variable(jax.random.normal(jax.random.PRNGKey(0), (1024, n_embd)))
        self.attn = CausalAttention(n_embd, n_head, rngs=rngs)
        self.ln1 = nnx.LayerNorm(n_embd, rngs=rngs)
        self.moe = MoELayer(n_embd, n_experts, rngs)
        self.ln2 = nnx.LayerNorm(n_embd, rngs=rngs)
        self.head = nnx.Linear(n_embd, vocab_size, rngs=rngs)

    def __call__(self, x, train=True):
        B, T = x.shape
        x = self.embed(x) + self.pos_embed[:T]
        x = x + self.attn(self.ln1(x))
        moe_out, ent, probs = self.moe(self.ln2(x), train=train)
        x = x + moe_out
        return self.head(x), ent, probs
#############################
# INITIALIZING XARRAY DATASET
#############################
rules_list = ["expert_specialization", "load_imbalance", "routing_stability",
              "learning_progress", "underutilized_experts", "healthy_specialization"]
steps = 5000
rngs = nnx.Rngs(42)
history = xr.Dataset(
    data_vars={
        "loss": (["step"], np.zeros(steps)),
        **{f"{r}_U": (["step"], np.zeros(steps)) for r in rules_list},
        **{f"{r}_L": (["step"], np.zeros(steps)) for r in rules_list}
    },
    coords={"step": np.arange(steps)}
)
##########################
# TRAINING LOOP WITH AUDIT
##########################
model = NanoMoE(vocab_size, 256, 8, 4, rngs)
optimizer = nnx.Optimizer(model, optax.adamw(1e-3), wrt=nnx.Param)
# SENSITIVITY TUNING (Before training)
@nnx.jit
def train_step(model, optimizer, x, y):
    def loss_fn(model):
        logits, ent, probs = model(x)
        loss = optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()
        return loss, (ent, probs)

    (loss, (ent, probs)), grads = nnx.value_and_grad(loss_fn, has_aux=True)(model)
    optimizer.update(model, grads)
    return loss, ent, probs
# TRAINING LOOP WITH AUDIT
# NOTE: get_batch() is called below but is not defined in the original listing.
# This is an assumed minimal sampler; batch_size and block_size are illustrative values.
def get_batch(batch_size=32, block_size=128):
    """Samples random (input, next-character target) windows from `data`."""
    starts = np.random.randint(0, len(data) - block_size - 1, size=batch_size)
    idx = starts[:, None] + np.arange(block_size)[None, :]
    return data[idx], data[idx + 1]

losses_cache = []
print("🚀 Training Nano-MoE on TPU with Neuro-Symbolic Audit...")
for step in tqdm(range(steps)):
    x_b, y_b = get_batch()
    loss, ent, probs = train_step(model, optimizer, x_b, y_b)
    losses_cache.append(float(loss))

    # Audit Metrics (Calculated on CPU/Host for logging)
    conf = float(jnp.max(probs, axis=-1).mean())
    ent_norm = float(ent / jnp.log(4))
    usage = jnp.mean(jax.nn.one_hot(jnp.argmax(probs, -1), 4), axis=(0,1))
    imbalance = float(jnp.std(usage) / 0.5)
    loss_trend = np.mean(np.diff(losses_cache[-10:])) if len(losses_cache) > 10 else 0

    lnn_inputs = {
        "high_routing_confidence": fuzzy_ramp(conf, offset=0.35),
        "low_activation_entropy": fuzzy_ramp(1.0 - ent_norm, offset=0.4),
        "high_expert_imbalance": fuzzy_ramp(imbalance, offset=0.2),
        # Step-based proxy: no actual weight norm is measured here
        "increasing_weight_magnitude": fuzzy_ramp(0.5 + 0.0005*step, offset=0.5),
        "stable_loss": fuzzy_ramp(1.0 - abs(loss_trend)*10, offset=0.5),
        "decreasing_loss": fuzzy_ramp(1.0 if loss_trend < 0 else 0.0, offset=0.4),
        # Not referenced by the current rule_strings
        "low_expert_imbalance": fuzzy_ramp(1.0 - imbalance, offset=0.5)
    }

    # Audit Inference (TPU accelerated)
    results = monitor(lnn_inputs)

    # Logging
    history["loss"].values[step] = float(loss)
    for i, r_name in enumerate(rules_list):
        res = safe_extract(results[i])
        history[f"{r_name}_L"].values[step] = res[0]
        history[f"{r_name}_U"].values[step] = res[1]
####################################
# VISUALIZATION: Performance & Audit
####################################
sns.set_style("whitegrid")
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 12), sharex=True)
# 1. PANEL: Convergence Model
history.loss.rolling(step=50, min_periods=1).mean().plot(ax=ax1, color='#2c3e50', lw=2.5)
ax1.set_title("Model Convergence: Cross-Entropy Loss", fontsize=15, fontweight='bold')
ax1.set_ylabel("Loss (Training Error)", fontsize=12)
ax1.grid(True, alpha=0.3)
# 2. PANEL: Neuro-Symbolic Audit
custom_colors = {
    "expert_specialization": "#e74c3c",   # Red (target)
    "healthy_specialization": "#27ae60",  # Green (ideal)
    "load_imbalance": "#f39c12",          # Orange (warning)
    "learning_progress": "#3498db",       # Blue
    "routing_stability": "#9b59b6",       # Purple
    "underutilized_experts": "#95a5a6"    # Gray
}
for r in rules_list:
    u_bound = history[f"{r}_U"].rolling(step=30, min_periods=1).mean()
    l_bound = history[f"{r}_L"].rolling(step=30, min_periods=1).mean()
    color = custom_colors.get(r, "black")
    line, = ax2.plot(history.step, u_bound, label=r, color=color, lw=2, alpha=0.9)
    ax2.fill_between(history.step, l_bound, u_bound, alpha=0.15, color=line.get_color())
ax2.set_title("Neuro-Symbolic Audit: Comprehensive Rule Monitoring", fontsize=15, fontweight='bold')
ax2.set_ylabel("Fuzzy Truth Value [0, 1]", fontsize=12)
ax2.set_ylim(-0.05, 1.05)
ax2.legend(loc='center left', bbox_to_anchor=(1, 0.5), title="Monitored Rules", fontsize=10)
ax2.set_xlabel("Training Steps", fontsize=12)
plt.tight_layout()
plt.show()
Interpreting the Audit¶
By evaluating these fuzzy logic rules in real time, the monitor creates a “Logical Self-Reflection” layer. When the truth intervals \([L, U]\) for the desirable rules (Specialization, Routing Stability, Learning Progress, Healthy Specialization) stabilize near 1.0, while those for Imbalance Risk and Underutilization stay low, the neuro-symbolic audit indicates that the MoE architecture has moved from stochastic routing to structured, specialized behavior.
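One way to make "stabilize near 1.0" operational is to check the most recent window of a rule's bounds. The helper below is a hypothetical sketch, not part of JLNN. The window length and both thresholds are arbitrary assumptions that would need tuning for a real run.

```python
import numpy as np

def rule_is_stable(lower, upper, window=200, threshold=0.9, max_width=0.1):
    """Check the last `window` steps of a rule's [L, U] history.

    Returns True when the mean lower bound is at least `threshold`
    and the mean interval width is at most `max_width`.
    """
    lower = np.asarray(lower, dtype=float)[-window:]
    upper = np.asarray(upper, dtype=float)[-window:]
    mean_width = (upper - lower).mean()
    return bool(lower.mean() >= threshold and mean_width <= max_width)

# Synthetic example: a rule whose truth value converges towards 1.0 ...
t = np.arange(1000)
converging_L = 1.0 - np.exp(-t / 100.0)
converging_U = np.minimum(converging_L + 0.05, 1.0)

# ... and a rule that stays uncertain throughout training.
uncertain_L = np.full(1000, 0.5)
uncertain_U = np.full(1000, 0.9)

converged = rule_is_stable(converging_L, converging_U)
still_uncertain = rule_is_stable(uncertain_L, uncertain_U)
```

With the audit history from this tutorial, the same check could be applied to `history[f"{rule}_L"].values` and `history[f"{rule}_U"].values` for each rule of interest.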
Download¶
You can also download the raw notebook file for local use:
JLNN_xarray_rule_monitoring_moe.ipynb
Tip
To run the notebook locally, make sure you have installed the package using pip install -e .[test].