# Source code for jlnn.export.torch_map

#!/usr/bin/env python3
"""
PyTorch Export Module for JLNN Models.

This module provides utilities to bridge the JAX/Flax NNX ecosystem with PyTorch.
It uses ONNX as an intermediate representation to translate logical neural 
network structures into PyTorch-compatible modules.
"""

# Imports
import os
import numpy as np
import jax.numpy as jnp
from flax import nnx
from typing import Optional, Dict, Any
import warnings

# Import our ONNX export function
from jlnn.export.onnx import export_to_onnx

# Optional PyTorch imports
try:
    import torch
    import onnx
    from onnx2pytorch import ConvertModel
    PYTORCH_AVAILABLE = True
except ImportError:
    torch = None
    onnx = None
    ConvertModel = None
    PYTORCH_AVAILABLE = False


def export_to_pytorch(
    model: nnx.Module,
    sample_input: jnp.ndarray,
    tmp_path: str = "tmp_model.onnx",
    cleanup: bool = True,
) -> 'torch.nn.Module':
    """
    Converts a JLNN model to a PyTorch Module via ONNX.

    The conversion traces the JAX model to an ONNX graph written to
    ``tmp_path``, validates that graph with ``onnx.checker``, and then
    re-maps its operations to PyTorch layers using onnx2pytorch.

    Args:
        model (nnx.Module): The logical neural network (NNX) to export.
        sample_input (jnp.ndarray): Representative input for shape tracing
            (batch, 2).
        tmp_path (str): Path for the intermediate ONNX file.
        cleanup (bool): If True, deletes the intermediate ONNX file once the
            conversion finishes, whether it succeeded or failed.

    Returns:
        torch.nn.Module: An equivalent PyTorch module for inference.

    Raises:
        ImportError: If torch, onnx, or onnx2pytorch are not installed.
        Exception: Errors raised by ``export_to_onnx``, ``onnx.load``,
            ``onnx.checker.check_model`` or ``ConvertModel`` propagate
            unchanged. The temporary file is still removed when ``cleanup``
            is True.
    """
    if not PYTORCH_AVAILABLE:
        raise ImportError(
            "PyTorch or onnx2pytorch not found. Install with: "
            "pip install torch onnx onnx2pytorch"
        )

    print("\n" + "=" * 80)
    print("JLNN MODEL EXPORT TO PYTORCH")
    print("=" * 80)

    try:
        # Stage 1: JAX -> ONNX
        print("\nStage 1/3: Exporting JAX model to ONNX...")
        export_to_onnx(model, sample_input, tmp_path)

        # Stage 2: ONNX -> PyTorch
        print("\nStage 2/3: Converting ONNX to PyTorch...")
        onnx_model = onnx.load(tmp_path)
        onnx.checker.check_model(onnx_model)
        print(" ✓ ONNX model loaded and validated")

        pytorch_model = ConvertModel(onnx_model)
        print(" ✓ PyTorch model created")
    finally:
        # Stage 3: Cleanup runs on success and on failure, so a failed
        # conversion does not leave a stale ONNX file behind.
        if cleanup and os.path.exists(tmp_path):
            print("\nStage 3/3: Cleanup...")
            try:
                os.remove(tmp_path)
            except OSError as err:
                # Raising from `finally` would mask the original exception
                # (or discard a successfully built model); a leftover temp
                # file is not fatal, so only warn.
                warnings.warn(
                    f"Could not remove temporary ONNX file {tmp_path}: {err}",
                    RuntimeWarning,
                    stacklevel=2,
                )
            else:
                print(f" ✓ Temporary ONNX file removed: {tmp_path}")

    # Reached only when stages 1-2 succeeded, so failures are never reported
    # as a completed conversion.
    print("\n" + "=" * 80)
    print("CONVERSION COMPLETE")
    print("=" * 80)
    return pytorch_model
def verify_pytorch_conversion(
    jax_model: nnx.Module,
    pytorch_model: 'torch.nn.Module',
    sample_input: jnp.ndarray,
    tolerance: float = 1e-5,
) -> Dict[str, Any]:
    """
    Verifies numerical consistency between the original JAX model and the
    exported PyTorch model.

    Both models get a forward pass on the same input. Their outputs are then
    compared element-wise using the given absolute tolerance. The outputs
    must have identical shapes; NumPy broadcasting is deliberately not relied
    on, since it could silently compare mismatched outputs.

    Side effect: ``pytorch_model`` is switched to eval mode and left there.

    Args:
        jax_model (nnx.Module): The original source model.
        pytorch_model (torch.nn.Module): The converted destination model.
        sample_input (jnp.ndarray): Input data for comparison. It is cast to
            float32 before being fed to the PyTorch model.
        tolerance (float): Maximum allowed absolute difference between outputs.

    Returns:
        Dict[str, Any]: A report containing:
            - 'passed' (bool): Whether the shapes match and the difference is
              within tolerance.
            - 'max_diff' (float): The maximum observed absolute error. This is
              ``inf`` when the output shapes differ and ``0.0`` for empty
              outputs. It is NaN if either output contains NaN, in which case
              'passed' is False.
            - 'jax_output' (np.ndarray): Output from the original model.
            - 'pytorch_output' (np.ndarray): Output from the converted model.
            - 'shapes_match' (bool): Whether both outputs have the same shape.

    Raises:
        ImportError: If PyTorch is not installed.
    """
    if not PYTORCH_AVAILABLE:
        raise ImportError("PyTorch not available for verification")

    # Run JAX model
    jax_out_np = np.array(jax_model(sample_input))

    # Run PyTorch model
    pytorch_model.eval()
    with torch.no_grad():
        # JAX -> NumPy -> Torch bridge
        numpy_input = np.array(sample_input).astype(np.float32)
        torch_input = torch.from_numpy(numpy_input)
        # Feed the input on the same device as the model's weights, if the
        # model has any; otherwise leave it on the CPU.
        first_param = next(pytorch_model.parameters(), None)
        if first_param is not None:
            torch_input = torch_input.to(first_param.device)
        pytorch_output = pytorch_model(torch_input).cpu().numpy()

    # Compare: require identical shapes instead of letting NumPy broadcast
    # (e.g. (4,) vs (4, 1) would otherwise become a (4, 4) cross comparison).
    shapes_match = jax_out_np.shape == pytorch_output.shape
    if not shapes_match:
        max_diff = float('inf')
    elif jax_out_np.size == 0:
        # np.max raises ValueError on zero-size arrays.
        max_diff = 0.0
    else:
        max_diff = float(np.max(np.abs(jax_out_np - pytorch_output)))
    passed = bool(shapes_match and max_diff <= tolerance)

    return {
        'passed': passed,
        'max_diff': max_diff,
        'jax_output': jax_out_np,
        'pytorch_output': pytorch_output,
        'shapes_match': shapes_match,
    }