"""
Tensornetwork Simplification
"""
# Part of the implementation and ideas are inspired by
# https://github.com/jcmgray/quimb/blob/a2968050eba5a8a04ced4bdaa5e43c4fb89edc33/quimb/tensor/tensor_core.py#L7309-L8293
# (Apache 2.0)
# Here we focus mainly on tensor networks derived from circuit simulation
# and less on general tensor network topologies.
# Note that the tensornetwork package has no direct hyperedge support.
from typing import Any, List, Optional, Tuple
import numpy as np
import tensornetwork as tn
[docs]
def infer_new_size(a: tn.Node, b: tn.Node, include_old: bool = True) -> Any:
    """
    Get the size (number of elements) of the tensor obtained by contracting
    ``a`` and ``b`` over all their shared edges, also supporting to return the
    sizes of the two original nodes.

    Sizes are exact Python integers. ``np.prod`` over int64 would silently wrap
    around for high-rank tensors (e.g. 64 legs of dimension 2), and it returns
    the float ``1.0`` for an empty dimension list.

    :param a: node one
    :type a: tn.Node
    :param b: node two
    :type b: tn.Node
    :param include_old: Whether to also return the sizes of ``a`` and ``b``,
        default is True.
    :type include_old: bool
    :return: ``new_size``, or ``(new_size, size_of_a, size_of_b)`` when
        ``include_old`` is truthy.
    :rtype: Union[int, Tuple[int, int, int]]
    """

    def _size(dims: List[Any]) -> int:
        # exact product using Python's arbitrary-precision ints
        size = 1
        for d in dims:
            size *= int(d)
        return size

    shared_edges = tn.get_shared_edges(a, b)
    a_dim = _size([e.dimension for e in a])
    b_dim = _size([e.dimension for e in b])
    # legs that survive the contraction: non-shared legs of a, then of b
    new_dim = _size(
        [e.dimension for e in a if e not in shared_edges]
        + [e.dimension for e in b if e not in shared_edges]
    )
    if include_old:
        return new_dim, a_dim, b_dim
    return new_dim
[docs]
def infer_new_shape(a: tn.Node, b: tn.Node, include_old: bool = True) -> Any:
    """
    Get the new shape of two nodes, also supporting to return original shapes of two nodes.

    The new shape lists the dimensions of the edges of ``a`` that are not
    shared with ``b``, followed by those of ``b`` that are not shared with ``a``.

    :Example:

    >>> a = tn.Node(np.ones([2, 3, 5]))
    >>> b = tn.Node(np.ones([3, 5, 7]))
    >>> _ = a[1] ^ b[0]
    >>> _ = a[2] ^ b[1]
    >>> tc.simplify.infer_new_shape(a, b)  # (new shape, shape of a, shape of b)
    ((2, 7), (2, 3, 5), (3, 5, 7))
    >>> tc.simplify.infer_new_shape(a, b, include_old=False)
    (2, 7)

    :param a: node one
    :type a: tn.Node
    :param b: node two
    :type b: tn.Node
    :param include_old: Whether to include original shape of two nodes, default is True.
    :type include_old: bool
    :return: The new shape of the two nodes, or ``(new shape, shape of a, shape of b)``
        when ``include_old`` is truthy.
    :rtype: Union[Tuple[int, ...], Tuple[Tuple[int, ...], Tuple[int, ...], Tuple[int, ...]]]
    """
    shared_edges = tn.get_shared_edges(a, b)
    a_shape = tuple(e.dimension for e in a)
    b_shape = tuple(e.dimension for e in b)
    new_shape = tuple(
        [e.dimension for e in a if e not in shared_edges]
        + [e.dimension for e in b if e not in shared_edges]
    )
    if include_old:
        return new_shape, a_shape, b_shape
    return new_shape
[docs]
def pseudo_contract_between(a: tn.Node, b: tn.Node, **kws: Any) -> tn.Node:
    """
    Contract between Node ``a`` and ``b``, with correct shape only and no calculation.

    An all-zeros node is created whose shape is the dimensions of the edges of
    ``a`` not shared with ``b``, followed by those of ``b`` not shared with
    ``a``. The shared edges are then passed to tensornetwork's private
    ``_remove_edges`` helper together with the new node, which rewires the
    uncontracted edges of ``a`` and ``b`` onto it.

    NOTE(review): relies on the private ``tn.network_components._remove_edges``
    API, which may change between tensornetwork releases.

    :param a: node one
    :type a: tn.Node
    :param b: node two
    :type b: tn.Node
    :param kws: accepted for signature compatibility; not used
    :return: placeholder node carrying the contracted shape
    :rtype: tn.Node
    """
    from .cons import backend

    contracted = tn.get_shared_edges(a, b)
    remaining_dims = [edge.dimension for edge in a if edge not in contracted]
    remaining_dims.extend(edge.dimension for edge in b if edge not in contracted)
    placeholder = tn.Node(backend.zeros(tuple(remaining_dims)))
    tn.network_components._remove_edges(contracted, a, b, placeholder)
    return placeholder
def _multi_remove(elems: List[Any], indices: List[int]) -> List[Any]:
"""Remove multiple indices from a list for one time."""
return [i for j, i in enumerate(elems) if j not in indices]
def _split_two_qubit_gate(
    a: tn.Node,
    max_singular_values: Optional[int] = None,
    max_truncation_err: Optional[float] = None,
    fixed_choice: Optional[int] = None,
) -> Any:
    """
    Split a four-leg (two-qubit) gate node into two three-leg nodes by SVD.

    Two bipartitions of the legs of ``a`` are available:

    * choice 1 (no swap): ``(a[0], a[2]) | (a[1], a[3])``
    * choice 2 (swap): ``(a[0], a[3]) | (a[1], a[2])``

    Every split is performed on its own fresh ``tn.copy`` of ``a``.
    ``tn.split_node`` consumes the node it splits (its edges are handed over to
    the two new nodes), so a second split must not reuse the first copy.
    ``a`` itself is never split.

    :param a: the four-leg gate node to split
    :type a: tn.Node
    :param max_singular_values: forwarded to ``tn.split_node``; when set and
        ``fixed_choice`` is None, choice 1 is used
    :type max_singular_values: Optional[int]
    :param max_truncation_err: forwarded to ``tn.split_node``
    :type max_truncation_err: Optional[float]
    :param fixed_choice: 1 or 2 to force a bipartition; otherwise the one with
        the smaller bond dimension is picked (ties go to choice 1)
    :type fixed_choice: Optional[int]
    :return: ``(left_node, right_node, is_swap)``, or None when both bond
        dimensions reach the full-rank thresholds (no compression either way)
    :rtype: Optional[Tuple[tn.Node, tn.Node, bool]]
    """
    if (max_singular_values is not None) and (fixed_choice is None):
        fixed_choice = 1

    def _split(left_axes: Tuple[int, int], right_axes: Tuple[int, int]) -> Any:
        # split a brand-new copy so that no split sees the leftovers of another
        ndict, _ = tn.copy([a])
        n = ndict[a]
        left, right, _ = tn.split_node(
            n,
            left_edges=[n[k] for k in left_axes],
            right_edges=[n[k] for k in right_axes],
            max_singular_values=max_singular_values,
            max_truncation_err=max_truncation_err,
        )
        return left, right

    if fixed_choice == 2:  # swap one; the unswapped split is not needed
        n3, n4 = _split((0, 3), (1, 2))
        return n3, n4, True  # swap

    n1, n2 = _split((0, 2), (1, 3))
    if fixed_choice == 1:
        # still considering better API type for fixed_choice
        return n1, n2, False  # no swap
    s1 = n1.tensor.shape[-1]  # bond dimension
    n3, n4 = _split((0, 3), (1, 2))
    s2 = n3.tensor.shape[-1]
    if (s1 >= a[0].dimension * a[2].dimension) and (
        s2 >= a[1].dimension * a[3].dimension
    ):
        # jax jit does not support split_node with truncation error anyway
        # tf function doesn't work either, though I believe it may work on tf side
        # CANNOT DONE(@refraction-ray): tf.function version with trun_err set
        return None
    if s1 <= s2:  # equal is necessary for max values to pick on unswap one
        return n1, n2, False  # no swap
    return n3, n4, True  # swap
def _rank_simplify(nodes: List[Any]) -> Tuple[List[Any], bool]:
# if total_size is None:
# total_size = sum([_sizen(t) for t in nodes])
is_changed = False
l = len(nodes)
for i in range(l):
if i < len(nodes):
n = nodes[i]
for e in n:
if not e.is_dangling():
nd, ad, bd = infer_new_shape(e.node1, e.node2)
if nd <= ad or nd <= bd:
n1, n2 = e.node1, e.node2
njs = [i for i, n in enumerate(nodes) if n is n1 or n is n2]
new_node = tn.contract_between(e.node1, e.node2)
# contract(e) is not enough for multi edges between two tensors
nodes[njs[0]] = new_node
nodes = _multi_remove(nodes, [njs[1]])
is_changed = True
break # switch to the next node
else:
break
return nodes, is_changed
def _full_rank_simplify(nodes: List[Any]) -> List[Any]:
    """
    Simplify the list of tc.Nodes without increasing the rank of any tensors.

    Sweeps of :func:`_rank_simplify` are repeated until one sweep performs no
    contraction.

    :Example:

    .. code-block:: python

        a = tn.Node(np.ones([2, 2]), name="a")
        b = tn.Node(np.ones([2, 2]), name="b")
        c = tn.Node(np.ones([2, 2, 2, 2]), name="c")
        d = tn.Node(np.ones([2, 2, 2, 2, 2, 2]), name="d")
        e = tn.Node(np.ones([2, 2]), name="e")
        a[1] ^ c[0]
        b[1] ^ c[1]
        c[2] ^ d[0]
        c[3] ^ d[1]
        d[4] ^ e[0]

        f = tn.Node(np.ones([2, 2]), name="f")
        g = tn.Node(np.ones([2, 2, 2, 2]), name="g")
        h = tn.Node(np.ones([2, 2, 2, 2]), name="h")
        f[1] ^ g[0]
        g[2] ^ h[1]

        nodes = simplify._full_rank_simplify([a, b, c, d, e])
        nodes[0].shape  # (2, 2, 2, 2, 2, 2)
        len(nodes)  # 1
        len(simplify._full_rank_simplify([f, g, h]))  # 2

    :param nodes: List of Nodes
    :type nodes: List[Any]
    :return: List of Nodes
    :rtype: List[Any]
    """
    changed = True
    while changed:
        nodes, changed = _rank_simplify(nodes)
    return nodes
def _light_cone_cancel(nodes: List[Any]) -> Tuple[List[Any], bool]:
"""
Scan the nodes and cancel pairs of U and U^dagger that are directly connected.
Assumes that for a gate node, the first half of its edges are 'output' (future-facing)
and the second half are 'input' (past-facing).
"""
is_changed = False
nodes_to_remove = set()
# Identify the "future" side. Nodes in expectation are typically [ket_nodes, bra_nodes, ops]
# We scan backward through the nodes list.
for i in range(len(nodes) - 1, -1, -1):
n = nodes[i]
if n in nodes_to_remove:
continue
if getattr(n, "is_dagger", None) is True:
continue
noe = len(n.shape)
if noe % 2 != 0:
continue
# Check if all output legs (0 to noe//2 - 1) are connected to the same conjugate node
match_node = None
for leg_idx in range(noe // 2):
e = n[leg_idx]
if e.is_dangling():
break
n1, n2 = e.node1, e.node2
other = n2 if n1 is n else n1
if getattr(other, "is_dagger", None) is not True:
break
if getattr(other, "id", None) != getattr(n, "id", -1):
break
if e.axis1 != e.axis2:
break
if match_node is None:
match_node = other
elif match_node is not other:
break
else:
if match_node is not None and match_node not in nodes_to_remove:
# Perform cancellation by bypass
for leg_idx in range(noe // 2, noe):
e_n = n[leg_idx]
e_m = match_node[leg_idx]
m_n, i_n = (
(e_n.node2, e_n.axis2)
if e_n.node1 is n
else (e_n.node1, e_n.axis1)
)
m_m, i_m = (
(e_m.node2, e_m.axis2)
if e_m.node1 is match_node
else (e_m.node1, e_m.axis1)
)
e_n.disconnect()
e_m.disconnect()
m_n[i_n] ^ m_m[i_m]
nodes_to_remove.add(n)
nodes_to_remove.add(match_node)
is_changed = True
if is_changed:
new_nodes = [n for n in nodes if n not in nodes_to_remove]
return new_nodes, True
return nodes, False
# TODO(@refraction-ray): better light cone cancellation for MPO gates (the three-leg ones).
# MPO cancellation requires matching internal bond dimensions and identities of all nodes in the MPO.
def _full_light_cone_cancel(nodes: List[Any]) -> List[Any]:
"""
Simplify the list of tc.Nodes using casual lightcone structure.
:param nodes: List of nodes representing the tensor network.
:type nodes: List[Any]
:return: Simplified list of nodes.
:rtype: List[Any]
"""
if not nodes:
return nodes
# Check metadata availability
if any(getattr(n, "is_dagger", None) is None for n in nodes):
return nodes
# A single pass of _light_cone_cancel (optimized for O(N)) is often sufficient
# but we use a while loop to ensure all possible cancellations are resolved.
nodes, is_changed = _light_cone_cancel(nodes)
while is_changed:
nodes, is_changed = _light_cone_cancel(nodes)
return nodes