StableHLO Integration
- jlnn.export.stablehlo.export_to_stablehlo(model: Module, sample_input: Array) → Exported
Compiles the JLNN model into a StableHLO artifact for high-performance deployment.
StableHLO is an operation set for deep learning compilers and is part of the OpenXLA ecosystem (https://openxla.org/stablehlo). This export lowers the Łukasiewicz logic kernels defined in functional.py into optimized, hardware-agnostic HLO operations.
The export process involves:

1. Splitting the stateful NNX model into GraphDef (structure) and State (parameters)
2. Creating a pure functional wrapper compatible with JAX tracing
3. Generating abstract value specifications (avals) for all inputs
4. Lowering to StableHLO via the JAX export API
- Parameters:
model (nnx.Module) – The logic-based neural network (Flax NNX) containing logical gates such as WeightedAnd, WeightedOr, etc.
sample_input (jnp.ndarray) – Input data representing truth intervals (shape: […, 2]) used to trace shapes and dtypes during compilation.
- Returns:
An object containing the StableHLO MLIR module and serialized model state. It can be inspected via .mlir_module() or executed via .call().
- Return type:
jax.export.Exported
References
StableHLO Specification: https://openxla.org/stablehlo
JAX Export Guide: https://docs.jax.dev/en/latest/jax.export.html
XLA Flags & Optimization: https://docs.jax.dev/en/latest/xla_flags.html
Example
>>> model = MyJLNNModel(feature_dim=10)
>>> sample = jnp.array([[0.3, 0.7], [0.1, 0.9]])
>>> exported = export_to_stablehlo(model, sample)
>>> # Inspect the StableHLO intermediate representation
>>> print(exported.mlir_module())
>>> # Execute the exported model
>>> graphdef, state = nnx.split(model)
>>> result = exported.call(state, sample)
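MyJLNNModel and the JLNN helpers above are placeholders for the library's own classes. The lowering they wrap can be sketched with plain jax.export applied to a pure function; here a hypothetical Łukasiewicz conjunction over truth intervals (shape [..., 2]) stands in for the model's logic kernels, so the whole sketch is an assumption about the general mechanism, not the JLNN implementation itself:

```python
import jax
import jax.numpy as jnp
from jax import export

# Hypothetical stand-in for a JLNN logic kernel: the Łukasiewicz t-norm
# max(0, x + y - 1), conjoining two truth intervals bound-by-bound.
def luk_and(a, b):
    return jnp.maximum(0.0, a + b - 1.0)

# Sample input used only to trace shapes and dtypes, mirroring the role
# of sample_input in export_to_stablehlo.
sample = jnp.array([[0.3, 0.7], [0.1, 0.9]], dtype=jnp.float32)
spec = jax.ShapeDtypeStruct(sample.shape, sample.dtype)

# Lower to StableHLO via the JAX export API (step 4 of the process above).
exported = export.export(jax.jit(luk_and))(spec, spec)

print(exported.mlir_module()[:120])    # beginning of the StableHLO MLIR text
result = exported.call(sample, sample)  # execute the exported artifact
```

The real function additionally performs steps 1-3 (nnx.split plus a pure functional wrapper) so that model parameters become explicit arguments before tracing.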
- jlnn.export.stablehlo.export_workflow(model: Module, sample_input: Array, output_path: str, inspect: bool = False) → Exported
Complete workflow for exporting and saving a JLNN model to StableHLO.
This convenience function combines export, inspection, and serialization in a single call for common deployment scenarios.
- Parameters:
model (nnx.Module) – The JLNN model to export.
sample_input (jnp.ndarray) – Sample input for tracing.
output_path (str) – Path where the StableHLO artifact will be saved.
inspect (bool) – If True, prints the MLIR module for debugging.
- Returns:
The exported model artifact.
- Return type:
jax.export.Exported
References
StableHLO: https://openxla.org/stablehlo
JAX Export: https://docs.jax.dev/en/latest/jax.export.html
Example
>>> model = MyJLNNModel(feature_dim=10)
>>> sample = jnp.array([[0.3, 0.7], [0.1, 0.9]])
>>> exported = export_workflow(
...     model, sample, "model.stablehlo", inspect=True
... )
- jlnn.export.stablehlo.inspect_stablehlo_module(exported: Exported, verbose: bool = False)
Inspects and prints the StableHLO MLIR representation for debugging.
This function is useful for:

- Verifying correct lowering of Łukasiewicz logic operations
- Identifying optimization opportunities in the HLO graph
- Debugging shape mismatches or type errors
- Understanding the compiled computation structure
- Parameters:
exported (jax.export.Exported) – The exported StableHLO model.
verbose (bool) – If True, prints the full MLIR module. If False, prints only a summary.
Example
>>> exported = export_to_stablehlo(model, sample_input)
>>> inspect_stablehlo_module(exported, verbose=True)
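The summary such an inspector might print can be approximated by scanning the MLIR text for StableHLO ops. A minimal sketch, assuming only an Exported produced by jax.export (the toy luk_or function below is a hypothetical stand-in, not a JLNN kernel):

```python
import re
import jax
import jax.numpy as jnp
from jax import export

# Toy stand-in for an exported JLNN model: the Łukasiewicz t-conorm
# min(1, x + y), i.e. a fuzzy disjunction.
def luk_or(a, b):
    return jnp.minimum(1.0, a + b)

spec = jax.ShapeDtypeStruct((2, 2), jnp.float32)
exported = export.export(jax.jit(luk_or))(spec, spec)

# Summary pass: count each StableHLO op appearing in the MLIR text.
mlir = exported.mlir_module()
ops = re.findall(r"stablehlo\.\w+", mlir)
for name in sorted(set(ops)):
    print(name, ops.count(name))
```

Seeing the expected arithmetic ops (and no unexpected fallbacks) is a quick way to verify that a logic kernel lowered cleanly.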
- jlnn.export.stablehlo.save_stablehlo_artifact(exported: Exported, path: str)
Serializes the StableHLO module to a binary file for deployment.
The serialized artifact is a portable representation of the compiled model that can be:

- Loaded by XLA runtimes without Python/JAX dependencies
- Converted to other formats (TFLite, ONNX via intermediate tools)
- Deployed to cloud TPUs, GPUs, or custom accelerators
- Used with the OpenXLA toolchain for further optimization
The serialization format preserves:

- Complete StableHLO computation graph
- Model parameters and their shapes
- Type information and constant values
- Control flow and conditional operations
- Parameters:
exported (jax.export.Exported) – The exported StableHLO model artifact returned by export_to_stablehlo().
path (str) – Destination file path for the serialized artifact. Convention: use .stablehlo or .mlir extension.
References
OpenXLA StableHLO: https://openxla.org/stablehlo
XLA Compilation: https://docs.jax.dev/en/latest/xla_flags.html
Example
>>> exported = export_to_stablehlo(model, sample_input)
>>> save_stablehlo_artifact(exported, "jlnn_model.stablehlo")
>>> # Later, load and execute:
>>> with open("jlnn_model.stablehlo", "rb") as f:
...     serialized = f.read()
>>> loaded = jax.export.deserialize(serialized)
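The save/load round trip can be exercised without the JLNN helpers. A sketch using the jax.export serialize/deserialize pair directly, with a hypothetical Łukasiewicz implication standing in for the model (everything here except the jax.export API is an illustrative assumption):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export

# Hypothetical stand-in kernel: Łukasiewicz implication min(1, 1 - a + b).
def luk_implies(a, b):
    return jnp.minimum(1.0, 1.0 - a + b)

x = jnp.array([[0.3, 0.7], [0.1, 0.9]], dtype=jnp.float32)
spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
exported = export.export(jax.jit(luk_implies))(spec, spec)

# Serialize to bytes (what save_stablehlo_artifact would write to disk) ...
blob = exported.serialize()
# ... and load it back as a callable artifact.
loaded = export.deserialize(blob)

# The deserialized artifact reproduces the original computation.
assert np.allclose(np.asarray(loaded.call(x, x)), np.asarray(luk_implies(x, x)))
```

In deployment the bytes would instead be loaded by an XLA runtime, but the round trip above is a cheap sanity check before shipping the file.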
This module exports the model to StableHLO (Stable High-Level Operations). This is key for:
TFX / Google Cloud: deploying models into the TensorFlow ecosystem.
Hardware accelerators: compiling for chips that support the XLA/HLO dialect.