diff --git a/tensorcircuit/cons.py b/tensorcircuit/cons.py index 86b32c95..9471d57b 100644 --- a/tensorcircuit/cons.py +++ b/tensorcircuit/cons.py @@ -587,9 +587,16 @@ def _algebraic_base_contraction( raw_tensors, input_sets, output_set, size_dict = _extract_topology(nodes) # Use the backend of the first node - be = nodes[0].backend + if len(nodes) > 0: + be = nodes[0].backend + else: + be = get_backend(get_default_backend()) - if len(raw_tensors) == 1: + if len(raw_tensors) == 0: + # Avoid cotengra bug for empty contraction paths + final_raw_tensor = be.ones([]) + exponent = 0.0 + elif len(raw_tensors) == 1: # Avoid cotengra bug for empty contraction paths final_raw_tensor = be.einsum(input_sets[0] + "->" + output_set, *raw_tensors) exponent = 0.0 diff --git a/tests/test_hyperedge.py b/tests/test_hyperedge.py index ad54b543..8cdcc5d5 100644 --- a/tests/test_hyperedge.py +++ b/tests/test_hyperedge.py @@ -545,3 +545,24 @@ def test_qir_fallback(contractor_setup, backend): qir = c.to_qir() c2 = tc.Circuit.from_qir(qir, circuit_params={"nqubits": n}) np.testing.assert_allclose(c.state(), c2.state(), atol=1e-5) + +@pytest.mark.parametrize("contractor_setup", [("cotengra", {"use_primitives": True})], indirect=True) +def test_algebraic_contraction_edge_cases(contractor_setup, backend_setup): + from tensorcircuit.cons import _algebraic_base_contraction + import opt_einsum + + # 0 nodes case + res0 = _algebraic_base_contraction([], opt_einsum.paths.greedy) + np.testing.assert_allclose(tc.backend.numpy(res0.tensor), 1.0) + + # 1 node case + a = tn.Node(tc.backend.convert_to_tensor(np.array([1.0, 2.0]))) + res1 = _algebraic_base_contraction([a], opt_einsum.paths.greedy) + np.testing.assert_allclose(tc.backend.numpy(res1.tensor), np.array([1.0, 2.0])) + + # 1 node with self-loop (trace) + # _extract_topology handles traces by mapping them to symbols + b = tn.Node(tc.backend.convert_to_tensor(np.eye(2))) + b[0] ^ b[1] + res1_trace = _algebraic_base_contraction([b], opt_einsum.paths.greedy) + np.testing.assert_allclose(tc.backend.numpy(res1_trace.tensor), 2.0)