Source code for jlnn.symbolic.graph
#!/usr/bin/env python3
# Imports
import networkx as nx
from typing import Optional, Any, Dict, List
from jlnn.symbolic.compiler import Node, PredicateNode, BinaryGateNode, NAryGateNode, UnaryGateNode
from networkx.drawing.nx_pydot import to_pydot
[docs]
def build_networkx_graph(root_node: Node, graph: Optional[nx.DiGraph] = None) -> nx.DiGraph:
    """
    Build a directed graph (NetworkX DiGraph) from a compiled JLNN tree.

    Vertices represent logical operations (gates) or variables (predicates).
    Edges point from each parent node to its children (structural dependency),
    i.e. the opposite of the direction in which truth values flow. Each vertex
    carries ``label``, ``color`` and ``node_type`` attributes for visualization
    and for reconstruction via ``from_networkx_to_jlnn``.

    A Node object referenced by several parents (a shared subexpression) becomes
    a single vertex, and its subtree is traversed only once.

    Args:
        root_node (Node): The root node of the model (typically model.root).
        graph (Optional[nx.DiGraph]): An existing graph instance to extend.
            If None, a new one is created.

    Returns:
        nx.DiGraph: The populated graph. Vertex ids are the ``id()`` values of
        the Python Node objects, so they only identify nodes while those
        objects are alive.
    """
    if graph is None:
        graph = nx.DiGraph()
    _add_subtree(root_node, graph, set())
    return graph


def _describe_node(node: Node) -> tuple:
    """Return ``(label, color, children)`` for *node* based on its concrete node type."""
    if isinstance(node, PredicateNode):
        return f"P: {node.name}", "#ADD8E6", []  # LightBlue
    if isinstance(node, BinaryGateNode):
        return node.gate.__class__.__name__, "#FFA500", [node.left, node.right]  # Orange
    if isinstance(node, NAryGateNode):
        label = f"{node.gate.__class__.__name__}\n(n={len(node.children)})"
        return label, "#90EE90", node.children  # LightGreen
    if isinstance(node, UnaryGateNode):
        return node.gate.__class__.__name__, "#FF6347", [node.child]  # Tomato
    return "Unknown", "#D3D3D3", []  # LightGrey


def _add_subtree(node: Node, graph: nx.DiGraph, visited: set) -> None:
    """Add *node* and its not-yet-visited descendants to *graph* in depth-first pre-order."""
    node_id = id(node)
    # Mark before descending so shared (or cyclically referenced) nodes are expanded once.
    visited.add(node_id)
    label, color, children = _describe_node(node)
    graph.add_node(
        node_id,
        label=label,
        color=color,
        node_type=type(node).__name__,
    )
    for child in children:
        child_id = id(child)
        if child_id not in visited:
            _add_subtree(child, graph, visited)
        # Edges are added in child order; from_networkx_to_jlnn relies on this
        # order (e.g. left before right for binary gates).
        graph.add_edge(node_id, child_id)
[docs]
def get_node_attributes(graph: nx.DiGraph) -> Dict[Any, Dict[str, Any]]:
    """
    Return a snapshot of every vertex's attributes, e.g. for matplotlib or pyvis.

    Args:
        graph (nx.DiGraph): Graph generated by the build_networkx_graph function
            (or any graph with node attributes).

    Returns:
        Dict: Maps each vertex id (for build_networkx_graph output, the ``id()``
        integer of a Node) to a copy of its attribute dict (there: 'label',
        'color' and 'node_type'). The copies may be modified freely without
        affecting the graph.
    """
    return {node_id: dict(attrs) for node_id, attrs in graph.nodes(data=True)}
[docs]
def to_dot(graph: nx.DiGraph) -> str:
    """
    Convert a NetworkX graph to DOT (Graphviz) format for advanced plotting.

    networkx's ``to_pydot`` rejects node or edge attribute values that contain
    ``:`` unless they are wrapped in double quotes. Predicate labels produced by
    ``build_networkx_graph`` have the form ``"P: name"``, so such values are
    quoted on a copy of the graph before conversion. The caller's graph is
    not modified.

    Args:
        graph (nx.DiGraph): Graph to convert.

    Returns:
        str: String in DOT format.
    """
    # Graph.copy() creates new attribute dicts, so quoting below is local to the copy.
    safe_graph = graph.copy()
    for _, node_attrs in safe_graph.nodes(data=True):
        _quote_colon_values(node_attrs)
    for _, _, edge_attrs in safe_graph.edges(data=True):
        _quote_colon_values(edge_attrs)
    return to_pydot(safe_graph).to_string()


def _quote_colon_values(attrs: Dict[str, Any]) -> None:
    """Wrap in double quotes (in place) every attribute value whose text contains ':'."""
    for key, value in list(attrs.items()):
        text = str(value)
        already_quoted = len(text) >= 2 and text.startswith('"') and text.endswith('"')
        if ":" in text and not already_quoted:
            escaped = text.replace('"', '\\"')
            attrs[key] = f'"{escaped}"'
[docs]
def from_networkx_to_jlnn(graph: nx.DiGraph, root_id: Any, rngs: Any) -> Node:
    """
    Reconstruct the functional JLNN computational tree from a NetworkX graph.

    This allows "importing" logical structures defined externally in a graph
    editor or generated by another algorithm. Vertices must carry a 'label'
    attribute; predicates are recognised by ``node_type == "PredicateNode"``
    and a label of the form ``"P: name"``. Gates are chosen by substring of the
    label: "And" -> WeightedAnd (n-ary), "Or" -> WeightedOr (n-ary),
    "Implication" -> WeightedImplication (binary). Children are taken from the
    vertex's successors in edge-insertion order.

    Args:
        graph (nx.DiGraph): Source graph.
        root_id: ID of the root vertex in the NetworkX graph.
        rngs (nnx.Rngs): Generators for initializing gate weights.

    Returns:
        Node: The reconstructed root node of the model.

    Raises:
        KeyError: If ``root_id`` or a required 'label' attribute is missing.
        ValueError: If a vertex cannot be mapped to a JLNN node, or an
            implication vertex does not have exactly two successors.
    """
    node_data = graph.nodes[root_id]
    n_type = node_data.get('node_type')
    # Successor order follows edge insertion order; build_networkx_graph inserts
    # edges in child order (left before right for binary gates).
    successors = list(graph.successors(root_id))

    if n_type == "PredicateNode":
        label = node_data['label']
        prefix = "P: "
        # Strip only the leading prefix; str.replace would also delete any
        # "P: " occurring inside the predicate name itself.
        name = label[len(prefix):] if label.startswith(prefix) else label
        return PredicateNode(name=name, rngs=rngs)

    # Build children first (before this node's gate) so rngs are consumed in
    # the same bottom-up order as before.
    child_nodes = [from_networkx_to_jlnn(graph, sid, rngs) for sid in successors]

    # Local import, presumably to avoid an import cycle with jlnn.nn — TODO confirm.
    from jlnn.nn import gates
    label = node_data['label']
    if "And" in label:
        gate = gates.WeightedAnd(num_inputs=len(child_nodes), rngs=rngs)
        return NAryGateNode(gate, child_nodes)
    elif "Or" in label:
        gate = gates.WeightedOr(num_inputs=len(child_nodes), rngs=rngs)
        return NAryGateNode(gate, child_nodes)
    elif "Implication" in label:
        if len(child_nodes) != 2:
            # A DiGraph cannot hold a duplicate edge, so e.g. "A -> A" yields a
            # single successor; report it instead of failing with an IndexError.
            raise ValueError(
                f"Implication vertex {root_id!r} must have exactly 2 successors, "
                f"got {len(child_nodes)}."
            )
        gate = gates.WeightedImplication(rngs=rngs)
        return BinaryGateNode(gate, child_nodes[0], child_nodes[1])
    raise ValueError(f"Failed to map type '{n_type}' with label '{label}' to JLNN node.")