#!/usr/bin/env python3
"""
JLNN Symbolic Compiler
This module provides the infrastructure to compile symbolic logical formulas
into neural computational graphs using Flax NNX and JAX. It leverages
the Lark parser to transform strings into a tree of NNX Modules.
"""
# Imports
from __future__ import annotations
from typing import Any, Dict, List, Union
from lark import Transformer, Tree, Token
from flax import nnx
import jax.numpy as jnp
from jlnn.nn import gates, predicates
from jlnn.symbolic.parser import FormulaParser
class Node(nnx.Module):
    """
    Base class shared by every node of a compiled JLNN graph.

    A node stands either for a predicate (a leaf) or for a logical
    operation (an inner node). Concrete subclasses must override
    :meth:`forward`.
    """

    def forward(self, values: Dict[str, jnp.ndarray]) -> jnp.ndarray:
        """
        Evaluate this node.

        Args:
            values: Maps each variable name to the JAX array holding the
                raw input features for that variable.

        Returns:
            A JAX array with the node's truth value (typically a truth
            interval).

        Raises:
            NotImplementedError: Always; subclasses supply the real logic.
        """
        raise NotImplementedError()
class PredicateNode(Node):
    """
    Leaf node binding a formula variable to its own LearnedPredicate.

    This is the grounding layer: the raw numeric input supplied for the
    variable is passed through a trainable predicate to obtain a fuzzy
    truth value.
    """

    def __init__(self, name: str, rngs: nnx.Rngs):
        """
        Initializes the predicate node.

        Args:
            name: The identifier of the variable in the formula; it is also
                the key looked up in the ``values`` mapping in ``forward``.
            rngs: Flax NNX random number generator stream for parameter
                initialization.
        """
        self.name = name
        # Every variable gets its own trainable grounding (LearnedPredicate),
        # configured here for a single input feature.
        self.predicate = predicates.LearnedPredicate(in_features=1, rngs=rngs)

    def forward(self, values: Dict[str, jnp.ndarray]) -> jnp.ndarray:
        """
        Retrieves this variable's input by name and passes it through the predicate.

        Args:
            values: Mapping of variable names to input arrays.

        Returns:
            The output of this node's LearnedPredicate for the variable's input.

        Raises:
            KeyError: If ``values`` has no entry for this node's variable name.
        """
        try:
            val = values[self.name]
        except KeyError:
            # A bare KeyError raised from deep inside a compiled formula is
            # hard to trace back; name the missing variable and list what was
            # actually supplied. ``list`` (not ``sorted``) avoids failing on
            # keys of mixed types.
            raise KeyError(
                f"No input provided for variable {self.name!r}; "
                f"available inputs: {list(values)}"
            ) from None
        return self.predicate(val)
class NAryGateNode(Node):
    """
    Node wrapping a logic gate that takes N inputs (e.g. weighted AND, OR, XOR).
    """

    def __init__(self, gate: nnx.Module, children: List[Node]):
        """
        Args:
            gate: Neural gate module applied to the stacked outputs of the
                children (e.g. ``WeightedAnd``).
            children: Child nodes, in order; their outputs become the gate's
                inputs.
        """
        self.gate = gate
        # Wrapping the list in nnx.List lets Flax NNX track each child as a
        # sub-module, together with its parameters, under JAX transformations.
        self.children = nnx.List(children)

    def forward(self, values: Dict[str, jnp.ndarray]) -> jnp.ndarray:
        """Evaluate every child, stack the results on axis -2, and apply the gate."""
        # With per-child outputs of shape (batch, truth_interval_dims), stacking
        # on the second-to-last axis yields (batch, num_inputs, truth_interval_dims).
        stacked = jnp.stack(
            [node.forward(values) for node in self.children],
            axis=-2,
        )
        return self.gate(stacked)
class BinaryGateNode(Node):
    """
    Node for a two-input gate whose operands are not interchangeable,
    such as implication (A -> B).
    """

    def __init__(self, gate: nnx.Module, left: Node, right: Node):
        """
        Args:
            gate: Binary logic gate, called as ``gate(antecedent, consequent)``.
            left: Node producing the antecedent (left-hand side).
            right: Node producing the consequent (right-hand side).
        """
        self.gate = gate
        self.left, self.right = left, right

    def forward(self, values: Dict[str, jnp.ndarray]) -> jnp.ndarray:
        """Evaluate antecedent then consequent, and pass both to the gate as separate arguments."""
        antecedent = self.left.forward(values)
        consequent = self.right.forward(values)
        out = self.gate(antecedent, consequent)
        # A 3-D result with a singleton axis 1 is collapsed to 2-D so the
        # output shape matches the other node types (e.g. (batch, 2)).
        needs_squeeze = out.ndim == 3 and out.shape[1] == 1
        return jnp.squeeze(out, axis=1) if needs_squeeze else out
class UnaryGateNode(Node):
    """
    Node wrapping a single-input gate, such as NOT.
    """

    def __init__(self, gate: nnx.Module, child: Node):
        """
        Args:
            gate: The unary logic gate module.
            child: The node whose output is fed to the gate.
        """
        self.gate = gate
        self.child = child

    def forward(self, values: Dict[str, jnp.ndarray]) -> jnp.ndarray:
        """Evaluate the child and return the gate applied to its output."""
        child_output = self.child.forward(values)
        return self.gate(child_output)
class JLNNCompiler(Transformer):
    """
    Lark Transformer that converts a CST (Concrete Syntax Tree) into an NNX model tree.

    Note:
        The grammar rules use ``and``, ``or``, and ``not``. These are reserved
        keywords in Python, so the handlers are defined as ``and_``, ``or_``
        and ``not_`` and must be aliased under the keyword names after the
        class definition.

    Predicate nodes are cached per variable name on the compiler instance.
    A variable that appears several times in a formula therefore maps to a
    single ``PredicateNode`` and a single set of predicate parameters. The
    cache is never cleared, so reusing one compiler instance for several
    formulas also shares predicates between those formulas.
    """

    def __init__(self, rngs: nnx.Rngs):
        """
        Args:
            rngs: Random streams for initializing gate and predicate parameters.
        """
        super().__init__()
        self.rngs = rngs
        # Cache: variable name -> PredicateNode. The stored values are the
        # PredicateNode wrappers, not bare LearnedPredicate modules.
        self.predicates: Dict[str, PredicateNode] = {}

    def variable(self, tokens: List[Token]) -> PredicateNode:
        """
        Transforms a variable token into a PredicateNode.

        Repeated names return the same cached node, so their weights are shared.
        """
        name = str(tokens[0])
        node = self.predicates.get(name)
        if node is None:
            # Build lazily: a new PredicateNode initializes parameters from
            # self.rngs, so only create one the first time a name is seen.
            node = PredicateNode(name, self.rngs)
            self.predicates[name] = node
        return node

    def and_(self, children: List[Node]) -> NAryGateNode:
        """Constructs a WeightedAnd gate node sized to the number of operands."""
        gate = gates.WeightedAnd(num_inputs=len(children), rngs=self.rngs)
        return NAryGateNode(gate, children)

    def or_(self, children: List[Node]) -> NAryGateNode:
        """Constructs a WeightedOr gate node sized to the number of operands."""
        gate = gates.WeightedOr(num_inputs=len(children), rngs=self.rngs)
        return NAryGateNode(gate, children)

    def not_(self, children: List[Node]) -> UnaryGateNode:
        """Constructs a WeightedNot gate node over the first (only) operand."""
        gate = gates.WeightedNot(rngs=self.rngs)
        return UnaryGateNode(gate, children[0])

    def implication(self, children: List[Node]) -> BinaryGateNode:
        """Constructs a WeightedImplication gate node (children[0] -> children[1])."""
        gate = gates.WeightedImplication(rngs=self.rngs)
        return BinaryGateNode(gate, children[0], children[1])

    def weighted_expr(self, children: List[Any]) -> Node:
        """
        Root rule processor; returns the final compiled node structure.

        Only the last child is returned.
        """
        # NOTE(review): any earlier children of this rule are discarded here;
        # confirm against the grammar that they carry nothing the model needs.
        return children[-1]
# CRITICAL: Lark dispatches each grammar alias ("-> and", "-> or", "-> not")
# to the transformer method with exactly that name. Those names are Python
# reserved words and cannot be written as ``def and(...)``, so the handlers
# are defined as ``and_`` / ``or_`` / ``not_`` and attached under the keyword
# names here via setattr().
for _keyword in ("and", "or", "not"):
    setattr(JLNNCompiler, _keyword, getattr(JLNNCompiler, _keyword + "_"))
# Don't leave the loop variable behind as a module-level name.
del _keyword