Logical Loss Functions

jlnn.training.losses.contradiction_loss(interval: Array) → Array

Calculates the penalty for a logical contradiction in the truth interval.

In Logical Neural Networks (LNNs), the axiom of logical consistency requires that the lower bound (L) never exceed the upper bound (U). If L > U, the system is in a contradictory state (e.g., it simultaneously claims that a statement is ‘certainly true’ and ‘certainly false’).

When this condition is violated, the function returns the squared difference between the bounds as a loss. This drives the optimizer to adjust the weights and biases so that the network returns to a valid logical space.

Parameters:

interval (jnp.ndarray) – A tensor of intervals of shape (…, 2). The last dimension contains the pair [Lower Bound, Upper Bound].

Returns:

A scalar value representing the average contradiction loss.

Return type:

jnp.ndarray
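A minimal sketch of the documented behavior (the `_sketch` name is invented here; the actual jlnn implementation may differ in its reduction or penalty shape):

```python
import jax.numpy as jnp

def contradiction_loss_sketch(interval: jnp.ndarray) -> jnp.ndarray:
    """Squared penalty wherever the lower bound exceeds the upper bound."""
    lower, upper = interval[..., 0], interval[..., 1]
    violation = jnp.maximum(lower - upper, 0.0)  # zero for valid intervals
    return jnp.mean(violation ** 2)

# A valid interval [0.2, 0.8] contributes nothing; a contradictory
# interval [0.9, 0.1] contributes (0.9 - 0.1)^2 = 0.64.
```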

jlnn.training.losses.jlnn_learning_loss(prediction: Array, target: Array, contradiction_weight: float = 1.0, uncertainty_weight: float = 0.05) → Array

Calculates a combined loss function designed for stable learning of neuro-symbolic weights.

This loss function unifies three optimization objectives:
  1. Accuracy: Standard Mean Squared Error (MSE) to align prediction bounds with targets.

  2. Consistency: Quadratic penalty for logical contradictions where the lower bound exceeds the upper bound (L > U).

  3. Decisiveness: A penalty on the interval width (U - L) that discourages the model from remaining in a neutral “unknown” state [0, 1] and pushes it to converge towards more certain truth values.

Parameters:
  • prediction (jnp.ndarray) – Predicted truth interval tensor of shape (…, 2).

  • target (jnp.ndarray) – Target (ground truth) interval tensor of shape (…, 2).

  • contradiction_weight (float) – Scalar multiplier for the consistency penalty. Defaults to 1.0.

  • uncertainty_weight (float) – Scalar multiplier for the uncertainty (width) penalty. Small values (e.g., 0.05) are recommended to maintain focus on accuracy.

Returns:

A scalar value representing the total learning loss.

Return type:

jnp.ndarray
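The three objectives can be combined as in the following sketch; the weighting and reductions are assumptions based on the description above, not the actual jlnn code:

```python
import jax.numpy as jnp

def jlnn_learning_loss_sketch(prediction, target,
                              contradiction_weight=1.0,
                              uncertainty_weight=0.05):
    # 1. Accuracy: MSE over both bounds.
    mse = jnp.mean((prediction - target) ** 2)
    lower, upper = prediction[..., 0], prediction[..., 1]
    # 2. Consistency: quadratic penalty where L > U.
    contradiction = jnp.mean(jnp.maximum(lower - upper, 0.0) ** 2)
    # 3. Decisiveness: penalty on the interval width U - L.
    width = jnp.mean(jnp.maximum(upper - lower, 0.0))
    return mse + contradiction_weight * contradiction + uncertainty_weight * width

# A perfect, decisive prediction [1, 1] costs nothing; a perfectly
# accurate but maximally uncertain prediction [0, 1] still pays the
# width penalty (0.05 with the default weight).
```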

jlnn.training.losses.logical_consistency_loss(model_output: Array, uncertainty_weight: float = 0.1) → Array

A composite loss function that enforces logical consistency and model certainty.

This function combines two aspects:
  1. Validity (Hinge Loss): Penalizes situations where the lower bound (L) exceeds the upper bound (U). In valid LNN logic, L <= U must always hold.

  2. Certainty (Uncertainty): Minimizes the width of the interval (U - L), encouraging the model to move away from the neutral “don’t know” state (0, 1) towards a definitive “true” (1, 1) or “false” (0, 0).

Parameters:
  • model_output (jnp.ndarray) – The model’s output tensor with intervals of shape (…, 2).

  • uncertainty_weight (float) – Coefficient determining the strength of the pressure to reduce uncertainty. Default value 0.1 ensures that validity and accuracy remain the primary goals.

Returns:

Scalar value representing the total inconsistency.

Return type:

jnp.ndarray
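A hedged sketch of the two combined aspects (a linear hinge for validity plus a width penalty; the exact penalty shapes in jlnn may differ):

```python
import jax.numpy as jnp

def logical_consistency_loss_sketch(model_output, uncertainty_weight=0.1):
    lower, upper = model_output[..., 0], model_output[..., 1]
    # Validity: linear hinge on L > U violations.
    hinge = jnp.mean(jnp.maximum(lower - upper, 0.0))
    # Certainty: average interval width U - L for valid intervals.
    width = jnp.mean(jnp.maximum(upper - lower, 0.0))
    return hinge + uncertainty_weight * width

# The fully uncertain state [0, 1] incurs only the width penalty,
# 0.1 * 1.0 with the default weight; a decisive [1, 1] costs nothing.
```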

jlnn.training.losses.logical_mse_loss(prediction: Array, target: Array) → Array

Calculates the mean squared error (MSE) between the predicted and target intervals.

This function measures the accuracy of the model by comparing predicted truth intervals with reference values (labels). The error is computed jointly over both the lower (L) and upper (U) bounds, forcing the model to converge to the target in both truth dimensions.

In the context of JLNN, this loss penalizes deviation from the known truth, while complementary functions (such as contradiction_loss) ensure that the resulting gradient step does not violate logical axioms.

Parameters:
  • prediction (jnp.ndarray) – Predicted interval tensor of shape (…, 2).

  • target (jnp.ndarray) – The target (ground truth) interval tensor of shape (…, 2).

Returns:

A scalar value representing the mean square error.

Return type:

jnp.ndarray
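Since the MSE runs over the trailing interval dimension as well, a single prediction contributes error from both bounds. A minimal sketch, assuming a plain mean reduction:

```python
import jax.numpy as jnp

def logical_mse_loss_sketch(prediction, target):
    # Averages the squared error over both the L and U components.
    return jnp.mean((prediction - target) ** 2)

# prediction [0.5, 1.0] vs target [1.0, 1.0]:
# per-bound squared errors (0.25, 0.0) -> mean 0.125
```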

jlnn.training.losses.rule_violation_loss(antecedent: Array, consequent: Array) → Array

Penalizes violation of the semantics of logical implication (A -> B).

In neuro-symbolic learning, this is a key function for knowledge embedding. If a model claims that the premise (A) is true (high lower bound) but at the same time claims that the conclusion (B) is false (low upper bound), a logical conflict arises, which this function penalizes.

This loss is defined as max(0, L(A) - U(B)), averaged across the batch.

Parameters:
  • antecedent (jnp.ndarray) – Truth interval of the premise (A).

  • consequent (jnp.ndarray) – Truth interval of the conclusion (B).

Returns:

The average rule-violation penalty across the batch.

Return type:

jnp.ndarray
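The stated formula max(0, L(A) - U(B)) translates directly; the sketch below assumes a batch-mean reduction as the Returns section describes:

```python
import jax.numpy as jnp

def rule_violation_loss_sketch(antecedent, consequent):
    # A -> B is violated when A's lower bound exceeds B's upper bound.
    violation = jnp.maximum(antecedent[..., 0] - consequent[..., 1], 0.0)
    return jnp.mean(violation)

# A confidently true (L=0.9) while B is capped low (U=0.2): loss 0.7.
```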

jlnn.training.losses.total_lnn_loss(prediction: Array, target: Array, contradiction_weight: float = 1.0) → Array

Calculates the combined loss function (Total Loss) for JLNN.

This function unifies two main optimization goals in Logical Neural Networks:
  1. Accuracy: Minimizing the difference between the predicted interval and the target using MSE.

  2. Consistency: Penalizing internal inconsistencies (L > U) that would impair the interpretability of the model.

The resulting gradient leads the model to parameters that not only describe the data well, but also form a consistent logical system in accordance with the axioms of Łukasiewicz logic.

Parameters:
  • prediction (jnp.ndarray) – The model’s output tensor (intervals) of shape (…, 2).

  • target (jnp.ndarray) – Reference truth values (labels) of shape (…, 2).

  • contradiction_weight (float) – Hyperparameter determining the strength of the penalty for logical contradiction. A higher value (e.g. > 1.0) places more emphasis on the logical consistency of the model, at the cost of slower MSE reduction. Defaults to 1.0.

Returns:

Total scalar loss prepared for gradient calculation in JAX.

Return type:

jnp.ndarray
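A sketch of the two-term combination, and of the "prepared for gradient calculation" claim: because the result is a scalar, it plugs directly into jax.grad. Term shapes are assumptions based on the description above:

```python
import jax
import jax.numpy as jnp

def total_lnn_loss_sketch(prediction, target, contradiction_weight=1.0):
    # Accuracy term: MSE over both bounds.
    mse = jnp.mean((prediction - target) ** 2)
    lower, upper = prediction[..., 0], prediction[..., 1]
    # Consistency term: quadratic penalty where L > U.
    contradiction = jnp.mean(jnp.maximum(lower - upper, 0.0) ** 2)
    return mse + contradiction_weight * contradiction

pred = jnp.array([[0.9, 0.1]])  # contradictory prediction
tgt = jnp.array([[0.5, 0.5]])
# Scalar output makes the loss differentiable end to end:
grads = jax.grad(total_lnn_loss_sketch)(pred, tgt)
```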

Loss functions in JLNN are designed to motivate the model to seek consistent interpretations of data.

Key Functions:

  • Contradiction Loss: The most important function for stability. Penalizes states where the lower bound (L) exceeds the upper bound (U). A logical contradiction \(L > U\) is not allowed in JLNN.

  • Uncertainty Penalization: Motivates the model to shrink its truth intervals (bringing L and U closer together), thereby reducing the system’s “ignorance”.

  • Rule Violation Loss: Specific loss for knowledge engineering. Penalizes situations where the premise (A) is true, but the conclusion (B) is false, enforcing the validity of logical rules.
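The functions above can be wired into an ordinary JAX training step. The following is a hypothetical end-to-end sketch with an invented toy model (all names here are for illustration, not part of the jlnn API):

```python
import jax
import jax.numpy as jnp

def predict(params, x):
    # Toy "model": an affine map squashed into [0, 1] to produce interval bounds.
    return jax.nn.sigmoid(x @ params["w"] + params["b"])

def loss_fn(params, x, target, contradiction_weight=1.0):
    pred = predict(params, x)
    mse = jnp.mean((pred - target) ** 2)  # accuracy
    contradiction = jnp.mean(            # consistency: penalize L > U
        jnp.maximum(pred[..., 0] - pred[..., 1], 0.0) ** 2)
    return mse + contradiction_weight * contradiction

@jax.jit
def train_step(params, x, target, lr=0.5):
    grads = jax.grad(loss_fn)(params, x, target)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

params = {"w": jnp.zeros((3, 2)), "b": jnp.zeros(2)}
x = jnp.ones((4, 3))
target = jnp.tile(jnp.array([1.0, 1.0]), (4, 1))  # "certainly true" labels

initial = loss_fn(params, x, target)
for _ in range(100):
    params = train_step(params, x, target)
final = loss_fn(params, x, target)  # loss should decrease over training
```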