# Source code for jlnn.utils.xarray_utils
#!/usr/bin/env python3
"""
Integration utilities for Xarray data structures.
This module provides tools to bridge JAX tensors with xarray, allowing
symbolic labeling of neural outputs and trained logical weights.
"""
# Imports
import xarray as xr
import jax.numpy as jnp
from typing import Any, Dict, List
# [docs]
def model_to_xarray(gate_outputs: Dict[str, jnp.ndarray], sample_labels: List[str]) -> xr.Dataset:
    """
    Converts logical model outputs into a labeled xarray Dataset.

    Each gate output must be a ``(batch, 2)`` tensor whose last axis holds
    the ``[L, U]`` truth-interval bounds. Every tensor becomes one data
    variable with dimensions ``("sample", "bound")``. The ``sample``
    coordinate is taken from ``sample_labels`` and the ``bound`` coordinate
    is ``["Lower", "Upper"]``. This enables label-based indexing and
    visualization of truth intervals.

    Args:
        gate_outputs: Mapping of gate names to their [L, U] output tensors,
            each of shape ``(len(sample_labels), 2)``.
        sample_labels: Names for the samples in the batch (e.g., individual IDs).

    Returns:
        xr.Dataset: Dataset with one data variable per gate, keyed by gate
        name. An empty ``gate_outputs`` mapping yields an empty Dataset.

    Raises:
        ValueError: If any gate output does not have shape
            ``(len(sample_labels), 2)``. The message names the offending gate.
    """
    # Materialize once so any sequence of labels works and len() is defined.
    labels = list(sample_labels)
    bounds = ["Lower", "Upper"]
    expected_shape = (len(labels), len(bounds))

    ds_dict = {}
    for name, data in gate_outputs.items():
        shape = tuple(jnp.shape(data))
        # Validate up front: xarray's own dimension/size errors do not say
        # which gate produced the malformed tensor.
        if shape != expected_shape:
            raise ValueError(
                f"Gate '{name}' output has shape {shape}; expected "
                f"{expected_shape} (samples, [Lower, Upper])."
            )
        ds_dict[name] = xr.DataArray(
            data,
            dims=["sample", "bound"],
            coords={"sample": labels, "bound": bounds},
            name=name,
        )
    return xr.Dataset(ds_dict)