JAX Execution Engine

class jlnn.reasoning.engine.JLNNEngine(*args: Any, **kwargs: Any)[source]

Bases: Module

Orchestration engine for JLNN models, managing high-performance execution.

This class serves as a high-level wrapper for compiled logical formulas, facilitating JIT-compiled inference and atomic training operations. By extending nnx.Module, it leverages the Flax NNX state management system, allowing JAX to efficiently trace and optimize the model’s computational graph across various hardware accelerators (CPU, GPU, TPU).

model

The compiled logical network or formula to be executed.

Type:

nnx.Module

infer(inputs: Dict[str, Array]) → Array[source]

Executes a JIT-compiled forward pass through the logical graph.

Parameters:

inputs – A dictionary mapping predicate names (strings) to input tensors of shape (batch, [time], features). The engine handles multi-dimensional data for both static and temporal reasoning.

Returns:

A JAX array of truth intervals [L, U] with shape (batch, [time], 2).
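The exact logical semantics used inside the compiled graph are not specified here. As an illustration of the dictionary input format and the [L, U] interval output shape, the sketch below composes two predicates with a Gödel-style (min) conjunction applied componentwise; the predicate names and t-norm are assumptions for the example, not the JLNN API.

```python
import jax.numpy as jnp

def interval_and(a, b):
    # Conjunction of two truth intervals [L, U] using a min-based
    # (Goedel) t-norm, applied componentwise -- illustrative only.
    return jnp.stack([jnp.minimum(a[..., 0], b[..., 0]),
                      jnp.minimum(a[..., 1], b[..., 1])], axis=-1)

# Inputs mirror the engine's interface: predicate name -> tensor,
# here with shape (batch=2, 2) where the last axis is [L, U].
inputs = {
    "Smokes":  jnp.array([[0.7, 0.9], [0.2, 0.4]]),
    "Friends": jnp.array([[0.8, 1.0], [0.5, 0.6]]),
}
out = interval_and(inputs["Smokes"], inputs["Friends"])
# out has shape (2, 2): one [L, U] truth interval per batch element
```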

train_step(inputs: Dict[str, Array], targets: Array, optimizer: Optimizer, loss_fn: Callable[[Array, Array], Array]) → Array[source]

Performs an atomic training step: forward pass, backward pass, and state update.

This method encapsulates the complete optimization cycle within a single JIT-compiled block. It computes the loss based on target intervals, calculates gradients via automatic differentiation, and updates the model parameters using the provided optimizer.

Parameters:
  • inputs – Input data dictionary for the forward pass.

  • targets – Ground truth intervals [L, U] representing the desired logical state for the output nodes.

  • optimizer – A Flax NNX optimizer instance (e.g., Adam, SGD, or a custom constrained optimizer).

  • loss_fn – A callable that computes a scalar loss value from model predictions and target intervals.

Returns:

The scalar loss value computed for the current step (pre-update).
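The forward/backward/update cycle that train_step fuses into a single JIT-compiled block can be sketched with plain JAX. The real engine operates on an nnx.Module and a Flax NNX optimizer; here a single scalar weight and hand-rolled SGD stand in for both (illustrative assumptions, not the JLNN implementation).

```python
import jax
import jax.numpy as jnp

def forward(w, x):
    # Toy "model": a weighted truth value clipped into [0, 1].
    return jnp.clip(w * x, 0.0, 1.0)

def loss_fn(pred, target):
    # Squared error between prediction and target truth value.
    return jnp.mean((pred - target) ** 2)

@jax.jit
def train_step(w, x, target, lr=0.1):
    # Forward + backward in one traced computation, then SGD update.
    loss, grad = jax.value_and_grad(
        lambda w: loss_fn(forward(w, x), target))(w)
    # Return updated parameters and the pre-update loss,
    # mirroring train_step's return convention.
    return w - lr * grad, loss

w, x, target = jnp.array(0.5), jnp.array(0.8), jnp.array(0.9)
w, loss = train_step(w, x, target)
```

Because the whole cycle lives inside one `jax.jit` boundary, XLA can fuse the gradient computation and the update into a single compiled program, which is the motivation for keeping the step atomic.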

The JLNNEngine class serves as the main orchestrator between JAX and the logical model. It encapsulates low-level operations so that end users do not have to manage NNX state manually.

Key roles of the engine:

  • JIT Compilation: The infer method uses @nnx.jit to transform recursive logical calls into highly optimized code for GPU/TPU.

  • Atomic Training: The train_step method ensures that weight updates and subsequent logical projections (constraints) occur as a single indivisible operation.

  • Abstraction: A clean interface for passing data as dictionaries (Dict[str, Array]) directly into symbolic predicates.
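The dictionary-based interface described above can be sketched with plain `jax.jit`, which traces through Python dicts as pytrees. The formula below (an implication under a Gödel-style reading) and the predicate names are assumptions chosen for the example; the point is that a jitted function consumes named predicate tensors and emits [L, U] intervals, just as infer does.

```python
import jax
import jax.numpy as jnp

@jax.jit
def infer(inputs):
    # Dict inputs are JAX pytrees, so jit traces them directly.
    smokes, cancer = inputs["Smokes"], inputs["Cancer"]
    # Implies(Smokes, Cancer), Goedel-style, componentwise on [L, U]:
    # the implication is fully true wherever Cancer is at least as
    # true as Smokes; otherwise its truth drops to Cancer's.
    lower = jnp.where(smokes[..., 0] <= cancer[..., 0], 1.0, cancer[..., 0])
    upper = jnp.where(smokes[..., 1] <= cancer[..., 1], 1.0, cancer[..., 1])
    return jnp.stack([lower, upper], axis=-1)

batch = {
    "Smokes": jnp.array([[0.9, 1.0], [0.1, 0.2]]),
    "Cancer": jnp.array([[0.3, 0.5], [0.6, 0.8]]),
}
result = infer(batch)  # shape (2, 2): one [L, U] interval per row
```

The first call compiles the traced logical graph; subsequent calls with the same input shapes reuse the compiled executable, which is what makes the engine's recursive formula evaluation efficient on accelerators.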