1010import omeco
1111
1212from .network import TensorNode
13- from .primitives import Backpointer , tropical_reduce_max
14- from .utils import build_index_map
13+ from .tropical_einsum import tropical_einsum , tropical_reduce_max , Backpointer
1514
1615
1716@dataclass
@@ -36,12 +35,6 @@ class ReduceNode:
3635TreeNode = TensorNode | ContractNode | ReduceNode
3736
3837
39- @dataclass (frozen = True )
40- class ContractionTree :
41- order : Tuple [int , ...]
42- nodes : Tuple [TensorNode , ...]
43-
44-
4538def _infer_var_sizes (nodes : Iterable [TensorNode ]) -> dict [int , int ]:
4639 sizes : dict [int , int ] = {}
4740 for node in nodes :
@@ -54,16 +47,128 @@ def _infer_var_sizes(nodes: Iterable[TensorNode]) -> dict[int, int]:
5447 return sizes
5548
5649
57- def _extract_leaf_index (node_dict : dict ) -> int | None :
58- for key in ("leaf" , "leaf_index" , "index" , "tensor" ):
59- if key in node_dict :
60- value = node_dict [key ]
61- if isinstance (value , int ):
62- return value
63- return None
50+ def get_omeco_tree (nodes : list [TensorNode ]) -> dict :
51+ """Get the optimized contraction tree from omeco.
52+
53+ Args:
54+ nodes: List of tensor nodes to contract.
55+
56+ Returns:
57+ The omeco tree as a dictionary with structure:
58+ - Leaf: {"tensor_index": int}
59+ - Node: {"args": [...], "eins": {"ixs": [[...], ...], "iy": [...]}}
60+ """
61+ ixs = [list (node .vars ) for node in nodes ]
62+ sizes = _infer_var_sizes (nodes )
63+ method = omeco .GreedyMethod ()
64+ tree = omeco .optimize_code (ixs , [], sizes , method )
65+ return tree .to_dict ()
66+
67+
68+ def contract_omeco_tree (
69+ tree_dict : dict ,
70+ nodes : list [TensorNode ],
71+ track_argmax : bool = True ,
72+ ) -> TreeNode :
73+ """Contract tensors following omeco's optimized tree structure.
74+
75+ Uses tropical-gemm for accelerated binary contractions when available.
76+
77+ Args:
78+ tree_dict: The omeco tree dictionary from get_omeco_tree().
79+ nodes: List of input tensor nodes.
80+ track_argmax: Whether to track argmax for MPE backtracing.
81+
82+ Returns:
83+ Root TreeNode with contracted result and backpointers.
84+ """
85+
86+ def recurse (node : dict ) -> TreeNode :
87+ # Leaf node - return the input tensor
88+ if "tensor_index" in node :
89+ return nodes [node ["tensor_index" ]]
90+
91+ # Internal node - contract children
92+ args = node ["args" ]
93+ eins = node ["eins" ]
94+ out_vars = tuple (eins ["iy" ])
95+
96+ # Recursively contract children
97+ children = [recurse (arg ) for arg in args ]
98+
99+ # Use tropical_einsum for the contraction
100+ tensors = [c .values for c in children ]
101+ child_ixs = [c .vars for c in children ]
102+
103+ values , backpointer = tropical_einsum (
104+ tensors , list (child_ixs ), out_vars , track_argmax = track_argmax
105+ )
106+
107+ # Build result node (for binary, use ContractNode)
108+ if len (children ) == 2 :
109+ all_input = set (children [0 ].vars ) | set (children [1 ].vars )
110+ elim_vars = tuple (v for v in all_input if v not in out_vars )
111+
112+ return ContractNode (
113+ vars = out_vars ,
114+ values = values ,
115+ left = children [0 ],
116+ right = children [1 ],
117+ elim_vars = elim_vars ,
118+ backpointer = backpointer ,
119+ )
120+ else :
121+ # For n-ary, chain as binary
122+ result = children [0 ]
123+ for i , child in enumerate (children [1 :], 1 ):
124+ is_final = (i == len (children ) - 1 )
125+ target_out = out_vars if is_final else tuple (dict .fromkeys (result .vars + child .vars ))
126+
127+ step_tensors = [result .values , child .values ]
128+ step_ixs = [result .vars , child .vars ]
129+
130+ step_values , step_bp = tropical_einsum (
131+ step_tensors , list (step_ixs ), target_out , track_argmax = track_argmax
132+ )
133+
134+ all_input = set (result .vars ) | set (child .vars )
135+ elim_vars = tuple (v for v in all_input if v not in target_out )
136+
137+ result = ContractNode (
138+ vars = target_out ,
139+ values = step_values ,
140+ left = result ,
141+ right = child ,
142+ elim_vars = elim_vars ,
143+ backpointer = step_bp ,
144+ )
145+ return result
146+
147+ return recurse (tree_dict )
148+
149+
150+ # =============================================================================
151+ # Legacy API for backward compatibility
152+ # =============================================================================
153+
154+ @dataclass (frozen = True )
155+ class ContractionTree :
156+ """Legacy contraction tree structure."""
157+ order : Tuple [int , ...]
158+ nodes : Tuple [TensorNode , ...]
159+
160+
161+ def choose_order (nodes : list [TensorNode ], heuristic : str = "omeco" ) -> list [int ]:
162+ """Legacy: Select elimination order. Use get_omeco_tree() instead."""
163+ if heuristic != "omeco" :
164+ raise ValueError ("Only the 'omeco' heuristic is supported." )
165+ tree_dict = get_omeco_tree (nodes )
166+ ixs = [list (node .vars ) for node in nodes ]
167+ return _elim_order_from_tree_dict (tree_dict , ixs )
64168
65169
66170def _elim_order_from_tree_dict (tree_dict : dict , ixs : list [list [int ]]) -> list [int ]:
171+ """Extract elimination order from omeco tree (legacy support)."""
67172 total_counts : dict [int , int ] = {}
68173 for vars in ixs :
69174 for var in vars :
@@ -72,14 +177,13 @@ def _elim_order_from_tree_dict(tree_dict: dict, ixs: list[list[int]]) -> list[in
72177 eliminated : set [int ] = set ()
73178
74179 def visit (node : dict ) -> tuple [dict [int , int ], list [int ]]:
75- leaf_index = _extract_leaf_index (node )
76- if leaf_index is not None :
180+ if "tensor_index" in node :
77181 counts : dict [int , int ] = {}
78- for var in ixs [leaf_index ]:
182+ for var in ixs [node [ "tensor_index" ] ]:
79183 counts [var ] = counts .get (var , 0 ) + 1
80184 return counts , []
81185
82- children = node .get ("children" , [])
186+ children = node .get ("args" ) or node . get ( " children" , [])
83187 if not isinstance (children , list ) or not children :
84188 return {}, []
85189
@@ -106,59 +210,49 @@ def visit(node: dict) -> tuple[dict[int, int], list[int]]:
106210 return order + remaining
107211
108212
109- def choose_order (nodes : list [TensorNode ], heuristic : str = "omeco" ) -> list [int ]:
110- """Select elimination order over variable indices using omeco."""
111- if heuristic != "omeco" :
112- raise ValueError ("Only the 'omeco' heuristic is supported." )
113- ixs = [list (node .vars ) for node in nodes ]
114- sizes = _infer_var_sizes (nodes )
115- method = omeco .GreedyMethod () if hasattr (omeco , "GreedyMethod" ) else None
116- tree = (
117- omeco .optimize_code (ixs , [], sizes , method )
118- if method is not None
119- else omeco .optimize_code (ixs , [], sizes )
120- )
121- tree_dict = tree .to_dict () if hasattr (tree , "to_dict" ) else tree
122- if not isinstance (tree_dict , dict ):
123- raise ValueError ("omeco.optimize_code did not return a usable tree." )
124- return _elim_order_from_tree_dict (tree_dict , ixs )
125-
126-
127213def build_contraction_tree (order : Iterable [int ], nodes : list [TensorNode ]) -> ContractionTree :
128- """Prepare a contraction plan from order and leaf nodes."""
214+ """Legacy: Prepare a contraction plan from order and leaf nodes."""
129215 return ContractionTree (order = tuple (order ), nodes = tuple (nodes ))
130216
131217
132218def contract_tree (
133219 tree : ContractionTree ,
134- einsum_fn ,
220+ einsum_fn = None ,
135221 track_argmax : bool = True ,
136222) -> TreeNode :
137- """Contract along the tree using the tropical einsum ."""
223+ """Legacy: Contract using elimination order. Use contract_omeco_tree() instead ."""
138224 active_nodes : list [TreeNode ] = list (tree .nodes )
225+
139226 for var in tree .order :
140227 bucket = [node for node in active_nodes if var in node .vars ]
141228 if not bucket :
142229 continue
143230 bucket_ids = {id (node ) for node in bucket }
144231 active_nodes = [node for node in active_nodes if id (node ) not in bucket_ids ]
232+
145233 combined : TreeNode = bucket [0 ]
146234 for i , other in enumerate (bucket [1 :]):
147235 is_last = i == len (bucket ) - 2
148236 elim_vars = (var ,) if is_last else ()
149- index_map = build_index_map (combined .vars , other .vars , elim_vars = elim_vars )
150- values , backpointer = einsum_fn (
151- combined .values , other .values , index_map ,
237+
238+ # Use tropical_einsum
239+ target_out = tuple (v for v in dict .fromkeys (combined .vars + other .vars ) if v not in elim_vars )
240+ values , backpointer = tropical_einsum (
241+ [combined .values , other .values ],
242+ [combined .vars , other .vars ],
243+ target_out ,
152244 track_argmax = track_argmax if is_last else False ,
153245 )
246+
154247 combined = ContractNode (
155- vars = index_map . out_vars ,
248+ vars = target_out ,
156249 values = values ,
157250 left = combined ,
158251 right = other ,
159252 elim_vars = elim_vars ,
160253 backpointer = backpointer ,
161254 )
255+
162256 if var in combined .vars :
163257 # Single-node bucket: eliminate via reduce
164258 values , backpointer = tropical_reduce_max (
@@ -172,20 +266,27 @@ def contract_tree(
172266 backpointer = backpointer ,
173267 )
174268 active_nodes .append (combined )
269+
175270 while len (active_nodes ) > 1 :
176271 left = active_nodes .pop (0 )
177272 right = active_nodes .pop (0 )
178- index_map = build_index_map (left .vars , right .vars , elim_vars = ())
179- values , _ = einsum_fn (left .values , right .values , index_map , track_argmax = False )
273+ target_out = tuple (dict .fromkeys (left .vars + right .vars ))
274+ values , _ = tropical_einsum (
275+ [left .values , right .values ],
276+ [left .vars , right .vars ],
277+ target_out ,
278+ track_argmax = False ,
279+ )
180280 combined = ContractNode (
181- vars = index_map . out_vars ,
281+ vars = target_out ,
182282 values = values ,
183283 left = left ,
184284 right = right ,
185285 elim_vars = (),
186286 backpointer = None ,
187287 )
188288 active_nodes .append (combined )
289+
189290 if not active_nodes :
190291 raise ValueError ("Contraction produced no nodes." )
191292 return active_nodes [0 ]
0 commit comments