JLNN: JAX Logical Neural Networks¶
JLNN is a high-performance neuro-symbolic framework built on a modern JAX stack and Flax NNX. It allows you to define logical knowledge using human-readable formulas and then compile them into differentiable neural graphs.
Why JLNN?¶
Unlike standard neural networks, JLNN works with interval logic: truth is not a single point but a range $[L, U]$. As a result, the framework can report not only what it "knows", but also where the data is contradictory (Contradiction) and where information is missing (Uncertainty).
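The interval semantics can be sketched with a small helper. This is an illustrative classification only, assuming a contradiction when the bounds cross ($L > U$) and uncertainty when the interval is wide; the names and thresholds here are not JLNN's API:

```python
def classify_interval(lower, upper, tol=1e-6):
    """Illustrative classification of an interval truth value [L, U].

    The categories mirror the concepts described above; JLNN's exact
    semantics and thresholds may differ.
    """
    if lower > upper + tol:
        return "CONTRADICTION"  # bounds have crossed: L > U
    if upper - lower > 0.5:
        return "UNCERTAIN"      # wide interval: little information
    if lower >= 0.5:
        return "TRUE"
    if upper <= 0.5:
        return "FALSE"
    return "UNKNOWN"

print(classify_interval(0.9, 0.3))   # crossed bounds -> CONTRADICTION
print(classify_interval(0.1, 0.9))   # wide interval  -> UNCERTAIN
print(classify_interval(0.8, 0.95))  # confidently true
```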
Key Components¶
- Symbolic Compiler: Using a Lark grammar, transforms string definitions (e.g. `A & B -> C`) directly into an NNX module hierarchy.
- Graph-Based Architecture (NetworkX): Full support for bidirectional conversion between JLNN and NetworkX. Allows importing topology from graph databases and visualizing logical trees as hierarchical graphs via `build_networkx_graph`.
- Flax NNX Integration: Uses the latest state management in Flax, ensuring high speed, clean parameter handling, and compatibility with XLA.
- Constraint Enforcement: Built-in projected gradients ensure the learned weights satisfy $w \geq 1$, keeping them consistent with the logical axioms.
- Unified Export: Direct path from a trained model to the ONNX, StableHLO, and PyTorch formats.
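The projection behind "Constraint Enforcement" can be sketched as a simple clamp applied after each unconstrained optimizer step. This is a conceptual illustration on a raw array; JLNN's own `apply_constraints` (shown in the full example below) operates on the whole NNX module:

```python
import jax.numpy as jnp

def project_weights(w):
    """Projected-gradient step: after an unconstrained optimizer update,
    clamp the logical-gate weights back onto the feasible set w >= 1."""
    return jnp.maximum(w, 1.0)

w = jnp.array([0.3, 1.0, 2.5])  # weights after an unconstrained update
print(project_weights(w))       # 0.3 is projected up to 1.0; the rest are untouched
```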
User Guide:
- Installation
- Quickstart
- Examples & Tutorials
- Introductory Example: JLNN Base
- Base Example: Basic inference and manual grounding
- Basic Boolean Gates
- Weighted Rules & Multiple Antecedents
- Temporal Logic (G, F, X) on Time-Series
- Contradiction Detection & Model Repair
- Model Export & Deployment (StableHLO, ONNX, PyTorch)
- Real Example: Iris dataset Classification
- Meta-Learning & Self-Reflection
- The Grand Cycle: Autonomous Tuning
- Differentiable Reasoning on Graphs (JLNN vs. PyReason)
- JLNN Explainability – From weights to symbolic rules
- Bayesian JLNN: Logic in an Uncertain World
- Neuro-Symbolic Bayesian GraphSAGE + JLNN
- JLNN – Accelerated Interval Logic
- Quantum Logic and Bell Inequalities with JLNN
- LLM Rule Refinement (The Grand Cycle)
- JLNN: Temporal Symbolic GNN for Pneumatic Digital Twin
- JLNN + Knowledge Graphs: RAG-like Reasoning over FB15k-237
- Theoretical foundations of JLNN
- Testing
API Reference:
- Core Logic Engine (jlnn.core)
- Model Export & Deployment (jlnn.export)
- Neural Network Components (jlnn.nn)
- Reasoning & Inference Engine (jlnn.reasoning)
- Model Storage & Persistence (jlnn.storage)
- Symbolic Front-end (jlnn.symbolic)
- Training & Optimization (jlnn.training)
- Utilities & Visualization (jlnn.utils)
About the Project:
Example of use¶
```python
import jax
import jax.numpy as jnp
from flax import nnx
import optax

from jlnn.symbolic.compiler import LNNFormula
from jlnn.nn.constraints import apply_constraints
from jlnn.training.losses import total_lnn_loss, logical_mse_loss, contradiction_loss
from jlnn.storage.checkpoints import save_checkpoint, load_checkpoint

# 1. Define and compile the formula
model = LNNFormula("0.8::A & B -> C", nnx.Rngs(42))

# 2. Ground inputs (including the initial state for C)
inputs = {
    "A": jnp.array([[0.9]]),
    "B": jnp.array([[0.7]]),
    "C": jnp.array([[0.5]]),  # MANDATORY – the consequent must be grounded!
}
target = jnp.array([[0.6, 0.85]])

# 3. Loss function
def loss_fn(model, inputs, target):
    pred = model(inputs)
    pred = jnp.nan_to_num(pred, nan=0.5, posinf=1.0, neginf=0.0)  # guard against NaN
    return total_lnn_loss(pred, target)

# 4. Initialize the optimizer
optimizer = nnx.Optimizer(
    model,
    wrt=nnx.Param,
    tx=optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.adam(learning_rate=0.001),
    ),
)

# 5. Training step
@nnx.jit
def train_step(model, optimizer, inputs, target):
    # Loss and gradients in one pass – the closure is traceable
    # (inputs/target are arrays)
    loss, grads = nnx.value_and_grad(lambda m: loss_fn(m, inputs, target))(model)
    optimizer.update(model, grads)
    apply_constraints(model)  # project weights back onto w >= 1
    final_loss = loss_fn(model, inputs, target)
    final_pred = model(inputs)
    return loss, final_loss, final_pred

print("=== Starting training ===")
steps = 50
for step in range(steps):
    loss, final_loss, pred = train_step(model, optimizer, inputs, target)
    print(f"Step {step:3d} | Loss before/after constraints: {loss:.6f} → {final_loss:.6f}")
    print(f"Prediction: {pred}")
    print("─" * 60)
    if jnp.isnan(final_loss).any():
        print("❌ NaN detected! Stopping.")
        break
print("=== Training completed ===")

# 6. Result after training
final_pred = model(inputs)
print("\nFinal prediction after training:")
print(final_pred)
print(f"\nTarget interval: {target}")
print(f"Final loss: {total_lnn_loss(final_pred, target):.6f}")
```
Discord channel:¶