tensorcircuit.interfaces.jax

Interface that wraps a quantum function as a JAX function.

tensorcircuit.interfaces.jax.create_jax_function(fun: Callable[[...], Any], enable_dlpack: bool = False, jit: bool = False, output_shape: Tuple[int, ...] | Tuple[()] | None = None, output_dtype: Any | None = None) → Callable[[...], Any]
tensorcircuit.interfaces.jax.jax_interface(fun: Callable[[...], Any], jit: bool = False, enable_dlpack: bool = False, output_shape: Tuple[int, ...] | Tuple[()] | None = None, output_dtype: Any | None = None) → Callable[[...], Any]

Wrap a function written for a different ML backend with a JAX interface.

Example:

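import jax
import jax.numpy as jnp
import tensorcircuit as tc

# The quantum function below runs on the TensorFlow backend
tc.set_backend("tensorflow")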

def f(params):
    c = tc.Circuit(1)
    c.rx(0, theta=params[0])
    c.ry(0, theta=params[1])
    return tc.backend.real(c.expectation([tc.gates.z(), [0]]))

f = tc.interfaces.jax_interface(f, jit=True)

params = jnp.ones(2)
value, grad = jax.value_and_grad(f)(params)

Parameters:
  • fun (Callable[..., Any]) – The quantum function, taking tensors as input and returning tensors as output

  • jit (bool, optional) – Whether to JIT-compile fun, defaults to False

  • enable_dlpack (bool, optional) – Whether to convert tensors between backends via DLPack, defaults to False

  • output_shape (Optional[Union[Tuple[int, ...], Tuple[()]]], optional) – Optional shape of the function output, defaults to None

  • output_dtype (Optional[Any], optional) – Optional dtype of the function output, defaults to None
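A plausible reason for output_shape and output_dtype is that a JAX-side callback into foreign code must declare its result shape and dtype up front, since JAX cannot trace inside the other backend. The following pure-JAX sketch illustrates that idea with jax.pure_callback; it is not tensorcircuit's implementation, and numpy_expectation and wrap_with_callback are hypothetical helpers. numpy_expectation returns the closed-form value cos(θ0)·cos(θ1) of ⟨Z⟩ for the Rx-then-Ry circuit in the example above.

```python
import jax
import jax.numpy as jnp
import numpy as np


def numpy_expectation(params):
    # Hypothetical "foreign backend" function written in plain NumPy, opaque to JAX.
    # <Z> after Rx(theta0) then Ry(theta1) on |0> equals cos(theta0) * cos(theta1).
    params = np.asarray(params)
    return np.asarray(np.cos(params[0]) * np.cos(params[1]), dtype=np.float32)


def wrap_with_callback(fun, output_shape=None, output_dtype=None):
    """Expose `fun` to JAX through jax.pure_callback (forward pass only)."""

    def wrapped(x):
        shape, dtype = output_shape, output_dtype
        if shape is None or dtype is None:
            # Infer the result spec by running `fun` once on concrete values.
            # This fails under jax.jit, where `x` is an abstract tracer.
            probe = np.asarray(fun(np.asarray(x)))
            shape = probe.shape if shape is None else shape
            dtype = probe.dtype if dtype is None else dtype
        spec = jax.ShapeDtypeStruct(shape, dtype)
        return jax.pure_callback(fun, spec, x)

    return wrapped


# With the shape and dtype given explicitly, the wrapper also works under jax.jit.
f = wrap_with_callback(numpy_expectation, output_shape=(), output_dtype=jnp.float32)
value = jax.jit(f)(jnp.ones(2, dtype=jnp.float32))
```

Declaring the shape and dtype explicitly avoids the extra probing call and keeps the wrapper usable inside jit-compiled code.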

Returns:

The same quantum function, now taking and returning JAX arrays, with automatic differentiation (AD) supported

Return type:

Callable[…, Any]
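One way a JAX interface can support AD for a function that JAX cannot trace is to pair jax.pure_callback with jax.custom_vjp, so that the backward pass also calls out to the foreign code. The sketch below shows that technique under stated assumptions; it is not tensorcircuit's implementation, and np_fun and np_grad are hypothetical NumPy stand-ins, with the hand-written gradient playing the role of whatever autodiff the original backend provides.

```python
import jax
import jax.numpy as jnp
import numpy as np


def np_fun(params):
    # Hypothetical foreign-backend forward pass: cos(theta0) * cos(theta1).
    return np.asarray(np.cos(params[0]) * np.cos(params[1]), dtype=np.float32)


def np_grad(params):
    # Hypothetical foreign-backend gradient of np_fun with respect to params.
    return np.asarray(
        [-np.sin(params[0]) * np.cos(params[1]), -np.cos(params[0]) * np.sin(params[1])],
        dtype=np.float32,
    )


@jax.custom_vjp
def jax_fun(params):
    spec = jax.ShapeDtypeStruct((), jnp.float32)
    return jax.pure_callback(np_fun, spec, params)


def jax_fun_fwd(params):
    # Forward pass: compute the value and keep the inputs as residuals.
    return jax_fun(params), params


def jax_fun_bwd(params, g):
    # Backward pass: ask the foreign backend for the gradient, then scale it
    # by the incoming cotangent g to form the vector-Jacobian product.
    spec = jax.ShapeDtypeStruct(params.shape, jnp.float32)
    grad = jax.pure_callback(np_grad, spec, params)
    return (g * grad,)


jax_fun.defvjp(jax_fun_fwd, jax_fun_bwd)

params = jnp.ones(2, dtype=jnp.float32)
value, grad = jax.value_and_grad(jax_fun)(params)
```

Because the gradient is supplied through custom_vjp, JAX never needs to differentiate through the callback itself.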

tensorcircuit.interfaces.jax.jax_wrapper(fun: Callable[[...], Any], enable_dlpack: bool = False, output_shape: Tuple[int, ...] | Sequence[Tuple[int, ...]] | None = None, output_dtype: Any | Sequence[Any] | None = None) → Callable[[...], Any]
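Unlike jax_interface, the jax_wrapper signature accepts a sequence of output shapes and dtypes, which suggests support for functions with several outputs. The pure-JAX sketch below shows how jax.pure_callback handles that case with one ShapeDtypeStruct per output. It is an illustration only, not the library's code; wrap_multi and np_two_outputs are hypothetical helpers.

```python
import jax
import jax.numpy as jnp
import numpy as np


def np_two_outputs(x):
    # Hypothetical foreign-backend function returning two arrays.
    return (
        np.asarray(np.sum(x), dtype=np.float32),
        np.asarray(2 * x, dtype=np.float32),
    )


def wrap_multi(fun, output_shapes, output_dtypes):
    # One ShapeDtypeStruct per output; pure_callback accepts a pytree of specs.
    specs = tuple(
        jax.ShapeDtypeStruct(shape, dtype)
        for shape, dtype in zip(output_shapes, output_dtypes)
    )

    def wrapped(x):
        return jax.pure_callback(fun, specs, x)

    return wrapped


g = wrap_multi(np_two_outputs, [(), (3,)], [jnp.float32, jnp.float32])
total, doubled = jax.jit(g)(jnp.arange(3, dtype=jnp.float32))
```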