Source code for tensorcircuit.zx.evaluator

"""
Evaluation of compiled scalar graphs using exact arithmetic.
"""

import functools
from typing import Any, NamedTuple, Tuple, cast

import jax
import jax.numpy as jnp
from jax import Array, lax

from ..cons import dtypestr, idtypestr

# ==============================================================================
# Exact Scalar arithmetic
# ==============================================================================

# omega = e^(i*pi/4) as a JAX complex scalar; no explicit dtype is requested here.
_E4 = jnp.exp(1j * jnp.pi / 4)  # e^(i*pi/4)
# Complex conjugate of omega, e^(-i*pi/4) (equivalently omega^-1).
# NOTE(review): neither constant is referenced in the visible part of this
# module (_scalar_to_complex rebuilds both with dtypestr) -- confirm before removing.
_E4D = jnp.exp(-1j * jnp.pi / 4)  # e^(-i*pi/4)


@jax.jit
def _scalar_mul(d1: jax.Array, d2: jax.Array) -> jax.Array:
    """Multiply two exact scalar coefficient arrays."""
    a1, b1, c1, d1_coeff = d1[..., 0], d1[..., 1], d1[..., 2], d1[..., 3]
    a2, b2, c2, d2_coeff = d2[..., 0], d2[..., 1], d2[..., 2], d2[..., 3]

    A = a1 * a2 + b1 * d2_coeff - c1 * c2 + d1_coeff * b2
    B = a1 * b2 + b1 * a2 + c1 * d2_coeff + d1_coeff * c2
    C = a1 * c2 + b1 * b2 + c1 * a2 - d1_coeff * d2_coeff
    D = a1 * d2_coeff - b1 * c2 - c1 * b2 + d1_coeff * a2

    return jnp.stack([A, B, C, D], axis=-1).astype(d1.dtype)


def _scalar_to_complex(data: jax.Array) -> jax.Array:
    """
    Convert exact-scalar coefficients to complex numbers.

    The last axis of ``data`` holds four coefficients ``(c0, c1, c2, c3)`` and
    the result is ``c0 + c1*e^(i*pi/4) + c2*i + c3*e^(-i*pi/4)`` computed in
    ``dtypestr``. Only the last axis is indexed, so any number of leading axes
    is accepted and preserved.

    :param data: Coefficient array whose last axis has length 4.
    :type data: jax.Array
    :return: Complex array with the last axis of ``data`` removed.
    :rtype: jax.Array
    """
    omega = jnp.exp(jnp.array(1j * jnp.pi / 4, dtype=dtypestr))
    omega_conj = jnp.exp(jnp.array(-1j * jnp.pi / 4, dtype=dtypestr))

    # Cast once, then pick out the four coefficient planes.
    as_complex = data.astype(dtypestr)
    const_term = as_complex[..., 0]
    omega_term = as_complex[..., 1] * omega
    imag_term = as_complex[..., 2] * 1j
    conj_term = as_complex[..., 3] * omega_conj

    return const_term + omega_term + imag_term + conj_term


[docs] class ExactScalarArray(NamedTuple): """Exact scalar array for ZX-calculus phase arithmetic using dyadic representation. Represents values of the form (c_0 + c_1·ω + c_2·ω² + c_3·ω³) × 2^power where ω = e^(iπ/4). """ coeffs: Array power: Array
[docs] @classmethod def create(cls, coeffs: Array, power: Array | None = None) -> "ExactScalarArray": if power is None: power = jnp.zeros(coeffs.shape[:-1], dtype=idtypestr) return cls(coeffs, power)
def __mul__(self, other: "ExactScalarArray") -> "ExactScalarArray": # type: ignore[override] new_coeffs = _scalar_mul(self.coeffs, other.coeffs) new_power = self.power + other.power return ExactScalarArray(new_coeffs, new_power)
[docs] def reduce(self) -> "ExactScalarArray": def cond_fun(carry: Tuple[Array, Array]) -> Any: coeffs, _ = carry reducible = jnp.all(coeffs % 2 == 0, axis=-1) & jnp.any( coeffs != 0, axis=-1 ) return jnp.any(reducible) def body_fun(carry: Tuple[Array, Array]) -> Tuple[Array, Array]: coeffs, power = carry reducible = jnp.all(coeffs % 2 == 0, axis=-1) & jnp.any( coeffs != 0, axis=-1 ) coeffs = jnp.where(reducible[..., None], coeffs // 2, coeffs) power = jnp.where(reducible, power + 1, power) return coeffs, power new_coeffs, new_power = jax.lax.while_loop( cond_fun, body_fun, (self.coeffs, self.power) ) return ExactScalarArray(new_coeffs, new_power)
[docs] def sum(self) -> "ExactScalarArray": min_power = jnp.min(self.power, keepdims=True, axis=-1) pow = (self.power - min_power)[..., None] aligned_coeffs = self.coeffs * 2**pow summed_coeffs = jnp.sum(aligned_coeffs, axis=-2) return ExactScalarArray(summed_coeffs, min_power.squeeze(-1))
[docs] def prod(self, axis: int = -1) -> "ExactScalarArray": if axis < 0: axis = self.coeffs.ndim + axis if self.coeffs.shape[axis] == 0: coeffs_shape = self.coeffs.shape[:axis] + self.coeffs.shape[axis + 1 :] result_coeffs = jnp.zeros(coeffs_shape, dtype=self.coeffs.dtype) result_coeffs = result_coeffs.at[..., 0].set(1) power_shape = self.power.shape[:axis] + self.power.shape[axis + 1 :] result_power = jnp.zeros(power_shape, dtype=self.power.dtype) return ExactScalarArray(result_coeffs, result_power) # Move the reduction axis to position 0 for sequential multiplication coeffs_t = jnp.moveaxis(self.coeffs, axis, 0) def body_fn(carry: Array, x: Array) -> Tuple[Array, Any]: return _scalar_mul(carry, x), None result_coeffs, _ = lax.scan(body_fn, coeffs_t[0], coeffs_t[1:]) result_power = jnp.sum(self.power, axis=axis) return ExactScalarArray(result_coeffs, result_power)
[docs] def to_complex(self) -> jax.Array: """ Convert the exact scalar to a complex JAX array. :return: Complex representation of the scalar. :rtype: jax.Array """ c_val = _scalar_to_complex(self.coeffs) scale = jnp.pow(2.0, self.power) return c_val * scale
# ==============================================================================
# Evaluation logic
# ==============================================================================

# Lookup table for exact scalars (1 + omega^k)
# These will be cast to idtypestr when used in evaluate
# NOTE(review): this placeholder is never assigned or read in the visible code;
# the table is rebuilt by _get_lookup_tables() instead -- confirm before removing.
_ONE_PLUS_PHASES = None

# Row k holds the coefficients of omega^k in the basis (1, omega, i, omega^-1).
_UNIT_PHASES_BASE = jnp.array(
    [
        [1, 0, 0, 0],  # omega^0 = 1
        [0, 1, 0, 0],  # omega^1
        [0, 0, 1, 0],  # omega^2 = i
        [0, 0, 0, -1],  # omega^3
        [-1, 0, 0, 0],  # omega^4 = -1
        [0, -1, 0, 0],  # omega^5
        [0, 0, -1, 0],  # omega^6 = -i
        [0, 0, 0, 1],  # omega^7
    ],
    dtype=idtypestr,  # Use idtypestr to stay consistent with precision settings
)


def _get_lookup_tables() -> Tuple[Array, Array, Array]:
    """
    Build the exact-scalar lookup tables used by ``evaluate``.

    :return: ``(unit_phases, one_plus_phases, identity)`` where row ``k`` of
        ``unit_phases`` is ``omega^k``, row ``k`` of ``one_plus_phases`` is
        ``1 + omega^k`` (the constant coefficient incremented by one), and
        ``identity`` is the scalar ``1``, i.e. ``[1, 0, 0, 0]``.
    :rtype: Tuple[Array, Array, Array]
    """
    unit_phases = _UNIT_PHASES_BASE.astype(idtypestr)
    one_plus_phases = unit_phases.at[:, 0].add(1)
    identity = jnp.array([1, 0, 0, 0], dtype=idtypestr)
    return unit_phases, one_plus_phases, identity


def _matmul_gf2(a: Array, b: Array) -> Array:
    """
    Compute the parity (mod 2) of the bit-wise inner products of ``b`` with ``a``.

    The product is evaluated as a float32 matrix multiplication and reduced
    modulo 2 *before* narrowing to ``uint8``. Out-of-range float-to-integer
    casts are implementation-dependent, so taking the parity after the cast
    would be unreliable once an inner product exceeds 255.

    :param a: Bit array of shape ``(G, T, P)``.
    :type a: Array
    :param b: Bit array whose rows have length ``P`` (``(batch, P)`` in ``evaluate``).
    :type b: Array
    :return: ``uint8`` array of shape ``(batch, G, T)`` with entries in ``{0, 1}``.
    :rtype: Array
    """
    n_graphs, n_terms, _ = a.shape
    if n_graphs * n_terms == 0:
        # Nothing to multiply; return uint8 like the non-empty path so callers
        # get an integer array they can index with regardless of b's dtype.
        return jnp.zeros((b.shape[0], n_graphs, n_terms), dtype=jnp.uint8)

    flat_rows = a.astype(jnp.float32).reshape(n_graphs * n_terms, -1)
    # float32 represents these integer counts exactly up to 2**24.
    counts = b.astype(jnp.float32) @ flat_rows.T
    parity = jnp.mod(counts, 2.0)
    return parity.reshape(-1, n_graphs, n_terms).astype(jnp.uint8)
@jax.jit
def evaluate(circuit: Any, param_vals: Array) -> Array:
    """
    Evaluate a compiled scalar graph circuit with batched parameter values.

    Four families of factors (types A-D below) are computed as exact scalars,
    multiplied together with the static phases and the float factor, and then
    summed either exactly or, when the circuit carries approximate float
    factors, in floating point.

    :param circuit: The compiled scalar graph program.
    :type circuit: Any
    :param param_vals: Array of parameter bit values (f-basis and measurement records).
    :type param_vals: Array
    :return: Evaluation results as a complex JAX array.
    :rtype: Array
    """
    unit_phases, one_plus_phases, identity = _get_lookup_tables()

    def parity(param_bits: Array) -> Array:
        # param_bits: (num_graphs, max_terms, n_params), param_vals: (batch_size, n_params)
        return _matmul_gf2(param_bits, param_vals)

    def product_over_terms(
        values: Array, const_phases: Array, num_terms: Array
    ) -> ExactScalarArray:
        # Padded slots are replaced by the multiplicative identity before the product.
        slot = jnp.arange(const_phases.shape[1])[None, :]
        is_real_term = slot < num_terms[:, None]
        masked = jnp.where(is_real_term[..., None], values, identity)
        return ExactScalarArray.create(masked).prod(axis=-2)

    # ====================================================================
    # TYPE A: Node Terms (1 + e^(i*alpha))
    # ====================================================================
    phase_a = (4 * parity(circuit.a_param_bits) + circuit.a_const_phases) % 8
    factor_a = product_over_terms(
        one_plus_phases[phase_a], circuit.a_const_phases, circuit.a_num_terms
    )

    # ====================================================================
    # TYPE B: Half-Pi Terms (e^(i*beta))
    # Padded values are 0, so they don't affect the sum.
    # ====================================================================
    phase_b = (parity(circuit.b_param_bits) * circuit.b_term_types) % 8
    total_phase_b = jnp.sum(phase_b, axis=-1) % 8
    factor_b = ExactScalarArray.create(unit_phases[total_phase_b])

    # ====================================================================
    # TYPE C: Pi-Pair Terms, (-1)^(Psi*Phi)
    # ====================================================================
    psi = (circuit.c_const_bits_a + parity(circuit.c_param_bits_a)) % 2
    phi = (circuit.c_const_bits_b + parity(circuit.c_param_bits_b)) % 2
    sign_bit = jnp.sum((psi * phi) % 2, axis=-1) % 2
    # NOTE(review): 1 - 2 * sign_bit yields -1 only if the promoted dtype can
    # hold it; the dtypes of the circuit fields are not visible here -- confirm.
    factor_c = ExactScalarArray.create((1 - 2 * sign_bit)[..., None] * identity)

    # ====================================================================
    # TYPE D: Phase Pairs (1 + e^a + e^b - e^g)
    # ====================================================================
    alpha = (circuit.d_const_alpha + parity(circuit.d_param_bits_a) * 4) % 8
    beta = (circuit.d_const_beta + parity(circuit.d_param_bits_b) * 4) % 8
    gamma = (alpha + beta) % 8
    pair_terms = (
        identity + unit_phases[alpha] + unit_phases[beta] - unit_phases[gamma]
    )
    factor_d = product_over_terms(
        pair_terms, circuit.d_const_alpha, circuit.d_num_terms
    )

    # ====================================================================
    # FINAL COMBINATION
    # ====================================================================
    static_phases = ExactScalarArray.create(unit_phases[circuit.phase_indices])
    float_factor = ExactScalarArray.create(circuit.floatfactor)
    total = factor_a * factor_b * factor_c * factor_d * static_phases * float_factor

    def exact_result() -> Array:
        # Fold the per-graph power of two into the exponent, then sum exactly.
        shifted = ExactScalarArray(total.coeffs, total.power + circuit.power2)
        return shifted.reduce().sum().to_complex()

    def approximate_result() -> Array:
        weighted = (
            total.to_complex()
            * circuit.approximate_floatfactors
            * 2.0**circuit.power2
        )
        return jnp.sum(weighted, axis=-1)

    outcome = lax.cond(
        circuit.has_approximate_floatfactors, approximate_result, exact_result
    )
    return cast(Array, outcome)