StableHLO Integration

jlnn.export.stablehlo.export_to_stablehlo(model: Module, sample_input: Array) → Exported

Compiles the JLNN model into a StableHLO artifact for high-performance deployment.

StableHLO is an operation set for deep learning compilers and is part of the OpenXLA ecosystem (https://openxla.org/stablehlo). The export lowers the Łukasiewicz logic kernels defined in ‘functional.py’ to portable, hardware-agnostic StableHLO operations, which an XLA backend can then optimize for its target device.

The export process involves:

  1. Splitting the stateful NNX model into GraphDef (structure) and State (parameters)

  2. Creating a pure functional wrapper compatible with JAX tracing

  3. Generating abstract value specifications (avals) for all inputs

  4. Lowering to StableHLO via the JAX export API
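The steps above can be sketched with plain JAX. This is a minimal illustration, not JLNN's implementation: a dict of arrays stands in for the NNX State so the snippet needs no Flax import, and `weighted_and` is a hypothetical stand-in for the kernels in ‘functional.py’.

```python
import jax
import jax.numpy as jnp
from jax import export


def weighted_and(params, x):
    """Illustrative Łukasiewicz weighted AND on truth intervals (not JLNN's kernel).

    x has shape (n_inputs, 2): column 0 holds lower bounds, column 1 upper bounds.
    """
    slack = params["w"][:, None] * (1.0 - x)  # weighted distance from "true"
    return jnp.clip(params["beta"] - slack.sum(axis=0), 0.0, 1.0)


# Steps 1-2 (stand-in): a plain dict plays the role of the NNX State, and
# weighted_and is already a pure function of (params, input).
params = {"w": jnp.ones(2), "beta": jnp.array(1.0)}
sample = jnp.array([[0.3, 0.7], [0.1, 0.9]])


def to_aval(a):
    # Keep only shape and dtype; the concrete values are not needed for tracing.
    return jax.ShapeDtypeStruct(a.shape, a.dtype)


# Step 3: abstract shape/dtype specifications for every input leaf.
param_avals = jax.tree_util.tree_map(to_aval, params)
sample_aval = to_aval(sample)

# Step 4: lower to StableHLO through the JAX export API.
exported = export.export(jax.jit(weighted_and))(param_avals, sample_aval)

# The parameters are passed back in explicitly at call time.
result = exported.call(params, sample)
print(result)
```

Because the wrapper is pure, the parameters travel as an ordinary argument, which is why the exported function is called with the state as well as the input. For this sample the result is approximately the interval [0.0, 0.6].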

Parameters:
  • model (nnx.Module) – The logic-based neural network (Flax NNX) containing logical gates such as WeightedAnd, WeightedOr, etc.

  • sample_input (jnp.ndarray) – Input data representing truth intervals (shape: […, 2]) used to trace shapes and dtypes during compilation.

Returns:

An object containing the StableHLO MLIR module and serialized model state. Can be inspected via .mlir_module() or executed via .call().

Return type:

jax.export.Exported

Example

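>>> import jax.numpy as jnp
>>> from flax import nnx
>>> from jlnn.export.stablehlo import export_to_stablehlo
>>> model = MyJLNNModel(feature_dim=10)  # MyJLNNModel: a user-defined nnx.Module
>>> sample = jnp.array([[0.3, 0.7], [0.1, 0.9]])
>>> exported = export_to_stablehlo(model, sample)
>>> # Inspect the StableHLO intermediate representation
>>> print(exported.mlir_module())
>>> # Execute the exported model; the state is passed explicitly
>>> graphdef, state = nnx.split(model)
>>> result = exported.call(state, sample)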

jlnn.export.stablehlo.export_workflow(model: Module, sample_input: Array, output_path: str, inspect: bool = False) → Exported

Complete workflow for exporting and saving a JLNN model to StableHLO.

This convenience function combines export, inspection, and serialization in a single call for common deployment scenarios.

Parameters:
  • model (nnx.Module) – The JLNN model to export.

  • sample_input (jnp.ndarray) – Sample input for tracing.

  • output_path (str) – Path where the StableHLO artifact will be saved.

  • inspect (bool) – If True, prints the MLIR module for debugging.

Returns:

The exported model artifact.

Return type:

jax.export.Exported

Example

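>>> import jax.numpy as jnp
>>> from jlnn.export.stablehlo import export_workflow
>>> model = MyJLNNModel(feature_dim=10)  # MyJLNNModel: a user-defined nnx.Module
>>> sample = jnp.array([[0.3, 0.7], [0.1, 0.9]])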
>>> exported = export_workflow(
...     model, sample, "model.stablehlo", inspect=True
... )

jlnn.export.stablehlo.inspect_stablehlo_module(exported: Exported, verbose: bool = False)

Inspects and prints the StableHLO MLIR representation for debugging.

This function is useful for:

  • Verifying correct lowering of Łukasiewicz logic operations

  • Identifying optimization opportunities in the HLO graph

  • Debugging shape mismatches or type errors

  • Understanding the compiled computation structure

Parameters:
  • exported (jax.export.Exported) – The exported StableHLO model.

  • verbose (bool) – If True, prints the full MLIR module. If False, prints only a summary.

Example

>>> exported = export_to_stablehlo(model, sample_input)
>>> inspect_stablehlo_module(exported, verbose=True)
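
The summary that is printed when verbose is False is not specified here. As a rough sketch of what such a summary could contain, the hypothetical helper `summarize_exported` below reports the input and output avals, the target platforms, and a count of StableHLO ops found in the module text. It uses only public attributes of jax.export.Exported and is an assumption, not JLNN's implementation; `clamp_sum` is a toy function used only to produce something to inspect.

```python
import re
from collections import Counter

import jax
import jax.numpy as jnp
from jax import export


def summarize_exported(exported, verbose=False):
    """Hypothetical helper: return the full MLIR text or a short summary."""
    text = exported.mlir_module()
    if verbose:
        return text
    # Count occurrences of each StableHLO op name in the textual module.
    op_counts = Counter(re.findall(r"stablehlo\.\w+", text))
    lines = [
        f"inputs:    {exported.in_avals}",
        f"outputs:   {exported.out_avals}",
        f"platforms: {exported.platforms}",
    ]
    lines += [f"{op}: {n}" for op, n in sorted(op_counts.items())]
    return "\n".join(lines)


def clamp_sum(x):
    # Toy computation: sum over inputs, clamped to the unit interval.
    return jnp.clip(x.sum(axis=0), 0.0, 1.0)


aval = jax.ShapeDtypeStruct((2, 2), jnp.float32)
exported = export.export(jax.jit(clamp_sum))(aval)
summary = summarize_exported(exported)
print(summary)
```

Comparing the op counts before and after a model change is a quick way to spot an unexpected lowering without reading the full module.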

jlnn.export.stablehlo.save_stablehlo_artifact(exported: Exported, path: str)

Serializes the StableHLO module to a binary file for deployment.

The serialized artifact is a portable representation of the compiled model that can be:

  • Loaded by XLA runtimes without Python/JAX dependencies

  • Converted to other formats (TFLite, ONNX via intermediate tools)

  • Deployed to cloud TPUs, GPUs, or custom accelerators

  • Used with OpenXLA toolchain for further optimization

The serialization format preserves:

  • Complete StableHLO computation graph

  • Model parameters and their shapes

  • Type information and constant values

  • Control flow and conditional operations

Parameters:
  • exported (jax.export.Exported) – The exported StableHLO model artifact returned by export_to_stablehlo().

  • path (str) – Destination file path for the serialized artifact. Convention: use .stablehlo or .mlir extension.

Example

>>> exported = export_to_stablehlo(model, sample_input)
>>> save_stablehlo_artifact(exported, "jlnn_model.stablehlo")
>>> # Later, load the artifact:
>>> import jax
>>> with open("jlnn_model.stablehlo", "rb") as f:
...     serialized = f.read()
>>> loaded = jax.export.deserialize(serialized)
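
A self-contained round trip can confirm that a saved artifact still computes the same values after loading. The sketch below assumes the file simply holds the bytes returned by `Exported.serialize()`; whether `save_stablehlo_artifact` writes exactly this is not stated here. `clamp_sum` and the file name are illustrative.

```python
import os
import tempfile

import jax
import jax.numpy as jnp
from jax import export


def clamp_sum(x):
    # Toy computation: sum over inputs, clamped to the unit interval.
    return jnp.clip(x.sum(axis=0), 0.0, 1.0)


sample = jnp.array([[0.3, 0.7], [0.1, 0.9]])
aval = jax.ShapeDtypeStruct(sample.shape, sample.dtype)
exported = export.export(jax.jit(clamp_sum))(aval)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "demo.stablehlo")  # illustrative file name

    # Save: write the serialized Exported object as raw bytes.
    with open(path, "wb") as f:
        f.write(exported.serialize())

    # Load: read the bytes back and rebuild the Exported object.
    with open(path, "rb") as f:
        loaded = export.deserialize(bytearray(f.read()))

# The reloaded artifact should agree with the original function.
restored = loaded.call(sample)
expected = clamp_sum(sample)
print(restored, expected)
```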

This module exports JLNN models to StableHLO (Stable High Level Operations), which is key for:

  • TFX / Google Cloud: Deployment of models into the TensorFlow ecosystem.

  • Hardware Accelerators: Compilation for specific chips that support the XLA/HLO dialect.
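
For the hardware accelerator case, the JAX export API lets the caller name the target platforms at lowering time. The sketch below uses that API directly; whether the JLNN export functions expose a comparable option is not documented here, so treat this as an illustration of the underlying mechanism only. `clamp_sum` is a toy function.

```python
import jax
import jax.numpy as jnp
from jax import export


def clamp_sum(x):
    # Toy computation: sum over inputs, clamped to the unit interval.
    return jnp.clip(x.sum(axis=0), 0.0, 1.0)


aval = jax.ShapeDtypeStruct((2, 2), jnp.float32)

# Lower once for two targets; the TPU does not need to be present on the
# machine that performs the export.
exported = export.export(jax.jit(clamp_sum), platforms=("cpu", "tpu"))(aval)
print(exported.platforms)
```

An artifact exported this way can only be executed on one of the listed platforms, so the list should match the intended deployment targets.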