Quickstart¶
This guide follows the rule from Introductory Example: JLNN Base. It demonstrates how to compile a logical formula, perform inference with truth intervals, and train the model using JAX.
1. Installation¶
To get started in a local environment or Colab, install directly from GitHub:
pip install git+https://github.com/RadimKozl/JLNN.git
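A quick smoke test that the install worked is simply importing the package:
import jlnn  # should complete without errors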
2. Define and Compile Logic¶
The LNNFormula class parses a formula string and compiles it into a differentiable graph of Flax NNX modules.
from jlnn.symbolic import LNNFormula
from flax import nnx
import jax.numpy as jnp
rngs = nnx.Rngs(42)
# Using the rule from the introductory example
formula = "0.8::A & B -> C"
model = LNNFormula(formula, rngs)
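To sanity-check the result, the compiled module tree can be inspected with standard Flax NNX tooling (nnx.display is a Flax utility, not a JLNN-specific API):
# Optional: render the compiled module graph
nnx.display(model)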
3. Inference with Intervals¶
JLNN operates on truth intervals $[L, U]$ rather than point truth values. Every proposition needs a grounding input, including the conclusion C (e.g., initialized to an uncertain state such as $[0, 1]$).
# Define inputs for A, B, and the initial state of C
inputs = {
"A": jnp.array([[1.0]]), # Certainly True
"B": jnp.array([[1.0]]), # Certainly True
"C": jnp.array([[0.0]]) # Initial grounding
}
# Forward pass returns the [Lower, Upper] interval
prediction = model(inputs)
print(f"Prediction for C: {prediction}")
4. Training (NaN-free)¶
To train the model, we use optax together with the specialized total_lnn_loss(), which combines an MSE term on the interval bounds with penalties for contradictions and uncertainty.
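Conceptually, such a composite loss might look like the sketch below. This is an illustrative approximation, not the jlnn implementation; the term weights and the exact uncertainty measure are assumptions:
import jax.numpy as jnp

def sketch_loss(pred, target, w_contra=1.0, w_unc=0.1):
    # MSE on the [L, U] bounds
    mse = jnp.mean((pred - target) ** 2)
    # Contradiction penalty: positive whenever the lower bound exceeds the upper
    contra = jnp.mean(jnp.maximum(pred[..., 0] - pred[..., 1], 0.0))
    # Uncertainty term: interval width as a proxy (assumed, not from the library)
    width = jnp.mean(pred[..., 1] - pred[..., 0])
    return mse + w_contra * contra + w_unc * width
The full training script: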
import jax
import jax.numpy as jnp
import optax
from flax import nnx

from jlnn.symbolic.compiler import LNNFormula
from jlnn.nn.constraints import apply_constraints
# total_lnn_loss combines the individual terms; only it is used below
from jlnn.training.losses import total_lnn_loss, logical_mse_loss, contradiction_loss
# optional persistence helpers (see the note at the end of this guide)
from jlnn.storage.checkpoints import save_checkpoint, load_checkpoint
# 1. Define and compile the formula
model = LNNFormula("0.8::A & B -> C", nnx.Rngs(42))
# 2. Ground inputs (including initial state for C)
inputs = {
"A": jnp.array([[0.9]]),
"B": jnp.array([[0.7]]),
"C": jnp.array([[0.5]]) # MANDATORY – consequent must have grounding!
}
target = jnp.array([[0.6, 0.85]])
# 3. Loss function
def loss_fn(model, inputs, target):
pred = model(inputs)
pred = jnp.nan_to_num(pred, nan=0.5, posinf=1.0, neginf=0.0) # protection against NaN
return total_lnn_loss(pred, target)
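# Why nan_to_num: NaNs map to 0.5 (the interval midpoint, i.e. maximal
# uncertainty) and infinities are clamped to the valid bounds [0, 1], so a
# single bad activation cannot poison the gradients.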
# 4. Initialize the optimizer: clip the global gradient norm before Adam so a
# single badly scaled step cannot blow up the logical weights
optimizer = nnx.Optimizer(
    model,
    wrt=nnx.Param,
    tx=optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.adam(learning_rate=0.001),
    ),
)
# 5. Training Step
@nnx.jit
def train_step(model, optimizer, inputs, target):
    # Loss and gradients in one pass; the closure is traceable because
    # inputs and target are plain arrays
    loss, grads = nnx.value_and_grad(lambda m: loss_fn(m, inputs, target))(model)
    optimizer.update(model, grads)
    apply_constraints(model)  # project parameters back into their valid logical range
    final_loss = loss_fn(model, inputs, target)  # loss after update and constraints
    final_pred = model(inputs)
    return loss, final_loss, final_pred
print("=== Starting training ===")
steps = 50
for step in range(steps):
loss, final_loss, pred = train_step(model, optimizer, inputs, target)
print(f"Step {step:3d} | Loss before/after constraints: {loss:.6f} → {final_loss:.6f}")
print(f"Prediction: {pred}")
print("─" * 60)
if jnp.isnan(final_loss).any():
print("❌ NaN detected! Stopping.")
break
print("=== Training completed ===")
# 6. Result after training
final_pred = model(inputs)
print("\nFinal prediction after training:")
print(final_pred)
print(f"\nTarget interval: {target}")
print(f"Final loss: {total_lnn_loss(final_pred, target):.6f}")