Source code for jlnn.export.data
#!/usr/bin/env python3
"""
Logic parameter extraction utilities for model portability and symbolic reasoning.
This module provides functions to traverse the Flax NNX module hierarchy and
extract learned logical parameters (weights, biases, thresholds) into standard
Python dictionaries. This allows for saving model metadata as JSON or
exporting the neural logic back into symbolic rules.
"""
# Imports
import jax.numpy as jnp
from flax import nnx
from typing import Dict, Any, Union
from jlnn.nn.gates import (
WeightedAnd, WeightedOr, WeightedXor,
WeightedNand, WeightedNor, WeightedNot, WeightedImplication
)
[docs]
def export_module_recursive(module: Any, name: str = "root") -> Dict[str, Any]:
    """
    Recursively extracts logical parameters while preserving the module hierarchy.

    This function traverses the Flax NNX module tree, identifying logical gates
    and capturing their specific parameters. It maintains the tree structure,
    allowing for the reconstruction of nested logical expressions (e.g.,
    reduction trees in n-ary gates).

    Sub-modules are discovered among the instance attributes of ``module``:
    either stored directly, or as elements of a ``list``/``tuple``/``dict``
    attribute. Container elements are named ``"<attr>[<index or key>]"``.

    Args:
        module (Any): An instance of an NNX Module or a sub-gate to be inspected.
        name (str): The attribute name of the module within its parent.
            Defaults to "root".

    Returns:
        Dict[str, Any]: A dictionary with keys ``"name"``, ``"type"`` (the
        class name) and ``"params"``. ``"params"`` holds whichever of
        ``"weights"`` (``.tolist()`` of the value), ``"beta"`` (a float when it
        holds a single element, otherwise a nested list) and
        ``"xor_threshold"`` (``.tolist()`` of the value) the module defines
        with a non-``None`` value. A ``"sub_gates"`` key, mapping child names
        to their own exported dictionaries, is present only when at least one
        nested NNX module is found.
    """
    params: Dict[str, Any] = {}
    data: Dict[str, Any] = {
        "name": name,
        "type": module.__class__.__name__,
        "params": params,
    }

    # Use [...] for NNX Variable access to avoid DeprecationWarnings.
    # Attributes that are absent or explicitly None are skipped rather than
    # crashing on ``None[...]``.

    # weights (w >= 1.0) represent the relative importance of antecedents
    weights = getattr(module, "weights", None)
    if weights is not None:
        params["weights"] = weights[...].tolist()

    # beta (b) is the threshold parameter for the Lukasiewicz t-norm.
    # A single threshold stays a plain float for JSON compatibility; a
    # multi-element beta cannot go through float() and is exported as a list.
    beta = getattr(module, "beta", None)
    if beta is not None:
        params["beta"] = _to_float_or_list(beta[...])

    # Specific threshold for XOR gates (separating True from Contradictory/False)
    xor_threshold = getattr(module, "xor_threshold", None)
    if xor_threshold is not None:
        params["xor_threshold"] = xor_threshold[...].tolist()

    # Recursive tree traversal over nested NNX modules.
    sub_gates = _collect_sub_gates(module)
    if sub_gates:
        data["sub_gates"] = sub_gates
    return data


def _to_float_or_list(value: Any) -> Any:
    """Convert a parameter value to a float if it holds one element, else to a nested list."""
    if getattr(value, "size", 1) == 1:
        # ``.item()`` extracts the lone element explicitly, avoiding NumPy's
        # deprecated float() conversion of arrays with ndim > 0 (e.g. shape (1,)).
        return float(value.item() if hasattr(value, "item") else value)
    return value.tolist()


def _collect_sub_gates(module: Any) -> Dict[str, Any]:
    """Export every NNX sub-module held by ``module``, directly or inside a list/tuple/dict attribute."""
    sub_gates: Dict[str, Any] = {}
    # Objects without an instance ``__dict__`` (e.g. ``__slots__`` classes)
    # simply have no attributes to walk, instead of making vars() raise.
    for attr_name, attr_value in getattr(module, "__dict__", {}).items():
        if isinstance(attr_value, nnx.Module):
            children = [(attr_name, attr_value)]
        elif isinstance(attr_value, (list, tuple)):
            children = [(f"{attr_name}[{index}]", item) for index, item in enumerate(attr_value)]
        elif isinstance(attr_value, dict):
            children = [(f"{attr_name}[{key}]", item) for key, item in attr_value.items()]
        else:
            continue
        for child_name, child in children:
            # Containers may mix modules with plain data; only modules are exported.
            if isinstance(child, nnx.Module):
                sub_gates[child_name] = export_module_recursive(child, child_name)
    return sub_gates