Meta-Learning & Self-Reflection

This tutorial demonstrates how to build a Self-Reflective System using JLNN. In most neural networks, the training process is a “black box,” but JLNN allows us to treat the training history as structured data that can be analyzed by other models.

Run in Google Colab

Execute the code directly in your browser without any local setup.

https://colab.research.google.com/github/RadimKozl/JLNN/blob/main/examples/JLNN_xarray_meta_learning.ipynb
View on GitHub

Browse the source code and outputs in the GitHub notebook viewer.

https://github.com/RadimKozl/JLNN/blob/main/examples/JLNN_xarray_meta_learning.ipynb

Content Overview

The pipeline consists of three main stages:

  1. Logical Reasoning (JLNN): A model learns rules on scikit-learn's Wine dataset using weighted logical conjunctions (sketched in code after this list).

  2. Episodic Memory (Xarray): Every 25 training steps, the model’s internal state (weights, loss, accuracy) is snapshotted into a structured multi-dimensional array.

  3. Meta-Analysis (Decision Tree): A Scikit-learn model analyzes this history to discover which parameter settings (like a specific feature weight) were responsible for reaching the highest accuracy.
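
The weighted conjunction from stage 1 appears in full in the tutorial code below. Stripped to its essentials, it is a normalized weighted sum scaled by the rule weight and clipped to [0, 1]; a minimal sketch (argument names are illustrative):

import jax.numpy as jnp

def weighted_and(x, weights, rule_weight):
    # x: (batch, n_features) truth values in [0, 1]
    # weights: per-antecedent importance, rule_weight: credibility of the whole rule
    w = jnp.abs(weights)
    return jnp.clip(jnp.sum(w * x, axis=1) / jnp.sum(w) * rule_weight, 0, 1)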

Core Components

Structured Logging with Xarray

We use xarray.Dataset to store the training dynamics. This allows us to query the training history using labels instead of raw indices:

history = xr.Dataset(
    data_vars={
        "member_weights": (["step", "antecedent"], np.zeros((len(steps_to_log), 4))),
        "rule_weight": (["step"], np.zeros(len(steps_to_log))),
        "accuracy": (["step"], np.zeros(len(steps_to_log)))
    },
    coords={"step": steps_to_log, "antecedent": ["alcohol", "acid", "magnesium", "ash"]}
)
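
Once populated, the history can be written to and queried by coordinate label rather than positional index. A small illustration (the step and value shown are arbitrary; the real logging loop appears in the tutorial code below):

# Label-based assignment of a snapshot at a given training step
history["accuracy"].loc[dict(step=500)] = 0.74

# Label-based selection: the "alcohol" antecedent weight at that step
history["member_weights"].sel(step=500, antecedent="alcohol")

# The full rule-weight trajectory as a pandas Series
history["rule_weight"].to_series()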

The Meta-Analyst Loop

By training a DecisionTreeRegressor on the recorded history, we can extract symbolic insights. For example, the tree can tell us: “If the ‘alcohol’ weight was above 1.5, accuracy increased by 10%.”

from sklearn.tree import export_text

# Analyze how weights affected accuracy
dt_analyzer.fit(df_meta, y_meta)
print(export_text(dt_analyzer, feature_names=df_meta.columns.tolist()))
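
Here df_meta and y_meta are derived from the xarray history; roughly (the full script below also adds a loss-trend column):

# One row per logged step, one column per antecedent weight
df_meta = history["member_weights"].to_series().unstack()
df_meta["rule_weight"] = history["rule_weight"].values

# Meta-target: accuracy reached at each logged step
y_meta = history["accuracy"].values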

Agentic Reflection (Prompt Generation)

The ultimate goal of this tutorial is to generate a report that can be understood by an LLM (like Gemma 3 or Llama 3). This output can be fed into an AI agent to autonomously adjust the next training run:

I analyzed the training of the JLNN model.
Maximum achieved accuracy: 0.74

Determined logical dependencies:
|--- magnesium >  1.61
|   |--- acid <= 0.61 -> value: [0.74]

Recommendation: Focus on magnesium weight stability.
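
In the tutorial code this report is assembled with a plain f-string from the exported tree rules and the best logged accuracy, roughly:

rules_text = export_text(dt_analyzer, feature_names=df_meta.columns.tolist())
best_acc = float(y_meta.max())

prompt = (
    f"I analyzed the training of the JLNN model.\n"
    f"Maximum achieved accuracy: {best_acc:.2f}\n\n"
    f"Determined logical dependencies:\n{rules_text}"
)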

Key Benefits

  • Transparency: No more guessing why the model stopped improving.

  • Auto-Tuning: Foundation for building self-correcting AI agents.

  • Auditability: A complete trace of the model’s “logic evolution” over time.

Tutorial code

'''
try:
    import jlnn
    from flax import nnx
    import jax.numpy as jnp
    import xarray as xr
    import pandas as pd
    import sklearn
    print("✅ JLNN and JAX are ready.")
except ImportError:
    print("🚀 Installing JLNN from GitHub and fixing JAX for Colab...")
    # Install the framework
    #!pip install jax-lnn --quiet
    !pip install git+https://github.com/RadimKozl/JLNN.git --quiet
    # Fix JAX/CUDA compatibility in Colab
    !pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
    !pip install scikit-learn pandas

    import os
    print("\n🔄 RESTARTING ENVIRONMENT... Please wait a second and then run the cell again.")
    os.kill(os.getpid(), 9)  # After this line, the cell stops and the environment restarts
'''

# Force JAX onto the CPU (must be set before importing jax)
import os
os.environ["JAX_PLATFORMS"] = "cpu"

import jlnn
import jax
import jax.numpy as jnp
from flax import nnx
from tqdm import tqdm
import optax
import xarray as xr
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_wine
from sklearn.tree import DecisionTreeRegressor, export_text, plot_tree
from sklearn.preprocessing import StandardScaler

data = load_wine()

# Select four feature columns and give them the short labels used for the antecedents
X_raw = data.data[:, [10, 1, 4, 3]]
feature_names = ["alcohol", "acid", "magnesium", "ash"]

y_raw = (data.target == 0).astype(float)

scaler = StandardScaler()
X_scaled = scaler.fit_transform(X_raw)

# Duplicate the scaled features along a trailing axis (samples, features, 2);
# the model reads channel 0 in __call__
inputs = jnp.stack([X_scaled, X_scaled], axis=-1)
targets = y_raw[:, None]

class MetaLogicModel(nnx.Module):
    def __init__(self, n_features, rngs):
        # Weights of individual members (antecedent)
        self.weights = nnx.Param(jnp.ones((n_features,)))
        # Weight of the entire rule (credibility)
        self.rule_weight = nnx.Param(jnp.array([0.9]))

    def __call__(self, x):
        # Implementation of weighted conjunction (Weighted AND)
        w = jnp.abs(self.weights)
        # Simplified Łukasiewicz-style conjunction: weighted sum normalized by the
        # total weight, scaled by the rule weight and clipped to [0, 1]
        logic_out = jnp.clip(jnp.sum(w * x[:, :, 0], axis=1) / jnp.sum(w) * self.rule_weight, 0, 1)
        return logic_out

# Snapshot the model state every 25 training steps (0, 25, ..., 975)
steps_to_log = np.arange(0, 1000, 25)

history = xr.Dataset(
    data_vars={
        "member_weights": (["step", "antecedent"], np.zeros((len(steps_to_log), len(feature_names)))),
        "rule_weight": (["step"], np.zeros(len(steps_to_log))),
        "loss": (["step"], np.zeros(len(steps_to_log))),
        "accuracy": (["step"], np.zeros(len(steps_to_log)))
    },
    coords={"step": steps_to_log, "antecedent": feature_names}
)

model = MetaLogicModel(len(feature_names), rngs=nnx.Rngs(0))

optimizer = nnx.Optimizer(model, optax.adam(0.01), wrt=nnx.Param)

def loss_fn(model, x, y):
    pred = model(x)
    mse = jnp.mean((pred - y[:, 0])**2)
    # Contradiction penalty: penalizes antecedent weights whose magnitude exceeds 4.0
    return mse + 1.5 * jnp.mean(jnp.maximum(0, jnp.abs(model.weights) - 4.0))

@nnx.jit
def train_step(model, optimizer, x, y):
    loss, grads = nnx.value_and_grad(loss_fn)(model, x, y)
    optimizer.update(model, grads)
    return loss

print("Starting training and collecting meta-data into Xarray...")

for i in tqdm(range(1001)):
    # Perform one step of the training
    loss = train_step(model, optimizer, inputs, targets)

    # Data logging at specified intervals
    if i in steps_to_log:
        # 1. Read the antecedent weights; .weights[...] returns the current parameter array
        current_weights = np.abs(model.weights[...])

        # 2. Read the rule weight; .item() converts the single-element array to a Python float
        current_rule_w = model.rule_weight[...].item()

        # 3. Store the snapshot in the Xarray Dataset
        history["member_weights"].loc[dict(step=i)] = current_weights
        history["rule_weight"].loc[dict(step=i)] = current_rule_w
        history["loss"].loc[dict(step=i)] = float(loss)

        # Calculate and store Accuracy
        preds = model(inputs) > 0.5
        acc_value = jnp.mean(preds == y_raw).item()  # .item() yields a plain Python float
        history["accuracy"].loc[dict(step=i)] = acc_value

print("\nTraining complete. Data ready for Meta-Analysis.")

df_weights = history["member_weights"].to_series().unstack()
df_meta = df_weights.copy()
df_meta["rule_weight"] = history["rule_weight"].values

# Calculate the trend (derivative) of the loss to see learning velocity
loss_series = history["loss"].to_series()
df_meta["loss_trend"] = loss_series.diff().fillna(0).values

y_meta = history["accuracy"].values

dt_analyzer = DecisionTreeRegressor(max_depth=3)
dt_analyzer.fit(df_meta, y_meta)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 7))

# Weight development graph
history.member_weights.plot.line(x="step", ax=ax1)
ax1.set_title("Evolution of antecedent weights over time")

# Tree visualization
plot_tree(dt_analyzer, feature_names=df_meta.columns.tolist(), filled=True, ax=ax2, fontsize=10)
ax2.set_title("Parameter Analysis: What Determines Model Success?")
plt.show()

print("\n" + "="*50)
print("ANALYSIS FOR LLM AGENT (Gemma/Ollama)")
print("="*50)

rules_text = export_text(dt_analyzer, feature_names=df_meta.columns.tolist())
best_acc = float(y_meta.max())

prompt = f"""
I analyzed the training of the JLNN model on the scikit-learn Wine dataset.
Maximum achieved accuracy: {best_acc:.2f}

Determined logical dependencies of parameters:
{rules_text}

Recommendations for the next iteration:
1. If the 'alcohol' weight is within the range shown in the top node, focus on stabilizing the 'rule_weight'.
2. Contradiction penalty 1.5 appears to be optimal for the balance between MSE and weight stability.
"""

print(prompt)

Download

You can also download the raw notebook file for local use: JLNN_xarray_meta_learning.ipynb

Tip

To run the notebook locally, make sure you have installed the package using pip install -e .[test].