JLNN: JAX Logical Neural Networks

JLNN is a high-performance neuro-symbolic framework built on a modern JAX stack and Flax NNX. It allows you to define logical knowledge using human-readable formulas and then compile them into differentiable neural graphs.

JLNN Architecture Diagram

Why JLNN?

Unlike standard neural networks, JLNN works with interval logic (truth is not just a point, but a range $[L, U]$). Thanks to this, the framework can detect not only what it “knows”, but also where the data is contradictory (Contradiction) or where it lacks information (Uncertainty).

Key Components

  • Symbolic Compiler: Using Lark grammar, transforms string definitions (e.g. A & B -> C) directly into the NNX module hierarchy.

  • Graph-Based Architecture (NetworkX): Full support for bidirectional conversion between JLNN and NetworkX. Allows importing topology from graph databases and visualizing logical trees as hierarchical graphs using build_networkx_graph.

  • Flax NNX Integration: Uses the latest state management in Flax, ensuring lightning speed, clean parameter handling, and compatibility with XLA.

  • Constraint Enforcement: Built-in projected gradients ensure that the learned weights \(w \geq 1\) always conform to logical axioms.

  • Unified Export: Direct path from trained model to ONNX, StableHLO and PyTorch formats.

Example of use

import jax
import jax.numpy as jnp
from flax import nnx
from jlnn.symbolic.compiler import LNNFormula
from jlnn.nn.constraints import apply_constraints
from jlnn.training.losses import total_lnn_loss, logical_mse_loss, contradiction_loss
from jlnn.storage.checkpoints import save_checkpoint, load_checkpoint
import optax

# 1. Define and compile the formula
model = LNNFormula("0.8::A & B -> C", nnx.Rngs(42))

# 2. Ground inputs (including initial state for C)
inputs = {
   "A": jnp.array([[0.9]]),
   "B": jnp.array([[0.7]]),
   "C": jnp.array([[0.5]])   # MANDATORY – consequent must have grounding!
}

target = jnp.array([[0.6, 0.85]])

# 3. Loss function
def loss_fn(model, inputs, target):
   pred = model(inputs)
   pred = jnp.nan_to_num(pred, nan=0.5, posinf=1.0, neginf=0.0)  # protection against NaN
   return total_lnn_loss(pred, target)

# 4. Initialize Optimizer
optimizer = nnx.Optimizer(
   model,
   wrt=nnx.Param,
   tx=optax.chain(
      optax.clip_by_global_norm(1.0),
      optax.adam(learning_rate=0.001)
   )
)

# 5. Training Step
@nnx.jit
def train_step(model, optimizer, inputs, target):
   # Gradients to the model – closure is traceable (inputs/target are arrays)
   grads = nnx.grad(lambda m: loss_fn(m, inputs, target))(model)

   # Loss before update (for debug)
   loss = loss_fn(model, inputs, target)

   optimizer.update(model, grads)
   apply_constraints(model)

   final_loss = loss_fn(model, inputs, target)
   final_pred = model(inputs)

   return loss, final_loss, final_pred

print("=== Starting training ===")
steps = 50
for step in range(steps):
   loss, final_loss, pred = train_step(model, optimizer, inputs, target)

   print(f"Step {step:3d} | Loss before/after constraints: {loss:.6f}{final_loss:.6f}")
   print(f"Prediction: {pred}")
   print("─" * 60)

   if jnp.isnan(final_loss).any():
      print("❌ NaN detected! Stopping.")
      break

print("=== Training completed ===")

# 6. Result after training

final_pred = model(inputs)
print("\nFinal prediction after training:")
print(final_pred)

print(f"\nTarget interval: {target}")
print(f"Final loss: {total_lnn_loss(final_pred, target):.6f}")

Discord channel:

Discord channel

Indexes