PyTorch Mapping
PyTorch Export Module for JLNN Models.
This module provides utilities to bridge the JAX/Flax NNX ecosystem with PyTorch. It uses ONNX as an intermediate representation to translate logical neural network structures into PyTorch-compatible modules.
- jlnn.export.torch_map.export_to_pytorch(model: Module, sample_input: Array, tmp_path: str = 'tmp_model.onnx', cleanup: bool = True) → torch.nn.Module
Converts a JLNN model to a PyTorch Module via ONNX.
The conversion process involves tracing the JAX model to an ONNX graph and then re-mapping those operations to PyTorch layers using onnx2pytorch.
- Parameters:
model (nnx.Module) – The logical neural network (NNX) to export.
sample_input (jnp.ndarray) – Representative input for shape tracing (batch, 2).
tmp_path (str) – Temporary path for the intermediate ONNX file.
cleanup (bool) – If True, deletes the temporary ONNX file after conversion.
- Returns:
An equivalent PyTorch module for inference.
- Return type:
torch.nn.Module
- Raises:
ImportError – If torch, onnx, or onnx2pytorch are not installed.
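The dependency check and the `tmp_path`/`cleanup` contract described above can be sketched as follows. This is not JLNN's implementation. `require_modules`, `convert_via_onnx_file`, `export_fn` and `load_fn` are hypothetical names: `export_fn` stands in for tracing the JAX model to ONNX, and `load_fn` stands in for rebuilding the graph as a `torch.nn.Module` (for example with onnx2pytorch). The sketch only shows how the ImportError and the temporary file could be handled.

```python
import importlib.util
import os
from typing import Callable, Iterable, TypeVar

T = TypeVar("T")


def require_modules(names: Iterable[str]) -> None:
    """Raise ImportError listing every top-level module in `names` that is not installed.

    Hypothetical helper; the documented function names torch, onnx and onnx2pytorch.
    """
    missing = [name for name in names if importlib.util.find_spec(name) is None]
    if missing:
        raise ImportError(f"Missing optional dependencies: {', '.join(missing)}")


def convert_via_onnx_file(
    export_fn: Callable[[str], None],
    load_fn: Callable[[str], T],
    tmp_path: str = "tmp_model.onnx",
    cleanup: bool = True,
) -> T:
    """Write an intermediate file with `export_fn`, then load it with `load_fn`.

    Hypothetical helper: both callables are assumptions standing in for the
    JAX-to-ONNX export step and the ONNX-to-PyTorch loading step.
    """
    try:
        export_fn(tmp_path)       # step 1: serialize the traced graph to disk
        return load_fn(tmp_path)  # step 2: rebuild it in the target framework
    finally:
        # When requested, delete the intermediate file even if loading failed.
        if cleanup and os.path.exists(tmp_path):
            os.remove(tmp_path)
```

Putting the deletion in a `finally` clause means a failed conversion does not leave a stale `tmp_model.onnx` behind when `cleanup=True`. Passing `cleanup=False` keeps the file around for inspection with ONNX tooling.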
- jlnn.export.torch_map.verify_pytorch_conversion(jax_model: Module, pytorch_model: torch.nn.Module, sample_input: Array, tolerance: float = 1e-05) → Dict[str, Any]
Verifies numerical consistency between the original JAX model and the exported PyTorch model.
This utility performs a forward pass on both models using the same input and compares the resulting truth intervals using the specified absolute tolerance.
- Parameters:
jax_model (nnx.Module) – The original source model.
pytorch_model (torch.nn.Module) – The converted destination model.
sample_input (jnp.ndarray) – Input data for comparison.
tolerance (float) – Maximum allowed absolute difference between outputs.
- Returns:
A report containing:
'passed' (bool): Whether the maximum difference is within tolerance.
'max_diff' (float): The maximum observed absolute error.
'jax_output' (np.ndarray): Output from the original model.
'pytorch_output' (np.ndarray): Output from the converted model.
- Return type:
Dict[str, Any]
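The comparison step can be sketched with NumPy. The sketch assumes both forward passes have already been run and converted to arrays, for example `np.asarray(jax_out)` and `torch_out.detach().cpu().numpy()`. `compare_outputs` is a hypothetical helper, not part of jlnn. It only reproduces the documented report keys and the absolute-tolerance check. Treating a difference exactly equal to `tolerance` as passing (`<=`) is also an assumption.

```python
from typing import Any, Dict

import numpy as np


def compare_outputs(jax_output: Any, pytorch_output: Any, tolerance: float = 1e-5) -> Dict[str, Any]:
    """Build a report with the same keys as the documented verification report.

    Hypothetical helper: the inputs are assumed to be already-computed outputs,
    e.g. truth intervals of shape (batch, 2) holding [lower, upper] bounds.
    """
    a = np.asarray(jax_output, dtype=np.float64)
    b = np.asarray(pytorch_output, dtype=np.float64)
    if a.shape != b.shape:
        raise ValueError(f"Shape mismatch: {a.shape} vs {b.shape}")
    # Largest element-wise absolute error across all lower and upper bounds.
    max_diff = float(np.max(np.abs(a - b))) if a.size else 0.0
    return {
        "passed": max_diff <= tolerance,
        "max_diff": max_diff,
        "jax_output": a,
        "pytorch_output": b,
    }
```

An absolute tolerance suits this case because truth values are bounded to [0, 1]. A relative tolerance would become overly strict for bounds near 0.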
This module maps JLNN operations to equivalent structures in PyTorch. That is useful when JLNN is one component of a larger system that runs primarily in the Torch ecosystem.
Key Features:
- Weight Transfer: Conversion of weights from JAX arrays to torch.nn.Parameter.
- Functional Mapping: Mapping Łukasiewicz operators to Torch operations (e.g., torch.clamp instead of jnp.clip).
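To illustrate the functional mapping, the plain (unweighted) Łukasiewicz conjunction and disjunction can both be written as a clip to [0, 1]. Only the clip primitive then changes between backends. This is a sketch with hypothetical function names, not JLNN's operator code, which may use weighted variants. The PyTorch substitution in the comment is an assumption about how the mapping could be expressed. `torch.clamp(x, min=..., max=...)` itself is a real PyTorch call.

```python
from typing import Any, Callable

import numpy as np

# A clip primitive with signature (x, lo, hi). NumPy and JAX provide np.clip / jnp.clip.
# A PyTorch backend could pass: lambda x, lo, hi: torch.clamp(x, min=lo, max=hi)
ClipFn = Callable[[Any, float, float], Any]


def lukasiewicz_and(a: Any, b: Any, clip: ClipFn = np.clip) -> Any:
    """Łukasiewicz t-norm max(0, a + b - 1), expressed as a clip to [0, 1]."""
    return clip(a + b - 1.0, 0.0, 1.0)


def lukasiewicz_or(a: Any, b: Any, clip: ClipFn = np.clip) -> Any:
    """Łukasiewicz t-conorm min(1, a + b), expressed as a clip to [0, 1]."""
    return clip(a + b, 0.0, 1.0)
```

For inputs in [0, 1], only one bound of each clip is ever active: the lower bound for AND and the upper bound for OR. Using the same `clip(x, 0, 1)` form in both keeps the backend swap to a single argument.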