Model Checkpoints

jlnn.storage.checkpoints.load_checkpoint(model: Module, filepath: str | Path)[source]

Loads the saved parameter state into an existing model instance, validating structural integrity by comparing parameter keys and shapes.

Parameters:
  • model (nnx.Module) – The model instance to load parameters into.

  • filepath (Union[str, Path]) – Path to the checkpoint file.

Raises:
  • FileNotFoundError – If the checkpoint file doesn’t exist.

  • ValueError – If the checkpoint structure doesn’t match the model structure.
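The integrity check described above can be sketched as follows. This is a minimal stand-in, assuming the checkpoint holds a flat dict of named arrays; the names `load_checkpoint_sketch` and `model_params` are illustrative, not part of the library API:

```python
import pickle
from pathlib import Path

def load_checkpoint_sketch(model_params: dict, filepath):
    """Load a pickled parameter dict, validating keys and shapes
    against the in-memory model before applying it."""
    path = Path(filepath)
    if not path.exists():
        raise FileNotFoundError(f"No checkpoint at {path}")
    with path.open("rb") as f:
        saved = pickle.load(f)
    # Structural check: same parameter names...
    if saved.keys() != model_params.keys():
        raise ValueError("Checkpoint keys do not match model parameters")
    # ...and same array shapes for every parameter.
    for name, value in saved.items():
        if value.shape != model_params[name].shape:
            raise ValueError(f"Shape mismatch for parameter {name!r}")
    model_params.update(saved)  # apply the saved state in place
    return model_params
```
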

jlnn.storage.checkpoints.save_checkpoint(model: Module, filepath: str | Path)[source]

Serializes and saves the current state of the NNX model parameters to a binary file.

This function uses the nnx.split mechanism, which separates the graph definition from the state itself (weights and beta parameters). Only the state is stored, keeping the files compact while preserving all learned logic rules defined in gates such as WeightedAnd or WeightedXor.

Parameters:
  • model (nnx.Module) – An instance of the logical model whose parameters (including the weights w >= 1.0 enforced in constraints.py) are to be stored.

  • filepath (Union[str, Path]) – The path to the target file (usually with a .pkl extension). If the directory does not exist, it will be created automatically.

Example

>>> model = WeightedXor(num_inputs=4, rngs=nnx.Rngs(42))
>>> save_checkpoint(model, "checkpoints/xor_v1.pkl")

This module uses the nnx.split mechanism to separate the graph structure from the actual data (weights). This means only the essential parameters are saved, reducing file size and improving robustness.
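The save path can be sketched in the same flat-dict terms. This is a hedged stand-in, not the library implementation: `save_checkpoint_sketch` is a hypothetical name, and a plain dict stands in for the NNX state that nnx.split would produce; only the documented behaviors (state-only serialization, automatic directory creation) are shown:

```python
import pickle
from pathlib import Path

def save_checkpoint_sketch(state: dict, filepath) -> None:
    """Pickle only the parameter state; the graph definition is not
    stored and is rebuilt from the model class at load time."""
    path = Path(filepath)
    # Create the target directory automatically, as documented.
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("wb") as f:
        pickle.dump(state, f)
```
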

Key Features:

  • Compact storage: Only nnx.Param type objects are stored (weights w >= 1.0 and the bias beta).

  • State integrity: After loading a checkpoint, it is recommended to apply the Parameter Constraints to ensure logical consistency, even after manual file modifications.

Note

When loading, the target model’s structure (number of gates and inputs) must exactly match the saved state.
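For illustration, a mismatch in gate or input count surfaces as a shape difference between the saved state and the target model. The states and shapes below are hypothetical, again assuming a flat dict of named arrays:

```python
import numpy as np

# Hypothetical states: the checkpoint was written by a 4-input gate,
# but the target model was built with 8 inputs.
saved_state = {"w": np.ones((4,))}
model_state = {"w": np.ones((8,))}

mismatched = [
    name for name in saved_state
    if saved_state[name].shape != model_state[name].shape
]
# 'w' differs in shape, so load_checkpoint would raise ValueError here
```
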