JLNN + Xarray: Real-time Rule Emergence Monitoring in Nano-MoE

This tutorial shows how to use JAX Logical Neural Networks (JLNN) to audit Mixture-of-Experts (MoE) models in real time. Fuzzy logic rules monitor and visualize how the router specializes while the model trains on a TPU.

Note

The interactive notebook is hosted externally to ensure the best viewing experience and to allow immediate execution in the cloud.

Run in Google Colab

Execute the code directly in your browser without any local setup.

https://colab.research.google.com/github/RadimKozl/JLNN/blob/main/examples/JLNN_xarray_rule_monitoring_moe.ipynb

View on GitHub

View source code and outputs in the GitHub notebook browser.

https://github.com/RadimKozl/JLNN/blob/main/examples/JLNN_xarray_rule_monitoring_moe.ipynb

The Vision: Neuro-symbolic Self-Reflection

In modern Mixture-of-Experts (MoE) architectures, the router is often a black box, making it difficult to understand why a specific expert was chosen. This tutorial demonstrates a Neuro-symbolic Monitoring System using JLNN (JAX Logical Neural Networks).

By evaluating fuzzy logic rules over neural activations in real-time, we can audit the interaction between routing efficiency, expert load, and training stability.

The Audit Logic: Multi-Rule Monitoring

Instead of a single metric, we define a set of formal logic rules to monitor Neuro-symbolic Emergence within the MoE layer. The audit tracks six rules, written in JLNN syntax, that cover the four hypotheses below.

Monitored Hypotheses

  • Expert Specialization: High routing confidence combined with low activation entropy signals that experts are becoming niche-specific.

  • Systemic Imbalance: Detection of high expert imbalance paired with increasing weight magnitudes (a potential warning sign of expert collapse).

  • Routing Maturity: Stable or decreasing loss combined with decisive routing indicates a maturing modular system.

  • Underutilization: Identifying states where the model is neither confident nor decisive, failing to leverage its modular capacity.
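
The raw signals behind these hypotheses can be derived from the router's softmax output. The sketch below, in plain Python on a made-up 4-token, 4-expert probability table, mirrors the metric definitions used in the training loop of the full listing: mean top-1 probability, entropy normalized by the log of the expert count, and the standard deviation of top-1 usage scaled by 0.5. The helper name `routing_metrics` and the toy numbers are illustrative assumptions, not part of JLNN.

```python
import math
from statistics import pstdev

def routing_metrics(probs):
    """Return (confidence, normalized entropy, imbalance) for router probabilities.

    probs: list of per-token probability rows, one entry per expert.
    """
    n_tokens, n_experts = len(probs), len(probs[0])
    # Routing confidence: average probability assigned to the top-1 expert.
    conf = sum(max(row) for row in probs) / n_tokens
    # Mean Shannon entropy, normalized so that uniform routing gives about 1.0.
    ent = sum(-sum(p * math.log(p + 1e-9) for p in row) for row in probs) / n_tokens
    ent_norm = ent / math.log(n_experts)
    # Fraction of tokens whose top-1 choice is each expert.
    usage = [0.0] * n_experts
    for row in probs:
        usage[row.index(max(row))] += 1.0 / n_tokens
    # Population standard deviation of usage, scaled by 0.5 as in the training loop.
    imbalance = pstdev(usage) / 0.5
    return conf, ent_norm, imbalance

# Hypothetical router output: each token strongly prefers a different expert.
specialized = [
    [0.97, 0.01, 0.01, 0.01],
    [0.01, 0.97, 0.01, 0.01],
    [0.01, 0.01, 0.97, 0.01],
    [0.01, 0.01, 0.01, 0.97],
]
# Hypothetical router output: no preference at all.
uniform = [[0.25] * 4 for _ in range(4)]

print(routing_metrics(specialized))
print(routing_metrics(uniform))
```

The specialized table yields high confidence, low normalized entropy, and zero imbalance, which is the pattern the specialization rules look for. With exactly uniform probabilities the top-1 tie resolves to the first expert, so the usage-based imbalance is not informative in that degenerate case.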

Core Symbolic Rules (JLNN Syntax)

  1. 0.85 :: high_routing_confidence & low_activation_entropy (Specialization)

  2. 0.75 :: high_expert_imbalance & increasing_weight_magnitude (Imbalance Risk)

  3. 0.80 :: high_routing_confidence & stable_loss (Routing Stability)

  4. 0.70 :: decreasing_loss & low_activation_entropy (Learning Progress)

  5. 0.65 :: ~low_activation_entropy & ~high_routing_confidence (Underutilization)

  6. 0.75 :: low_activation_entropy & ~high_expert_imbalance (Healthy Specialization)
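
To make the rule syntax more concrete, the sketch below shows how truth bounds could propagate through `&` and `~`. It is not JLNN's implementation: JLNN compiles each rule into a learnable formula, and the exact operator semantics and the role of the numeric prefix (such as 0.85) are defined by the library. The sketch assumes unweighted Łukasiewicz connectives, a common choice in Logical Neural Networks, and uses made-up predicate intervals.

```python
def l_and(a, b):
    """Lukasiewicz conjunction applied to truth intervals (lower, upper)."""
    return (max(0.0, a[0] + b[0] - 1.0), max(0.0, a[1] + b[1] - 1.0))

def l_not(a):
    """Negation complements the bounds and swaps their order."""
    return (1.0 - a[1], 1.0 - a[0])

# Hypothetical predicate intervals at one training step.
high_routing_confidence = (0.80, 0.90)
low_activation_entropy = (0.70, 0.85)

# Body of rule 1: high_routing_confidence & low_activation_entropy
specialization = l_and(high_routing_confidence, low_activation_entropy)

# Body of rule 5: ~low_activation_entropy & ~high_routing_confidence
underutilization = l_and(l_not(low_activation_entropy), l_not(high_routing_confidence))

print(specialization)
print(underutilization)
```

Under these assumptions, a confident and low-entropy router pushes the specialization interval up while the underutilization interval collapses toward zero, so the two rules act as rough opposites.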

Implementation Details

The system is optimized for TPU acceleration using JAX and Flax NNX. Key components include:

Fuzzy Ramp Sensing

To bridge the gap between neural scalars and logical predicates, we use a fuzzy_ramp function. This converts continuous metrics like entropy or confidence into truth intervals \([L, U]\).

def fuzzy_ramp(x, slope=12, offset=0.5):
    L = jax.nn.sigmoid(slope * (x - (offset + 0.05)))
    U = jax.nn.sigmoid(slope * (x - (offset - 0.05)))
    return jnp.stack([L, U], axis=-1).reshape(1, 1, 2)
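
As a quick sanity check, the following plain-Python mirror of `fuzzy_ramp` evaluates the same two sigmoids for a few inputs. The helper name `fuzzy_ramp_scalar` is an assumption made for this sketch; it drops the array reshape so the bounds are easy to read.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def fuzzy_ramp_scalar(x, slope=12, offset=0.5):
    """Scalar mirror of fuzzy_ramp: returns (L, U) without the array reshape."""
    lower = sigmoid(slope * (x - (offset + 0.05)))  # threshold shifted right
    upper = sigmoid(slope * (x - (offset - 0.05)))  # threshold shifted left
    return lower, upper

for x in (0.2, 0.5, 0.8):
    lower, upper = fuzzy_ramp_scalar(x)
    print(f"x={x:.1f}  L={lower:.3f}  U={upper:.3f}  width={upper - lower:.3f}")
```

Because the lower bound uses a threshold shifted right by 0.05 and the upper bound one shifted left by 0.05, L never exceeds U. The interval is widest when x is near the offset and narrows as x moves toward clearly low or clearly high values.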

Robust Extraction

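Depending on how a rule is evaluated, its result may come back as a node object that carries the interval in an attribute, or as a raw JAX/NumPy array. The safe_extract utility handles both cases: it reads .output or .value when present, otherwise uses the array itself, then flattens it and returns the last two elements as \([L, U]\).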

def safe_extract(data_or_node):
    target = getattr(data_or_node, "output", getattr(data_or_node, "value", data_or_node))
    arr = np.array(target).reshape(-1)
    return arr[-2:]
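
A short usage sketch, with NumPy only, shows both code paths. `FakeNode` is a hypothetical stand-in for any object that exposes its interval through `.value`; it is not a JLNN class.

```python
import numpy as np

def safe_extract(data_or_node):
    target = getattr(data_or_node, "output", getattr(data_or_node, "value", data_or_node))
    arr = np.array(target).reshape(-1)
    return arr[-2:]

class FakeNode:
    """Hypothetical object that stores its interval in a .value attribute."""
    def __init__(self, value):
        self.value = value

raw = np.array([[[0.2, 0.7]]])            # shape (1, 1, 2), like fuzzy_ramp output
node = FakeNode(np.array([0.4, 0.9]))     # attribute-based access path
batched = np.array([[0.1, 0.3], [0.6, 0.8]])

print(safe_extract(raw))      # interval taken from the raw array
print(safe_extract(node))     # interval taken from node.value
print(safe_extract(batched))  # only the last interval survives the flatten
```

The last call illustrates a caveat: because the array is flattened and sliced from the end, a batched result yields only its final interval, which is fine for the single-sample audit here but would silently drop data elsewhere.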

Xarray Telemetry

We use xarray to store the high-dimensional audit history, allowing for sophisticated multi-rule visualization with uncertainty bounds.

history = xr.Dataset(
    data_vars={
        "loss": (["step"], np.zeros(steps)),
        **{f"{rule}_U": (["step"], np.zeros(steps)) for rule in rules_list},
        **{f"{rule}_L": (["step"], np.zeros(steps)) for rule in rules_list}
    },
    coords={"step": np.arange(steps)}
)
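
The sketch below shows, on a small synthetic dataset, the kinds of queries this layout supports: interval width as an uncertainty measure, a rolling mean over the step dimension, and label-based selection. The values are made up for illustration and are not training results; only two rule names are used to keep the example short.

```python
import numpy as np
import xarray as xr

steps = 6
rules = ["expert_specialization", "load_imbalance"]  # subset, for illustration

history = xr.Dataset(
    data_vars={
        **{f"{r}_L": (["step"], np.zeros(steps)) for r in rules},
        **{f"{r}_U": (["step"], np.zeros(steps)) for r in rules},
    },
    coords={"step": np.arange(steps)},
)

# Fill one rule with synthetic bounds, writing in place as the training loop does.
for step in range(steps):
    history["expert_specialization_L"].values[step] = 0.1 * step
    history["expert_specialization_U"].values[step] = 0.1 * step + 0.2

# Interval width (U - L) as a per-step uncertainty measure.
width = history["expert_specialization_U"] - history["expert_specialization_L"]

# Trailing rolling mean over the step dimension.
smooth_u = history["expert_specialization_U"].rolling(step=3, min_periods=1).mean()

# Label-based selection of steps 3 through 5 (inclusive for label slices).
tail = history.sel(step=slice(3, 5))

print(float(width.mean()))
print(smooth_u.values)
print(tail.sizes["step"])
```

Keeping the lower and upper bounds as separate variables on a shared step coordinate means that arithmetic between them aligns automatically, which is what makes the shaded uncertainty bands in the final plot straightforward to produce.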

All example code

#######################
# Installation & Imports
#######################

'''
def setup_jlnn_tpu_environment():
    print("🧹 Cleaning environment from potential conflicts...")
    # Remove the old JAX to force a clean TPU installation
    !pip uninstall -y jax jaxlib --quiet

    print("🚀 Installing JAX with TPU support...")
    # Install the TPU-specific version
    !pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html --quiet

    print("📦 Installing jax-lnn and visual tools...")
    # Install JLNN directly from GitHub
    !pip install git+https://github.com/RadimKozl/JLNN.git --quiet
    !pip install seaborn xarray --quiet

    print("\n🔄 RESTARTING KERNEL to apply TPU changes...")
    os.kill(os.getpid(), 9)

try:
    import os
    import jlnn
    import jax
    # Check whether a TPU is visible
    if 'TPU_NAME' in os.environ:
        import jax.tools.colab_tpu
        jax.tools.colab_tpu.setup_tpu()

    print("✅ jax-lnn is ready.")
    print(f"✅ Devices: {jax.devices()}")
    # Here you should see: [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), ...]

except (ImportError, RuntimeError):
    setup_jlnn_tpu_environment()
'''

import jax
import os

if 'TPU_NAME' in os.environ:
    import jax.tools.colab_tpu
    jax.tools.colab_tpu.setup_tpu()

print(f"Confirmed Devices: {jax.devices()}")

# Download real data (TinyShakespeare)

'''
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt -O input.txt
'''

# Imports

import jax
import jax.numpy as jnp
from flax import nnx
import optax
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
from tqdm import tqdm
import seaborn as sns

# JLNN imports
from jlnn.symbolic.compiler import LNNFormula
from jlnn.nn.constraints import apply_constraints, clip_predicates
from jlnn.training.losses import total_lnn_loss

#################################
# DATASET SETUP (TinyShakespeare)
#################################

with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

chars = sorted(list(set(text)))
vocab_size = len(chars)
stoi = { ch:i for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s]
data = jnp.array(encode(text), dtype=jnp.int32)

print(f"Dataset loaded: {len(data)} characters, dictionary: {vocab_size} unique characters.")

##################################
# NEURO-SYMBOLIC AUDIT: JLNN Rules
##################################

class MoELogicMonitor(nnx.Module):
    """Modular monitor inspired by QuantumLogicModel, holding independent LNN formulas."""
    def __init__(self, rules, rngs):
        self.rules = nnx.List([LNNFormula(r, rngs) for r in rules])
    def __call__(self, x):
        return jnp.stack([r(x) for r in self.rules])

rule_strings = [
    "0.85 :: high_routing_confidence & low_activation_entropy",
    "0.75 :: high_expert_imbalance & increasing_weight_magnitude",
    "0.80 :: high_routing_confidence & stable_loss",
    "0.70 :: decreasing_loss & low_activation_entropy",
    "0.65 :: ~low_activation_entropy & ~high_routing_confidence",
    "0.75 :: low_activation_entropy & ~high_expert_imbalance"
]

monitor = MoELogicMonitor(rule_strings, nnx.Rngs(45))

def fuzzy_ramp(x, slope=12, offset=0.5):
    """Maps a scalar value to a [Lower, Upper] fuzzy interval."""
    L = jax.nn.sigmoid(slope * (x - (offset + 0.05)))
    U = jax.nn.sigmoid(slope * (x - (offset - 0.05)))
    return jnp.stack([L, U], axis=-1).reshape(1, 1, 2)

def safe_extract(data_or_node):
    """
    Robustly extracts [L, U] intervals.
    Handles both JLNN Nodes (via .output/.value) and raw JAX/NumPy arrays.
    """
    # If it's a node, get its data. If it's already an array, use it directly.
    target = getattr(data_or_node, "output", getattr(data_or_node, "value", data_or_node))

    # Convert to numpy and flatten to get the last two elements [L, U]
    arr = np.array(target).reshape(-1)
    return arr[-2:]

##############################################
# ARCHITECTURE: Nano-MoE with Causal Attention
##############################################

class CausalAttention(nnx.Module):
    """Multi-head Causal Self-Attention mechanism."""
    def __init__(self, n_embd, n_head, rngs):
        self.n_head = n_head
        self.qkv = nnx.Linear(n_embd, 3 * n_embd, rngs=rngs)
        self.proj = nnx.Linear(n_embd, n_embd, rngs=rngs)

    def __call__(self, x):
        B, T, C = x.shape
        q, k, v = jnp.split(self.qkv(x), 3, axis=-1)
        q = q.reshape(B, T, self.n_head, C // self.n_head).transpose(0, 2, 1, 3)
        k = k.reshape(B, T, self.n_head, C // self.n_head).transpose(0, 2, 1, 3)
        v = v.reshape(B, T, self.n_head, C // self.n_head).transpose(0, 2, 1, 3)

        mask = jnp.tril(jnp.ones((T, T)))
        attn = (q @ k.transpose(0, 1, 3, 2)) * (1.0 / jnp.sqrt(k.shape[-1]))
        attn = jnp.where(mask == 0, -jnp.inf, attn)
        attn = jax.nn.softmax(attn, axis=-1)

        y = (attn @ v).transpose(0, 2, 1, 3).reshape(B, T, C)
        return self.proj(y)

class MoELayer(nnx.Module):
    """Mixture-of-Experts layer with Noisy Top-2 Routing."""
    def __init__(self, n_embd, n_experts, rngs):
        self.router = nnx.Linear(n_embd, n_experts, rngs=rngs)
        self.experts = nnx.List([
            nnx.Sequential(nnx.Linear(n_embd, 4*n_embd, rngs=rngs), nnx.gelu, nnx.Linear(4*n_embd, n_embd, rngs=rngs))
            for _ in range(n_experts)
        ])
        self.n_experts = n_experts

    def __call__(self, x, train=True):
        logits = self.router(x)
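        if train:
            # Small Gaussian jitter on the router logits. The PRNG key is fixed,
            # so the same noise tensor is reused on every call.
            logits += jax.random.normal(jax.random.PRNGKey(0), logits.shape) * 1e-2

        probs = jax.nn.softmax(logits, axis=-1)
        # Mean Shannon entropy of the routing distribution (in nats)
        ent = -jnp.sum(probs * jnp.log(probs + 1e-9), axis=-1).mean()

        # Top-2 selection: every expert runs on every token, and a 0/1 mask
        # keeps only the outputs of the two selected experts (unweighted sum).
        val, idx = jax.lax.top_k(logits, 2)
        out = jnp.zeros_like(x)
        for i, exp in enumerate(self.experts):
            mask = jnp.any(idx == i, axis=-1, keepdims=True)
            out += exp(x) * mask.astype(x.dtype)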
        return out, ent, probs

class NanoMoE(nnx.Module):
    """Language Model using Mixture-of-Experts blocks."""
    def __init__(self, vocab_size, n_embd, n_head, n_experts, rngs):
        self.embed = nnx.Embed(vocab_size, n_embd, rngs=rngs)
        self.pos_embed = nnx.Variable(jax.random.normal(jax.random.PRNGKey(0), (1024, n_embd)))
        self.attn = CausalAttention(n_embd, n_head, rngs=rngs)
        self.ln1 = nnx.LayerNorm(n_embd, rngs=rngs)
        self.moe = MoELayer(n_embd, n_experts, rngs)
        self.ln2 = nnx.LayerNorm(n_embd, rngs=rngs)
        self.head = nnx.Linear(n_embd, vocab_size, rngs=rngs)

    def __call__(self, x, train=True):
        B, T = x.shape
        x = self.embed(x) + self.pos_embed[:T]
        x = x + self.attn(self.ln1(x))
        moe_out, ent, probs = self.moe(self.ln2(x), train=train)
        x = x + moe_out
        return self.head(x), ent, probs

#############################
# INITIALIZING XARRAY DATASET
#############################

rules_list = ["expert_specialization", "load_imbalance", "routing_stability",
          "learning_progress", "underutilized_experts", "healthy_specialization"]

steps = 5000
rngs = nnx.Rngs(42)

history = xr.Dataset(
    data_vars={
        "loss": (["step"], np.zeros(steps)),
        **{f"{r}_U": (["step"], np.zeros(steps)) for r in rules_list},
        **{f"{r}_L": (["step"], np.zeros(steps)) for r in rules_list}
    },
    coords={"step": np.arange(steps)}
)

##########################
# TRAINING LOOP WITH AUDIT
##########################

model = NanoMoE(vocab_size, 256, 8, 4, rngs)
optimizer = nnx.Optimizer(model, optax.adamw(1e-3), wrt=nnx.Param)

# JIT-compiled training step (returns loss plus the routing statistics used by the audit)

@nnx.jit
def train_step(model, optimizer, x, y):
    def loss_fn(model):
        logits, ent, probs = model(x)
        loss = optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()
        return loss, (ent, probs)
    (loss, (ent, probs)), grads = nnx.value_and_grad(loss_fn, has_aux=True)(model)
    optimizer.update(model, grads)
    return loss, ent, probs

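# BATCH SAMPLING
# Minimal batch sampler for the loop below. This is a sketch: batch_size and
# block_size are assumed values, not taken from the original notebook.

batch_size, block_size = 32, 128
batch_rng = np.random.default_rng(0)

def get_batch():
    """Samples random text windows and their next-character targets."""
    ix = batch_rng.integers(0, len(data) - block_size - 1, size=batch_size)
    x = jnp.stack([data[i:i + block_size] for i in ix])
    y = jnp.stack([data[i + 1:i + block_size + 1] for i in ix])
    return x, y

# TRAINING LOOP WITH AUDIT

losses_cache = []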

print("🚀 Training Nano-MoE on TPU with Neuro-Symbolic Audit...")

for step in tqdm(range(steps)):
    x_b, y_b = get_batch()
    loss, ent, probs = train_step(model, optimizer, x_b, y_b)
    losses_cache.append(float(loss))

    # Audit Metrics (Calculated on CPU/Host for logging)
    conf = float(jnp.max(probs, axis=-1).mean())
    ent_norm = float(ent / jnp.log(4))
    usage = jnp.mean(jax.nn.one_hot(jnp.argmax(probs, -1), 4), axis=(0,1))
    imbalance = float(jnp.std(usage) / 0.5)
    loss_trend = np.mean(np.diff(losses_cache[-10:])) if len(losses_cache) > 10 else 0

    lnn_inputs = {
        "high_routing_confidence": fuzzy_ramp(conf, offset=0.35),
        "low_activation_entropy": fuzzy_ramp(1.0 - ent_norm, offset=0.4),
        "high_expert_imbalance": fuzzy_ramp(imbalance, offset=0.2),
        # Step-based proxy that grows linearly with training time;
        # it is not a measured weight norm.
        "increasing_weight_magnitude": fuzzy_ramp(0.5 + 0.0005*step, offset=0.5),
        "stable_loss": fuzzy_ramp(1.0 - abs(loss_trend)*10, offset=0.5),
        "decreasing_loss": fuzzy_ramp(1.0 if loss_trend < 0 else 0.0, offset=0.4),
        # Not referenced by the rules above, which use ~high_expert_imbalance instead.
        "low_expert_imbalance": fuzzy_ramp(1.0 - imbalance, offset=0.5)
    }

    # Audit Inference (TPU accelerated)
    results = monitor(lnn_inputs)

    # Logging
    history["loss"].values[step] = float(loss)
    for i, r_name in enumerate(rules_list):
        res = safe_extract(results[i])
        history[f"{r_name}_L"].values[step] = res[0]
        history[f"{r_name}_U"].values[step] = res[1]

####################################
# VISUALIZATION: Performance & Audit
####################################

sns.set_style("whitegrid")
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 12), sharex=True)

# 1. PANEL: Model Convergence
history.loss.rolling(step=50, min_periods=1).mean().plot(ax=ax1, color='#2c3e50', lw=2.5)
ax1.set_title("Model Convergence: Cross-Entropy Loss", fontsize=15, fontweight='bold')
ax1.set_ylabel("Loss (Training Error)", fontsize=12)
ax1.grid(True, alpha=0.3)

# 2. PANEL: Neuro-Symbolic Audit
custom_colors = {
    "expert_specialization": "#e74c3c", # Red (target)
    "healthy_specialization": "#27ae60", # Green (ideal)
    "load_imbalance": "#f39c12",         # Orange (warning)
    "learning_progress": "#3498db",      # Blue
    "routing_stability": "#9b59b6",      # Purple
    "underutilized_experts": "#95a5a6"   # Gray
}

for r in rules_list:
    u_bound = history[f"{r}_U"].rolling(step=30, min_periods=1).mean()
    l_bound = history[f"{r}_L"].rolling(step=30, min_periods=1).mean()

    color = custom_colors.get(r, "black")
    line, = ax2.plot(history.step, u_bound, label=r, color=color, lw=2, alpha=0.9)
    ax2.fill_between(history.step, l_bound, u_bound, alpha=0.15, color=line.get_color())

ax2.set_title("Neuro-Symbolic Audit: Comprehensive Rule Monitoring", fontsize=15, fontweight='bold')
ax2.set_ylabel("Fuzzy Truth Value [0, 1]", fontsize=12)
ax2.set_ylim(-0.05, 1.05)
ax2.legend(loc='center left', bbox_to_anchor=(1, 0.5), title="Monitored Rules", fontsize=10)
ax2.set_xlabel("Training Steps", fontsize=12)

plt.tight_layout()
plt.show()

Interpreting the Audit

By evaluating these fuzzy logic rules in real time, the monitor adds a “Logical Self-Reflection” layer on top of training. When the truth intervals \([L, U]\) of the specialization, routing-stability, and healthy-specialization rules stabilize near 1.0, while those of the imbalance-risk and underutilization rules stay low, the audit indicates that the MoE has moved from near-random routing to structured, specialized behavior.
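
One way to turn "stabilize near 1.0" into a concrete check is sketched below in plain Python. The function name `has_stabilized` and its default window, threshold, and width limits are hypothetical choices made for illustration; they are not part of JLNN and would need tuning for a real run.

```python
def has_stabilized(lower, upper, window=100, threshold=0.9, max_width=0.1):
    """True if, over the last `window` steps, the lower bound stays at or above
    `threshold` and the interval width stays at or below `max_width`."""
    if len(lower) < window or len(upper) < window:
        return False
    recent = zip(lower[-window:], upper[-window:])
    return all(lo >= threshold and (up - lo) <= max_width for lo, up in recent)

# Synthetic traces (made-up values): a rule that ramps up and then holds.
lower = [min(0.95, 0.01 * t) for t in range(300)]
upper = [min(1.0, lo + 0.05) for lo in lower]

print(has_stabilized(lower, upper))              # full trace
print(has_stabilized(lower[:120], upper[:120]))  # trace cut off during the ramp
```

Requiring both a high lower bound and a narrow interval separates a rule that is confidently true from one whose upper bound is high only because the monitor is still uncertain.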

Download

You can also download the raw notebook file for local use: JLNN_xarray_rule_monitoring_moe.ipynb

Tip

To run the notebook locally, make sure you have installed the package using pip install -e .[test].