Parameter Constraints

jlnn.nn.constraints.apply_constraints(model: Module)[source]

Top-level function to aggregate and apply all logical and structural constraints.

This should be called immediately after the optimizer’s update step but before the next forward pass. It keeps the model within the feasible space of logical formulas.

Parameters:

model (nnx.Module) – The Flax NNX model to be constrained.
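The aggregation itself is just sequential application of the two clipping passes documented below. The sketch uses a plain dict and numpy arrays as a stand-in for the real nnx.Module graph; the parameter names (`gate_weight`, `offset_l`, `offset_u`) are illustrative, not the library's actual attribute names.

```python
import numpy as np

def clip_weights(params):
    # Project gate weights onto [1.0, inf) -- stand-in for the documented pass.
    params["gate_weight"] = np.maximum(params["gate_weight"], 1.0)

def clip_predicates(params):
    # Enforce offset_u <= offset_l so that L <= U -- stand-in for the documented pass.
    params["offset_u"] = np.minimum(params["offset_u"], params["offset_l"])

def apply_constraints(params):
    # Top-level aggregation: run each structural constraint once, in order.
    clip_weights(params)
    clip_predicates(params)

params = {"gate_weight": np.array([0.7, 1.3]),
          "offset_l": np.array([0.1]),
          "offset_u": np.array([0.6])}
apply_constraints(params)  # both violations are repaired in one call
```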

jlnn.nn.constraints.clip_predicates(model: Module)[source]

Ensures logical consistency in grounding layers (predicates) by maintaining L <= U.

For a LearnedPredicate, the lower bound (L) must never exceed the upper bound (U). This is enforced by adjusting the offset parameters so that the upper bound's transition never lags behind the lower bound's.

Parameters:

model (nnx.Module) – The Flax NNX model containing LearnedPredicate modules.
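Why clipping the offsets is enough can be seen in a hypothetical parameterization where both bounds share a slope and differ only in offset: with `offset_u <= offset_l` the upper bound's transition starts earlier, so U(x) >= L(x) for every input. The sigmoid form below is an assumption for illustration, not the library's actual predicate.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def predicate_bounds(x, slope, offset_l, offset_u):
    # Assumed form: shared slope, per-bound offset.
    L = sigmoid(slope * (x - offset_l))
    U = sigmoid(slope * (x - offset_u))
    return L, U

def clip_offsets(offset_l, offset_u):
    # Projection applied after the optimizer step: force offset_u <= offset_l.
    return offset_l, min(offset_u, offset_l)

offset_l, offset_u = clip_offsets(0.0, 0.8)  # 0.8 > 0.0 violates the constraint
x = np.linspace(-3.0, 3.0, 101)
L, U = predicate_bounds(x, slope=2.0, offset_l=offset_l, offset_u=offset_u)
assert np.all(L <= U + 1e-12)  # L never crosses above U after clipping
```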

jlnn.nn.constraints.clip_weights(model: Module)[source]

Ensures that all trainable weights in logic gates satisfy the condition w >= 1.0.

In Logical Neural Networks (LNN) using Łukasiewicz semantics, maintaining weights >= 1.0 is crucial for the interpretability of t-norms and t-conorms. If weights fall below this threshold, gates lose their identity as logical operators and behave like standard neural nodes.

This function implements a ‘Projected Gradient Descent’ step by projecting violating weights back to the valid domain [1.0, inf).

Parameters:

model (nnx.Module) – The Flax NNX model/module to be constrained. The function traverses the entire graph and updates parameters in-place.
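The projection step is a pointwise `max(w, 1.0)` over every gate weight. A minimal sketch of the graph traversal, assuming a nested dict of numpy arrays in place of the real nnx.Module and the key name `weight` as an illustrative convention:

```python
import numpy as np

def project_weights(tree):
    # Recursively project every array stored under 'weight' onto [1.0, inf).
    # Stand-in for the in-place traversal over an nnx.Module graph.
    out = {}
    for key, value in tree.items():
        if isinstance(value, dict):
            out[key] = project_weights(value)
        elif key == "weight":
            out[key] = np.maximum(value, 1.0)
        else:
            out[key] = value  # non-weight parameters pass through untouched
    return out

gates = {"and1": {"weight": np.array([0.3, 1.2]), "bias": np.array([0.0])},
         "or1": {"weight": np.array([2.0, 0.9])}}
projected = project_weights(gates)
```

Only the violating entries move (0.3 and 0.9 become 1.0); valid weights and non-weight parameters are left unchanged, which is exactly what makes this a projection rather than a rescaling.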

This module implements mechanisms that preserve the logical integrity of the model during training. It uses Projected Gradient Descent, which projects the parameters back into the feasible space after each optimizer step.

Why are constraints necessary?

Within Logical Neural Networks (LNN), axiomatic conditions must be met for the model to remain interpretable as a set of logical rules:

  1. Gate weights (:math:`w \geq 1`): If a weight drops below 1.0, the gate loses its identity (e.g., AND stops behaving like a t-norm). The function clip_weights enforces this for all gate types.

  2. Consistency of predicates (:math:`L \leq U`): In learned predicates (LearnedPredicate), the lower truth bound must never exceed the upper bound. Mathematically, this requires the condition offset_u <= offset_l. The function clip_predicates ensures that these bounds never cross.

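A small numeric check makes condition 1 concrete. One common weighted form of the Łukasiewicz conjunction from the LNN literature (assumed here, not taken from this library's source) is `clip(beta - sum_i w_i * (1 - x_i), 0, 1)`:

```python
import numpy as np

def weighted_and(inputs, weights, beta=1.0):
    # Assumed weighted Lukasiewicz conjunction:
    # clip(beta - sum_i w_i * (1 - x_i), 0, 1).
    x = np.asarray(inputs, dtype=float)
    w = np.asarray(weights, dtype=float)
    return float(np.clip(beta - np.sum(w * (1.0 - x)), 0.0, 1.0))

# With unit weights this reduces to the classical t-norm max(0, a + b - 1),
# so a completely false conjunct forces the conjunction to be false:
assert weighted_and([0.0, 1.0], [1.0, 1.0]) == 0.0

# With a weight below 1, the same false conjunct no longer forces falsity --
# the gate stops acting like a logical AND:
out = weighted_and([0.0, 1.0], [0.4, 1.0])  # out is 0.6, not 0.0
```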
Main Functions

apply_constraints should be called in every step of the training loop, immediately after the weight update:

optimizer.update(grads)
apply_constraints(model)  # Ensuring logical integrity
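The ordering matters: the gradient step may leave the feasible set, and the projection immediately repairs it before the next forward pass. A self-contained sketch with a toy SGD step and dict parameters standing in for the optimizer and the nnx.Module (the key name `and.weight` is illustrative):

```python
import numpy as np

def sgd_update(params, grads, lr=0.1):
    # Toy gradient step -- stand-in for optimizer.update(grads).
    return {k: params[k] - lr * grads[k] for k in params}

def apply_constraints(params):
    # Stand-in for jlnn.nn.constraints.apply_constraints:
    # project gate weights back onto [1.0, inf).
    return {k: (np.maximum(v, 1.0) if k.endswith("weight") else v)
            for k, v in params.items()}

params = {"and.weight": np.array([1.05, 2.0])}
grads = {"and.weight": np.array([1.0, -1.0])}

params = sgd_update(params, grads)   # first entry drops to 0.95: infeasible
params = apply_constraints(params)   # projected back to 1.0 before next step
```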