Source code for jlnn.reasoning.engine
#!/usr/bin/env python3
from functools import partial
from typing import Callable, Dict

import jax.numpy as jnp
from flax import nnx
class JLNNEngine(nnx.Module):
    """Orchestration engine for JLNN models, managing high-performance execution.

    This class is a high-level wrapper around compiled logical formulas,
    providing JIT-compiled inference and atomic training steps. By extending
    `nnx.Module`, it plugs into the Flax NNX state-management system, so JAX
    can trace and optimize the model's computational graph across hardware
    accelerators (CPU, GPU, TPU).

    Attributes:
        model (nnx.Module): The compiled logical network or formula to execute.
    """
    def __init__(self, model: nnx.Module):
        """Initializes the engine with a target logical model.

        Args:
            model: A compiled LNN formula or neural logic network instance
                conforming to the NNX module interface.
        """
        self.model = model
    @nnx.jit
    def infer(self, inputs: Dict[str, jnp.ndarray]) -> jnp.ndarray:
        """Executes a JIT-compiled forward pass through the logical graph.

        Args:
            inputs: A dictionary mapping predicate names (strings) to input
                tensors of shape (batch, [time], features). The engine handles
                multi-dimensional data for both static and temporal reasoning.

        Returns:
            A JAX array of truth intervals [L, U] with shape (batch, [time], 2).
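
        Example:
            A minimal sketch (the predicate name ``"Smokes"`` is a
            hypothetical placeholder, not part of the API)::

                batch = {"Smokes": jnp.zeros((32, 1))}   # (batch, features)
                bounds = engine.infer(batch)             # (batch, 2) intervals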
"""
return self.model(inputs)
    @partial(nnx.jit, static_argnames=('loss_fn',))
    def train_step(self,
                   inputs: Dict[str, jnp.ndarray],
                   targets: jnp.ndarray,
                   optimizer: nnx.Optimizer,
                   loss_fn: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]) -> jnp.ndarray:
        """Performs an atomic training step: forward, backward, and state update.

        This method encapsulates the complete optimization cycle in a single
        JIT-compiled block: it computes the loss against the target intervals,
        obtains gradients via automatic differentiation, and updates the model
        parameters with the provided optimizer.

        Args:
            inputs: Input data dictionary for the forward pass.
            targets: Ground-truth intervals [L, U] representing the desired
                logical state of the output nodes.
            optimizer: A Flax NNX optimizer instance (e.g., wrapping Adam,
                SGD, or a custom constrained optimizer).
            loss_fn: A callable computing a scalar loss from model predictions
                and target intervals. Declared static so the step is only
                retraced when a different loss function is supplied; a plain
                Python callable is not a valid traced JAX argument.

        Returns:
            The scalar loss value computed for the current step (pre-update).
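
        Example:
            An illustrative sketch, assuming an ``optax``-backed NNX optimizer
            and a simple interval MSE loss (both hypothetical choices, not
            mandated by the engine)::

                opt = nnx.Optimizer(engine.model, optax.adam(1e-3), wrt=nnx.Param)
                mse = lambda preds, targets: jnp.mean((preds - targets) ** 2)
                loss = engine.train_step(batch, targets, opt, mse)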
"""
        def compute_loss(model: nnx.Module) -> jnp.ndarray:
            preds = model(inputs)
            return loss_fn(preds, targets)

        # Differentiate with respect to the model's nnx.Param state.
        loss, grads = nnx.value_and_grad(compute_loss)(self.model)
        # Apply the gradients in place; on Flax < 0.11 the call is
        # optimizer.update(grads).
        optimizer.update(self.model, grads)
        return loss
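

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the library API). It assumes a
# hypothetical toy formula with a single learnable bias and a mean-squared
# interval loss; the predicate name "Smokes" is likewise a placeholder.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import optax

    class _BiasFormula(nnx.Module):
        """Toy stand-in for a compiled formula: shifts input bounds by a bias."""

        def __init__(self):
            self.bias = nnx.Param(jnp.zeros((2,)))

        def __call__(self, inputs: Dict[str, jnp.ndarray]) -> jnp.ndarray:
            # Keep outputs inside the valid truth-interval range [0, 1].
            return jnp.clip(inputs["Smokes"] + self.bias, 0.0, 1.0)

    def interval_mse(preds: jnp.ndarray, targets: jnp.ndarray) -> jnp.ndarray:
        return jnp.mean((preds - targets) ** 2)

    engine = JLNNEngine(_BiasFormula())
    # On Flax < 0.11 the constructor is nnx.Optimizer(model, tx).
    optimizer = nnx.Optimizer(engine.model, optax.adam(1e-2), wrt=nnx.Param)

    batch = {"Smokes": jnp.full((8, 2), 0.4)}          # [L, U] bounds per sample
    targets = jnp.tile(jnp.array([0.6, 0.9]), (8, 1))  # desired output bounds

    for _ in range(100):
        loss = engine.train_step(batch, targets, optimizer, interval_mse)
    print("final loss:", float(loss))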