PyTorch Mapping

PyTorch Export Module for JLNN Models.

This module provides utilities to bridge the JAX/Flax NNX ecosystem with PyTorch. It uses ONNX as an intermediate representation to translate logical neural network structures into PyTorch-compatible modules.

jlnn.export.torch_map.export_to_pytorch(model: Module, sample_input: Array, tmp_path: str = 'tmp_model.onnx', cleanup: bool = True) → torch.nn.Module

Converts a JLNN model to a PyTorch Module via ONNX.

Conversion runs in two stages. First, the JAX model is traced into an ONNX graph, which is written to tmp_path. Second, onnx2pytorch maps the operations in that graph onto equivalent PyTorch layers.

Parameters:
  • model (nnx.Module) – The logical neural network (NNX) to export.

  • sample_input (jnp.ndarray) – Representative input used for shape tracing, with shape (batch, 2).

  • tmp_path (str) – Temporary path for the intermediate ONNX file.

  • cleanup (bool) – If True, deletes the temporary ONNX file after conversion.

Returns:

An equivalent PyTorch module for inference.

Return type:

torch.nn.Module

Raises:

ImportError – If torch, onnx, or onnx2pytorch is not installed.
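The tmp_path and cleanup parameters describe a write-then-convert lifecycle for the intermediate ONNX file. The sketch below illustrates that lifecycle in plain Python. convert_via_intermediate_file, fake_export, and fake_convert are hypothetical stand-ins for the JAX-to-ONNX and ONNX-to-PyTorch stages, and removing the file inside a try/finally (so it is deleted even when conversion fails) is an assumption, not confirmed behavior of export_to_pytorch.

```python
import os
import tempfile
from typing import Callable, TypeVar

T = TypeVar("T")


def convert_via_intermediate_file(
    write_file: Callable[[str], None],
    load_converted: Callable[[str], T],
    tmp_path: str = "tmp_model.onnx",
    cleanup: bool = True,
) -> T:
    """Write an intermediate file, convert it, and optionally delete it afterwards.

    Hypothetical helper: `write_file` stands in for the JAX -> ONNX export and
    `load_converted` for the ONNX -> PyTorch conversion.
    """
    try:
        write_file(tmp_path)  # stage 1: serialize the traced graph to disk
        return load_converted(tmp_path)  # stage 2: rebuild it in the target framework
    finally:
        # Assumed behavior: remove the file even if stage 1 or stage 2 raised.
        if cleanup and os.path.exists(tmp_path):
            os.remove(tmp_path)


# Demo with toy stand-ins: the "export" writes text and the "conversion" reads it back.
workdir = tempfile.mkdtemp()
path = os.path.join(workdir, "tmp_model.onnx")


def fake_export(p: str) -> None:
    with open(p, "w") as f:
        f.write("graph")


def fake_convert(p: str) -> str:
    with open(p) as f:
        return f.read()


result = convert_via_intermediate_file(fake_export, fake_convert, path, cleanup=True)
print(result, os.path.exists(path))  # graph False
```

Passing cleanup=False leaves the file at tmp_path, which can be useful for inspecting the intermediate graph with ONNX tooling.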

jlnn.export.torch_map.verify_pytorch_conversion(jax_model: Module, pytorch_model: torch.nn.Module, sample_input: Array, tolerance: float = 1e-05) → Dict[str, Any]

Verifies numerical consistency between the original JAX model and the exported PyTorch model.

This utility runs a forward pass through both models on the same input and checks that the resulting truth intervals agree element-wise to within the given absolute tolerance.

Parameters:
  • jax_model (nnx.Module) – The original source model.

  • pytorch_model (torch.nn.Module) – The converted destination model.

  • sample_input (jnp.ndarray) – Input data for comparison.

  • tolerance (float) – Maximum allowed absolute difference between outputs.

Returns:

A report containing:
  • 'passed' (bool): Whether the difference is within tolerance.

  • 'max_diff' (float): The maximum observed absolute error.

  • 'jax_output' (np.ndarray): Output from the original model.

  • 'pytorch_output' (np.ndarray): Output from the converted model.

Return type:

Dict[str, Any]
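The report returned by verify_pytorch_conversion can be approximated in NumPy as follows, assuming both model outputs have already been materialized as host arrays. compare_outputs is a hypothetical helper, and treating a difference exactly equal to the tolerance as a pass is an assumption.

```python
from typing import Any, Dict

import numpy as np


def compare_outputs(jax_output, pytorch_output, tolerance: float = 1e-5) -> Dict[str, Any]:
    """Hypothetical helper building a report with the documented keys.

    Torch tensors that require gradients should be converted with
    `.detach().cpu().numpy()` before calling this; JAX arrays work with np.asarray.
    """
    a = np.asarray(jax_output, dtype=np.float64)
    b = np.asarray(pytorch_output, dtype=np.float64)
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")
    # Largest element-wise absolute error; an empty output counts as zero error.
    max_diff = float(np.max(np.abs(a - b))) if a.size else 0.0
    return {
        "passed": max_diff <= tolerance,
        "max_diff": max_diff,
        "jax_output": a,
        "pytorch_output": b,
    }


jax_out = np.array([[0.2, 0.9], [0.0, 0.5]])
close = compare_outputs(jax_out, jax_out + 2e-6)  # small drift, within 1e-5
far = compare_outputs(jax_out, jax_out + 1e-3)  # drift well above 1e-5
print(close["passed"], far["passed"])
```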

This module maps JLNN operations to equivalent PyTorch structures, which is useful when JLNN is one component of a larger system that runs primarily on PyTorch.

Key Features:

  • Weight Transfer: Conversion of weights from JAX arrays to torch.nn.Parameter objects.

  • Functional Mapping: Mapping Łukasiewicz operators to Torch operations (e.g., torch.clamp instead of jnp.clip).
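To illustrate the functional mapping, the sketch below writes the standard (unweighted, scalar-valued) Łukasiewicz operators once against a generic clip function. np.clip, jnp.clip, and torch.clamp all accept (input, min, max) positionally, so switching backends is a one-argument change. The helper names are hypothetical, and JLNN's own operators, which are weighted and act on truth intervals, may be defined differently.

```python
from typing import Callable

import numpy as np

# Any function with the positional signature clip(x, lo, hi);
# np.clip, jnp.clip and torch.clamp all qualify.
ClipFn = Callable[..., object]


def luk_and(a, b, clip: ClipFn = np.clip):
    """Łukasiewicz t-norm max(0, a + b - 1); for a, b in [0, 1] this equals clip(a + b - 1, 0, 1)."""
    return clip(a + b - 1.0, 0.0, 1.0)


def luk_or(a, b, clip: ClipFn = np.clip):
    """Łukasiewicz t-conorm min(1, a + b); for a, b in [0, 1] this equals clip(a + b, 0, 1)."""
    return clip(a + b, 0.0, 1.0)


def luk_implies(a, b, clip: ClipFn = np.clip):
    """Łukasiewicz implication min(1, 1 - a + b)."""
    return clip(1.0 - a + b, 0.0, 1.0)


def luk_not(a):
    """Standard negation 1 - a; no clipping is needed for inputs in [0, 1]."""
    return 1.0 - a


a = np.array([0.2, 0.7, 1.0])
b = np.array([0.5, 0.6, 1.0])
print(luk_and(a, b), luk_or(a, b))
```

With PyTorch installed, luk_and(ta, tb, clip=torch.clamp) evaluates the same formula on tensors ta and tb.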