Source code for tensorcircuit.zx.utils

"""
Utility functions for ZX-calculus and stabilizer decomposition.
"""

from __future__ import annotations
from collections import deque
from dataclasses import dataclass
from typing import Any, Sequence, List
import numpy as np
from pyzx_param.graph.graph_s import GraphS as Graph


def find_basis(vectors: Any) -> tuple[Any, Any]:
    """
    Decompose a set of binary vectors into a basis subset and a transformation
    matrix over GF(2).

    Rows are scanned in order; a row is added to the basis when it cannot be
    written as an XOR of the rows selected so far. The basis is therefore a
    subset of the input rows, kept in their original order. For binary input,
    row ``i`` of the transformation matrix lists which basis rows XOR together
    to give input row ``i``.

    :param vectors: Input binary vectors (2D array-like), one vector per row.
        An empty sequence is accepted and treated as zero vectors of length
        zero.
    :type vectors: Any
    :return: Tuple of (basis_vectors, transformation_matrix). With ``n`` input
        rows and rank ``r``, ``basis_vectors`` has ``r`` rows and
        ``transformation_matrix`` has shape ``(n, r)``; both are ``uint8``.
    :rtype: tuple[Any, Any]
    :raises ValueError: If ``vectors`` is not two-dimensional (and not empty).
    """
    vecs = np.array(vectors, dtype=np.uint8)
    if vecs.ndim == 1 and vecs.size == 0:
        # ``[]`` converts to shape (0,); treat it like the 2D empty case, which
        # the algorithm below already handles.
        vecs = vecs.reshape(0, 0)
    if vecs.ndim != 2:
        raise ValueError(
            f"find_basis expects a 2D array of binary vectors, got {vecs.ndim}D input"
        )
    num_vectors = vecs.shape[0]

    basis_indices: list[int] = []  # positions of the chosen rows in ``vecs``
    reduced_basis: list[Any] = []  # chosen rows after elimination
    pivots: list[int] = []  # pivots[j]: first nonzero column of reduced_basis[j]
    # basis_expansion[j] writes reduced_basis[j] as an XOR of the chosen
    # original rows (coordinates indexed by position in ``basis_indices``).
    basis_expansion: list[Any] = []
    t_rows: list[Any] = []  # per input row: its coordinates in the basis

    for i in range(num_vectors):
        v = vecs[i].copy()
        coeffs = []
        # Each reduced vector is zero at all earlier pivots, so one in-order
        # pass is enough to clear every pivot column of ``v``.
        for j, b in enumerate(reduced_basis):
            if v[pivots[j]]:
                v ^= b
                coeffs.append(j)

        is_independent = bool(np.any(v))
        current_rank = len(basis_indices)
        new_size = current_rank + 1 if is_independent else current_rank

        # XOR of the expansions of every reduced vector used above. Older
        # expansions are shorter, hence the prefix slice.
        dep_sum = np.zeros(new_size, dtype=np.uint8)
        for idx in coeffs:
            e = basis_expansion[idx]
            dep_sum[: len(e)] ^= e

        if is_independent:
            basis_indices.append(i)
            reduced_basis.append(v)
            pivots.append(int(np.argmax(v)))
            # reduced = original XOR (used reduced vectors), so the expansion
            # of the new reduced vector also contains the new basis row.
            dep_sum[current_rank] = 1
            basis_expansion.append(dep_sum)
            # The original row is itself a basis row: a unit vector.
            t_row = np.zeros(new_size, dtype=np.uint8)
            t_row[current_rank] = 1
            t_rows.append(t_row)
        else:
            # The row reduced to zero, so it equals the XOR recorded above.
            t_rows.append(dep_sum)

    rank = len(basis_indices)
    transform = np.zeros((num_vectors, rank), dtype=np.uint8)
    for i, row in enumerate(t_rows):
        # Rows recorded before the rank grew are shorter; pad with zeros.
        transform[i, : len(row)] = row
    return vecs[basis_indices], transform
@dataclass
class ConnectedComponent:
    """
    One connected piece of a ZX graph, paired with output indices.

    ``connected_components`` fills ``graph`` with the induced subgraph of the
    component and ``output_indices`` with the sorted positions, within the
    parent graph's output tuple, of the outputs that lie in this component.
    """

    # The component's subgraph.
    graph: Any
    # Output indices associated with the component.
    output_indices: list[int]
def connected_components(g: Any) -> list[ConnectedComponent]:
    """
    Split ``g`` into its connected components, each as its own ZX subgraph.

    Components are produced in the order in which their first vertex appears
    in ``g.vertices()``. Each one records the sorted positions (within
    ``g.outputs()``) of the output vertices it contains.

    :param g: The input ZX graph.
    :type g: Any
    :return: List of ConnectedComponent objects.
    :rtype: list[ConnectedComponent]
    """
    # Position of every output vertex in the parent graph's output tuple.
    position_of_output = {v: pos for pos, v in enumerate(tuple(g.outputs()))}

    seen: set[Any] = set()
    result: list[ConnectedComponent] = []
    for root in list(g.vertices()):
        if root in seen:
            # Already swallowed by an earlier component.
            continue
        members = _collect_vertices(g, root, seen)
        piece = _induced_subgraph(g, members)[0]
        indices = sorted(
            position_of_output[v] for v in members if v in position_of_output
        )
        result.append(ConnectedComponent(graph=piece, output_indices=indices))
    return result
def _collect_vertices(
    g: Any,
    start: Any,
    visited: set[Any],
) -> list[Any]:
    """
    Gather the not-yet-visited vertices reachable from ``start``.

    ``visited`` is updated in place, so repeated calls with the same set never
    return a vertex twice. Vertices are returned in first-in-first-out
    traversal order.
    """
    pending: deque[Any] = deque((start,))
    members: list[Any] = []
    while pending:
        current = pending.popleft()
        if current in visited:
            # A vertex may be queued more than once before it is processed.
            continue
        visited.add(current)
        members.append(current)
        pending.extend(n for n in g.neighbors(current) if n not in visited)
    return members


def _induced_subgraph(
    g: Any,
    vertices: Sequence[Any],
) -> tuple[Any, dict[Any, Any]]:
    """
    Copy the given vertices of ``g``, and the edges among them, into a new graph.

    Returns the new graph together with the mapping from each original vertex
    to its counterpart in the new graph.
    """
    sub = Graph()
    sub.track_phases = g.track_phases
    sub.merge_vdata = g.merge_vdata
    old_to_new: dict[Any, Any] = {}

    phase_of = g.phases()
    qubit_of = g.qubits()
    row_of = g.rows()
    type_of = g.types()
    # Not every graph object offers per-vertex parameters.
    fetch_params = getattr(g, "get_params", None)

    for old in vertices:
        var_names = set(fetch_params(old)) if fetch_params is not None else None
        new = sub.add_vertex(
            type_of[old],
            qubit=qubit_of.get(old, -1),
            row=row_of.get(old, -1),
            phase=phase_of.get(old, 0),
            phaseVars=var_names,
        )
        for key in g.vdata_keys(old):
            sub.set_vdata(new, key, g.vdata(old, key))
        old_to_new[old] = new

    # Every edge is met once from each endpoint; add it only the first time.
    seen_edges: set[tuple[Any, Any]] = set()
    for old in vertices:
        for nbr in g.neighbors(old):
            if nbr not in old_to_new:
                continue
            e = g.edge(old, nbr)
            if e not in seen_edges:
                seen_edges.add(e)
                sub.add_edge((old_to_new[old], old_to_new[nbr]), g.edge_type(e))

    # Keep only the boundary vertices that were copied, in the parent's order.
    kept_inputs = tuple(old_to_new[v] for v in g.inputs() if v in old_to_new)
    kept_outputs = tuple(old_to_new[v] for v in g.outputs() if v in old_to_new)
    sub.set_inputs(kept_inputs)
    sub.set_outputs(kept_outputs)
    return sub, old_to_new
def get_params(g: Any) -> set[str]:
    """
    Collect every parameter variable referenced by the graph or its scalar.

    Vertex parameters are read from ``g._phaseVars``. The scalar contributes
    whichever of ``phasevars_pi``, ``phasevars_pi_pair``, ``phasevars_halfpi``,
    ``phasepairs`` and ``phasenodevars`` it defines; attributes it lacks are
    skipped.

    :param g: The ZX graph to inspect.
    :type g: Any
    :return: A set of parameter names.
    :rtype: set[str]
    """
    found: set[str] = set()

    # Parameters attached to individual vertices.
    for vertex in g.vertices():
        found |= g._phaseVars.get(vertex, set())

    sc = g.scalar

    found |= getattr(sc, "phasevars_pi", set())

    for group in getattr(sc, "phasevars_pi_pair", []):
        for names in group:
            found |= names

    for name_sets in getattr(sc, "phasevars_halfpi", {}).values():
        for names in name_sets:
            found |= names

    for pair in getattr(sc, "phasepairs", []):
        found |= pair.paramsA
        found |= pair.paramsB

    for names in getattr(sc, "phasenodevars", []):
        found |= names

    return found