Contradiction Detection & Model Repair

This notebook demonstrates a distinctive feature of JLNN: the ability to detect and correct logical conflicts. Where a classical neural network blurs conflicting data into an average, JLNN enters an L > U state that can be explicitly identified and resolved.
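
Because the truth bounds are ordinary arrays, the contradiction check itself is a single element-wise comparison. A minimal, framework-agnostic sketch (the array below is illustrative, not JLNN output):

import jax.numpy as jnp

# Each row is an (L, U) truth interval; a row with L > U is contradictory.
bounds = jnp.array([
    [0.2, 0.8],   # consistent: L <= U
    [0.9, 0.3],   # contradictory: L > U
])
print(bounds[:, 0] > bounds[:, 1])  # [False  True]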

Note

The interactive notebook is hosted externally to ensure the best viewing experience and to allow immediate execution in the cloud.

Run in Google Colab

Execute the code directly in your browser without any local setup.

https://colab.research.google.com/github/RadimKozl/JLNN/blob/main/examples/JLNN_contradiction_detection.ipynb
View on GitHub

Browse the source code and outputs in the GitHub notebook viewer.

https://github.com/RadimKozl/JLNN/blob/main/examples/JLNN_contradiction_detection.ipynb

Content Overview

This tutorial demonstrates the “Self-Healing” capability of JLNN. When logical rules conflict with observed data, the network enters a state of contradiction (L > U).

The following example shows how to:

  1. Initialize a model with a forced contradiction.
  2. Use total_lnn_loss to penalize logically invalid states.
  3. Apply apply_constraints to keep the model within the bounds of Łukasiewicz semantics.

Key Takeaways

  • Contradiction Detection: Unlike black-box models, JLNN explicitly signals when it is confused by conflicting information.

  • Differentiable Logic: The consistency constraint L ≤ U is part of the loss function, allowing the model to “unlearn” or weaken conflicting rules during training (a minimal sketch of such a penalty term follows below).
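
The exact composition of total_lnn_loss lives in jlnn.training.losses; conceptually, the consistency term only needs to be a differentiable hinge on the violation L − U. A minimal sketch, where contradiction_penalty is a hypothetical helper rather than the library's API:

import jax.numpy as jnp

def contradiction_penalty(bounds):
    # bounds[..., 0] holds the lower bound L, bounds[..., 1] the upper bound U.
    violation = bounds[..., 0] - bounds[..., 1]
    # Zero whenever L <= U; grows linearly with the size of the violation otherwise,
    # so gradient descent is pushed back toward consistent states.
    return jnp.mean(jnp.maximum(violation, 0.0))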

'''
try:
    import jlnn
    from flax import nnx
    import jax.numpy as jnp
    print("✅ JLNN and JAX are ready.")
except ImportError:
    print("🚀 Installing JLNN from GitHub and fixing JAX for Colab...")
    # Install the framework
    !pip install jax-lnn --quiet
    #!pip install git+https://github.com/RadimKozl/JLNN.git --quiet
    # Fix JAX/CUDA compatibility for 2026 Colab runtimes
    !pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

    import os
    print("\n🔄 RESTARTING ENVIRONMENT... Please wait a second and then run the cell again.")
    os.kill(os.getpid(), 9) # After this line, the cell stops and the environment restarts
'''

import os
os.environ["JAX_PLATFORMS"] = "cpu"  # select the CPU backend (must be set before jax is imported)

import jax.numpy as jnp
from flax import nnx
import optax
import matplotlib.pyplot as plt

from jlnn.symbolic.compiler import LNNFormula
from jlnn.nn.constraints import apply_constraints
from jlnn.training.losses import total_lnn_loss

# Compile the rule "A & B -> C" with confidence 0.95 into a trainable LNN
model = LNNFormula("0.95::A & B -> C", nnx.Rngs(42))

inputs = {
    "A": jnp.array([[0.8, 0.9]]),  # A is (mostly) true
    "B": jnp.array([[0.7, 0.8]]),  # B is (mostly) true
    "C": jnp.array([[0.0, 0.2]]),  # C is (mostly) false -- conflicts with the rule A & B -> C
}

# Unwrap the root logical node (the compiled formula may be wrapped in a container node)
root_node = model.root.children[0] if hasattr(model.root, 'children') else model.root
initial_output = root_node.forward(inputs)

flat_out = initial_output.reshape(-1, 2)  # each row is an (L, U) pair

L_init = flat_out[0, 0].item()
U_init = flat_out[0, 1].item()

print(f"Initial C: [{L_init:.4f}, {U_init:.4f}]")
print(f"Contradiction detected: {L_init > U_init}")

def plot_contradiction(L, U, title):
    fig, ax = plt.subplots(figsize=(8, 2))
    is_conflict = L > U
    color = 'salmon' if is_conflict else 'skyblue'

    # Plot the interval
    start = min(L, U)
    width = abs(U - L)  # positive bar width, even when the interval is inverted (L > U)

    ax.barh(['Truth Value'], [width], left=[start], color=color, height=0.5)
    ax.axvline(L, color='blue', linestyle='--', label=f'L={L:.2f}')
    ax.axvline(U, color='red', linestyle='--', label=f'U={U:.2f}')

    ax.set_xlim(-0.1, 1.1)
    ax.set_title(title)
    ax.legend(loc='lower right')
    if is_conflict:
        ax.text(0.5, 0.2, "CONTRADICTION (L > U)", color='red', fontweight='bold', ha='center')
    plt.show()
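
With the helper in place, the forced contradiction computed above can be visualized directly, using the L_init and U_init values printed earlier:

plot_contradiction(L_init, U_init, "State Before Repair (Forced Contradiction)")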


optimizer = nnx.Optimizer(
    model,
    optax.adam(0.02),
    wrt=nnx.Param  # update only nnx.Param leaves (required by Flax NNX 0.11+)
)

target = jnp.array([[0.0, 0.2]]) # Target interval for C: a low truth value (at most 0.2)

@nnx.jit
def train_step(model, optimizer, inputs, target):
    def loss_fn(m):
        # Forward pass through the logical node
        node = m.root.children[0] if hasattr(m.root, 'children') else m.root
        pred = node.forward(inputs)
        # Using the total_lnn_loss function
        return total_lnn_loss(pred, target)

    # Calculate loss and gradients
    loss, grads = nnx.value_and_grad(loss_fn)(model)

    # FLAX 0.11+: update now requires both model and grads
    optimizer.update(model, grads)

    # Applying logical constraints (weights >= 1, L <= U)
    apply_constraints(model)

    return loss
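
Note that apply_constraints runs after the optimizer update rather than inside loss_fn: this is the standard projected-gradient pattern, where the unconstrained step is taken first and the parameters are then projected back onto the feasible region (weights ≥ 1, L ≤ U).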

print("=== Repairing model consistency ===")
for step in range(101):
    current_loss = train_step(model, optimizer, inputs, target)
    if step % 20 == 0:
        print(f"Step {step:3d} | Loss: {current_loss:.6f}")

final_output = root_node.forward(inputs)

flat_final = final_output.reshape(-1, 2)

L_final = flat_final[0, 0].item()
U_final = flat_final[0, 1].item()

print(f"Fixed C: [{L_final:.4f}, {U_final:.4f}]")
print(f"Logical conflict removed? {L_final <= U_final}")

plot_contradiction(L_final, U_final, "State After Repair (Consistency Restored)")
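
To quantify the repair, the contradiction magnitude max(0, L − U) can be compared before and after training; a short sketch reusing the values computed above:

gap_before = max(0.0, L_init - U_init)
gap_after = max(0.0, L_final - U_final)
print(f"Contradiction magnitude: {gap_before:.4f} -> {gap_after:.4f}")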

Download

You can also download the raw notebook file for local use: JLNN_contradiction_detection.ipynb

Tip

To run the notebook locally, make sure you have installed the package using pip install -e .[test].