tensorcircuit.interfaces.jax
Interface that wraps a quantum function as a JAX function.
- tensorcircuit.interfaces.jax.create_jax_function(fun: Callable[..., Any], enable_dlpack: bool = False, jit: bool = False, output_shape: Tuple[int, ...] | Tuple | None = None, output_dtype: Any | None = None) → Callable[..., Any]
- tensorcircuit.interfaces.jax.jax_interface(fun: Callable[..., Any], jit: bool = False, enable_dlpack: bool = False, output_shape: Tuple[int, ...] | Tuple | None = None, output_dtype: Any | None = None) → Callable[..., Any]
Wrap a function written for a different ML backend with a JAX interface.
- Example:
    import jax
    import jax.numpy as jnp
    import tensorcircuit as tc

    tc.set_backend("tensorflow")

    def f(params):
        c = tc.Circuit(1)
        c.rx(0, theta=params[0])
        c.ry(0, theta=params[1])
        return tc.backend.real(c.expectation([tc.gates.z(), [0]]))

    # wrap the TensorFlow-backed function so it takes and returns JAX arrays
    f = tc.interfaces.jax_interface(f, jit=True)
    params = jnp.ones(2)
    value, grad = jax.value_and_grad(f)(params)
- Parameters:
fun (Callable[..., Any]) – The quantum function, which takes tensors as input and returns tensors as output
jit (bool, optional) – Whether to jit fun, defaults to False
enable_dlpack (bool, optional) – Whether to transfer tensors between backends via DLPack, defaults to False
output_shape (Optional[Union[Tuple[int, ...], Tuple[()]]], optional) – Optional shape of the function output, defaults to None
output_dtype (Optional[Any], optional) – Optional dtype of the function output, defaults to None
- Returns:
The same quantum function, now taking and returning JAX arrays, with automatic differentiation (AD) supported
- Return type:
Callable[..., Any]
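To illustrate where shape and dtype information such as output_shape and output_dtype can come into play, here is a minimal sketch in plain JAX of one way to expose a non-JAX function to jax.jit and jax.grad. jax.pure_callback runs the foreign function on the host and must be told the result's shape and dtype in advance via jax.ShapeDtypeStruct, while jax.custom_vjp supplies the backward pass. This is not tensorcircuit's implementation: NumPy stands in for the other backend, and the names np_expval, np_grad, and wrap_as_jax are hypothetical. The gradient is hand-written for <Z> = cos(theta0) * cos(theta1), the analytic value of the one-qubit circuit in the example above, whereas a real interface would obtain it from the original backend's AD.

```python
import jax
import jax.numpy as jnp
import numpy as np


# Hypothetical "foreign backend" function: <Z> after rx(a), ry(b) on |0>.
def np_expval(params):
    a, b = params
    return np.asarray(np.cos(a) * np.cos(b), dtype=np.float32)


# Hypothetical hand-written gradient, standing in for the foreign backend's AD.
def np_grad(params):
    a, b = params
    return np.asarray(
        [-np.sin(a) * np.cos(b), -np.cos(a) * np.sin(b)], dtype=params.dtype
    )


def wrap_as_jax(fun, grad_fun, output_shape, output_dtype):
    # JAX traces abstractly, so the callback's output shape/dtype is declared up front.
    out_spec = jax.ShapeDtypeStruct(output_shape, output_dtype)

    @jax.custom_vjp
    def jax_fun(params):
        # Run the non-JAX function on the host; the result comes back as a JAX array.
        return jax.pure_callback(fun, out_spec, params)

    def fwd(params):
        # Keep the inputs as residuals for the backward pass.
        return jax_fun(params), params

    def bwd(params, g):
        grad_spec = jax.ShapeDtypeStruct(params.shape, params.dtype)
        grad = jax.pure_callback(grad_fun, grad_spec, params)
        # Chain rule: scale the foreign gradient by the incoming cotangent.
        return (g * grad,)

    jax_fun.defvjp(fwd, bwd)
    return jax_fun


# Usage mirrors the tensorcircuit example: jit the wrapped function, then differentiate.
f = jax.jit(wrap_as_jax(np_expval, np_grad, (), jnp.float32))
params = jnp.ones(2)
value, grad = jax.value_and_grad(f)(params)
```

Because jax.jit traces jax_fun without running the NumPy code, it cannot discover the output's shape on its own; that is why out_spec has to be supplied before the call.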