Source code for jlnn.reasoning.temporal
#!/usr/bin/env python3
# Imports
from typing import Dict, Optional

import jax.numpy as jnp

from jlnn.symbolic.compiler import Node
[docs]
class AlwaysNode(Node):
    r"""
    Implementation of the 'Always' (Globally) temporal operator, denoted as $\mathcal{G}$.

    In Linear Temporal Logic (LTL), the formula $\mathcal{G}\phi$ is true if the
    sub-formula $\phi$ holds at every time step within a given sequence.
    Within the JLNN framework, this is realized as a generalized conjunction (AND)
    over the temporal axis.

    It utilizes the Gödel t-norm (minimum) to aggregate truth intervals. The
    lower and upper bounds are minimized independently, so each resulting
    bound reflects the "least true" moment of the time series for that bound.

    Attributes:
        child (Node): The logical subtree or formula to be evaluated over time.
        window_size (Optional[int]): The temporal look-ahead window, i.e. the
            number of time steps, counted from the start of the sequence, over
            which the operator is evaluated (bounded operator
            $\mathcal{G}_{[0, w)}$). If None, the operator applies to the
            entire input sequence.
    """

    def __init__(self, child: Node, window_size: Optional[int] = None):
        """
        Initializes the Always (G) node.

        Args:
            child: The sub-formula to monitor.
            window_size: The temporal range for the operator: a positive number
                of leading time steps to inspect, or None for the whole sequence.

        Raises:
            ValueError: If ``window_size`` is not None and is smaller than 1.
        """
        if window_size is not None and window_size < 1:
            raise ValueError(
                f"window_size must be a positive integer or None, got {window_size!r}"
            )
        # NOTE(review): Node.__init__ is not invoked here; confirm the base
        # class requires no initialization of its own.
        self.child = child
        self.window_size = window_size

    def forward(self, values: Dict[str, jnp.ndarray]) -> jnp.ndarray:
        """
        Calculates the minimum truth interval across the temporal dimension.

        Computes the intersection of truth values over time (restricted to the
        first ``window_size`` steps when a window is set), effectively finding
        the invariant truth level of the sequence.

        Args:
            values: A dictionary of input tensors, passed unchanged to the child.
                The child's output is expected to carry the temporal dimension
                at axis 1.

        Returns:
            A JAX array of truth intervals $[L, U]$ with the temporal
            dimension collapsed via the `min` operation. If the (windowed)
            sequence has no time steps, the vacuously true interval [1, 1]
            is returned, since 1 is the identity of the Gödel t-norm.
        """
        # Child activations; expected shape (batch, time, 2).
        a = self.child.forward(values)
        if self.window_size is not None:
            # Bounded "always": only the first `window_size` steps count.
            # Slicing past the end simply keeps the whole sequence.
            a = a[:, : self.window_size]
        if a.shape[1] == 0:
            # jnp.min has no identity for an empty axis and would raise.
            # An empty conjunction is vacuously true.
            return jnp.ones(a.shape[:1] + a.shape[2:], dtype=a.dtype)
        # Gödel conjunction over time: element-wise minimum of each bound.
        return jnp.min(a, axis=1)
[docs]
class EventuallyNode(Node):
    r"""
    Implementation of the 'Eventually' (Finally) temporal operator, denoted as $\mathcal{F}$.

    In LTL, the formula $\mathcal{F}\phi$ is true if the sub-formula $\phi$ holds at
    least once at some point in the future or present. In JLNN, this is
    implemented as a generalized disjunction (OR) over the temporal axis.

    It utilizes a t-conorm (maximum) to aggregate truth intervals. The lower
    and upper bounds are maximized independently, so each resulting bound
    reflects the "most true" moment of the sequence for that bound.

    Attributes:
        child (Node): The logical subtree to be evaluated.
        window_size (Optional[int]): The temporal look-ahead window, i.e. the
            number of time steps, counted from the start of the sequence, over
            which the operator is evaluated (bounded operator
            $\mathcal{F}_{[0, w)}$). If None, the operator applies to the
            entire input sequence.
    """

    def __init__(self, child: Node, window_size: Optional[int] = None):
        """
        Initializes the Eventually (F) node.

        Args:
            child: The sub-formula expected to occur.
            window_size: The temporal range for the operator: a positive number
                of leading time steps to inspect, or None for the whole sequence.

        Raises:
            ValueError: If ``window_size`` is not None and is smaller than 1.
        """
        if window_size is not None and window_size < 1:
            raise ValueError(
                f"window_size must be a positive integer or None, got {window_size!r}"
            )
        # NOTE(review): Node.__init__ is not invoked here; confirm the base
        # class requires no initialization of its own.
        self.child = child
        self.window_size = window_size

    def forward(self, values: Dict[str, jnp.ndarray]) -> jnp.ndarray:
        """
        Calculates the maximum truth interval across the temporal dimension.

        Determines the peak truth value within the sequence (restricted to the
        first ``window_size`` steps when a window is set), identifying if the
        condition is met at any point.

        Args:
            values: A dictionary of input tensors, passed unchanged to the child.
                The child's output is expected to carry the temporal dimension
                at axis 1.

        Returns:
            A JAX array of truth intervals $[L, U]$ with the temporal
            dimension collapsed via the `max` operation. If the (windowed)
            sequence has no time steps, the false interval [0, 0] is
            returned, since 0 is the identity of the max t-conorm.
        """
        # Child activations; expected shape (batch, time, 2).
        a = self.child.forward(values)
        if self.window_size is not None:
            # Bounded "eventually": only the first `window_size` steps count.
            # Slicing past the end simply keeps the whole sequence.
            a = a[:, : self.window_size]
        if a.shape[1] == 0:
            # jnp.max has no identity for an empty axis and would raise.
            # An empty disjunction is false.
            return jnp.zeros(a.shape[:1] + a.shape[2:], dtype=a.dtype)
        # Disjunction over time: element-wise maximum of each bound.
        return jnp.max(a, axis=1)