Base Example: Basic inference and manual grounding

This tutorial presents the simplest way to use the JLNN framework for logical reasoning with uncertainty. It focuses on defining rules, manually setting truth intervals, and calculating the resulting inference without the need to train a model.

Note

The interactive notebook is hosted externally to ensure the best viewing experience and to allow immediate execution in the cloud.

Run in Google Colab

Execute the code directly in your browser without any local setup.

https://colab.research.google.com/github/RadimKozl/JLNN/blob/main/examples/JLNN_basic_inference.ipynb

View on GitHub

View source code and outputs in the GitHub notebook browser.

https://github.com/RadimKozl/JLNN/blob/main/examples/JLNN_basic_inference.ipynb

Key concepts

In JLNN, logical rules are represented as a differentiable graph built on Flax NNX. Instead of a single truth value, each predicate carries a truth interval [L, U]:

  • L (Lower bound): Minimum confirmed truth.

  • U (Upper bound): Maximum possible truth.

  • Interval width (U - L): Expresses the degree of uncertainty or ignorance about the given predicate.
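
The interval semantics above can be sketched in a few lines of plain Python. The TruthInterval class below is a hypothetical helper written for illustration only; it is not part of JLNN, which (as the tutorial code shows) passes intervals as arrays of the form [[L, U]].

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TruthInterval:
    """Hypothetical helper (not part of JLNN): a truth interval [L, U]."""
    lower: float
    upper: float

    def __post_init__(self):
        # A valid interval satisfies 0 <= L <= U <= 1.
        if not (0.0 <= self.lower <= self.upper <= 1.0):
            raise ValueError(
                f"Expected 0 <= L <= U <= 1, got [{self.lower}, {self.upper}]"
            )

    @property
    def width(self) -> float:
        """Uncertainty: 0 is a precise value, 1 is complete ignorance."""
        return self.upper - self.lower


known_true = TruthInterval(1.0, 1.0)    # certainly true, no uncertainty
unknown = TruthInterval(0.0, 1.0)       # nothing is known yet
partly_known = TruthInterval(0.7, 0.9)  # at least 0.7, at most 0.9

for name, iv in [
    ("known_true", known_true),
    ("unknown", unknown),
    ("partly_known", partly_known),
]:
    print(f"{name}: [{iv.lower}, {iv.upper}] width={iv.width:.2f}")
```

A width of 0 means the truth value is known exactly, while the interval [0, 1] encodes complete ignorance, which is how the predicate C is grounded later in this tutorial.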

Rule definition

We create the model by compiling a symbolic rule with the LNNFormula class. In this example, we compile the implication A & B -> C with a weight of 0.8:

'''
try:
    import jlnn
    from flax import nnx
    import jax.numpy as jnp
    print("✅ JLNN and JAX are ready.")
except ImportError:
    print("🚀 Installing JLNN from GitHub and fixing JAX for Colab...")
    # Install the framework
    !pip install jax-lnn --quiet
    #!pip install git+https://github.com/RadimKozl/JLNN.git --quiet
    # Fix JAX/CUDA compatibility for 2026 in Colab
    !pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

    import os
    print("\n🔄 RESTARTING ENVIRONMENT... Please wait a second and then run the cell again.")
    os.kill(os.getpid(), 9)  # After this line, the cell stops and the environment restarts
'''

import os
os.environ["JAX_PLATFORMS"] = "cpu"

import jax.numpy as jnp
from flax import nnx
from jlnn.symbolic.compiler import LNNFormula

print("JLNN loaded.")

# Model creation – compiling rules into an NNX graph
model = LNNFormula("0.8::A & B -> C", nnx.Rngs(42))

print("Model created. Predicates:", list(model.predicates.keys()))

# Manual grounding – setting intervals

inputs = {
    "A": jnp.array([[0.7, 0.9]]),   # A is quite likely
    "B": jnp.array([[0.4, 0.8]]),   # B has more uncertainty
    "C": jnp.array([[0.0, 1.0]])    # C is completely unknown (ignorance)
}

output = model(inputs)

print("Output shape:", output.shape)

if len(output.shape) == 3:          # (batch, 1, 2) or similar
    L = output[0, 0, 0].item()
    U = output[0, 0, 1].item()
elif len(output.shape) == 2:        # (batch, 2)
    L = output[0, 0].item()
    U = output[0, 1].item()
elif len(output.shape) == 1:        # only (2,)
    L = output[0].item()
    U = output[1].item()
else:
    raise ValueError(f"Unknown output shape: {output.shape}")

print("Output interval for C:")
print(f"  L = {L:.4f}")
print(f"  U = {U:.4f}")
print(f"  Uncertainty (width): {U - L:.4f}")

# Experiments – different input intervals

inputs_exp1 = {
    "A": jnp.array([[0.95, 1.0]]),
    "B": jnp.array([[0.90, 0.98]]),
    "C": jnp.array([[0.0, 1.0]])
}
print("Exp 1 – strong A and B:")
print(model(inputs_exp1))

inputs_exp2 = {
    "A": jnp.array([[0.95, 1.0]]),
    "B": jnp.array([[0.1, 0.3]]),
    "C": jnp.array([[0.0, 1.0]])
}
print("\nExp 2 – weak B:")
print(model(inputs_exp2))

inputs_exp3 = {
    "A": jnp.array([[0.4, 0.9]]),
    "B": jnp.array([[0.8, 0.95]]),
    "C": jnp.array([[0.0, 1.0]])
}
print("\nExp 3 – high uncertainty in A:")
print(model(inputs_exp3))
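
To build intuition for the experiments above, the following sketch evaluates the same groundings with plain, unweighted Łukasiewicz operators on interval bounds. This is an assumption made for illustration: it does not model the 0.8 rule weight, and JLNN's internal operators may differ, so the numbers need not match the framework's output. The helper names luk_and and luk_implies are hypothetical and not part of JLNN.

```python
def luk_and(a, b):
    """Unweighted Lukasiewicz conjunction on (L, U) intervals.

    The t-norm max(0, x + y - 1) is increasing in both arguments,
    so lower bounds combine with lower bounds and upper with upper.
    """
    return (max(0.0, a[0] + b[0] - 1.0), max(0.0, a[1] + b[1] - 1.0))


def luk_implies(p, q):
    """Unweighted Lukasiewicz implication p -> q on (L, U) intervals.

    min(1, 1 - x + y) decreases in the antecedent and increases in the
    consequent, so the lower bound pairs p's upper with q's lower bound.
    """
    return (min(1.0, 1.0 - p[1] + q[0]), min(1.0, 1.0 - p[0] + q[1]))


# The same groundings as in the tutorial, written as (L, U) tuples.
groundings = {
    "base":          {"A": (0.7, 0.9),  "B": (0.4, 0.8),   "C": (0.0, 1.0)},
    "exp1_strong":   {"A": (0.95, 1.0), "B": (0.90, 0.98), "C": (0.0, 1.0)},
    "exp2_weak_B":   {"A": (0.95, 1.0), "B": (0.1, 0.3),   "C": (0.0, 1.0)},
    "exp3_unsure_A": {"A": (0.4, 0.9),  "B": (0.8, 0.95),  "C": (0.0, 1.0)},
}

results = {}
for name, g in groundings.items():
    antecedent = luk_and(g["A"], g["B"])    # interval for A & B
    rule = luk_implies(antecedent, g["C"])  # interval for (A & B) -> C
    results[name] = (antecedent, rule)
    print(
        f"{name}: A & B = [{antecedent[0]:.2f}, {antecedent[1]:.2f}], "
        f"rule = [{rule[0]:.2f}, {rule[1]:.2f}]"
    )
```

Under these assumed operators, a wide input interval for A or B carries over into a wide interval for the antecedent A & B, which is the effect the third experiment is designed to show.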

Tutorial summary

  • Symbols to Graph: JLNN converts logic rules to NNX graphs.

  • Interval Logic: We work with intervals [L, U], not fixed points.

  • No Training: In this mode, all parameters are set manually and the model is used for direct logical inference.

Download

You can also download the raw notebook file for local use: JLNN_basic_inference.ipynb

Tip

To run the notebook locally, first install the package from the root of a local clone of the repository using pip install -e .[test].