"""Pauli noise channels and error sampling infrastructure."""
from __future__ import annotations
import functools
from collections import defaultdict
from typing import Any
from dataclasses import dataclass
import jax
import jax.numpy as jnp
import numpy as np
[docs]
@dataclass
class Channel:
"""A probability distribution over error outcomes.
Attributes:
probs: Shape (2^k,) probability array, sums to 1, dtype float64
unique_col_ids: Tuple of column IDs, where each ID corresponds to a bit of the channel.
"""
probs: Any
unique_col_ids: tuple[int, ...]
@property
def num_bits(self) -> int:
"""Number of bits in the channel (k where probs has shape 2^k)."""
return int(np.log2(len(self.probs)))
[docs]
def error_probs(p: float) -> Any:
"""
Single-bit error channel probability distribution.
:param p: Error probability.
:type p: float
:return: Array [1-p, p] of shape (2,).
:rtype: np.ndarray
"""
return np.array([1 - p, p], dtype=np.float64)
[docs]
def pauli_channel_1_probs(px: float, py: float, pz: float) -> Any:
"""
Single-qubit Pauli channel probability distribution.
:param px: X error probability.
:type px: float
:param py: Y error probability.
:type py: float
:param pz: Z error probability.
:type pz: float
:return: Array [I, Z, X, Y] of shape (4,).
:rtype: np.ndarray
"""
return np.array([1 - px - py - pz, pz, px, py], dtype=np.float64)
[docs]
def pauli_channel_2_probs(
pix: float,
piy: float,
piz: float,
pxi: float,
pxx: float,
pxy: float,
pxz: float,
pyi: float,
pyx: float,
pyy: float,
pyz: float,
pzi: float,
pzx: float,
pzy: float,
pzz: float,
) -> Any:
"""
Two-qubit Pauli channel probability distribution.
:param pix: Probability of error I on qubit 1 and X on qubit 2.
...
:return: Array of shape (16,).
:rtype: np.ndarray
"""
remainder = (
1
- pix
- piy
- piz
- pxi
- pxx
- pxy
- pxz
- pyi
- pyx
- pyy
- pyz
- pzi
- pzx
- pzy
- pzz
)
probs = np.array(
[
remainder, # 00,00
pzi, # 10,00
pxi, # 01,00
pyi, # 11,00
piz, # 00,10
pzz, # 10,10
pxz, # 01,10
pyz, # 11,10
pix, # 00,01
pzx, # 10,01
pxx, # 01,01
pyx, # 11,01
piy, # 00,11
pzy, # 10,11
pxy, # 01,11
pyy, # 11,11
],
dtype=np.float64,
)
return probs
[docs]
def xor_convolve(probs_a: Any, probs_b: Any) -> Any:
"""
XOR convolution of two probability distributions.
:param probs_a: Probability distribution of channel A.
:type probs_a: np.ndarray
:param probs_b: Probability distribution of channel B.
:type probs_b: np.ndarray
:return: Combined probability distribution.
:rtype: np.ndarray
"""
n = len(probs_a)
if len(probs_b) != n:
raise ValueError("Both channels must have same number of outcomes")
# NOTE: The convolution could be done in O(n*log(n)) using Walsh-Hadamard transform.
# But since probability arrays are usually limited to <=16 entries, this is not
# worth the complexity.
result = np.zeros(n, dtype=np.float64)
for a in range(n):
for b in range(n):
o = a ^ b
result[o] += probs_a[a] * probs_b[b]
return result
[docs]
def reduce_null_bits(
channels: list[Channel], null_col_id: int | None = None
) -> list[Channel]:
"""Remove bits corresponding to the null column (all-zero column).
If a channel has bits mapped to null_col_id (representing an all-zero
column in the transform matrix), those bits don't affect any f-variable
and can be marginalized out by summing over them.
Args:
channels: List of channels
null_col_id: Column ID representing the all-zero column, or None if
there is no all-zero column.
Returns:
List of channels with null bits marginalized out. Channels with all
null entries are removed entirely (they have no effect on outputs).
"""
if null_col_id is None:
# No null column, nothing to reduce
return channels
result: list[Channel] = []
for channel in channels:
n = channel.num_bits
non_null_positions = [
i
for i, col_id in enumerate(channel.unique_col_ids)
if col_id != null_col_id
]
if len(non_null_positions) == 0:
# All entries are null, channel has no effect - remove it
continue
# Marginalize out the null bits by summing over them
new_col_ids = tuple(channel.unique_col_ids[i] for i in non_null_positions)
new_num_bits = len(non_null_positions)
sum_axes = tuple(i for i in range(n) if i not in non_null_positions)
probs_tensor = channel.probs.reshape((2,) * n, order="F")
new_probs = probs_tensor.sum(axis=sum_axes).reshape(2**new_num_bits, order="F")
result.append(Channel(probs=new_probs, unique_col_ids=new_col_ids))
return result
[docs]
def normalize_channels(channels: list[Channel]) -> list[Channel]:
"""Normalize channels by sorting unique_col_ids, permuting probs accordingly.
This ensures channels affecting the same set of columns have identical
unique_col_ids tuples, enabling merge_identical_channels to group them.
Args:
channels: List of channels
Returns:
List of channels with sorted unique_col_ids
"""
result: list[Channel] = []
for channel in channels:
n = channel.num_bits
source_col_ids = np.array(channel.unique_col_ids)
axis_perm = np.argsort(source_col_ids, kind="stable")
probs_tensor = channel.probs.reshape((2,) * n, order="F")
new_probs = probs_tensor.transpose(axis_perm).reshape(2**n, order="F")
result.append(
Channel(probs=new_probs, unique_col_ids=tuple(source_col_ids[axis_perm]))
)
return result
[docs]
def expand_channel(channel: Channel, target_col_ids: tuple[int, ...]) -> Channel:
"""Expand a channel's distribution to a larger signature set.
The channel's existing col_ids must be a strict subset of target_col_ids.
Both must be sorted. New bit positions are treated as "don't care" (always 0).
Args:
channel: Channel to expand (must have sorted unique_col_ids)
target_col_ids: Target signature set (must be sorted superset)
Returns:
New channel with expanded distribution
"""
source_col_ids = channel.unique_col_ids
assert source_col_ids == tuple(sorted(source_col_ids)), "Source must be sorted"
assert target_col_ids == tuple(sorted(target_col_ids)), "Target must be sorted"
assert set(source_col_ids) < set(target_col_ids), "Source must be strict subset"
# Map source columns to their positions in target
source_to_target = {s: target_col_ids.index(s) for s in source_col_ids}
n_target = len(target_col_ids)
new_probs = np.zeros(2**n_target, dtype=np.float64)
for old_idx in range(len(channel.probs)):
# Map old bit pattern to new bit pattern (new bits stay 0)
new_idx = 0
for src_pos, src_col in enumerate(source_col_ids):
if (old_idx >> src_pos) & 1:
new_idx |= 1 << source_to_target[src_col]
new_probs[new_idx] += channel.probs[old_idx]
return Channel(probs=new_probs, unique_col_ids=target_col_ids)
[docs]
def merge_identical_channels(channels: list[Channel]) -> list[Channel]:
"""Merge all channels with identical signature sets.
Groups channels by their unique_col_ids and convolves all channels
in each group into a single channel.
Args:
channels: List of channels
Returns:
List with at most one channel per unique signature set
"""
groups: dict[tuple[int, ...], list[Channel]] = defaultdict(list)
for channel in channels:
key = channel.unique_col_ids
groups[key].append(channel)
result: list[Channel] = []
for col_ids, group in groups.items():
if len(group) == 1:
result.append(group[0])
else:
# Convolve all channels in the group
combined_probs = group[0].probs.copy()
for channel in group[1:]:
combined_probs = xor_convolve(combined_probs, channel.probs)
result.append(Channel(probs=combined_probs, unique_col_ids=col_ids))
return result
[docs]
def absorb_subset_channels(channels: list[Channel], max_bits: int = 4) -> list[Channel]:
"""Absorb channels whose signatures are subsets of others.
If channel A's signatures are a strict subset of channel B's signatures,
and |B| <= max_bits, then A is absorbed into B.
Args:
channels: List of channels
max_bits: Maximum number of bits allowed per channel
Returns:
List with no channel being a strict subset of another
"""
# Sort by number of bits (largest first) for efficient processing
channels = sorted(channels, key=lambda c: -len(c.unique_col_ids))
result: list[Channel] = []
absorbed: set[int] = set()
for i, channel_i in enumerate(channels):
if i in absorbed:
continue
set_i = set(channel_i.unique_col_ids)
# Try to absorb smaller channels into this one
current_probs = channel_i.probs.copy()
current_col_ids = channel_i.unique_col_ids
for j, channel_j in enumerate(channels):
if j <= i or j in absorbed:
continue
set_j = set(channel_j.unique_col_ids)
# Check if j is a strict subset of i
if set_j < set_i and len(set_i) <= max_bits:
# Expand channel_j to match channel_i's signatures and convolve
expanded_j = expand_channel(channel_j, current_col_ids)
current_probs = xor_convolve(current_probs, expanded_j.probs)
absorbed.add(j)
result.append(Channel(probs=current_probs, unique_col_ids=current_col_ids))
return result
[docs]
def simplify_channels(
channels: list[Channel], max_bits: int = 4, null_col_id: int | None = None
) -> list[Channel]:
"""
Simplify a list of channels by reducing null bits, merging identical signatures, and absorbing subsets.
:param channels: List of input channels.
:type channels: list[Channel]
:param max_bits: Maximum number of bits allowed per simplified channel, defaults to 4.
:type max_bits: int, optional
:param null_col_id: ID of the all-zero column in the error transform, defaults to None.
:type null_col_id: int, optional
:return: Simplified list of channels.
:rtype: list[Channel]
"""
channels = reduce_null_bits(channels, null_col_id)
channels = normalize_channels(channels)
channels = merge_identical_channels(channels)
channels = absorb_subset_channels(channels, max_bits)
return channels
[docs]
class ChannelSampler:
"""Samples from multiple error channels and transforms to a reduced basis.
This class combines multiple error channels (each producing error bits e0, e1, ...)
and applies a linear transformation over GF(2) to convert samples from the original
"e" basis to a reduced "f" basis using geometric-skip sampling optimized for
low-noise regimes.
f_i = error_transform_ij * e_j mod 2
Channels are automatically simplified by:
1. Removing bits e_i that do not affect any f-variable (i.e. all-zero columns in error_transform)
2. Merging channels with identical column signatures, i.e. channels whose corresponding
columns in error_transform are identical.
3. Absorbing channels whose signatures are subsets of others, i.e. channels whose corresponding
columns in error_transform are a strict subset of another channel's columns.
Example:
>>> probs = [error_probs(0.1), error_probs(0.2)] # two 1-bit channels
>>> transform = np.array([[1, 1]]) # f0 = e0 XOR e1
>>> sampler = ChannelSampler(probs, transform)
>>> samples = sampler.sample(1000) # shape (1000, 1)
"""
[docs]
def __init__(
self,
channel_probs: list[Any],
error_transform: Any,
seed: int | None = None,
):
"""Initialize the sampler with channel probabilities and a basis transformation.
Args:
channel_probs: List of probability arrays. Channel i has shape (2^k_i,)
and produces k_i error bits starting at index sum(k_0:k_{i-1}).
For example, if channels have shapes [(4,), (2,), (4,)], they
produce variables [e0,e1], [e2], [e3,e4].
error_transform: Binary matrix of shape (num_f, num_e) where entry [i, j] = 1
means f_i depends on e_j. For example, if row 0 is [0, 1, 0, 1],
then f0 = e1 XOR e3.
seed: Random seed for sampling. If None, a random seed is generated.
"""
unique_cols, inverse = np.unique(error_transform, axis=1, return_inverse=True)
# Signature matrix: each row is a unique column signature
signature_matrix = unique_cols.T # shape (num_signatures, num_f)
# Find null_col_id: the index of the all-zero column (or None)
zero_col_indices = np.flatnonzero(np.all(unique_cols == 0, axis=0))
null_col_id = int(zero_col_indices[0]) if len(zero_col_indices) else None
# Create Channel objects with unique_col_ids from inverse mapping
channels: list[Channel] = []
e_offset = 0
for probs in channel_probs:
num_bits = int(np.log2(len(probs)))
col_ids = tuple(int(inverse[e_offset + i]) for i in range(num_bits))
channels.append(Channel(probs=probs, unique_col_ids=col_ids))
e_offset += num_bits
self.channels = simplify_channels(channels, null_col_id=null_col_id)
self.signature_matrix = signature_matrix.astype(np.uint8)
self._rng = np.random.default_rng(
seed if seed is not None else np.random.default_rng().integers(0, 2**30)
)
self._sparse_data = self._precompute_sparse(
self.channels, self.signature_matrix
)
self._jax_data = self._precompute_jax(self.channels, self.signature_matrix)
@property
def num_f_params(self) -> int:
"""
Number of reduced basis parameters (f-basis).
:return: Number of parameters.
:rtype: int
"""
return int(self.signature_matrix.shape[1])
@staticmethod
def _precompute_sparse(
channels: list[Channel], signature_matrix: Any
) -> list[tuple[float, Any, Any]]:
"""
Precompute per-channel data for geometric-skip sampling.
:param channels: List of error channels.
:type channels: list[Channel]
:param signature_matrix: Binary matrix mapping error bits to f-basis bits.
:type signature_matrix: np.ndarray
:return: List of precomputed sampling data (p_fire, cond_cdf, xor_patterns).
:rtype: list[tuple[float, np.ndarray, np.ndarray]]
"""
data: list[tuple[float, Any, Any]] = []
for ch in channels:
probs = ch.probs.astype(np.float64)
p_fire = 1.0 - float(probs[0])
n_outcomes = len(probs)
if p_fire <= 1e-15 or n_outcomes <= 1:
continue
cond_cdf = np.cumsum(probs[1:] / p_fire, dtype=np.float64)
cond_cdf /= cond_cdf[-1]
col_ids = np.asarray(ch.unique_col_ids)
num_bits = len(col_ids)
outcomes = np.arange(1, n_outcomes)
bits_mask = ((outcomes[:, None] >> np.arange(num_bits)) & 1).astype(
np.uint8
)
xor_patterns = (bits_mask @ signature_matrix[col_ids] % 2).astype(np.uint8)
data.append((p_fire, cond_cdf, xor_patterns))
return data
[docs]
def sample(self, num_samples: int = 1) -> Any:
"""
Sample from all error channels and transform to the f-basis.
Uses geometric-skip sampling optimized for low-noise regimes.
:param num_samples: Number of samples to draw, defaults to 1.
:type num_samples: int, optional
:return: Binary array of shape (num_samples, num_f) containing sampled outcomes.
:rtype: np.ndarray
"""
num_outputs = self.signature_matrix.shape[1]
result = np.zeros((num_samples, num_outputs), dtype=np.uint8)
for p_fire, cond_cdf, xor_pats in self._sparse_data:
expected = num_samples * p_fire
sigma = np.sqrt(expected * (1.0 - p_fire))
# At 7 sigma, we undersample in about 1 out of 1e12 cases
n_draws = int(expected + 7.0 * sigma) + 100
positions = np.cumsum(self._rng.geometric(p_fire, size=n_draws)) - 1
positions = positions[positions < num_samples]
if len(positions) == 0:
continue
outcome_idx = np.searchsorted(
cond_cdf, self._rng.uniform(size=len(positions))
)
result[positions] ^= xor_pats[outcome_idx]
return result
@staticmethod
def _precompute_jax(
channels: list[Channel], signature_matrix: Any
) -> tuple[Any, Any] | None:
"""Precompute batched JAX arrays for on-device categorical sampling.
Returns (log_probs_padded, xor_patterns_padded) or None if no active
channels. log_probs_padded has shape (num_channels, max_outcomes) and
xor_patterns_padded has shape (num_channels, max_outcomes, num_f).
Inactive outcome slots are filled with -inf log-prob so they are never
sampled.
"""
num_f = signature_matrix.shape[1]
active: list[tuple[Any, Any]] = []
for ch in channels:
probs = ch.probs.astype(np.float64)
p_fire = 1.0 - float(probs[0])
n_outcomes = len(probs)
if p_fire <= 1e-15 or n_outcomes <= 1:
continue
col_ids = np.asarray(ch.unique_col_ids)
num_bits = len(col_ids)
outcomes = np.arange(n_outcomes)
bits_mask = ((outcomes[:, None] >> np.arange(num_bits)) & 1).astype(
np.uint8
)
xor_patterns = (bits_mask @ signature_matrix[col_ids] % 2).astype(np.uint8)
log_probs = np.log(np.maximum(probs, 1e-30)).astype(np.float32)
active.append((log_probs, xor_patterns))
if not active:
return None
max_outcomes = max(lp.shape[0] for lp, _ in active)
# Pad all channels to the same number of outcomes
log_probs_list = []
xor_pats_list = []
for lp, xp in active:
n = lp.shape[0]
if n < max_outcomes:
lp_padded = np.full(max_outcomes, -1e30, dtype=np.float32)
lp_padded[:n] = lp
xp_padded = np.zeros((max_outcomes, num_f), dtype=np.uint8)
xp_padded[:n] = xp
else:
lp_padded = lp
xp_padded = xp
log_probs_list.append(lp_padded)
xor_pats_list.append(xp_padded)
return (
jnp.array(np.stack(log_probs_list)), # (num_channels, max_outcomes)
jnp.array(np.stack(xor_pats_list)), # (num_channels, max_outcomes, num_f)
)
[docs]
def sample_jax(self, num_samples: int, key: Any) -> tuple[Any, Any]:
"""Sample from all error channels on-device using JAX categorical sampling.
:param num_samples: Number of samples to draw.
:type num_samples: int
:param key: JAX PRNG key.
:type key: Any
:return: Tuple of (samples, next_key) where samples is a JAX array
of shape (num_samples, num_f) with dtype uint8.
:rtype: tuple[jax.Array, Any]
"""
num_outputs = self.signature_matrix.shape[1]
if self._jax_data is None:
return jnp.zeros((num_samples, num_outputs), dtype=jnp.uint8), key
log_probs_all, xor_pats_all = self._jax_data
key, subkey = jax.random.split(key)
result = _sample_jax_impl(log_probs_all, xor_pats_all, subkey, num_samples)
return result, key
@functools.partial(jax.jit, static_argnums=(3,))
def _sample_jax_impl(
log_probs_all: jax.Array,
xor_pats_all: jax.Array,
key: jax.Array,
num_samples: int,
) -> jax.Array:
"""JIT-compiled core of JAX categorical sampling."""
num_channels = log_probs_all.shape[0]
subkeys = jax.random.split(key, num_channels)
outcome_indices = jax.vmap(
lambda k, lp: jax.random.categorical(k, lp, shape=(num_samples,))
)(subkeys, log_probs_all)
gathered = jax.vmap(lambda xp, idx: xp[idx])(xor_pats_all, outcome_indices)
return jnp.bitwise_xor.reduce(gathered, axis=0)