Source code for jlnn.utils.helpers
#!/usr/bin/env python3
"""
General utility functions for Logical Neural Networks.
"""
# Imports
import jax.numpy as jnp
[docs]
def scalar_to_interval(x: jnp.ndarray) -> jnp.ndarray:
    """
    Lift exact scalar truth values into degenerate [L, U] intervals.

    Classical datasets provide truth values that are known exactly, so the
    lower and upper bounds of each resulting interval are both equal to the
    input value (L = U = x). The bounds are appended as a new trailing axis.

    Args:
        x (jnp.ndarray): Tensor of scalar truth values.

    Returns:
        jnp.ndarray: Tensor of intervals with shape ``(*x.shape, 2)``, where
        index 0 of the last axis is L and index 1 is U.
    """
    # Give every value its own length-1 trailing axis...
    column = jnp.expand_dims(x, axis=-1)
    # ...then place two copies side by side: [L, U] == [x, x].
    return jnp.concatenate([column, column], axis=-1)
[docs]
def is_precise(interval: jnp.ndarray, epsilon: float = 1e-5) -> bool:
    """
    Checks if truth intervals have collapsed into single points (L ≈ U).

    The width of each interval is measured as ``|L - U|``. An inverted
    interval (L > U) is therefore judged by its absolute width as well.

    Args:
        interval: Truth interval tensor of shape ``(..., 2)`` holding [L, U]
            on the last axis. Leading dimensions are treated as a batch.
        epsilon: Exclusive upper bound on the allowed difference between
            bounds.

    Returns:
        True if every interval's width is strictly below ``epsilon``,
        False otherwise. A NaN width counts as not precise. An empty batch
        is vacuously precise.

    Note:
        The result is a concrete Python ``bool``, so this function cannot be
        called on traced values inside ``jax.jit``.
    """
    widths = jnp.abs(interval[..., 0] - interval[..., 1])
    # Reduce over the batch before converting. Calling float()/bool()
    # directly on a non-scalar array raises, which previously made this
    # helper unusable on batched intervals.
    return bool(jnp.all(widths < epsilon))