JLNN: JAX Logical Neural Networks

JLNN is a high-performance neuro-symbolic framework built on JAX and Flax NNX. It lets you define logical knowledge as human-readable formulas and compile them into differentiable neural graphs.
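
As a rough illustration of what "differentiable" means here, the sketch below evaluates a weighted Łukasiewicz conjunction on truth intervals in plain JAX. This is the kind of operator used in the Logical Neural Networks formulation. The function `weighted_and` and its `weights` and `beta` parameters are assumptions made for this example, and JLNN's own compiled nodes may be parameterized differently.

```python
import jax
import jax.numpy as jnp

def weighted_and(lower, upper, weights, beta):
    """Hypothetical weighted Lukasiewicz AND over truth intervals [L, U].

    The operator is monotone increasing in every input, so the output lower
    bound depends only on the input lower bounds (and likewise for upper).
    """
    def f(x):
        # clip(beta - sum_i w_i * (1 - x_i), 0, 1)
        return jnp.clip(beta - jnp.sum(weights * (1.0 - x)), 0.0, 1.0)
    return f(lower), f(upper)

lower = jnp.array([0.9, 0.7])    # lower bounds of A and B
upper = jnp.array([0.9, 0.7])    # point-valued inputs, so L == U
weights = jnp.array([1.0, 1.0])
beta = 1.0

L, U = weighted_and(lower, upper, weights, beta)

# Gradient of the output lower bound with respect to the weights
grad_w = jax.grad(lambda w: weighted_and(lower, upper, w, beta)[0])(weights)
print(L, U, grad_w)
```

Because the result is an ordinary JAX expression, `jax.grad` can differentiate it with respect to the weights, which is what lets logical structure be trained like a neural layer.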

[Figure: JLNN architecture diagram]

Why JLNN?

Unlike standard neural networks, JLNN works with interval logic: a truth value is not a single point but a range $[L, U]$ between a lower bound $L$ and an upper bound $U$. This lets the framework detect not only what it “knows”, but also where the data is contradictory (the bounds cross, $L > U$) and where information is missing (a wide interval, up to $[0, 1]$ for a completely unknown proposition).
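
The sketch below shows how both diagnostics can be read off an interval under the usual LNN convention: crossed bounds ($L > U$) signal a contradiction, and the width $U - L$ measures missing information. `interval_state` is a hypothetical helper written for this illustration, not part of the JLNN API.

```python
import jax.numpy as jnp

def interval_state(lower, upper, tol=1e-6):
    """Hypothetical helper: summarize truth intervals [L, U]."""
    lower = jnp.asarray(lower)
    upper = jnp.asarray(upper)
    contradiction = lower > upper + tol               # bounds crossed: conflicting evidence
    uncertainty = jnp.clip(upper - lower, 0.0, 1.0)   # interval width: missing information
    return contradiction, uncertainty

# Nearly true, fully unknown, contradictory, partially known
L = jnp.array([0.9, 0.0, 0.8, 0.4])
U = jnp.array([1.0, 1.0, 0.3, 0.6])
contradiction, uncertainty = interval_state(L, U)
print(contradiction)
print(uncertainty)
```

A contradictory interval is assigned zero width here because its bounds no longer describe a valid range; how JLNN itself reports such cases is not specified in this overview.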

Key Components

  • Symbolic Compiler: Uses a Lark grammar to transform string definitions (e.g. A & B -> C) directly into a hierarchy of NNX modules.

  • Graph-Based Architecture (NetworkX): Supports bidirectional conversion between JLNN and NetworkX, so topologies can be imported from graph databases and logical trees can be visualized as hierarchical graphs with build_networkx_graph.

  • Flax NNX Integration: Builds on the Flax NNX state-management API for clean parameter handling and full compatibility with XLA compilation.

  • Constraint Enforcement: Built-in projected gradients keep the learned weights in the region $w \geq 1$, so they always conform to the logical axioms.

  • Unified Export: Direct path from trained model to ONNX, StableHLO and PyTorch formats.
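
The constraint-enforcement idea can be illustrated as projected gradient descent: take an ordinary gradient step, then project the weights back onto $w \geq 1$. This is a minimal sketch in plain JAX under that assumption. The toy `loss` and `projected_step` are hypothetical and are not a reproduction of JLNN's `apply_constraints`.

```python
import jax
import jax.numpy as jnp

def loss(w, x, target):
    # Toy objective (hypothetical): weighted Lukasiewicz AND of the inputs x,
    # compared with a target truth value by squared error.
    y = jnp.clip(1.0 - jnp.sum(w * (1.0 - x)), 0.0, 1.0)
    return (y - target) ** 2

def projected_step(w, x, target, lr=0.5):
    g = jax.grad(loss)(w, x, target)   # gradient w.r.t. the weights
    w = w - lr * g                     # unconstrained gradient step
    return jnp.maximum(w, 1.0)         # Euclidean projection onto {w >= 1}

x = jnp.array([0.9, 0.7])
w = jnp.array([1.5, 1.0])
w_new = projected_step(w, x, target=0.9)
# The second weight would drop to about 0.895 but is projected back to 1.0
print(w_new)
```

Projection onto the region $w \geq 1$ is just an element-wise maximum, so it is cheap enough to run after every optimizer update. That appears to be the role of `apply_constraints(model)` following `optimizer.update` in the usage example.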

Example of use

import jax
import jax.numpy as jnp
from flax import nnx
from jlnn.symbolic.compiler import LNNFormula
from jlnn.nn.constraints import apply_constraints
from jlnn.training.losses import total_lnn_loss, logical_mse_loss, contradiction_loss
from jlnn.storage.checkpoints import save_checkpoint, load_checkpoint
import optax

# 1. Define and compile the formula
model = LNNFormula("0.8::A & B -> C", nnx.Rngs(42))

# 2. Ground inputs (including initial state for C)
inputs = {
   "A": jnp.array([[0.9]]),
   "B": jnp.array([[0.7]]),
   "C": jnp.array([[0.5]])   # MANDATORY – consequent must have grounding!
}

target = jnp.array([[0.6, 0.85]])

# 3. Loss function
def loss_fn(model, inputs, target):
   pred = model(inputs)
   pred = jnp.nan_to_num(pred, nan=0.5, posinf=1.0, neginf=0.0)  # protection against NaN
   return total_lnn_loss(pred, target)

# 4. Initialize Optimizer
optimizer = nnx.Optimizer(
   model,
   wrt=nnx.Param,
   tx=optax.chain(
      optax.clip_by_global_norm(1.0),
      optax.adam(learning_rate=0.001)
   )
)

# 5. Training Step
@nnx.jit
def train_step(model, optimizer, inputs, target):
   # Loss and gradients w.r.t. the model in a single pass
   loss, grads = nnx.value_and_grad(loss_fn)(model, inputs, target)

   # Gradient update, then projection of the weights back onto the logical constraints
   optimizer.update(model, grads)
   apply_constraints(model)

   # Loss and prediction after the update and projection
   final_loss = loss_fn(model, inputs, target)
   final_pred = model(inputs)

   return loss, final_loss, final_pred

print("=== Starting training ===")
steps = 50
for step in range(steps):
   loss, final_loss, pred = train_step(model, optimizer, inputs, target)

   print(f"Step {step:3d} | Loss before/after update: {loss:.6f} -> {final_loss:.6f}")
   print(f"Prediction: {pred}")
   print("─" * 60)

   if jnp.isnan(final_loss).any():
      print("❌ NaN detected! Stopping.")
      break

print("=== Training completed ===")

# 6. Result after training

final_pred = model(inputs)
print("\nFinal prediction after training:")
print(final_pred)

print(f"\nTarget interval: {target}")
print(f"Final loss: {total_lnn_loss(final_pred, target):.6f}")

Discord channel