Skip to content

Commit 9bdd0f7

Browse files
authored
Merge pull request #65 from TensorBFS/feature/tropical-einsum-omeco
feat: add tropical einsum module with OMEinsum-style design and tropical-gemm acceleration
2 parents 13cb07b + 0063a62 commit 9bdd0f7

20 files changed

+11196
-87
lines changed

tropical_in_new/src/__init__.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,22 @@
11
"""Tropical tensor network tools for MPE (independent package)."""
22

3-
from .contraction import build_contraction_tree, choose_order, contract_tree
3+
from .contraction import (
4+
build_contraction_tree,
5+
choose_order,
6+
contract_omeco_tree,
7+
contract_tree,
8+
get_omeco_tree,
9+
)
410
from .mpe import mpe_tropical, recover_mpe_assignment
511
from .network import TensorNode, build_network
6-
from .primitives import argmax_trace, safe_log, tropical_einsum
12+
from .primitives import safe_log
13+
from .tropical_einsum import (
14+
Backpointer,
15+
argmax_trace,
16+
match_rule,
17+
tropical_einsum,
18+
tropical_reduce_max,
19+
)
720
from .utils import (
821
Factor,
922
UAIModel,
@@ -14,6 +27,7 @@
1427
)
1528

1629
__all__ = [
30+
"Backpointer",
1731
"Factor",
1832
"TensorNode",
1933
"UAIModel",
@@ -22,12 +36,16 @@
2236
"build_network",
2337
"build_tropical_factors",
2438
"choose_order",
39+
"contract_omeco_tree",
2540
"contract_tree",
41+
"get_omeco_tree",
42+
"match_rule",
2643
"mpe_tropical",
2744
"read_evidence_file",
2845
"read_model_file",
2946
"read_model_from_string",
3047
"recover_mpe_assignment",
3148
"safe_log",
3249
"tropical_einsum",
50+
"tropical_reduce_max",
3351
]

tropical_in_new/src/contraction.py

Lines changed: 148 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@
1010
import omeco
1111

1212
from .network import TensorNode
13-
from .primitives import Backpointer, tropical_reduce_max
14-
from .utils import build_index_map
13+
from .tropical_einsum import tropical_einsum, tropical_reduce_max, Backpointer
1514

1615

1716
@dataclass
@@ -36,12 +35,6 @@ class ReduceNode:
3635
TreeNode = TensorNode | ContractNode | ReduceNode
3736

3837

39-
@dataclass(frozen=True)
40-
class ContractionTree:
41-
order: Tuple[int, ...]
42-
nodes: Tuple[TensorNode, ...]
43-
44-
4538
def _infer_var_sizes(nodes: Iterable[TensorNode]) -> dict[int, int]:
4639
sizes: dict[int, int] = {}
4740
for node in nodes:
@@ -54,16 +47,128 @@ def _infer_var_sizes(nodes: Iterable[TensorNode]) -> dict[int, int]:
5447
return sizes
5548

5649

57-
def _extract_leaf_index(node_dict: dict) -> int | None:
58-
for key in ("leaf", "leaf_index", "index", "tensor"):
59-
if key in node_dict:
60-
value = node_dict[key]
61-
if isinstance(value, int):
62-
return value
63-
return None
50+
def get_omeco_tree(nodes: list[TensorNode]) -> dict:
51+
"""Get the optimized contraction tree from omeco.
52+
53+
Args:
54+
nodes: List of tensor nodes to contract.
55+
56+
Returns:
57+
The omeco tree as a dictionary with structure:
58+
- Leaf: {"tensor_index": int}
59+
- Node: {"args": [...], "eins": {"ixs": [[...], ...], "iy": [...]}}
60+
"""
61+
ixs = [list(node.vars) for node in nodes]
62+
sizes = _infer_var_sizes(nodes)
63+
method = omeco.GreedyMethod()
64+
tree = omeco.optimize_code(ixs, [], sizes, method)
65+
return tree.to_dict()
66+
67+
68+
def contract_omeco_tree(
69+
tree_dict: dict,
70+
nodes: list[TensorNode],
71+
track_argmax: bool = True,
72+
) -> TreeNode:
73+
"""Contract tensors following omeco's optimized tree structure.
74+
75+
Uses tropical-gemm for accelerated binary contractions when available.
76+
77+
Args:
78+
tree_dict: The omeco tree dictionary from get_omeco_tree().
79+
nodes: List of input tensor nodes.
80+
track_argmax: Whether to track argmax for MPE backtracing.
81+
82+
Returns:
83+
Root TreeNode with contracted result and backpointers.
84+
"""
85+
86+
def recurse(node: dict) -> TreeNode:
87+
# Leaf node - return the input tensor
88+
if "tensor_index" in node:
89+
return nodes[node["tensor_index"]]
90+
91+
# Internal node - contract children
92+
args = node["args"]
93+
eins = node["eins"]
94+
out_vars = tuple(eins["iy"])
95+
96+
# Recursively contract children
97+
children = [recurse(arg) for arg in args]
98+
99+
# Use tropical_einsum for the contraction
100+
tensors = [c.values for c in children]
101+
child_ixs = [c.vars for c in children]
102+
103+
values, backpointer = tropical_einsum(
104+
tensors, list(child_ixs), out_vars, track_argmax=track_argmax
105+
)
106+
107+
# Build result node (for binary, use ContractNode)
108+
if len(children) == 2:
109+
all_input = set(children[0].vars) | set(children[1].vars)
110+
elim_vars = tuple(v for v in all_input if v not in out_vars)
111+
112+
return ContractNode(
113+
vars=out_vars,
114+
values=values,
115+
left=children[0],
116+
right=children[1],
117+
elim_vars=elim_vars,
118+
backpointer=backpointer,
119+
)
120+
else:
121+
# For n-ary, chain as binary
122+
result = children[0]
123+
for i, child in enumerate(children[1:], 1):
124+
is_final = (i == len(children) - 1)
125+
target_out = out_vars if is_final else tuple(dict.fromkeys(result.vars + child.vars))
126+
127+
step_tensors = [result.values, child.values]
128+
step_ixs = [result.vars, child.vars]
129+
130+
step_values, step_bp = tropical_einsum(
131+
step_tensors, list(step_ixs), target_out, track_argmax=track_argmax
132+
)
133+
134+
all_input = set(result.vars) | set(child.vars)
135+
elim_vars = tuple(v for v in all_input if v not in target_out)
136+
137+
result = ContractNode(
138+
vars=target_out,
139+
values=step_values,
140+
left=result,
141+
right=child,
142+
elim_vars=elim_vars,
143+
backpointer=step_bp,
144+
)
145+
return result
146+
147+
return recurse(tree_dict)
148+
149+
150+
# =============================================================================
151+
# Legacy API for backward compatibility
152+
# =============================================================================
153+
154+
@dataclass(frozen=True)
155+
class ContractionTree:
156+
"""Legacy contraction tree structure."""
157+
order: Tuple[int, ...]
158+
nodes: Tuple[TensorNode, ...]
159+
160+
161+
def choose_order(nodes: list[TensorNode], heuristic: str = "omeco") -> list[int]:
162+
"""Legacy: Select elimination order. Use get_omeco_tree() instead."""
163+
if heuristic != "omeco":
164+
raise ValueError("Only the 'omeco' heuristic is supported.")
165+
tree_dict = get_omeco_tree(nodes)
166+
ixs = [list(node.vars) for node in nodes]
167+
return _elim_order_from_tree_dict(tree_dict, ixs)
64168

65169

66170
def _elim_order_from_tree_dict(tree_dict: dict, ixs: list[list[int]]) -> list[int]:
171+
"""Extract elimination order from omeco tree (legacy support)."""
67172
total_counts: dict[int, int] = {}
68173
for vars in ixs:
69174
for var in vars:
@@ -72,14 +177,13 @@ def _elim_order_from_tree_dict(tree_dict: dict, ixs: list[list[int]]) -> list[in
72177
eliminated: set[int] = set()
73178

74179
def visit(node: dict) -> tuple[dict[int, int], list[int]]:
75-
leaf_index = _extract_leaf_index(node)
76-
if leaf_index is not None:
180+
if "tensor_index" in node:
77181
counts: dict[int, int] = {}
78-
for var in ixs[leaf_index]:
182+
for var in ixs[node["tensor_index"]]:
79183
counts[var] = counts.get(var, 0) + 1
80184
return counts, []
81185

82-
children = node.get("children", [])
186+
children = node.get("args") or node.get("children", [])
83187
if not isinstance(children, list) or not children:
84188
return {}, []
85189

@@ -106,59 +210,49 @@ def visit(node: dict) -> tuple[dict[int, int], list[int]]:
106210
return order + remaining
107211

108212

109-
def choose_order(nodes: list[TensorNode], heuristic: str = "omeco") -> list[int]:
110-
"""Select elimination order over variable indices using omeco."""
111-
if heuristic != "omeco":
112-
raise ValueError("Only the 'omeco' heuristic is supported.")
113-
ixs = [list(node.vars) for node in nodes]
114-
sizes = _infer_var_sizes(nodes)
115-
method = omeco.GreedyMethod() if hasattr(omeco, "GreedyMethod") else None
116-
tree = (
117-
omeco.optimize_code(ixs, [], sizes, method)
118-
if method is not None
119-
else omeco.optimize_code(ixs, [], sizes)
120-
)
121-
tree_dict = tree.to_dict() if hasattr(tree, "to_dict") else tree
122-
if not isinstance(tree_dict, dict):
123-
raise ValueError("omeco.optimize_code did not return a usable tree.")
124-
return _elim_order_from_tree_dict(tree_dict, ixs)
125-
126-
127213
def build_contraction_tree(order: Iterable[int], nodes: list[TensorNode]) -> ContractionTree:
128-
"""Prepare a contraction plan from order and leaf nodes."""
214+
"""Legacy: Prepare a contraction plan from order and leaf nodes."""
129215
return ContractionTree(order=tuple(order), nodes=tuple(nodes))
130216

131217

132218
def contract_tree(
133219
tree: ContractionTree,
134-
einsum_fn,
220+
einsum_fn=None,
135221
track_argmax: bool = True,
136222
) -> TreeNode:
137-
"""Contract along the tree using the tropical einsum."""
223+
"""Legacy: Contract using elimination order. Use contract_omeco_tree() instead."""
138224
active_nodes: list[TreeNode] = list(tree.nodes)
225+
139226
for var in tree.order:
140227
bucket = [node for node in active_nodes if var in node.vars]
141228
if not bucket:
142229
continue
143230
bucket_ids = {id(node) for node in bucket}
144231
active_nodes = [node for node in active_nodes if id(node) not in bucket_ids]
232+
145233
combined: TreeNode = bucket[0]
146234
for i, other in enumerate(bucket[1:]):
147235
is_last = i == len(bucket) - 2
148236
elim_vars = (var,) if is_last else ()
149-
index_map = build_index_map(combined.vars, other.vars, elim_vars=elim_vars)
150-
values, backpointer = einsum_fn(
151-
combined.values, other.values, index_map,
237+
238+
# Use tropical_einsum
239+
target_out = tuple(v for v in dict.fromkeys(combined.vars + other.vars) if v not in elim_vars)
240+
values, backpointer = tropical_einsum(
241+
[combined.values, other.values],
242+
[combined.vars, other.vars],
243+
target_out,
152244
track_argmax=track_argmax if is_last else False,
153245
)
246+
154247
combined = ContractNode(
155-
vars=index_map.out_vars,
248+
vars=target_out,
156249
values=values,
157250
left=combined,
158251
right=other,
159252
elim_vars=elim_vars,
160253
backpointer=backpointer,
161254
)
255+
162256
if var in combined.vars:
163257
# Single-node bucket: eliminate via reduce
164258
values, backpointer = tropical_reduce_max(
@@ -172,20 +266,27 @@ def contract_tree(
172266
backpointer=backpointer,
173267
)
174268
active_nodes.append(combined)
269+
175270
while len(active_nodes) > 1:
176271
left = active_nodes.pop(0)
177272
right = active_nodes.pop(0)
178-
index_map = build_index_map(left.vars, right.vars, elim_vars=())
179-
values, _ = einsum_fn(left.values, right.values, index_map, track_argmax=False)
273+
target_out = tuple(dict.fromkeys(left.vars + right.vars))
274+
values, _ = tropical_einsum(
275+
[left.values, right.values],
276+
[left.vars, right.vars],
277+
target_out,
278+
track_argmax=False,
279+
)
180280
combined = ContractNode(
181-
vars=index_map.out_vars,
281+
vars=target_out,
182282
values=values,
183283
left=left,
184284
right=right,
185285
elim_vars=(),
186286
backpointer=None,
187287
)
188288
active_nodes.append(combined)
289+
189290
if not active_nodes:
190291
raise ValueError("Contraction produced no nodes.")
191292
return active_nodes[0]

tropical_in_new/src/mpe.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,14 @@
44

55
from typing import Dict, Iterable
66

7-
from .contraction import ContractNode, ReduceNode, build_contraction_tree, choose_order
8-
from .contraction import contract_tree as _contract_tree
7+
from .contraction import (
8+
ContractNode,
9+
ReduceNode,
10+
contract_omeco_tree,
11+
get_omeco_tree,
12+
)
913
from .network import TensorNode, build_network
10-
from .primitives import argmax_trace, tropical_einsum, tropical_reduce_max
14+
from .tropical_einsum import argmax_trace, tropical_reduce_max
1115
from .utils import UAIModel, build_tropical_factors
1216

1317

@@ -70,16 +74,22 @@ def traverse(node, out_assignment: Dict[int, int]) -> None:
7074
def mpe_tropical(
7175
model: UAIModel,
7276
evidence: Dict[int, int] | None = None,
73-
order: Iterable[int] | None = None,
7477
) -> tuple[Dict[int, int], float, Dict[str, int | tuple[int, ...]]]:
75-
"""Return MPE assignment, score, and contraction metadata."""
78+
"""Return MPE assignment, score, and contraction metadata.
79+
80+
Uses omeco for optimized contraction order and tropical-gemm for acceleration.
81+
"""
7682
evidence = evidence or {}
7783
factors = build_tropical_factors(model, evidence)
7884
nodes = build_network(factors)
79-
if order is None:
80-
order = choose_order(nodes, heuristic="omeco")
81-
tree = build_contraction_tree(order, nodes)
82-
root = _contract_tree(tree, einsum_fn=tropical_einsum)
85+
86+
# Get optimized contraction tree from omeco
87+
tree_dict = get_omeco_tree(nodes)
88+
89+
# Contract using the optimized tree
90+
root = contract_omeco_tree(tree_dict, nodes, track_argmax=True)
91+
92+
# Final reduction if there are remaining variables
8393
if root.vars:
8494
values, backpointer = tropical_reduce_max(
8595
root.values, root.vars, tuple(root.vars), track_argmax=True
@@ -91,12 +101,11 @@ def mpe_tropical(
91101
elim_vars=tuple(root.vars),
92102
backpointer=backpointer,
93103
)
104+
94105
assignment = recover_mpe_assignment(root)
95106
assignment.update({int(k): int(v) for k, v in evidence.items()})
96107
score = float(root.values.item())
97108
info = {
98-
"order": tuple(order),
99109
"num_nodes": len(nodes),
100-
"num_elims": len(tuple(order)),
101110
}
102111
return assignment, score, info

0 commit comments

Comments
 (0)