ONNX Export

jlnn.export.onnx.export_to_onnx(model: Module, sample_input, path: str)[source]

Exports a JLNN (Logical Neural Network) model to ONNX format.

This function first exports the model to StableHLO, then converts the result to an ONNX graph via manual graph construction (currently a placeholder implementation). The process handles both single-tensor and dictionary-based predicate inputs.

Parameters:
  • model (nnx.Module) – The trained JLNN model instance.

  • sample_input (Any) – Sample input tensor or PyTree for tracing.

  • path (str) – Destination file path for the .onnx model.
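Since the jlnn internals are not shown on this page, the following is only a minimal sketch of how a dictionary-based `sample_input` traces through the underlying `jax.export` pipeline; `conj` is a hypothetical stand-in for a JLNN conjunction gate, not the library's actual forward pass.

```python
import jax
import jax.numpy as jnp
from jax import export

def conj(preds):
    # Hypothetical stand-in gate: Łukasiewicz conjunction max(0, a + b - 1),
    # applied elementwise to the predicate arrays in the input dict.
    return jnp.maximum(0.0, preds["A"] + preds["B"] - 1.0)

# Dict-of-predicates PyTree, matching the shape of the doc's workflow example.
sample = {"A": jnp.array([[0.5, 0.8]]), "B": jnp.array([[0.2, 0.6]])}

# jax.export flattens the PyTree and traces shapes/dtypes from the sample.
exp = export.export(jax.jit(conj))(sample)
```

The exported artifact records one abstract value per dict leaf, which is what allows `export_to_onnx` to accept either a single tensor or a PyTree.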

jlnn.export.onnx.export_to_stablehlo(model: Module, sample_input)[source]

Compiles a JLNN model into a StableHLO artifact.

This function bridges the gap between stateful Flax NNX modules and the stateless requirements of the JAX export pipeline. It lowers the model’s logical operations (e.g., Łukasiewicz t-norms) into a StableHLO representation.

Parameters:
  • model (nnx.Module) – The trained JLNN model instance containing logic gates.

  • sample_input (Any) – A sample input tensor or PyTree (e.g., dict of arrays) representing truth intervals. Used for shape and dtype tracing.

Returns:

Exported StableHLO model artifact that can be serialized or executed.

Return type:

jax.export.Exported
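A minimal sketch of the lowering step, assuming the stateful NNX model has already been split into a pure function (`logic_fn` below is a hypothetical stand-in; the `nnx.split` step is not shown):

```python
import jax
import jax.numpy as jnp
from jax import export

def logic_fn(x):
    # Stand-in logical op: Łukasiewicz t-norm over the two interval bounds.
    return jnp.maximum(0.0, x[..., 0] + x[..., 1] - 1.0)

# Sample input used only for shape and dtype tracing.
sample = jnp.zeros((1, 3, 2), dtype=jnp.float32)
exported = export.export(jax.jit(logic_fn))(sample)

# The resulting jax.export.Exported carries the StableHLO module as MLIR text.
stablehlo_text = exported.mlir_module()
```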

jlnn.export.onnx.export_workflow_example(model: Module, sample_input, base_name: str)[source]

Complete export workflow demonstrating both StableHLO and ONNX export.

This example demonstrates the end-to-end pipeline: splitting the model state, lowering to StableHLO, and generating a portable ONNX artifact. It supports both simple tensor inputs and complex PyTree (dictionary) structures.

Parameters:
  • model (nnx.Module) – The JLNN model to export.

  • sample_input (Any) – Sample input for tracing (tensor or dict of tensors).

  • base_name (str) – Base filename (without extension).

Example

>>> # Example with dictionary-based predicates
>>> sample = {"A": jnp.array([[0.5, 0.8]]), "B": jnp.array([[0.2, 0.6]])}
>>> export_workflow_example(model, sample, "logic_model")

jlnn.export.onnx.save_for_xla_runtime(exported: Exported, filename: str)[source]

Serializes the StableHLO model for XLA/StableHLO runtime executors.
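A sketch of what this serialization step might look like, assuming `save_for_xla_runtime` wraps `jax.export`'s `Exported.serialize` and writes the resulting bytes to `filename`; `disj` is a hypothetical stand-in model:

```python
import jax
import jax.numpy as jnp
from jax import export

def disj(x):
    # Stand-in logical op: Łukasiewicz t-conorm min(1, a + b).
    return jnp.minimum(1.0, x[..., 0] + x[..., 1])

exp = export.export(jax.jit(disj))(jnp.zeros((1, 2, 2), dtype=jnp.float32))

# serialize() produces portable bytes that jax.export.deserialize understands.
blob = exp.serialize()
with open("logic_model.bin", "wb") as fh:
    fh.write(bytes(blob))
```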

This module handles the transformation of JLNN models into the Open Neural Network Exchange (ONNX) format.

Implementation

Unlike standard wrappers, JLNN uses the native jax.export pipeline.

  1. Model Tracing: Using export_to_stablehlo, the stateful NNX model is converted into a static computational graph.

  2. Metadata Serialization: The function export_to_onnx prepares the model for external runtime environments.

# Example export (assumes a trained `model` and a predicate count `n_inputs`)
import jax.numpy as jnp

sample_input = jnp.zeros((1, n_inputs, 2))  # (batch, predicates, [lower, upper] bounds)
export_to_onnx(model, sample_input, "logic_model.onnx")
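The two steps above can be checked end to end with plain `jax.export`, independently of jlnn: serialize the traced artifact, deserialize it, and execute the restored graph. `neg` is a hypothetical stand-in gate:

```python
import jax
import jax.numpy as jnp
from jax import export

def neg(x):
    # Stand-in logical negation on truth values in [0, 1].
    return 1.0 - x

sample = jnp.array([[0.25, 0.75]])
exp = export.export(jax.jit(neg))(sample)

# Round-trip: bytes -> Exported -> executable call on the original input.
restored = export.deserialize(exp.serialize())
out = restored.call(sample)
```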