# Source code for tensorcircuit.interfaces.jax

"""
Interface wraps quantum function as a jax function
"""

from typing import Any, Callable, Tuple, Optional, Union, Sequence
from functools import wraps, partial

from ..cons import backend
from .tensortrans import general_args_to_backend

# Annotation alias for tensor values; it is plain ``Any``, so no concrete
# tensor class is enforced by the type checker.
Tensor = Any


def jax_wrapper(
    fun: Callable[..., Any],
    enable_dlpack: bool = False,
    output_shape: Optional[
        Union[Tuple[int, ...], Sequence[Tuple[int, ...]]]
    ] = None,
    output_dtype: Optional[Union[Any, Sequence[Any]]] = None,
) -> Callable[..., Any]:
    """
    Wrap ``fun`` so that it can be called from jax through ``jax.pure_callback``.

    Inside the callback the inputs are converted with
    ``general_args_to_backend`` before ``fun`` is called, and the outputs are
    converted back with ``target_backend="jax"``.

    If both ``output_shape`` and ``output_dtype`` are given, they describe the
    result directly. Otherwise ``fun`` is evaluated once on the actual inputs
    to discover the output shapes and dtypes, and then evaluated again inside
    the callback.

    :param fun: The function to wrap
    :type fun: Callable[..., Any]
    :param enable_dlpack: forwarded to ``general_args_to_backend``,
        defaults to False
    :type enable_dlpack: bool, optional
    :param output_shape: shape of the output, or a sequence of shapes when
        ``fun`` returns several outputs, defaults to None
    :type output_shape: Optional[Union[Tuple[int, ...],
        Sequence[Tuple[int, ...]]]], optional
    :param output_dtype: dtype of the output. For several outputs this is
        either one dtype per output or a single dtype shared by all outputs,
        defaults to None
    :type output_dtype: Optional[Union[Any, Sequence[Any]]], optional
    :raises ValueError: if several output shapes are given together with a
        sequence of dtypes of a different length
    :return: A function taking and returning jax arrays
    :rtype: Callable[..., Any]
    """
    import jax

    def declared_struct() -> Any:
        """Build the result description from ``output_shape``/``output_dtype``."""
        multiple_outputs = (
            isinstance(output_shape, Sequence)
            and len(output_shape) > 0
            and not isinstance(output_shape[0], int)
        )
        if not multiple_outputs:
            return jax.ShapeDtypeStruct(output_shape, output_dtype)  # type: ignore

        # A string such as "float32" is a Sequence too, but it names ONE dtype.
        per_output_dtypes = isinstance(output_dtype, Sequence) and not isinstance(
            output_dtype, str
        )
        if per_output_dtypes:
            if len(output_dtype) != len(output_shape):  # type: ignore
                # zip() would silently drop outputs and the mismatch would only
                # surface later as an opaque pure_callback structure error.
                raise ValueError(
                    "output_shape describes %d outputs but output_dtype has "
                    "%d entries" % (len(output_shape), len(output_dtype))  # type: ignore
                )
            dtypes = output_dtype
        else:
            # A single dtype is shared by every output.
            dtypes = [output_dtype] * len(output_shape)  # type: ignore
        return tuple(
            jax.ShapeDtypeStruct(s, d) for s, d in zip(output_shape, dtypes)  # type: ignore
        )

    def inferred_struct(sample: Any, inputs: Tuple[Any, ...]) -> Any:
        """Build the result description from an actual output of the function."""

        def struct_of(t: Any) -> Any:
            # Outputs without shape/dtype attributes are treated as scalars
            # carrying the dtype of the first input.
            return jax.ShapeDtypeStruct(
                t.shape if hasattr(t, "shape") else (),
                t.dtype if hasattr(t, "dtype") else inputs[0].dtype,
            )

        if isinstance(sample, tuple):
            return tuple(struct_of(t) for t in sample)
        return struct_of(sample)

    @wraps(fun)
    def fun_jax(*x: Any) -> Any:
        def wrapped_fun(*args: Any) -> Any:
            args = general_args_to_backend(args, enable_dlpack=enable_dlpack)
            y = fun(*args)
            y = general_args_to_backend(
                y, target_backend="jax", enable_dlpack=enable_dlpack
            )
            return y

        # Use provided shape and dtype if available, otherwise run the function
        # once to find out what it returns.
        if output_shape is not None and output_dtype is not None:
            out_shape = declared_struct()
        else:
            out_shape = inferred_struct(wrapped_fun(*x), x)

        return jax.pure_callback(wrapped_fun, out_shape, *x)

    return fun_jax
def jax_interface(
    fun: Callable[..., Any],
    jit: bool = False,
    enable_dlpack: bool = False,
    output_shape: Optional[Union[Tuple[int, ...], Tuple[()]]] = None,
    output_dtype: Optional[Any] = None,
) -> Callable[..., Any]:
    """
    Expose a function written on another ML backend through a jax interface.

    This is a thin entry point: all arguments are forwarded unchanged to
    ``create_jax_function`` and its result is returned.

    :Example:

    .. code-block:: python

        tc.set_backend("tensorflow")

        def f(params):
            c = tc.Circuit(1)
            c.rx(0, theta=params[0])
            c.ry(0, theta=params[1])
            return tc.backend.real(c.expectation([tc.gates.z(), [0]]))

        f = tc.interfaces.jax_interface(f, jit=True)

        params = jnp.ones(2)
        value, grad = jax.value_and_grad(f)(params)

    :param fun: The quantum function with tensor in and tensor out
    :type fun: Callable[..., Any]
    :param jit: whether to jit ``fun``, defaults to False
    :type jit: bool, optional
    :param enable_dlpack: whether transform tensor backend via dlpack,
        defaults to False
    :type enable_dlpack: bool, optional
    :param output_shape: Optional shape of the function output,
        defaults to None
    :type output_shape: Optional[Union[Tuple[int, ...], Tuple[()]]], optional
    :param output_dtype: Optional dtype of the function output,
        defaults to None
    :type output_dtype: Optional[Any], optional
    :return: The same quantum function but now with jax array in and
        jax array out while AD is also supported
    :rtype: Callable[..., Any]
    """
    return create_jax_function(
        fun,
        jit=jit,
        enable_dlpack=enable_dlpack,
        output_shape=output_shape,
        output_dtype=output_dtype,
    )
def create_jax_function(
    fun: Callable[..., Any],
    enable_dlpack: bool = False,
    jit: bool = False,
    output_shape: Optional[Union[Tuple[int, ...], Tuple[()]]] = None,
    output_dtype: Optional[Any] = None,
) -> Callable[..., Any]:
    """
    Build a jax function around ``fun`` with a custom vjp rule.

    The forward pass calls ``fun`` through ``jax_wrapper``. The backward pass
    calls ``backend.vjp`` on ``fun`` inside a ``jax.pure_callback`` and
    returns one gradient per positional input.

    :param fun: The function to wrap
    :type fun: Callable[..., Any]
    :param enable_dlpack: forwarded to ``general_args_to_backend`` and
        ``jax_wrapper``, defaults to False
    :type enable_dlpack: bool, optional
    :param jit: whether to apply ``backend.jit`` to ``fun`` and to its vjp,
        defaults to False
    :type jit: bool, optional
    :param output_shape: Optional shape of the function output, forwarded to
        ``jax_wrapper``, defaults to None
    :type output_shape: Optional[Union[Tuple[int, ...], Tuple[()]]], optional
    :param output_dtype: Optional dtype of the function output, forwarded to
        ``jax_wrapper``, defaults to None
    :type output_dtype: Optional[Any], optional
    :return: A jax function with ``defvjp`` registered
    :rtype: Callable[..., Any]
    """
    import jax
    from jax import custom_vjp

    if jit:
        fun = backend.jit(fun)

    wrapped = jax_wrapper(
        fun,
        enable_dlpack=enable_dlpack,
        output_shape=output_shape,
        output_dtype=output_dtype,
    )

    # Jit the backend vjp once. Rebuilding it inside every backward pass would
    # hand the backend a brand-new callable each time and defeat its jit cache.
    jitted_vjp = (
        backend.jit(partial(backend.vjp, fun)) if jit else None  # type: ignore
    )

    @custom_vjp
    def f(*x: Any) -> Any:
        return wrapped(*x)

    def f_fwd(*x: Any) -> Tuple[Any, Tuple[Any, ...]]:
        # The inputs are the only residuals the backward pass needs.
        return wrapped(*x), x

    def f_bwd(res: Tuple[Any, ...], g: Any) -> Tuple[Any, ...]:
        x = res
        if len(x) == 1:
            # A single input is handed to the backend vjp as a bare tensor.
            x = x[0]  # type: ignore

        vjp_fun = (
            jitted_vjp if jitted_vjp is not None else partial(backend.vjp, fun)
        )

        def vjp_wrapped(args: Any, cotangent: Any) -> Any:
            args = general_args_to_backend(args, enable_dlpack=enable_dlpack)
            gb = general_args_to_backend(cotangent, enable_dlpack=enable_dlpack)
            r = vjp_fun(args, gb)[1]
            r = general_args_to_backend(
                r, target_backend="jax", enable_dlpack=enable_dlpack
            )
            return r

        # Handle gradient shape for both single input and tuple inputs
        if isinstance(x, tuple):
            grad_shape = tuple(jax.ShapeDtypeStruct(xi.shape, xi.dtype) for xi in x)
        else:
            grad_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)  # type: ignore

        # The cotangent is passed as an operand rather than captured by the
        # closure: when this backward pass is traced (e.g. under jax.jit) ``g``
        # is a tracer, and only operands reach the host callback as concrete
        # arrays.
        dx = jax.pure_callback(vjp_wrapped, grad_shape, x, g)
        if not isinstance(dx, tuple):
            dx = (dx,)
        return dx  # type: ignore

    f.defvjp(f_fwd, f_bwd)
    return f