# Source code for tensorcircuit.pauliprop

"""
Pauli Propagation Engine
========================

This module implements a Pauli Propagation Engine (PPE) that tracks the evolution
of Pauli observables through a quantum circuit using a global k-local representation.

Key Features:
    - Tracks all k-local Pauli strings globally, avoiding subset boundary issues.
    - Uses precomputed Neighbor Maps for O(1) fiber lookups.
    - JAX JIT and AD compatible.
    - Naturally handles truncation of terms exceeding locality k.
"""

from typing import Any, Callable, Dict, List, Optional, Tuple, Sequence
import itertools
import logging
import numpy as np

from . import gates
from .cons import backend
from .circuit import Circuit

logger = logging.getLogger(__name__)

Tensor = Any


class PauliPropagationEngine:
    r"""
    A Pauli Propagation Engine that tracks observables in the global k-local space.

    The state is represented as a flat vector of size
    $|P_k| = \sum_{i=0}^k \binom{N}{i} 3^i$. The state vectors built by this
    class carry one additional trailing slot (index ``dim``) used as a sink
    for terms whose locality would exceed ``k``.

    .. note::
        The engine internally uses single precision (``complex64``) for state
        coefficients. Since Pauli propagation is an approximate algorithm, the
        systematic errors introduced by truncation are typically much larger
        than the numerical errors from single-precision floating point.
    """
def __init__(self, N: int, k: int) -> None:
    """
    Set up the engine: enumerate the k-local basis, precompute the
    neighbor map and the Z-only index list, and store the Pauli matrices.

    :param N: Total number of qubits.
    :type N: int
    :param k: Maximum locality of the Pauli strings to track.
    :type k: int
    """
    self.N = N
    self.k = k

    # Order matters: the neighbor map and the Z cache both read self.basis.
    self._build_basis()
    self._build_neighbor_map()
    self._cache_z_indices()

    # Single-qubit Pauli matrices in code order (0:I, 1:X, 2:Y, 3:Z),
    # used when constructing Pauli transfer matrices.
    pauli_i = [[1, 0], [0, 1]]
    pauli_x = [[0, 1], [1, 0]]
    pauli_y = [[0, -1j], [1j, 0]]
    pauli_z = [[1, 0], [0, -1]]
    paulis_np = np.array([pauli_i, pauli_x, pauli_y, pauli_z], dtype=np.complex64)
    self.pauli_mats = backend.cast(backend.convert_to_tensor(paulis_np), "complex64")
def _build_basis(self) -> None:
    """
    Enumerate every Pauli string of locality at most k in canonical order.

    Each string is stored as ``((q_idx, ...), (pauli_code, ...))`` with
    codes 1:X, 2:Y, 3:Z; qubits not listed carry the identity. Populates
    ``self.basis``, ``self.string_to_idx`` and ``self.dim``.
    """
    identity = ((), ())
    # The identity string always sits at index 0.
    basis: List[Tuple[Tuple[int, ...], Tuple[int, ...]]] = [identity]
    lookup: Dict[Tuple[Tuple[int, ...], Tuple[int, ...]], int] = {identity: 0}

    for loc in range(1, self.k + 1):
        for qubits in itertools.combinations(range(self.N), loc):
            for codes in itertools.product([1, 2, 3], repeat=loc):
                key = (qubits, codes)
                lookup[key] = len(basis)
                basis.append(key)

    self.basis = basis
    self.string_to_idx = lookup
    self.dim = len(basis)
    logger.info(
        f"Initialized Pauli basis with size {self.dim} for N={self.N}, k={self.k}"
    )


def _build_neighbor_map(self) -> None:
    """
    Precompute ``neighbor_map[idx, q, code] = new_idx``.

    ``new_idx`` is the basis index obtained from string ``idx`` by setting
    the operator on qubit ``q`` to ``code`` (0:I, 1:X, 2:Y, 3:Z). Results
    whose locality would exceed k point to ``self.dim`` (the SINK index).
    The table has one extra row for the SINK, which maps to itself for
    every operation.
    """
    sink = self.dim
    # Every entry starts at SINK; only reachable in-basis targets are overwritten.
    table = np.full((self.dim + 1, self.N, 4), sink, dtype=np.int32)

    for idx, (qubits, codes) in enumerate(self.basis):
        occupied = dict(zip(qubits, codes))
        for q in range(self.N):
            here = occupied.get(q, 0)
            for target in range(4):
                if target == here:
                    # Writing the operator that is already there is a no-op.
                    table[idx, q, target] = idx
                    continue

                # Rebuild the string with qubit q replaced by `target`.
                changed = {site: c for site, c in occupied.items() if site != q}
                if target != 0:
                    changed[q] = target

                if len(changed) > self.k:
                    # Exceeds locality k: leave the SINK default in place.
                    continue

                sites = tuple(sorted(changed))
                key = (sites, tuple(changed[site] for site in sites))
                table[idx, q, target] = self.string_to_idx[key]

    self.neighbor_map_np = table
    # Backend copy used by the traced gather operations.
    self.neighbor_map = backend.convert_to_tensor(table, dtype="int32")


def _cache_z_indices(self) -> None:
    """
    Precompute indices of Pauli strings that consist only of I and Z.

    The identity string (no codes at all) qualifies as well.
    """
    z_only = [
        idx
        for idx, (_, codes) in enumerate(self.basis)
        if all(c == 3 for c in codes)
    ]
    self.z_indices_np = np.array(z_only, dtype=np.int32)
    self.z_indices = backend.convert_to_tensor(self.z_indices_np, dtype="int32")
def string_to_code(self, s: Tuple[Tuple[int, ...], Tuple[int, ...]]) -> int:
    """
    Look up the basis index of a Pauli string.

    :param s: Pauli string as ((qidx, ...), (opcode, ...)).
    :type s: Tuple[Tuple[int, ...], Tuple[int, ...]]
    :return: Index in the basis, or ``self.dim`` (the sink index) when the
        string is not part of the tracked basis.
    :rtype: int
    """
    try:
        return self.string_to_idx[s]
    except KeyError:
        return self.dim
def get_ptm_1q(self, u: Any) -> Any:
    r"""
    Build the Heisenberg-picture Pauli transfer matrix of a one-qubit gate.

    Each stored Pauli matrix is conjugated as $U^\dagger \sigma_j U$ and
    contracted against the (conjugated) Pauli basis, scaled by 1/2.

    :param u: Gate matrix, expected to be 2x2.
    :type u: Any
    :return: Real part of the resulting 4x4 matrix.
    :rtype: Any
    """
    paulis = self.pauli_mats  # (4, 2, 2)
    unitary = backend.convert_to_tensor(u)
    unitary_dag = backend.conj(backend.transpose(unitary))

    # U^dagger sigma_j U for all four Paulis at once.
    conjugated = backend.matmul(backend.matmul(unitary_dag, paulis), unitary)

    basis_flat = backend.conj(backend.reshape(paulis, [4, 4]))
    conjugated_flat = backend.transpose(backend.reshape(conjugated, [4, 4]))
    return backend.real(0.5 * backend.matmul(basis_flat, conjugated_flat))
def get_ptm_2q(self, u: Any) -> Any:
    r"""
    Build the Heisenberg-picture Pauli transfer matrix of a two-qubit gate.

    The 16 two-qubit Pauli operators are formed as products of the stored
    single-qubit matrices, conjugated as $U^\dagger P_j U$, and contracted
    against the (conjugated) two-qubit Pauli basis, scaled by 1/4.

    :param u: Gate matrix, expected to be 4x4.
    :type u: Any
    :return: Real part of the resulting 16x16 matrix.
    :rtype: Any
    """
    single = self.pauli_mats
    unitary = backend.convert_to_tensor(u)
    unitary_dag = backend.conj(backend.transpose(unitary))

    # Broadcasted product of two single-qubit tables -> (16, 4, 4).
    left = backend.reshape(single, [4, 1, 2, 1, 2, 1])
    right = backend.reshape(single, [1, 4, 1, 2, 1, 2])
    paulis_2q = backend.reshape(left * right, [16, 4, 4])

    conjugated = backend.matmul(backend.matmul(unitary_dag, paulis_2q), unitary)

    basis_flat = backend.conj(backend.reshape(paulis_2q, [16, 16]))
    conjugated_flat = backend.transpose(backend.reshape(conjugated, [16, 16]))
    return backend.real(0.25 * backend.matmul(basis_flat, conjugated_flat))
def get_initial_state(self, structures: Any, weights: Any) -> Any:
    """
    Build the initial observable vector from Hamiltonian terms.

    Terms acting on more than k qubits, or otherwise absent from the
    basis, are dropped. Coefficients of repeated terms are accumulated.

    :param structures: Hamiltonian structure array of shape
        [n_terms, n_qubits]. Each entry is 0:I, 1:X, 2:Y, 3:Z.
    :type structures: Any
    :param weights: Coefficients for each Hamiltonian term.
    :type weights: Any
    :return: Initial state vector of size dim+1 (including sink).
    :rtype: Any

    .. code-block:: python

        # Initialize Z0 + Z1 for a 2-qubit system
        engine = PauliPropagationEngine(N=2, k=2)
        structures = [[3, 0], [0, 3]]
        weights = [1.0, 1.0]
        state = engine.get_initial_state(structures, weights)
    """
    structures = np.array(structures)
    slots = []
    coeffs = []
    for i in range(len(weights)):
        row = structures[i]
        support = tuple(np.where(row != 0)[0])
        if len(support) > self.k:
            # Beyond the tracked locality: discard the term.
            continue
        key = (support, tuple(row[q] for q in support))
        slot = self.string_to_idx.get(key)
        if slot is not None:
            slots.append([slot])
            coeffs.append(weights[i])

    # Flat state vector with one extra trailing entry for the SINK.
    state = backend.zeros((self.dim + 1,), dtype="complex64")
    if len(slots) == 0:
        return state

    slot_tensor = backend.convert_to_tensor(
        np.array(slots, dtype=np.int32), dtype="int32"
    )
    coeff_tensor = backend.cast(backend.stack(coeffs), "complex64")
    return backend.scatter(state, slot_tensor, coeff_tensor, mode="add")
def expectation(self, state: Any) -> Any:
    r"""
    Evaluate $\langle 0| O(t) |0 \rangle$ for a propagated observable.

    Only the coefficients stored at the cached I/Z-only indices are
    summed; every other entry of the state is ignored.

    :param state: Propagated observable state vector.
    :type state: Any
    :return: Real-valued expectation value.
    :rtype: Any
    """
    diagonal_part = backend.gather1d(state, self.z_indices)
    total = backend.sum(diagonal_part)
    return backend.real(total)
def apply_gate(
    self, state: Any, gate_name: str, wires: Any, params: Any = None
) -> Any:
    r"""
    Propagate the observable through a quantum gate in the Heisenberg picture.
    Applies $O \to U^\dagger O U$.

    The gate matrix is obtained from ``gates.<gate_name>`` and turned into a
    Pauli transfer matrix (PTM). The basis is partitioned into "fibers":
    groups of 4 (one wire) or 16 (two wires) basis indices that differ only
    on the gate's wires. The PTM is applied to every fiber and the results
    are scattered back into the state.

    NOTE(review): gates acting on more than two wires match neither branch
    below, so the state is returned unchanged without any warning.

    :param state: Current observable state vector.
    :type state: Any
    :param gate_name: Name of the gate (e.g., 'rx', 'cnot').
    :type gate_name: str
    :param wires: List of qubit indices the gate acts on.
    :type wires: Any
    :param params: Optional gate parameters. A dict is expanded as keyword
        arguments; any other value is passed positionally, falling back to
        ``theta=params`` if that raises ``TypeError``.
    :type params: Any
    :return: Updated observable state vector.
    :rtype: Any
    """
    gate_name = gate_name.lower()
    gate_func = getattr(gates, gate_name)
    if params is not None:
        if isinstance(params, dict):
            u = gate_func(**params).tensor
        else:
            try:
                u = gate_func(params).tensor
            except TypeError:
                u = gate_func(theta=params).tensor
    else:
        u = gate_func().tensor
    u = backend.convert_to_tensor(u)
    # Flatten the gate tensor into a square matrix over the gate's wires.
    dim = 2 ** len(wires)
    u = backend.reshape(u, [dim, dim])

    if len(wires) == 1:
        ptm = self.get_ptm_1q(u)  # (4, 4)
        q = wires[0]
        # Fiber partitioning for qubit q: a basis string is a fiber
        # representative when it already carries I on q, i.e. setting
        # q to I maps the string to itself.
        repr_indices = np.where(
            self.neighbor_map_np[: self.dim, q, 0] == np.arange(self.dim)
        )[0]
        repr_idx_tensor = backend.convert_to_tensor(repr_indices, dtype="int32")
        # Neighbor indices for these representatives (I, X, Y, Z on q);
        # entries may be the SINK index when the string would exceed k.
        fiber_indices = backend.gather1d(
            self.neighbor_map[:, q, :], repr_idx_tensor
        )

        def fiber_update(indices: Any) -> Any:
            # gather values including SINK (which is always 0)
            vals = backend.gather1d(state, indices)
            # new_vals[i] = sum_j ptm[i, j] * vals[j]
            new_vals = backend.tensordot(
                backend.cast(ptm, vals.dtype), vals, ([1], [0])
            )
            return new_vals

        # Apply PTM to all fibers
        new_vals_all = backend.vmap(fiber_update)(fiber_indices)
        # Flatten and scatter back
        state = backend.scatter(
            state,
            backend.reshape(fiber_indices, [-1, 1]),
            backend.reshape(new_vals_all, [-1]),
            mode="update",
        )
        # Ensure SINK stays 0: the scatter above may have written
        # truncated contributions into the SINK slot.
        sink_idx = backend.convert_to_tensor([[self.dim]], dtype="int32")
        sink_val = backend.zeros([1], dtype=state.dtype)
        state = backend.scatter(state, sink_idx, sink_val, mode="update")
    elif len(wires) == 2:
        ptm = self.get_ptm_2q(u)  # (16, 16)
        q1, q2 = wires
        # Representatives have I at both q1 AND q2
        # neighbor_map is indexed by (global_idx, qubit, code)
        # string at glob_idx has I at q1 IF neighbor_map[glob_idx, q1, 0] == glob_idx
        mask = (self.neighbor_map_np[: self.dim, q1, 0] == np.arange(self.dim)) & (
            self.neighbor_map_np[: self.dim, q2, 0] == np.arange(self.dim)
        )
        repr_indices = np.where(mask)[0]
        repr_idx_tensor = backend.convert_to_tensor(repr_indices, dtype="int32")
        # To get 16 indices, we use neighbor_map twice
        # indices at q1 (4) -> for each, indices at q2 (4) -> 16 total
        fiber4_q1 = backend.gather1d(
            self.neighbor_map[:, q1, :], repr_idx_tensor
        )  # (N_repr, 4)

        # Vmap over the 4 columns for q2 lookup
        def lookup_q2(idx_q1: Any) -> Any:
            # The neighbor map has a row for the SINK index too, so a
            # SINK entry coming from the q1 lookup stays at SINK here.
            return backend.gather1d(self.neighbor_map[:, q2, :], idx_q1)

        fiber16 = backend.vmap(lookup_q2, vectorized_argnums=0)(
            backend.transpose(fiber4_q1)
        )  # (4, N_repr, 4)
        fiber16 = backend.transpose(fiber16, (1, 0, 2))  # (N_repr, 4, 4)
        # Flattened position within a fiber is 4 * code_q1 + code_q2.
        fiber16 = backend.reshape(fiber16, [-1, 16])

        def fiber_update_16(indices: Any) -> Any:
            vals = backend.gather1d(state, indices)
            # new_vals[i] = sum_j ptm[i, j] * vals[j]
            new_vals = backend.tensordot(
                backend.cast(ptm, vals.dtype), vals, ([1], [0])
            )
            return new_vals

        new_vals_all = backend.vmap(fiber_update_16)(fiber16)
        state = backend.scatter(
            state,
            backend.reshape(fiber16, [-1, 1]),
            backend.reshape(new_vals_all, [-1]),
            mode="update",
        )
        # Reset the SINK slot to zero, as in the single-wire branch.
        sink_idx = backend.convert_to_tensor([[self.dim]], dtype="int32")
        sink_val = backend.zeros([1], dtype=state.dtype)
        state = backend.scatter(state, sink_idx, sink_val, mode="update")
    return state
def compute_expectation_scan(
    self,
    ham_structures: Any,
    ham_weights: Any,
    layer_fn: Callable[..., None],
    params_batch: Any,
    extra_inputs: Optional[Sequence[Any]] = None,
) -> Any:
    """
    Compute expectation value with JAX scan optimization for layers.

    The observable is propagated backwards through the circuit: layers are
    visited last-to-first, and within each layer the gates recorded by
    ``layer_fn`` are applied in reverse order.

    :param ham_structures: Hamiltonian structure array.
    :type ham_structures: Any
    :param ham_weights: Hamiltonian weights.
    :type ham_weights: Any
    :param layer_fn: A function `f(circuit, params, *extra_inputs)` that
        defines a circuit layer.
    :type layer_fn: Callable
    :param params_batch: Batched parameters for the layers.
    :type params_batch: Any
    :param extra_inputs: Optional additional per-layer inputs. Each element
        is reversed and scanned over together with ``params_batch``, so the
        layer function receives one slice of each per layer.
    :type extra_inputs: Optional[Sequence[Any]]
    :return: Final expectation value.
    :rtype: Any
    """
    state = self.get_initial_state(ham_structures, ham_weights)

    def scan_body(state: Any, scan_inputs: Any) -> Any:
        # scan_inputs might be a tuple or a single tensor depending on
        # backend and extra_inputs
        if isinstance(scan_inputs, (tuple, list)):
            p_l = scan_inputs[0]
            args = scan_inputs[1:]
        else:
            p_l = scan_inputs
            args = ()
        if isinstance(state, (tuple, list)):
            state = state[0]

        # Record the layer's gates on a scratch circuit, then replay them
        # last-to-first on the observable.
        c_layer = Circuit(self.N)
        layer_fn(c_layer, p_l, *args)
        for op in reversed(c_layer.to_qir()):
            p_dict = op.get("parameters", {})
            param_val = p_dict if p_dict else None
            state = self.apply_gate(state, op["name"], op["index"], param_val)
        return state

    # Prepare inputs for scan consistently as a tuple. (An earlier pass that
    # built and reversed the inputs was immediately overwritten by this one
    # and has been removed.)
    scan_inputs = (
        (params_batch,) if extra_inputs is None else (params_batch, *extra_inputs)
    )
    scan_inputs_rev = tuple(backend.reverse(x) for x in scan_inputs)

    if backend.name in ["jax", "tensorflow"]:
        final_state = backend.scan(scan_body, scan_inputs_rev, state)
    else:
        # Manual scan for backends like NumPy to avoid abstract_backend
        # inconsistencies
        final_state = state
        n_layers = backend.shape_tuple(scan_inputs_rev[0])[0]
        for i in range(n_layers):
            layer_inputs = tuple(x[i] for x in scan_inputs_rev)
            final_state = scan_body(final_state, layer_inputs)
    return self.expectation(final_state)
def pauli_propagation(
    c: Circuit,
    observable: Any,
    weights: Optional[Any] = None,
    k: int = 3,
) -> Any:
    r"""
    High-level API for Heisenberg-picture Pauli propagation.

    :param c: The quantum circuit to propagate through.
    :type c: Circuit
    :param observable: The initial observable. Can be:

        1. A list of (coeff, pauli_string) pairs, e.g.,
           ``[(1.0, "ZZ"), (-0.5, "XI")]``. Characters beyond the circuit
           width are ignored; missing trailing characters mean identity.
        2. A structure array if ``weights`` is also provided.
        3. A 2-tuple ``(structures, weights)``.
    :type observable: Any
    :param weights: Optional weights if ``observable`` is a structure array.
    :type weights: Optional[Any]
    :param k: Maximum locality to track. Defaults to 3.
    :type k: int
    :raises ValueError: If ``observable`` matches none of the accepted forms.
    :return: Real-valued expectation value
        $\langle 0 | U^\dagger O U | 0 \rangle$.
    :rtype: Any

    .. code-block:: python

        import tensorcircuit as tc
        from tensorcircuit.pauliprop import pauli_propagation

        c = tc.Circuit(2)
        c.h(0)
        c.cnot(0, 1)

        # Compute expectation of Z0 + Z1
        obs = [(1.0, "ZI"), (1.0, "IZ")]
        energy = pauli_propagation(c, obs, k=2)
    """
    N = c._nqubits
    pp = PauliPropagationEngine(N, k)

    # Normalise every accepted observable format to (structures, coefficients).
    if weights is not None:
        term_structures, term_coeffs = observable, weights
    elif (
        isinstance(observable, list)
        and len(observable) > 0
        and isinstance(observable[0], tuple)
    ):
        map_char = {"I": 0, "X": 1, "Y": 2, "Z": 3}
        num_terms = len(observable)
        term_structures = np.zeros((num_terms, N), dtype=int)
        term_coeffs = np.zeros(num_terms, dtype=np.complex64)
        for row, (coeff, p_str) in enumerate(observable):
            term_coeffs[row] = coeff
            # zip with range(N) stops at the circuit width, so extra
            # characters are never looked up.
            for col, char in zip(range(N), p_str):
                term_structures[row, col] = map_char[char]
    elif isinstance(observable, tuple) and len(observable) == 2:
        term_structures, term_coeffs = observable[0], observable[1]
    else:
        raise ValueError("observable format not supported")

    state = pp.get_initial_state(term_structures, term_coeffs)

    # Heisenberg picture: push the observable through the gates last-to-first.
    for op in reversed(c.to_qir()):
        params_dict = op.get("parameters", {})
        param_val = params_dict if params_dict else None
        state = pp.apply_gate(state, op["name"], op["index"], param_val)
    return pp.expectation(state)
class SparsePauliPropagationEngine:
    """
    A truly sparse Pauli Propagation Engine that tracks Pauli strings using
    bit-packed integers (2 bits per qubit, 32 qubits per ``int64`` word).
    No combinatorial basis is precomputed, making it suitable for hundreds
    of qubits. At most ``buffer_size`` terms are kept after each gate.

    .. note::
        The engine internally uses single precision (``complex64``) for
        coefficients, but utilizes ``int64`` bit-packing for encoding Pauli
        strings. It is highly recommended to call
        ``tc.set_dtype("complex128")`` when using this engine. While the
        coefficients remain single-precision, this setting ensures the
        backend (especially JAX) enables native 64-bit support, which is
        essential for the ``int64`` encoding and bitwise operations used in
        the large-scale simulation.

    .. note::
        The numerical precision error of ``complex64`` is significantly
        smaller than the systematic approximation error from buffer
        truncation, making double-precision coefficients practically
        unnecessary for this engine.
    """
def __init__(self, N: int, k: int, buffer_size: int = 2000) -> None:
    """
    Initialize the sparse engine.

    :param N: Total number of qubits.
    :type N: int
    :param k: Maximum locality of the Pauli strings to keep.
    :type k: int
    :param buffer_size: Number of term slots held in the sparse state.
        Defaults to 2000.
    :type buffer_size: int
    """
    self.N = N
    self.k = k
    self.buffer_size = buffer_size

    # Number of int64 words to track 2 bits per qubit (32 qubits per word).
    self.W = (N + 31) // 32

    # Single-qubit Pauli matrices in code order (0:I, 1:X, 2:Y, 3:Z).
    identity = [[1, 0], [0, 1]]
    pauli_x = [[0, 1], [1, 0]]
    pauli_y = [[0, -1j], [1j, 0]]
    pauli_z = [[1, 0], [0, -1]]
    pauli_table = np.array([identity, pauli_x, pauli_y, pauli_z])
    self.pauli_mats = backend.convert_to_tensor(pauli_table)
def get_initial_state(self, structures: Any, weights: Any) -> Any:
    """
    Build the sparse state (codes, coefficients) from Hamiltonian terms.

    Each term is packed into ``W`` int64 words, 2 bits per qubit. The
    result is padded to ``buffer_size`` rows with code ``-1`` and
    coefficient 0, or cut down to the first ``buffer_size`` terms if there
    are more.

    :param structures: Hamiltonian structures [n_terms, n_qubits].
    :type structures: Any
    :param weights: Hamiltonian coefficients.
    :type weights: Any
    :return: A tuple of (codes, weights) representing the sparse state.
    :rtype: Any

    .. code-block:: python

        # Initialize Z0 + Z1 for a 100-qubit system
        engine = SparsePauliPropagationEngine(N=100, k=2, buffer_size=1000)
        structures = np.zeros((2, 100), dtype=int)
        structures[0, 0] = 3  # Z
        structures[1, 1] = 3  # Z
        weights = [1.0, 1.0]
        state = engine.get_initial_state(structures, weights)
    """
    K = backend
    structures = K.convert_to_tensor(structures, dtype="int32")
    weights = K.convert_to_tensor(weights, dtype="complex64")
    n_terms = K.shape_tuple(structures)[0]

    packed = []
    for w in range(self.W):
        first_qubit = 32 * w
        word = K.zeros((n_terms,), dtype="int64")
        # Only qubits that exist (q < N) contribute to the last word.
        for q in range(first_qubit, min(first_qubit + 32, self.N)):
            shift = 2 * (q - first_qubit)
            word = word | (K.cast(structures[:, q], "int64") << shift)
        packed.append(word)
    codes = K.stack(packed, axis=1)  # (n_terms, W)

    n_missing = self.buffer_size - n_terms
    if n_missing > 0:
        # Empty slots: every word is -1 and the coefficient is zero.
        empty_codes = K.zeros((n_missing, self.W), dtype="int64") - 1
        empty_weights = K.zeros((n_missing,), dtype="complex64")
        codes = K.concat([codes, empty_codes])
        weights = K.concat([weights, empty_weights])
    else:
        codes = codes[: self.buffer_size]
        weights = weights[: self.buffer_size]
    return (codes, weights)
def string_to_code(self, s: Tuple[Tuple[int, ...], Tuple[int, ...]]) -> Any:
    """
    Convert a Pauli string representation ((qidx, ...), (opcode, ...)) to
    its bit-packed int64 representation.

    Qubit ``q`` occupies bits ``2*(q % 32)`` and ``2*(q % 32) + 1`` of word
    ``q // 32``. Qubit indices outside ``[0, 32 * W)`` are ignored, and if a
    qubit index is listed more than once only its first occurrence is used.

    Words are wrapped to signed 64-bit two's complement before conversion.
    Without this, a Y or Z on bit position 31 of a word gives a Python int
    >= 2**63, which ``np.array(..., dtype=np.int64)`` rejects with
    ``OverflowError``. The wrapped value matches what int64 shifts produce
    in ``get_initial_state``.

    :param s: Pauli string as ((qidx, ...), (opcode, ...)).
    :type s: Tuple[Tuple[int, ...], Tuple[int, ...]]
    :return: Bit-packed int64 tensor of length ``W``.
    :rtype: Any
    """
    K = backend
    qubits, opcodes = s

    words = [0] * self.W
    n_slots = 32 * self.W
    seen = set()
    for pos, q in enumerate(qubits):
        q = int(q)
        if q in seen or not 0 <= q < n_slots:
            continue
        seen.add(q)
        w, i = divmod(q, 32)
        words[w] |= int(opcodes[pos]) << (2 * i)

    # Reinterpret each unsigned 64-bit pattern as a signed int64 value.
    signed_words = [x - (1 << 64) if x >= (1 << 63) else x for x in words]
    return K.convert_to_tensor(
        np.array(signed_words, dtype=np.int64), dtype="int64"
    )
def expectation(self, state: Any) -> Any:
    r"""
    Evaluate $\langle 0| O(t) |0 \rangle$ for a sparse observable.

    A term contributes when, in every word, the low and high bit of each
    2-bit field agree (codes 0:I or 3:Z only) and the word is not the
    empty-slot marker ``-1``.

    :param state: Propagated sparse observable state (codes, weights).
    :type state: Any
    :return: Real-valued expectation value.
    :rtype: Any
    """
    K = backend
    codes, coeffs = state
    n_slots = K.shape_tuple(codes)[0]

    low_bit_mask = 0x5555555555555555  # low bit of every 2-bit field
    diagonal = K.ones(n_slots, dtype="bool")
    for w in range(self.W):
        word = codes[:, w]
        lows = word & low_bit_mask
        highs = (word >> 1) & low_bit_mask
        diagonal = diagonal & ((lows == highs) & (word != -1))

    contributing = K.where(diagonal, coeffs, 0.0)
    return K.real(K.sum(contributing))
def _get_weight(self, codes: Any) -> Any:
    """
    Count the non-identity operators in each bit-packed Pauli string.

    :param codes: Packed codes, indexed as ``codes[:, w]`` for word ``w``.
    :type codes: Any
    :return: int32 tensor with one weight per row of ``codes``.
    :rtype: Any
    """
    K = backend
    total_weight = K.zeros(K.shape_tuple(codes)[0], dtype="int32")
    m55 = 0x5555555555555555  # selects the low bit of every 2-bit field
    for w in range(self.W):
        word = codes[:, w]
        # A field is non-identity when either of its two bits is set;
        # fold both bits onto the low-bit position.
        non_iden = (word & m55) | ((word >> 1) & m55)
        # Use hardware-accelerated popcount if available via backend
        total_weight += K.cast(K.popc(non_iden), "int32")
    return total_weight


def _aggregate_and_truncate(self, codes: Any, coeffs: Any) -> Any:
    """
    Merge duplicate Pauli strings and keep the ``buffer_size`` largest terms.

    Rows are sorted by code so duplicates become adjacent, their
    coefficients are summed, and the terms with the largest aggregated
    magnitude are kept. Slots that end up with a negligible coefficient
    or an empty first word are reset to code ``-1`` / coefficient 0.

    :param codes: Packed codes of the candidate terms.
    :type codes: Any
    :param coeffs: Coefficients of the candidate terms.
    :type coeffs: Any
    :return: Tuple ``(final_codes, final_coeffs)`` with ``buffer_size`` rows.
    :rtype: Any
    """
    K = backend
    M = K.shape_tuple(coeffs)[0]
    # Sort by codes to group duplicates
    # (presumably K.lexsort follows numpy.lexsort key order — verify)
    sort_idx = K.lexsort([codes[:, i] for i in range(self.W - 1, -1, -1)])
    codes_s = codes[sort_idx]
    coeffs_s = coeffs[sort_idx]
    # Identify boundaries of unique codes: a row starts a new group when
    # any of its words differs from the previous row.
    is_diff = K.zeros((M - 1,), dtype="bool")
    for w in range(self.W):
        is_diff = is_diff | (codes_s[1:, w] != codes_s[:-1, w])
    # The first row always starts a group.
    is_diff = K.concat([K.convert_to_tensor([True], dtype="bool"), is_diff])
    # Aggregate coefficients: seg_ids numbers the groups 0, 1, 2, ...
    seg_ids = K.cumsum(K.cast(is_diff, "int32")) - 1
    unique_coeffs = K.zeros((M,), dtype="complex64")
    unique_coeffs = K.scatter(
        unique_coeffs, K.reshape(seg_ids, [-1, 1]), coeffs_s, mode="add"
    )
    # Use aggregated coefficients at boundary positions for truncation;
    # non-boundary rows get magnitude -1 so they rank below every group.
    agg_coeffs_at_i = unique_coeffs[seg_ids]
    magnitudes = K.where(is_diff, K.abs(agg_coeffs_at_i), -1.0)
    # Keep top buffer_size terms
    _, top_idx = K.top_k(magnitudes, self.buffer_size)
    final_codes = codes_s[top_idx]
    final_coeffs = agg_coeffs_at_i[top_idx]
    # Nullify inactive/empty slots
    # NOTE(review): only word 0 is compared against the -1 marker here.
    is_active = (K.abs(final_coeffs) > 1e-12) & (final_codes[:, 0] != -1)
    final_codes = K.where(K.reshape(is_active, [-1, 1]), final_codes, -1)
    final_coeffs = K.where(is_active, final_coeffs, 0.0)
    return (final_codes, final_coeffs)
def apply_gate(
    self, state: Any, gate_name: str, wires: Any, params: Any = None
) -> Any:
    r"""
    Propagate the observable through a quantum gate in the Heisenberg picture.
    Applies $O \to U^\dagger O U$.

    Every stored term is expanded into 4 (one wire) or 16 (two wires)
    candidate terms, one per possible operator on the gate's wires, each
    weighted by the matching Pauli-transfer-matrix entry. Candidates
    heavier than ``k`` are zeroed (when ``k < N``), then duplicates are
    merged and the list is cut back to ``buffer_size`` terms.

    NOTE(review): gates on more than two wires match neither branch and
    the state is returned unchanged.

    :param state: Current sparse observable state (codes, weights).
    :type state: Any
    :param gate_name: Name of the gate.
    :type gate_name: str
    :param wires: List of qubit indices the gate acts on.
    :type wires: Any
    :param params: Optional gate parameters.
    :type params: Any
    :return: Updated sparse observable state (codes, weights).
    :rtype: Any
    """
    indices, coeffs = state
    K = backend
    ptm = self._get_ptm(gate_name, wires, params)
    ptm_complex = K.cast(ptm, "complex64")
    if len(wires) == 1:
        q = wires[0]
        # Word holding qubit q and the bit offset of its 2-bit field.
        w_idx, b_pos = q // 32, (q % 32) * 2
        # 1. Extract current ops (S,); empty slots (-1) are read as identity.
        curr_ops = K.where(
            indices[:, w_idx] != -1, (indices[:, w_idx] >> b_pos) & 3, 0
        )
        # 2. Get multipliers (S, 4): row s is ptm[:, curr_ops[s]].
        ptm_t = K.transpose(ptm_complex)
        multipliers = K.gather1d(ptm_t, curr_ops)
        # 3. New coeffs (S*4,)
        flat_coeffs = K.reshape(coeffs[:, None] * multipliers, [-1])
        # 4. New indices (S*4, W): clear the field of q, then write each
        #    of the four target operators into it.
        mask = ~(K.convert_to_tensor(3, dtype="int64") << b_pos)
        t_ops = K.convert_to_tensor(np.arange(4), dtype="int64")
        new_words = (indices[:, w_idx, None] & mask) | (t_ops[None, :] << b_pos)
        flat_new_words = K.reshape(new_words, [-1])
        cols = []
        for i in range(self.W):
            # Use K.repeat for correct alignment during expansion
            # [T0, T1] -> [T0, T0, T0, T0, T1, T1, T1, T1]
            # (i == w_idx is a plain Python bool, so this picks one array.)
            col = K.where(
                i == w_idx, flat_new_words, K.repeat(indices[:, i], 4, axis=0)
            )
            cols.append(col)
        flat_indices = K.stack(cols, axis=1)
        if self.k < self.N:
            # Locality truncation: drop candidates heavier than k.
            w = self._get_weight(flat_indices)
            flat_coeffs = K.where(w <= self.k, flat_coeffs, 0.0)
        return self._aggregate_and_truncate(flat_indices, flat_coeffs)
    elif len(wires) == 2:
        q1, q2 = wires
        w1, b1 = q1 // 32, (q1 % 32) * 2
        w2, b2 = q2 // 32, (q2 % 32) * 2
        # Current operators on both wires; empty slots are read as identity.
        o1 = K.where(indices[:, w1] != -1, (indices[:, w1] >> b1) & 3, 0)
        o2 = K.where(indices[:, w2] != -1, (indices[:, w2] >> b2) & 3, 0)
        # Combined two-qubit Pauli index, q1 as the major digit.
        curr_ops12 = o1 * 4 + o2
        ptm_t = K.transpose(ptm_complex)
        multipliers = K.gather1d(ptm_t, curr_ops12)
        flat_coeffs = K.reshape(coeffs[:, None] * multipliers, [-1])
        # Target operators for q1 (t1) and q2 (t2) for each of 16 outcomes.
        t12 = K.convert_to_tensor(np.arange(16), dtype="int64")
        t1, t2 = t12 // 4, t12 % 4
        mask1, mask2 = ~(K.convert_to_tensor(3, dtype="int64") << b1), ~(
            K.convert_to_tensor(3, dtype="int64") << b2
        )
        # All possible target words regardless of w1==w2
        new_words1 = (indices[:, w1, None] & mask1) | (t1[None, :] << b1)
        new_words2 = (indices[:, w2, None] & mask2) | (t2[None, :] << b2)
        # Used when both qubits live in the same word.
        new_words_both = (
            (indices[:, w1, None] & mask1 & mask2)
            | (t1[None, :] << b1)
            | (t2[None, :] << b2)
        )
        flat_w1 = K.reshape(new_words1, [-1])
        flat_w2 = K.reshape(new_words2, [-1])
        flat_both = K.reshape(new_words_both, [-1])
        cols = []
        for i in range(self.W):
            # Use K.repeat for correct alignment
            # [T0, T1] -> [T0... (16 times), T1... (16 times)]
            target_word = K.repeat(indices[:, i], 16, axis=0)
            # Correctly handle overlapping vs distinct word indices
            target_word = K.where(
                i == w1, K.where(w1 == w2, flat_both, flat_w1), target_word
            )
            target_word = K.where((i == w2) & (w1 != w2), flat_w2, target_word)
            cols.append(target_word)
        flat_indices = K.stack(cols, axis=1)
        if self.k < self.N:
            # Locality truncation: drop candidates heavier than k.
            w = self._get_weight(flat_indices)
            flat_coeffs = K.where(w <= self.k, flat_coeffs, 0.0)
        return self._aggregate_and_truncate(flat_indices, flat_coeffs)
    return state
def _get_ptm(self, gate_name: str, wires: Sequence[int], params: Any) -> Any:
    r"""
    Build the Pauli transfer matrix of a named gate.

    The gate matrix comes from ``gates.<gate_name>``. Dict parameters
    (or ``None``, treated as no parameters) are expanded as keyword
    arguments; any other value is passed positionally, falling back to
    ``theta=params`` on ``TypeError``. Paulis are conjugated as
    $U^\dagger P U$.

    :param gate_name: Name of the gate (case-insensitive).
    :type gate_name: str
    :param wires: Qubits the gate acts on. One wire yields a 4x4 matrix;
        any other count is handled by the two-qubit (16x16) path.
    :type wires: Sequence[int]
    :param params: Gate parameters, or ``None``.
    :type params: Any
    :return: Real part of the transfer matrix.
    :rtype: Any
    """
    K = backend
    build_gate = getattr(gates, gate_name.lower())

    kwargs = {} if params is None else params
    if isinstance(kwargs, dict):
        u = build_gate(**kwargs).tensor
    else:
        try:
            u = build_gate(kwargs).tensor
        except TypeError:
            u = build_gate(theta=kwargs).tensor

    side = 2 ** len(wires)
    u = K.reshape(K.convert_to_tensor(u), [side, side])
    u_dag = K.conj(K.transpose(u))
    single = self.pauli_mats

    if len(wires) == 1:
        conjugated = K.matmul(K.matmul(u_dag, single), u)
        basis_flat = K.conj(K.reshape(single, [4, 4]))
        conjugated_flat = K.transpose(K.reshape(conjugated, [4, 4]))
        return K.real(0.5 * K.matmul(basis_flat, conjugated_flat))

    # Two-qubit Paulis as a broadcasted product of single-qubit tables.
    left = K.reshape(single, [4, 1, 2, 1, 2, 1])
    right = K.reshape(single, [1, 4, 1, 2, 1, 2])
    double = K.reshape(left * right, [16, 4, 4])
    conjugated = K.matmul(K.matmul(u_dag, double), u)
    basis_flat = K.conj(K.reshape(double, [16, 16]))
    conjugated_flat = K.transpose(K.reshape(conjugated, [16, 16]))
    return K.real(0.25 * K.matmul(basis_flat, conjugated_flat))
def compute_expectation_scan(
    self,
    structures: Any,
    weights: Any,
    layer: Callable[[Any, Any], None],
    params: Any,
) -> Any:
    """
    Compute expectation value with JAX scan optimization for layers.

    Layers are visited last-to-first (``params`` is reversed along its
    leading axis), and within each layer the recorded gates are applied
    in reverse order.

    :param structures: Initial Hamiltonian structures.
    :type structures: Any
    :param weights: Initial Hamiltonian weights.
    :type weights: Any
    :param layer: A function `f(circuit, params)` that defines a circuit layer.
    :type layer: Callable
    :param params: Batched parameters for the layers.
    :type params: Any
    :return: Final expectation value.
    :rtype: Any
    """
    K = backend
    initial = self.get_initial_state(structures, weights)

    def propagate_layer(carry: Any, layer_params: Any) -> Any:
        # Record this layer's gates on a scratch circuit, then replay
        # them backwards on the sparse observable.
        scratch = Circuit(self.N)
        layer(scratch, layer_params)
        recorded = scratch.to_qir()
        for op in reversed(recorded):
            carry = self.apply_gate(
                carry, op["name"], op["index"], op.get("parameters")
            )
        return carry

    reversed_params = params[::-1]
    final_state = K.scan(propagate_layer, reversed_params, initial)
    return self.expectation(final_state)