# Source code for tensorcircuit.cons

"""
Constants and setups
"""

# pylint: disable=invalid-name

import logging
import sys
import time
from contextlib import contextmanager
from functools import partial, reduce, wraps, lru_cache
from operator import mul
from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple

import numpy as np
import opt_einsum
import tensornetwork as tn
from tensornetwork.backend_contextmanager import get_default_backend
from networkx.utils import UnionFind

from .backends.numpy_backend import NumpyBackend
from .backends import get_backend
from .simplify import _multi_remove

logger = logging.getLogger(__name__)

## monkey patch
# Every tn.Node / tn.CopyNode created after this module is imported receives an
# integer attribute ``_stable_id_`` taken from a module-level counter. Helpers
# in this module (``_get_edge_stable_key``, ``_extract_topology``,
# ``_get_path_cache_friendly``, ``_get_sorted_nodes``) sort by this id so that
# contraction inputs are ordered deterministically instead of by ``id()``.
# Nodes created before the patch (or elsewhere without it) lack the attribute;
# those helpers fall back to -1 via ``getattr``.
_NODE_CREATION_COUNTER = 0
# keep references to the unpatched constructors so the wrappers can delegate
_original_node_init = tn.Node.__init__
_original_copynode_init = tn.CopyNode.__init__


@wraps(_original_node_init)
def _patched_node_init(self: Any, *args: Any, **kwargs: Any) -> None:
    """Patched Node.__init__ to add a stable creation ID."""
    global _NODE_CREATION_COUNTER
    # run the real constructor first; if it raises, no id is consumed
    _original_node_init(self, *args, **kwargs)
    self._stable_id_ = _NODE_CREATION_COUNTER
    # NOTE(review): plain global increment, not guarded by a lock
    _NODE_CREATION_COUNTER += 1


@wraps(_original_copynode_init)
def _patched_copynode_init(self: Any, *args: Any, **kwargs: Any) -> None:
    """Patched CopyNode.__init__ to add a stable creation ID."""
    global _NODE_CREATION_COUNTER
    _original_copynode_init(self, *args, **kwargs)
    self._stable_id_ = _NODE_CREATION_COUNTER
    _NODE_CREATION_COUNTER += 1


# install the wrappers (import-time side effect on the tensornetwork classes)
tn.Node.__init__ = _patched_node_init
tn.CopyNode.__init__ = _patched_copynode_init


def _get_edge_stable_key(edge: tn.Edge) -> Tuple[int, int, int, int]:
    """
    Build a sort key for ``edge`` that is independent of Python object ids.

    The key is ``(node_id, axis, other_node_id, other_axis)`` where the node
    ids come from the ``_stable_id_`` attribute (-1 when missing) and a
    dangling end is represented by -2. The two endpoints are ordered so the
    smaller id (or, for equal ids, the smaller axis) comes first, making the
    key identical regardless of which end is ``node1``.
    """
    left, right = edge.node1, edge.node2
    left_id = getattr(left, "_stable_id_", -1)
    right_id = -2 if right is None else getattr(right, "_stable_id_", -1)

    must_swap = left_id > right_id or (
        left_id == right_id and edge.axis1 > edge.axis2
    )
    if must_swap:
        return (right_id, edge.axis2, left_id, edge.axis1)
    return (left_id, edge.axis1, right_id, edge.axis2)


def sorted_edges(edges: Iterator[tn.Edge]) -> List[tn.Edge]:
    """
    Return ``edges`` as a new list in a deterministic order.

    The order is given by ``_get_edge_stable_key`` (stable node creation ids and
    axes), so it does not vary between runs the way ``id()``-based orders do.
    """
    ordered = list(edges)
    ordered.sort(key=_get_edge_stable_key)
    return ordered
package_name = "tensorcircuit"
thismodule = sys.modules[__name__]

# Placeholder runtime configuration. The real values are (re)bound in every
# tensorcircuit module at import time by ``set_tensornetwork_backend()``,
# ``set_dtype()`` and ``set_contractor(...)``.
dtypestr = "complex64"
rdtypestr = "float32"
idtypestr = "int32"
npdtype = np.complex64
backend: NumpyBackend = get_backend("numpy")
contractor = tn.contractors.auto
_ALPHABET = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
# these above lines are just for mypy, it is not very good at evaluating runtime object


def _set_global_contractor(contractor_fn: Callable[..., Any]) -> None:
    """
    Bind ``contractor_fn`` as the ``contractor`` attribute of every currently
    loaded module whose name starts with ``package_name``.

    :param contractor_fn: The contractor callable to install.
    """
    owned_modules = [
        sys.modules[name] for name in sys.modules if name.startswith(package_name)
    ]
    for mod in owned_modules:
        setattr(mod, "contractor", contractor_fn)
def set_tensornetwork_backend(
    backend: Optional[str] = None, set_global: bool = True
) -> Any:
    r"""To set the runtime backend of tensorcircuit.

    Note: ``tc.set_backend`` and ``tc.cons.set_tensornetwork_backend`` are the same.

    :Example:

    >>> tc.set_backend("numpy")
    numpy_backend
    >>> tc.gates.num_to_tensor(0.1)
    array(0.1+0.j, dtype=complex64)
    >>>
    >>> tc.set_backend("tensorflow")
    tensorflow_backend
    >>> tc.gates.num_to_tensor(0.1)
    <tf.Tensor: shape=(), dtype=complex64, numpy=(0.1+0j)>
    >>>
    >>> tc.set_backend("pytorch")
    pytorch_backend
    >>> tc.gates.num_to_tensor(0.1)
    tensor(0.1000+0.j)
    >>>
    >>> tc.set_backend("jax")
    jax_backend
    >>> tc.gates.num_to_tensor(0.1)
    DeviceArray(0.1+0.j, dtype=complex64)

    :param backend: "numpy", "tensorflow", "jax", "pytorch". A falsy value
        (the default None) falls back to
        ``tensornetwork.backend_contextmanager.get_default_backend()``.
    :type backend: Optional[str], optional
    :param set_global: When True, bind the backend object as ``backend`` in
        every loaded tensorcircuit module and make it tensornetwork's default
        backend; when False, only build and return the backend object.
    :type set_global: bool
    :return: The `tc.backend` object that with all registered universal functions.
    :rtype: backend object
    """
    backend = backend or get_default_backend()
    backend_obj = get_backend(backend)
    if not set_global:
        return backend_obj

    owned_modules = [
        sys.modules[name] for name in sys.modules if name.startswith(package_name)
    ]
    for mod in owned_modules:
        setattr(mod, "backend", backend_obj)
    tn.set_default_backend(backend)
    return backend_obj


set_backend = set_tensornetwork_backend
# initialize the global backend when the module is imported
set_tensornetwork_backend()
def set_function_backend(backend: Optional[str] = None) -> Callable[..., Any]:
    """
    Function decorator to set function-level runtime backend.

    While the decorated function runs, the global backend is switched to
    ``backend``. The previous backend (by name) is restored afterwards, also
    when the function raises.

    :param backend: "numpy", "tensorflow", "jax", "pytorch", defaults to None
    :type backend: Optional[str], optional
    :return: Decorated function
    :rtype: Callable[..., Any]
    """

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        @wraps(func)
        def with_backend(*args: Any, **kws: Any) -> Any:
            previous = getattr(thismodule, "backend").name
            set_backend(backend)
            try:
                return func(*args, **kws)
            finally:
                set_backend(previous)

        return with_backend

    return decorator
@contextmanager
def runtime_backend(backend: Optional[str] = None) -> Iterator[Any]:
    """
    Context manager to set with-level runtime backend.

    The global backend is switched to ``backend`` for the duration of the
    ``with`` block. The previous backend is restored on exit, including when
    the block raises (consistent with ``set_function_backend`` and
    ``runtime_contractor``).

    :param backend: "numpy", "tensorflow", "jax", "pytorch", defaults to None
    :type backend: Optional[str], optional
    :yield: the backend object
    :rtype: Iterator[Any]
    """
    old_backend = getattr(thismodule, "backend").name
    K = set_backend(backend)
    try:
        yield K
    finally:
        # restore even if the with-body raised
        set_backend(old_backend)
def set_dtype(dtype: Optional[str] = None, set_global: bool = True) -> Tuple[str, str]:
    """
    Set the global runtime numerical dtype of tensors.

    :param dtype: "complex64"/"float32" or "complex128"/"float64", defaults to None,
        which is equivalent to "complex64". A real dtype name is mapped to the
        complex dtype of the same precision.
    :type dtype: Optional[str], optional
    :param set_global: When True, the following happen:
        ``dtypestr``, ``rdtypestr``, ``idtypestr`` and ``npdtype`` are rebound
        in every loaded tensorcircuit module; jax's ``jax_enable_x64`` flag is
        updated (if jax is importable); and ``gates.meta_gate()`` is called.
        When False (as used by ``get_dtype``), the function only resolves the
        dtype strings and has no side effects.
    :type set_global: bool
    :raises ValueError: If ``dtype`` is not one of the supported names.
    :return: complex dtype str and the corresponding real dtype str
    :rtype: Tuple[str, str]
    """
    if not dtype:
        dtype = "complex64"

    if dtype == "complex64":
        rdtype = "float32"
    elif dtype == "complex128":
        rdtype = "float64"
    elif dtype == "float32":
        dtype = "complex64"
        rdtype = "float32"
    elif dtype == "float64":
        dtype = "complex128"
        rdtype = "float64"
    else:
        raise ValueError(f"Unsupported data type: {dtype}")

    idtype = "int64" if dtype == "complex128" else "int32"

    if set_global:
        # jax precision is process-wide state: only change it when the caller
        # actually asked for a global dtype change (not for get_dtype queries)
        try:
            from jax import config
        except ImportError:
            config = None  # type: ignore
        if config is not None:
            # dtype is now either "complex64" or "complex128"
            config.update("jax_enable_x64", dtype == "complex128")

        npdtype = getattr(np, dtype)
        for module in sys.modules:
            if module.startswith(package_name):
                setattr(sys.modules[module], "dtypestr", dtype)
                setattr(sys.modules[module], "rdtypestr", rdtype)
                setattr(sys.modules[module], "idtypestr", idtype)
                setattr(sys.modules[module], "npdtype", npdtype)

        from .gates import meta_gate

        meta_gate()
    return dtype, rdtype
# ``get_dtype(dtype)`` resolves the (complex, real) dtype strings without
# rebinding the module-level dtype globals (``set_global=False``).
get_dtype = partial(set_dtype, set_global=False)
# initialize the global dtype configuration at import time (complex64 default)
set_dtype()
def set_function_dtype(dtype: Optional[str] = None) -> Callable[..., Any]:
    """
    Function decorator to set function-level numerical dtype.

    The global dtype is switched to ``dtype`` while the decorated function runs
    and restored afterwards, also when the function raises (consistent with
    ``set_function_backend``).

    :param dtype: "complex64" or "complex128", defaults to None
    :type dtype: Optional[str], optional
    :return: The decorated function
    :rtype: Callable[..., Any]
    """

    def wrapper(f: Callable[..., Any]) -> Callable[..., Any]:
        @wraps(f)
        def newf(*args: Any, **kws: Any) -> Any:
            old_dtype = getattr(thismodule, "dtypestr")
            set_dtype(dtype)
            try:
                return f(*args, **kws)
            finally:
                # restore even if f raised
                set_dtype(old_dtype)

        return newf

    return wrapper
@contextmanager
def runtime_dtype(dtype: Optional[str] = None) -> Iterator[Tuple[str, str]]:
    """
    Context manager to set with-level runtime dtype.

    The previous global dtype is restored on exit, including when the
    ``with`` block raises.

    :param dtype: "complex64" or "complex128", defaults to None ("complex64")
    :type dtype: Optional[str], optional
    :yield: complex dtype str and real dtype str
    :rtype: Iterator[Tuple[str, str]]
    """
    old_dtype = getattr(thismodule, "dtypestr")
    dtuple = set_dtype(dtype)
    try:
        yield dtuple
    finally:
        set_dtype(old_dtype)
# here below comes other contractors (just works, # but correctness has not been extensively tested for some of them) def _sizen(node: tn.Node, is_log: bool = False) -> int: s = reduce(mul, node.tensor.shape + (1,)) if is_log: return int(np.log2(s)) return s # type: ignore def _merge_single_gates( nodes: List[Any], total_size: Optional[int] = None ) -> Tuple[List[Any], int]: # TODO(@refraction-ray): investigate whether too much copy here so that staging is slow for large circuit nodes = list(nodes) if total_size is None: total_size = sum([_sizen(t) for t in nodes]) queue = [n for n in nodes if len(n.tensor.shape) <= 2] while queue: n0 = queue[0] try: n0[0] except IndexError: queue = _multi_remove(queue, [0]) continue if n0[0].is_dangling(): try: e0 = n0[1] if e0.is_dangling(): queue = _multi_remove(queue, [0]) continue except IndexError: queue = _multi_remove(queue, [0]) continue else: e0 = n0[0] njs = [i for i, n in enumerate(nodes) if id(n) in [id(e0.node1), id(e0.node2)]] qjs = [i for i, n in enumerate(queue) if id(n) in [id(e0.node1), id(e0.node2)]] new_node = tn.contract_parallel(e0) total_size += _sizen(new_node) logger.debug( _sizen(new_node, is_log=True), ) queue = _multi_remove(queue, qjs) if len(njs) > 1: nodes[njs[1]] = new_node nodes = _multi_remove(nodes, [njs[0]]) else: # trace edge? nodes[njs[0]] = new_node if len(new_node.tensor.shape) <= 2: # queue.append(new_node) queue.insert(0, new_node) return nodes, total_size
def experimental_contractor(
    nodes: List[Any],
    output_edge_order: Optional[List[Any]] = None,
    ignore_edge_order: bool = False,
    local_steps: int = 2,
) -> Any:
    """
    Heuristic contraction: merge small gates, fuse neighbours, then contract
    sequentially.

    1. For more than 5 nodes, rank <= 2 nodes are merged into neighbours
       (``_merge_single_gates``).
    2. For more than 15 nodes, up to ``local_steps`` sweeps contract each
       node with its list successor (stopping early below 10 nodes).
    3. The rest is contracted front to back, like ``plain_contractor``.

    :param nodes: The list of ``tn.Node``.
    :param output_edge_order: Edge order of the final node, defaults to None.
    :param ignore_edge_order: Unused; kept for contractor interface compatibility.
    :param local_steps: Number of neighbour-fusion sweeps, defaults to 2.
    :return: The ``tn.Node`` after contraction
    """
    total_size = sum([_sizen(t) for t in nodes])
    nodes = list(nodes)
    # merge single qubit gate
    if len(nodes) > 5:
        nodes, total_size = _merge_single_gates(nodes, total_size)
    # further dq fusion
    if len(nodes) > 15:
        for r in range(local_steps):
            if len(nodes) < 10:
                break
            i = 0
            while len(nodes) > i + 1:
                new_node = tn.contract_between(
                    nodes[i], nodes[i + 1], allow_outer_product=True
                )
                total_size += _sizen(new_node)
                # pass ``r`` as an argument: used as the format string it made
                # logging fail with a formatting error whenever DEBUG was on
                logger.debug("%s %s", r, _sizen(new_node, is_log=True))
                nodes[i] = new_node
                nodes = _multi_remove(nodes, [i + 1])
                i += 1
    logger.info("length of remaining nodes after dq fusion: %s", len(nodes))
    nodes = list(reversed(nodes))
    while len(nodes) > 1:
        new_node = tn.contract_between(nodes[-1], nodes[-2], allow_outer_product=True)
        nodes = _multi_remove(nodes, [len(nodes) - 2, len(nodes) - 1])
        nodes.append(new_node)
        logger.debug(_sizen(new_node, is_log=True))
        total_size += _sizen(new_node)
    logger.info("----- WRITE: %s --------\n", np.log2(total_size))

    # if the final node has more than one edge,
    # output_edge_order has to be specified
    final_node = nodes[0]
    if output_edge_order is not None:
        final_node.reorder_edges(output_edge_order)
    return final_node
def plain_contractor(
    nodes: List[Any],
    output_edge_order: Optional[List[Any]] = None,
    ignore_edge_order: bool = False,
) -> Any:
    """
    The naive state-vector simulator contraction path.

    Nodes are contracted strictly in list order: the first node absorbs the
    second, the result absorbs the third, and so on. The largest intermediate
    (log2 size) and the log2 of the total written elements are logged at INFO
    level.

    :param nodes: The list of ``tn.Node``.
    :type nodes: List[Any]
    :param output_edge_order: The list of dangling node edges, defaults to be None.
    :type output_edge_order: Optional[List[Any]], optional
    :param ignore_edge_order: Unused; kept for contractor interface compatibility.
    :type ignore_edge_order: bool
    :return: The ``tn.Node`` after contraction
    :rtype: tn.Node
    """
    written = sum([_sizen(t) for t in nodes])
    # stack with the first node on top, so popping from the end follows list order
    stack = list(reversed(list(nodes)))
    width = 0
    while len(stack) > 1:
        merged = tn.contract_between(stack[-1], stack[-2], allow_outer_product=True)
        stack = _multi_remove(stack, [len(stack) - 2, len(stack) - 1])
        stack.append(merged)
        log_size = _sizen(merged, is_log=True)
        logger.debug(log_size)
        width = max(width, log_size)
        written += _sizen(merged)
    logger.info("----- SIZE: %s --------\n", width)
    logger.info("----- WRITE: %s --------\n", np.log2(written))

    result = stack[0]
    if output_edge_order is not None:
        result.reorder_edges(output_edge_order)
    return result
# TODO(@refraction-ray): consistent logger system for different contractors.
def nodes_to_adj(ns: List[Any]) -> Any:
    """
    Build a weighted adjacency matrix for the nodes ``ns``.

    Entry ``[i, j]`` accumulates ``log10(edge.dimension)`` over all
    non-dangling edges of node ``i`` whose other endpoint is node ``j``. A
    trace edge contributes to the diagonal once per appearance in the node's
    edge list.

    :param ns: Nodes; each must be iterable over its edges.
    :return: ``numpy`` array of shape ``(len(ns), len(ns))``.
    """
    position = {id(node): k for k, node in enumerate(ns)}
    count = len(ns)
    adj = np.zeros([count, count])
    for node in ns:
        row = position[id(node)]
        for edge in node:
            if edge.is_dangling():
                continue
            other = edge.node2 if edge.node1 is node else edge.node1
            adj[row, position[id(other)]] += np.log10(edge.dimension)
    return adj
# probe for tensornetwork's custom path solvers (needed by the "tng" contractor);
# ``has_ps`` records whether the installed tensornetwork version provides them
try:
    _ps = tn.contractors.custom_path_solvers.pathsolvers
except AttributeError:
    has_ps = False
else:
    has_ps = True
def d2s(n: int, dl: List[Any]) -> List[Any]:
    """
    Translate a "dynamic" contraction path into "static" node ids.

    In ``dl`` each pair ``[a, b]`` indexes the current list of live tensors.
    The returned pairs use fixed ids instead: ``0..n-1`` for the original
    tensors, and ``n + k`` for the result of the k-th contraction. After each
    pair, the two entries are removed via ``_multi_remove`` (presumably the
    entries at positions ``a`` and ``b``), and the new id is inserted at
    position ``a``.

    :param n: Number of original tensors.
    :param dl: Dynamic path as a list of index pairs.
    :return: Static path as a list of ``[id, id]`` pairs.
    """
    live = list(range(n))
    static_pairs = []
    for new_id, (a, b) in enumerate(dl, start=n):
        static_pairs.append([live[a], live[b]])
        live = _multi_remove(live, [a, b])
        live.insert(a, new_id)
    return static_pairs
# seems worse than plain contraction in most cases
def tn_greedy_contractor(
    nodes: List[Any],
    output_edge_order: Optional[List[Any]] = None,
    ignore_edge_order: bool = False,
    max_branch: int = 1,
) -> Any:
    """
    Contract using tensornetwork's ``pathsolvers.full_solve_complete``.

    Deprecated (see ``set_contractor``, method "tng"); it requires ``has_ps``.
    The solver's dynamic path is converted with ``d2s``. New nodes are then
    appended so that their list positions match the static ids, and the last
    node is the result.

    :param nodes: The list of ``tn.Node``.
    :param output_edge_order: Edge order of the final node, defaults to None.
    :param ignore_edge_order: Unused; kept for contractor interface compatibility.
    :param max_branch: Forwarded to ``full_solve_complete``.
    :return: The contracted ``tn.Node``.
    """
    pool = list(nodes)
    adjacency = nodes_to_adj(pool)
    solver_path = _ps.full_solve_complete(adjacency, max_branch=max_branch)[0]
    # solver_path stores one contraction per column
    dynamic_pairs = [[a, b] for a, b in solver_path.T]
    for a, b in d2s(len(pool), dynamic_pairs):
        pool.append(tn.contract_between(pool[a], pool[b], allow_outer_product=True))

    result = pool[-1]
    if output_edge_order is not None:
        result.reorder_edges(output_edge_order)
    return result
# base = tn.contractors.opt_einsum_paths.path_contractors.base
# utils = tn.contractors.opt_einsum_paths.utils
_einsum_symbols_base = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"


@lru_cache(2**14)
def get_symbol(i: int) -> str:
    """
    Map the integer ``i`` to an einsum index character.

    The first 52 indices use ``a-z`` then ``A-Z``. Larger indices map to
    unicode code points starting at ``chr(192)`` ('À'), skipping the surrogate
    block U+D800-U+DFFF. From cotengra codebase.
    """
    if i < 52:
        return _einsum_symbols_base[i]
    code_point = i + 140  # 52 -> 192
    if code_point >= 0xD800:
        code_point += 0x800  # jump over the 2048 surrogate code points
    return chr(code_point)
def _extract_topology(
    nodes: List[tn.Node],
) -> Tuple[List[Any], List[str], str, Dict[str, int]]:
    """
    Convert a physical tensor network graph (with possible CopyNodes) into
    algebraic components for einsum/cotengra.

    Each ``tn.CopyNode`` is dissolved: all of its edges are merged into one
    union-find class, so every regular tensor touching it gets the same einsum
    symbol (a hyperedge).

    :param nodes: The nodes of the network, regular nodes and CopyNodes.
    :return: ``(raw_tensors, input_sets, output_set, size_dict)``:

        * ``raw_tensors[k]`` is the tensor of the k-th regular node, in
          ``_stable_id_`` order;
        * ``input_sets[k]`` is its subscript string;
        * ``output_set`` is the subscript string of the dangling edges, in
          ``sorted_edges`` order;
        * ``size_dict`` maps each symbol to its dimension.
    """
    ordered = sorted(nodes, key=lambda node: getattr(node, "_stable_id_", -1))
    regular_nodes = [n for n in ordered if not isinstance(n, tn.CopyNode)]
    copy_nodes = [n for n in ordered if isinstance(n, tn.CopyNode)]

    # every edge starts in its own class; a CopyNode glues all its legs together
    uf = UnionFind()
    for edge in tn.get_all_edges(ordered):
        uf[edge]  # registers the edge
    for cn in copy_nodes:
        legs = cn.edges
        if not legs:
            continue
        for other in legs[1:]:
            uf.union(legs[0], other)

    # symbols are handed out in order of first appearance
    symbols: Dict[Any, str] = {}

    def symbol_of(edge: tn.Edge) -> str:
        root = uf[edge]
        if root not in symbols:
            symbols[root] = get_symbol(len(symbols))
        return symbols[root]

    raw_tensors = [node.tensor for node in regular_nodes]
    input_sets = [
        "".join([symbol_of(edge) for edge in node.edges]) for node in regular_nodes
    ]
    dangling_edges = sorted_edges(tn.get_subgraph_dangling(ordered))
    output_set = "".join([symbol_of(edge) for edge in dangling_edges])
    size_dict = {symbol: root.dimension for root, symbol in symbols.items()}
    return raw_tensors, input_sets, output_set, size_dict


def _normalize_omeco_topology(
    input_sets: Sequence[Sequence[Any]],
    output_set: Sequence[Any],
    size_dict: Dict[Any, int],
) -> Tuple[List[List[int]], List[int], Dict[int, int]]:
    """
    Relabel arbitrary hashable index labels as consecutive integers.

    Labels are numbered in order of first appearance: inputs, then output,
    then any remaining ``size_dict`` keys. Sizes are converted to ``int``.
    """
    mapping: Dict[Any, int] = {}

    def relabel(symbol: Any) -> int:
        # len(mapping) is evaluated before insertion, giving the next free id
        return mapping.setdefault(symbol, len(mapping))

    new_inputs = [[relabel(symbol) for symbol in term] for term in input_sets]
    new_output = [relabel(symbol) for symbol in output_set]
    new_sizes = {relabel(symbol): int(size) for symbol, size in size_dict.items()}
    return new_inputs, new_output, new_sizes


def _omeco_tree_to_path(tree: Dict[str, Any], ntensors: int) -> List[Tuple[int, int]]:
    """
    Convert an ``omeco`` contraction tree (dict form) to a linear path.

    Leaves carry ``tensor_index`` or ``tensorindex``. If the leaf indices are
    exactly ``1..ntensors``, they are treated as 1-based. Each inner node must
    have exactly two ``args``. Every returned pair is the two current
    positions (larger first) in the working list of tensors; the contracted
    result is appended to the end, as in opt_einsum-style linear paths.

    :raises ValueError: If the tree is not binary.
    """

    def leaf_value(node: Dict[str, Any]) -> Optional[int]:
        for key in ("tensor_index", "tensorindex"):
            if key in node:
                return int(node[key])
        return None

    leaf_indices: List[int] = []
    pending = [tree]
    while pending:
        node = pending.pop()
        value = leaf_value(node)
        if value is None:
            pending.extend(node["args"])
        else:
            leaf_indices.append(value)

    offset = 1 if sorted(leaf_indices) == list(range(1, ntensors + 1)) else 0
    working = list(range(ntensors))
    path: List[Tuple[int, int]] = []
    counter = [ntensors]  # id of the next intermediate tensor

    def visit(node: Dict[str, Any]) -> int:
        value = leaf_value(node)
        if value is not None:
            return value - offset
        children = node["args"]
        if len(children) != 2:
            raise ValueError("omeco returned a non-binary contraction tree")
        left = visit(children[0])
        right = visit(children[1])
        pos_left = working.index(left)
        pos_right = working.index(right)
        high, low = max(pos_left, pos_right), min(pos_left, pos_right)
        # pop the larger position first so the smaller one stays valid
        working.pop(high)
        working.pop(low)
        new_id = counter[0]
        counter[0] += 1
        working.append(new_id)
        path.append((high, low))
        return new_id

    visit(tree)
    return path
class OMEOptimizer:
    """
    Adapt an ``omeco`` optimizer for ``tc.set_contractor("custom", ...)``.

    ``omeco`` stays an optional dependency; it is imported lazily, only when
    the adapter is called.

    Example:

    >>> import omeco
    >>> opt = omeco.TreeSA(ntrials=4, niters=8, betas=[0.1, 1.0, 10.0])
    >>> tc.set_contractor("custom", optimizer=opt, preprocessing=True)
    """

    def __init__(self, optimizer: Optional[Any] = None):
        # an omeco optimizer object (e.g. GreedyMethod/TreeSA); None uses
        # omeco.optimize_code's default
        self.optimizer = optimizer

    def __call__(
        self,
        input_sets: Sequence[Sequence[Any]],
        output_set: Sequence[Any],
        size_dict: Dict[Any, int],
        _memory_limit: Optional[int] = None,
        **_kws: Any,
    ) -> List[Tuple[int, int]]:
        """
        Find a contraction path with omeco.

        The signature mirrors opt_einsum path functions. ``_memory_limit`` and
        extra keywords are accepted and ignored.

        :raises ImportError: If ``omeco`` is not installed.
        :return: Linear contraction path as a list of position pairs.
        """
        try:
            import omeco
        except ImportError as e:
            raise ImportError(
                "OMEOptimizer requires the optional dependency 'omeco'."
            ) from e

        inputs, output, sizes = _normalize_omeco_topology(
            input_sets, output_set, size_dict
        )
        extra_args = () if self.optimizer is None else (self.optimizer,)
        tree = omeco.optimize_code(inputs, output, sizes, *extra_args)
        return _omeco_tree_to_path(tree.to_dict(), len(inputs))
def _wrap_omeco_optimizer(optimizer: Any) -> Any:
    """
    Wrap a bare ``omeco.GreedyMethod``/``omeco.TreeSA`` in ``OMEOptimizer``.

    None, lists (precomputed paths), callables, anything else when omeco is
    not installed, and other objects are returned unchanged.
    """
    if optimizer is None or isinstance(optimizer, list) or callable(optimizer):
        return optimizer
    try:
        import omeco
    except ImportError:
        return optimizer
    if isinstance(optimizer, (omeco.GreedyMethod, omeco.TreeSA)):
        return OMEOptimizer(optimizer)
    return optimizer


def _algebraic_base_contraction(
    nodes: List[tn.Node],
    algorithm: Any,
    output_edge_order: Optional[Sequence[tn.Edge]] = None,
    ignore_edge_order: bool = False,
    **kws: Any,
) -> Any:
    """
    Execute contraction using cotengra and autoray for bare tensors.

    :param nodes: The network, possibly including ``tn.CopyNode`` hyperedges.
    :param algorithm: Path finder ``(input_sets, output_set, size_dict) -> path``,
        or a precomputed path (list) whose indices refer to the regular nodes
        in ``_stable_id_`` order.
    :param output_edge_order: Desired edge order of the result.
    :param ignore_edge_order: Skip reordering the result's edges.
    :param kws: ``strip_exponent=True`` contracts with cotengra's exponent
        stripping and returns ``(node, exponent)``.
    :return: The contracted ``tn.Node`` (or ``(node, exponent)``).
    """
    import cotengra as ctg

    raw_tensors, input_sets, output_set, size_dict = _extract_topology(nodes)
    # Use the backend of the first node
    be = nodes[0].backend

    if len(raw_tensors) == 1:
        # Avoid cotengra bug for empty contraction paths
        final_raw_tensor = be.einsum(input_sets[0] + "->" + output_set, *raw_tensors)
        exponent = 0.0
    else:
        if isinstance(algorithm, list):
            # a precomputed path passed through by ``custom``
            path = algorithm
        else:
            path = algorithm(input_sets, output_set, size_dict)
        logger.info("the contraction path is given as %s" % str(path))
        tree = ctg.ContractionTree.from_path(
            input_sets, output_set, size_dict, path=path
        )
        # Use autoray to keep AD and JIT support across backends
        # Note: cotengra's make_contractor handles the orchestration
        if not kws.get("strip_exponent", False):
            contractor = ctg.core.make_contractor(tree, implementation="autoray")
            final_raw_tensor = contractor(*raw_tensors)
        else:
            final_raw_tensor, exponent = tree.contract(raw_tensors, strip_exponent=True)

    final_node = tn.Node(final_raw_tensor, backend=be)

    # Resolve dangling edges in the same order as in _extract_topology
    dangling_edges = sorted_edges(tn.get_subgraph_dangling(nodes))

    # Update the edges to point to the new final_node
    for i, edge in enumerate(dangling_edges):
        if edge.node1 in nodes:
            edge.node1 = final_node
            edge.axis1 = i
        else:
            edge.node2 = final_node
            edge.axis2 = i
    final_node.edges = list(dangling_edges)

    if not ignore_edge_order:
        if output_edge_order is None:
            output_edge_order = dangling_edges
        final_node.reorder_edges(list(output_edge_order))

    if kws.get("strip_exponent", False):
        return final_node, exponent
    return final_node


def _get_path(
    nodes: List[tn.Node], algorithm: Any
) -> Tuple[List[Tuple[int, int]], List[tn.Node]]:
    """Run ``algorithm`` on an ``id()``-labelled topology (order depends on object ids)."""
    nodes = list(nodes)
    input_sets = [set([id(e) for e in node.edges]) for node in nodes]
    output_set = set([id(e) for e in tn.get_subgraph_dangling(nodes)])
    size_dict = {id(edge): edge.dimension for edge in tn.get_all_edges(nodes)}
    return algorithm(input_sets, output_set, size_dict), nodes


def _identity(*args: Any, **kws: Any) -> Any:
    """Return the positional arguments unchanged (used to expose the topology)."""
    return args


def _get_path_cache_friendly(
    nodes: List[tn.Node], algorithm: Any
) -> Tuple[List[Tuple[int, int]], List[tn.Node]]:
    """
    Compute a contraction path on a deterministic, id-independent labelling.

    Nodes are sorted by ``_stable_id_`` and edges are labelled with
    ``get_symbol`` in ``sorted_edges`` order. Equal networks therefore yield
    identical inputs, which keeps caching path optimizers effective.

    :param nodes: The nodes to contract.
    :param algorithm: Path finder ``(input_sets, output_set, size_dict) -> path``,
        or a precomputed path (list) whose indices refer to the sorted nodes.
    :return: ``(path, sorted_nodes)``.
    """
    nodes = list(nodes)
    nodes_new = sorted(nodes, key=lambda node: getattr(node, "_stable_id_", -1))
    all_edges = tn.get_all_edges(nodes_new)
    all_edges_sorted = sorted_edges(all_edges)
    mapping_dict = {}
    i = 0
    for edge in all_edges_sorted:
        if id(edge) not in mapping_dict:
            mapping_dict[id(edge)] = get_symbol(i)
            i += 1

    input_sets = [list([mapping_dict[id(e)] for e in node.edges]) for node in nodes_new]
    output_set = list(
        [mapping_dict[id(e)] for e in sorted_edges(tn.get_subgraph_dangling(nodes_new))]
    )
    size_dict = {mapping_dict[id(edge)]: edge.dimension for edge in all_edges_sorted}
    logger.debug("input_sets: %s" % input_sets)
    logger.debug("output_set: %s" % output_set)
    logger.debug("size_dict: %s" % size_dict)
    logger.debug("path finder algorithm: %s" % algorithm)
    if isinstance(algorithm, list):
        # ``custom`` passes a user-supplied path list through unchanged;
        # calling it would raise TypeError
        return algorithm, nodes_new
    return algorithm(input_sets, output_set, size_dict), nodes_new


# directly get input_sets, output_set and size_dict by using identity function as algorithm
get_tn_info = partial(_get_path_cache_friendly, algorithm=_identity)


# some contractor setup usages
"""
import cotengra as ctg
import opt_einsum as oem

sys.setrecursionlimit(10000)  # for successful ctg parallel

opt = ctg.ReusableHyperOptimizer(
    methods=["greedy", "kahypar"],
    parallel=True,
    minimize="write",
    max_time=30,
    max_repeats=4096,
    progbar=True,
)
tc.set_contractor("custom", optimizer=opt, preprocessing=True)

tc.set_contractor("custom_stateful", optimizer=oem.RandomGreedy, max_time=60, max_repeats=128, minimize="size")

tc.set_contractor("plain-experimental", local_steps=3)

# hyper efficient contractor: though long computation time required, suitable for extra large circuit simulation
opt = ctg.ReusableHyperOptimizer(
    minimize='combo',
    max_repeats=1024,
    max_time='equil:128',
    optlib='nevergrad',
    progbar=True,
)

def opt_reconf(inputs, output, size, **kws):
    tree = opt.search(inputs, output, size)
    tree_r = tree.subtree_reconfigure_forest(progbar=True, num_trees=10, num_restarts=20, subtree_weight_what=("size", ))
    return tree_r.get_path()

tc.set_contractor("custom", optimizer=opt_reconf)
"""


def _base(
    nodes: List[tn.Node],
    algorithm: Any,
    output_edge_order: Optional[Sequence[tn.Edge]] = None,
    ignore_edge_order: bool = False,
    total_size: Optional[int] = None,
    debug_level: int = 0,
    use_primitives: Optional[bool] = None,
    **kws: Any,
) -> Any:
    """
    The base method for all `opt_einsum` contractors.

    :param nodes: A collection of connected nodes.
    :type nodes: List[tn.Node]
    :param algorithm: `opt_einsum` contraction method to use, or a precomputed
        contraction path (list) indexing the nodes in ``_stable_id_`` order.
    :type algorithm: Any
    :param output_edge_order: An optional list of edges. Edges of the
        final node in `nodes_set`
        are reordered into `output_edge_order`;
        if final node has more than one edge,
        `output_edge_order` must be provided.
    :type output_edge_order: Optional[Sequence[tn.Edge]], optional
    :param ignore_edge_order: An option to ignore the output edge order.
    :type ignore_edge_order: bool
    :param total_size: The total size of the tensor network.
    :type total_size: Optional[int], optional
    :param debug_level: 1 uses ``pseudo_contract_between`` instead of real
        contractions; 2 skips contraction and returns a zero tensor of the
        output shape.
    :type debug_level: int
    :param use_primitives: True forces the cotengra/einsum path; None uses it
        only when the network contains ``tn.CopyNode`` hyperedges.
    :type use_primitives: Optional[bool]
    :raises ValueError: "The final node after contraction has more than one
        remaining edge. In this case `output_edge_order` has to be provided," or
        "Output edges are not equal to the remaining non-contracted edges of the final node."
    :return: The final node after full contraction.
    :rtype: tn.Node
    """
    # rewrite tensornetwork default to add logging infras
    nodes_set = set(nodes)
    edges = tn.get_all_edges(nodes_set)
    # output edge order has to be determinded before any contraction
    # (edges are refreshed after contractions)

    if not ignore_edge_order:
        if output_edge_order is None:
            output_edge_order = list(tn.get_subgraph_dangling(nodes))
            if len(output_edge_order) > 1:
                raise ValueError(
                    "The final node after contraction has more than "
                    "one remaining edge. In this case `output_edge_order` "
                    "has to be provided."
                )

        if set(output_edge_order) != tn.get_subgraph_dangling(nodes):
            raise ValueError(
                "output edges are not equal to the remaining "
                "non-contracted edges of the final node."
            )

    # 1. Resolve topology and check for hyperedges
    has_hyperedges = any(isinstance(n, tn.CopyNode) for n in nodes)

    if use_primitives is True or (use_primitives is None and has_hyperedges):
        # ==========================================
        # NEW ALGEBRAIC EXECUTION PATH (Opt-in)
        # ==========================================
        return _algebraic_base_contraction(
            nodes, algorithm, output_edge_order, ignore_edge_order, **kws
        )

    # ==========================================
    # ORIGINAL EXECUTION PATH (100% Backward Compatible)
    # ==========================================
    for edge in edges:
        if not edge.is_disabled:  # if its disabled we already contracted it
            if edge.is_trace():
                idx = [i for i, n in enumerate(nodes) if id(n) == id(edge.node1)]
                nodes = _multi_remove(nodes, idx)
                nodes.append(tn.contract_parallel(edge))
                # nodes_set.remove(edge.node1)
                # nodes_set.add(tn.contract_parallel(edge))
    if len(nodes) == 1:
        # There's nothing to contract.
        if ignore_edge_order:
            return list(nodes)[0]
        return list(nodes)[0].reorder_edges(output_edge_order)

    # Then apply `opt_einsum`'s algorithm (or the precomputed path)
    path, nodes = _get_path_cache_friendly(nodes, algorithm)
    if debug_level == 2:  # do nothing
        if output_edge_order:
            shape = [e.dimension for e in output_edge_order]
        else:
            shape = []
        return tn.Node(backend.zeros(shape))
    logger.info("the contraction path is given as %s" % str(path))

    if total_size is None:
        total_size = sum([_sizen(t) for t in nodes])
    for ab in path:
        if len(ab) < 2:
            logger.warning("single element tuple in contraction path!")
            continue
        a, b = ab

        if debug_level == 1:
            from .simplify import pseudo_contract_between

            new_node = pseudo_contract_between(nodes[a], nodes[b])
        else:
            new_node = tn.contract_between(nodes[a], nodes[b], allow_outer_product=True)
        nodes.append(new_node)
        # nodes[a] = backend.zeros([1])
        # nodes[b] = backend.zeros([1])
        nodes = _multi_remove(nodes, [a, b])

        logger.debug(_sizen(new_node, is_log=True))
        total_size += _sizen(new_node)
    logger.info("----- WRITE: %s --------\n" % np.log2(total_size))

    # if the final node has more than one edge,
    # output_edge_order has to be specified
    final_node = nodes[0]  # nodes were connected, we checked this
    if not ignore_edge_order:
        final_node.reorder_edges(output_edge_order)
    return final_node
[docs] class NodesReturn(Exception): """ Intentionally stop execution to return a value. """
[docs] def __init__(self, value_to_return: Any): self.value = value_to_return super().__init__( f"Intentionally stopping execution to return: {value_to_return}" )
def _get_sorted_nodes(nodes: List[Any], *args: Any, **kws: Any) -> Any: nodes_new = sorted(nodes, key=lambda node: getattr(node, "_stable_id_", -1)) raise NodesReturn(nodes_new)
def function_nodes_capture(func: Callable[[Any], Any]) -> Callable[[Any], Any]:
    """
    Decorate ``func`` so that a call returns the nodes it would contract.

    While the wrapped function runs, the "before" contractor is active. The
    first contraction attempt therefore raises ``NodesReturn``, and its payload
    (the nodes sorted by ``_stable_id_``) is returned instead of ``func``'s
    result. If ``func`` never contracts, its normal return value is passed
    through. The previous contractor is restored afterwards.
    """

    @wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        with runtime_contractor(method="before"):
            try:
                return func(*args, **kwargs)
            except NodesReturn as stop:
                return stop.value

    return wrapper
@contextmanager
def runtime_nodes_capture(key: str = "nodes") -> Iterator[Any]:
    """
    Context manager that captures the nodes of the first contraction attempt.

    Inside the ``with`` block the "before" contractor is active. When it fires,
    the resulting ``NodesReturn`` is swallowed and its payload is stored under
    ``key`` in the yielded dict. The previous contractor is always restored.

    :param key: Dict key under which the captured nodes are stored.
    :yield: A dict that receives the captured nodes.
    """
    previous = getattr(thismodule, "contractor")
    set_contractor(method="before")
    captured: Dict[str, List[tn.Node]] = {}
    try:
        yield captured
    except NodesReturn as stop:
        captured[key] = stop.value
    finally:
        _set_global_contractor(previous)
def custom(
    nodes: List[Any],
    optimizer: Any,
    memory_limit: Optional[int] = None,
    output_edge_order: Optional[List[Any]] = None,
    ignore_edge_order: bool = False,
    debug_level: int = 0,
    use_primitives: Optional[bool] = None,
    **kws: Any,
) -> Any:
    """
    Contract ``nodes`` with a user-supplied path optimizer.

    Networks with fewer than 5 nodes always use ``opt_einsum.paths.optimal``.
    With ``preprocessing=True`` in ``kws`` and no ``tn.CopyNode`` present,
    rank <= 2 nodes are first merged via ``_merge_single_gates``.

    :param nodes: The nodes to contract.
    :param optimizer: Path finder called as
        ``optimizer(input_sets, output_set, size_dict, memory_limit=...)``, or a
        precomputed path given as a list.
    :param memory_limit: Forwarded to ``optimizer``.
    :param output_edge_order: Desired edge order of the result.
    :param ignore_edge_order: Skip output edge ordering checks/reordering.
    :param debug_level: See ``_base``; a ``debug_level`` in ``kws`` overrides it.
    :param use_primitives: See ``_base``.
    :return: The contracted node.
    """
    extra = dict(kws)
    debug_level = extra.pop("debug_level", debug_level)

    if len(nodes) < 5:
        # not good at minimize WRITE actually...
        return _base(
            nodes,
            opt_einsum.paths.optimal,
            output_edge_order,
            ignore_edge_order,
            debug_level=debug_level,
            use_primitives=use_primitives,
            **extra,
        )

    total_size = None
    contains_copy_nodes = any(isinstance(n, tn.CopyNode) for n in nodes)
    if extra.get("preprocessing", None) and not contains_copy_nodes:
        # nodes = _full_light_cone_cancel(nodes)
        nodes, total_size = _merge_single_gates(nodes)

    alg = (
        optimizer
        if isinstance(optimizer, list)
        else partial(optimizer, memory_limit=memory_limit)
    )
    return _base(
        nodes,
        alg,
        output_edge_order,
        ignore_edge_order,
        total_size,
        debug_level=debug_level,
        use_primitives=use_primitives,
        **extra,
    )
def custom_stateful(
    nodes: List[Any],
    optimizer: Any,
    memory_limit: Optional[int] = None,
    opt_conf: Optional[Dict[str, Any]] = None,
    output_edge_order: Optional[List[Any]] = None,
    ignore_edge_order: bool = False,
    use_primitives: Optional[bool] = None,
    **kws: Any,
) -> Any:
    """
    Like ``custom``, but ``optimizer`` is a class instantiated on every call.

    :param optimizer: Optimizer class, instantiated as ``optimizer(**opt_conf)``.
    :param opt_conf: Constructor keywords for ``optimizer``, defaults to ``{}``.
    :param kws: ``contraction_info`` (wrap the instance with
        ``contraction_info_decorator``) and ``debug_level`` are consumed here;
        the rest is forwarded to ``custom``.
    :return: The contracted node.
    """
    conf = {} if opt_conf is None else opt_conf
    opt = optimizer(**conf)  # reinitiate the optimizer each time
    if kws.get("contraction_info", None):
        opt = contraction_info_decorator(opt)

    forwarded = dict(kws)
    debug_level = forwarded.pop("debug_level", 0)
    forwarded.pop("contraction_info", None)
    return custom(
        nodes,
        opt,
        memory_limit=memory_limit,
        output_edge_order=output_edge_order,
        ignore_edge_order=ignore_edge_order,
        debug_level=debug_level,
        use_primitives=use_primitives,
        **forwarded,
    )
# used by custom contractor variants
def contraction_info_decorator(algorithm: Callable[..., Any]) -> Callable[..., Any]:
    """Decorator to add contraction information logging to an optimizer.

    The returned function calls ``algorithm`` and builds a cotengra
    ``ContractionTree`` from the path it returns. It then prints log10 of the
    FLOP count, the contraction width (log2 size), log2 of the total write and
    the path finding time, and returns the path unchanged.

    :param algorithm: The optimization algorithm to decorate.
    :type algorithm: Callable[..., Any]
    :return: The decorated optimization algorithm.
    :rtype: Callable[..., Any]
    """
    from cotengra import ContractionTree

    def new_algorithm(
        input_sets: Dict[Any, int],
        output_set: Dict[Any, int],
        size_dict: Dict[Any, int],
        **kws: Any,
    ) -> Any:
        started = time.time()
        path = algorithm(input_sets, output_set, size_dict, **kws)
        elapsed = time.time() - started

        tree = ContractionTree.from_path(input_sets, output_set, size_dict, path=path)
        flops_txt = "%.3f" % np.log10(float(tree.total_flops()))
        width_txt = "%.0f" % tree.contraction_width()
        write_txt = "%.3f" % np.log2(float(tree.total_write()))
        time_txt = "%.3f" % elapsed

        print("------ contraction cost summary ------")
        print(
            "log10[FLOPs]: ",
            flops_txt,
            " log2[SIZE]: ",
            width_txt,
            " log2[WRITE]: ",
            write_txt,
            " PathFindingTime: ",
            time_txt,
        )
        return path

    return new_algorithm
def set_contractor(
    method: Optional[str] = None,
    optimizer: Optional[Any] = None,
    memory_limit: Optional[int] = None,
    opt_conf: Optional[Dict[str, Any]] = None,
    set_global: bool = True,
    contraction_info: bool = False,
    debug_level: int = 0,
    use_primitives: Optional[bool] = None,
    **kws: Any,
) -> Callable[..., Any]:
    """
    To set runtime contractor of the tensornetwork for a better contraction path.
    For more information on the usage of contractor, please refer to independent tutorial.

    :param method: "auto", "greedy", "branch", "plain", "plain-experimental",
        "tng", "custom", "custom_stateful", "cotengra", "cotengra-<max_time>-<max_repeats>"
        or any other path function name in ``opt_einsum.paths``.
        Defaults to None ("greedy").
    :type method: Optional[str], optional
    :param optimizer: Valid for "custom" or "custom_stateful" as method, defaults to None
    :type optimizer: Optional[Any], optional
    :param memory_limit: It is not very useful, as ``memory_limit`` leads to ``branch`` contraction
        instead of ``greedy`` which is rather slow, defaults to None
    :type memory_limit: Optional[int], optional
    :param opt_conf: Constructor keywords for the "custom_stateful" optimizer class.
    :type opt_conf: Optional[Dict[str, Any]], optional
    :param set_global: Install the contractor in every loaded tensorcircuit module.
    :type set_global: bool
    :param contraction_info: Print a contraction cost summary for each path found.
    :type contraction_info: bool
    :param debug_level: Forwarded to the contractor (see ``_base``).
    :type debug_level: int
    :param use_primitives: Forwarded to the contractor (see ``_base``); defaults to
        True when ``strip_exponent=True`` is given.
    :type use_primitives: Optional[bool]
    :raises Exception: Tensornetwork version is too low to support some of the contractors.
    :raises ValueError: Unknown method options.
    :return: The contractor function.
    :rtype: Callable[..., Any]
    """
    if not method:
        method = "greedy"
        # auto for small size fallbacks to dp, which has bug for now
        # see: https://github.com/dgasmith/opt_einsum/issues/172
    if kws.get("strip_exponent", False):
        if use_primitives is None:
            use_primitives = True
    if method.startswith("cotengra"):
        # cotengra shortcut
        import cotengra

        if method == "cotengra":
            method = "custom"
            optimizer = cotengra.ReusableHyperOptimizer(
                methods=["greedy", "kahypar"],
                parallel=True,
                minimize="combo",
                max_time=30,
                max_repeats=64,
                progbar=True,
            )
        else:  # "cotengra-30-64"
            _, mt, mr = method.split("-")
            method = "custom"
            optimizer = cotengra.ReusableHyperOptimizer(
                methods=["greedy", "kahypar"],
                parallel=True,
                minimize="combo",
                max_time=int(mt),
                max_repeats=int(mr),
                progbar=True,
            )

    if method == "plain":
        cf = plain_contractor
    elif method == "plain-experimental":
        cf = partial(experimental_contractor, local_steps=kws.get("local_steps", 2))
    elif method == "tng":  # don't use, deprecated, no guarantee
        if has_ps:
            cf = tn_greedy_contractor
        else:
            raise Exception(
                "current version of tensornetwork doesn't support tng contraction"
            )
    elif method == "custom_stateful":
        cf = partial(  # type: ignore
            custom_stateful,
            optimizer=optimizer,
            # memory_limit used to be dropped silently for this method
            memory_limit=memory_limit,
            opt_conf=opt_conf,
            contraction_info=contraction_info,
            debug_level=debug_level,
            use_primitives=use_primitives,
            **kws,
        )
    elif method == "before":  # a hack way to get the nodes
        cf = _get_sorted_nodes
    else:
        if method != "custom":
            optimizer = getattr(opt_einsum.paths, method, None)
            if optimizer is None:
                raise ValueError("Unknown contractor type: %s" % method)
        else:
            optimizer = _wrap_omeco_optimizer(optimizer)
        if contraction_info is True:
            optimizer = contraction_info_decorator(optimizer)  # type: ignore
        cf = partial(
            custom,
            optimizer=optimizer,
            memory_limit=memory_limit,
            debug_level=debug_level,
            use_primitives=use_primitives,
            **kws,
        )
    if set_global:
        _set_global_contractor(cf)
    return cf


# default contractor installed at import time
set_contractor("greedy", preprocessing=True)

get_contractor = partial(set_contractor, set_global=False)
def set_function_contractor(*confargs: Any, **confkws: Any) -> Callable[..., Any]:
    """
    Function decorator to change the function-level contractor.

    While the decorated function runs, the contractor built by
    ``set_contractor(*confargs, **confkws)`` is installed globally; the
    previous contractor is restored afterwards, also if the function raises.

    :return: A decorator applying this contractor configuration.
    :rtype: Callable[..., Any]
    """

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        @wraps(func)
        def with_contractor(*args: Any, **kws: Any) -> Any:
            previous = getattr(thismodule, "contractor")
            set_contractor(*confargs, **confkws)
            try:
                return func(*args, **kws)
            finally:
                _set_global_contractor(previous)

        return with_contractor

    return decorator
@contextmanager
def runtime_contractor(*confargs: Any, **confkws: Any) -> Iterator[Any]:
    """
    Context manager to change the with-level contractor.

    ``set_contractor(*confargs, **confkws)`` is installed globally for the
    duration of the ``with`` block. The previous contractor is restored on
    exit, including when the block raises.

    :yield: The newly installed contractor function.
    :rtype: Iterator[Any]
    """
    previous = getattr(thismodule, "contractor")
    installed = set_contractor(*confargs, **confkws)
    try:
        yield installed
    finally:
        _set_global_contractor(previous)
def split_rules(
    max_singular_values: Optional[int] = None,
    max_truncation_err: Optional[float] = None,
    relative: bool = False,
) -> Any:
    """
    Obtain the dictionary of truncation rules.

    The keys match the keyword names of tensornetwork's SVD-based split/MPS
    routines (``max_singular_values``, ``max_truncation_err``, ``relative``),
    so the dict can be forwarded as ``**rules``. Options left as None are
    omitted.

    :param max_singular_values: The maximum number of singular values to keep.
    :type max_singular_values: int, optional
    :param max_truncation_err: The maximum allowed truncation error.
    :type max_truncation_err: float, optional
    :param relative: Multiply `max_truncation_err` with the largest singular value.
    :type relative: bool, optional
    :return: The truncation rules.
    :rtype: Dict[str, Any]
    """
    rules: Any = {}
    if max_singular_values is not None:
        rules["max_singular_values"] = max_singular_values
    if max_truncation_err is not None:
        # the key was previously misspelled "max_truncattion_err", so it was
        # not the keyword tensornetwork accepts
        rules["max_truncation_err"] = max_truncation_err
    if relative is not None:
        rules["relative"] = relative
    return rules