Projected Optimizers

class jlnn.training.optimizers.ProjectedOptimizer(optimizer: GradientTransformation, model: Module)[source]

Bases: object

Optimizer with support for logical constraints (Constraint Projection).

This class wraps a standard optax optimizer and ensures that, after each learning step, the model parameters satisfy the logical axioms (e.g. \(w \geq 1\) for LNN gates).

step(model: Module, grads: Any)[source]

Performs one optimization step, followed by a projection of the parameters onto the set defined by the logical constraints.

Parameters:
  • model – The model whose parameters are being updated.

  • grads – Gradients calculated using jax.grad or jax.value_and_grad.

Standard optimizers (such as Adam or SGD) can push parameters into an uninterpretable region during training. ProjectedOptimizer addresses this using the technique of Projected Gradient Descent.
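The idea of Projected Gradient Descent can be illustrated with a minimal, hedged sketch (this is not the library's implementation; `pgd_step` and its scalar parameter are illustrative assumptions): take an ordinary gradient step, then project the result back onto the feasible set \(w \geq 1\).

```python
# Minimal sketch of Projected Gradient Descent (illustrative, not the
# library's actual code): a plain gradient step, then projection of the
# parameter back onto the constraint set {w : w >= 1}.

def pgd_step(w, grad, lr=0.1):
    """One gradient descent step followed by projection onto w >= 1."""
    w_new = w - lr * grad      # standard gradient descent update
    return max(w_new, 1.0)     # projection: clamp back into the feasible set

# Example: an update that would push w below 1 gets projected back.
w = pgd_step(1.2, grad=5.0, lr=0.1)  # unconstrained update would give 0.7
print(w)                             # 1.0 after projection
```

Without the projection, the gate weight would drift to 0.7 and lose its logical interpretation; the clamp restores it to the boundary of the feasible set.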

Mechanism of operation:

  1. The standard gradient update is computed and applied (e.g. via optax).

  2. A projection is applied immediately using apply_constraints.

  3. Gate weights are mapped back onto the space \(w \geq 1\), and predicate bounds are clamped to valid values.
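The update-then-project mechanism above can be sketched as follows. This is a simplified stand-in, not the class's real code: a hand-written SGD replaces the optax transformation, and `apply_constraints` here is a toy version that only clamps gate weights to \(w \geq 1\) (the key names and the `gate_w` prefix are illustrative assumptions).

```python
# Hedged sketch of the update -> project mechanism, using a plain dict of
# parameters. In the real class, step 1 is delegated to an optax optimizer.

def sgd_update(params, grads, lr=0.05):
    # Step 1: standard gradient update (done by optax in the library).
    return {k: v - lr * grads[k] for k, v in params.items()}

def apply_constraints(params):
    # Steps 2-3: project gate weights back onto w >= 1; leave others as-is.
    return {k: max(v, 1.0) if k.startswith("gate_w") else v
            for k, v in params.items()}

params = {"gate_w_and": 1.01, "predicate_bias": 0.3}
grads  = {"gate_w_and": 2.0,  "predicate_bias": -1.0}

params = apply_constraints(sgd_update(params, grads))
print(params["gate_w_and"])   # 1.0 -- raw update would give 0.91
```

The projection runs after every step, so the parameters never leave the feasible set between updates, regardless of what the wrapped optimizer proposes.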

A wrapper over any optax optimizer; it guarantees that the model satisfies the logical axioms after every step.