Source code for jlnn.training.optimizers
#!/usr/bin/env python3
# Imports
from typing import Any

import optax
from flax import nnx

from jlnn.nn.constraints import apply_constraints
class ProjectedOptimizer:
    """
    Optimizer with support for logical constraints (constraint projection).

    This class wraps a standard Optax optimizer and ensures that, after each
    training step, the model parameters satisfy the logical axioms
    (e.g. w >= 1 for LNN gates).
    """
    def __init__(self, optimizer: optax.GradientTransformation, model: nnx.Module):
        """
        Initializes the optimizer state for the given model.

        Args:
            optimizer: Optax transformation chain (e.g. optax.adam(1e-3)).
            model: JLNNModel instance whose parameters we will optimize.
        """
        self.optimizer = optimizer
        # Initialize the optimizer state for trainable parameters only
        self.opt_state = optimizer.init(nnx.state(model, nnx.Param))
    def step(self, model: nnx.Module, grads: Any):
        """
        Performs one optimization step followed by a projection onto the
        logically feasible set.

        Args:
            model: The model whose parameters are being updated.
            grads: Gradients computed with jax.grad or jax.value_and_grad.
        """
        # 1. Get the current parameter state
        params = nnx.state(model, nnx.Param)
        # 2. Compute the updates (Adam/SGD mechanics)
        updates, self.opt_state = self.optimizer.update(grads, self.opt_state, params)
        # 3. Apply the updates to the parameters
        new_params = optax.apply_updates(params, updates)
        nnx.update(model, new_params)
        # 4. PROJECTION: enforce the logical constraints (e.g. clip gate weights).
        #    This turns plain SGD into Projected Gradient Descent.
        apply_constraints(model)
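
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the library API). It assumes a
# hypothetical JLNNModel constructor and a `batches` iterable; everything else
# (ProjectedOptimizer, apply_constraints, optax, nnx) is defined or imported
# above. apply_constraints is expected to project the gate weights back onto
# the logically valid set (e.g. w >= 1) after each Adam step.
#
#   import jax.numpy as jnp
#
#   model = JLNNModel(...)                              # hypothetical model
#   opt = ProjectedOptimizer(optax.adam(1e-3), model)
#
#   def loss_fn(m, x, y):
#       return jnp.mean((m(x) - y) ** 2)                # simple MSE loss
#
#   for x, y in batches:                                # placeholder iterable
#       grads = nnx.grad(loss_fn)(model, x, y)          # grads w.r.t. nnx.Param
#       opt.step(model, grads)                          # Adam step + projection
# ---------------------------------------------------------------------------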