Source code for tensorcircuit.backends.pytorch_backend

"""
Backend magic inherited from tensornetwork: pytorch backend
"""

# pylint: disable=invalid-name

import logging
from typing import Any, Callable, Optional, Sequence, Tuple, Union
from operator import mul
from functools import reduce, partial

from scipy.sparse import coo_matrix
import tensornetwork
from tensornetwork.backends.pytorch import pytorch_backend
from .abstract_backend import ExtendedBackend

dtypestr: str
rdtypestr: str
Tensor = Any
pytree = Any

torchlib: Any

logger = logging.getLogger(__name__)


# TODO(@refraction-ray): lack stateful random methods implementation for now
# TODO(@refraction-ray): lack scatter impl for now
# To be added once pytorch backend is ready


[docs] class torch_jit_func: """ Delay the tracing of torch jit to the first run time: consistent with tf and jax mechanism """
[docs] def __init__(self, f: Callable[..., Any]): self.compiled = False self.f = f
def __call__(self, *args: Any, **kws: Any) -> Any: if self.compiled is False: self.f = torchlib.jit.trace(self.f, example_inputs=args) self.compiled = True return self.f(*args, **kws)
[docs] class torch_optimizer:
[docs] def __init__(self, optimizer: Any) -> None: self.optimizer = optimizer self.is_init = False
[docs] def update(self, grads: pytree, params: pytree) -> pytree: # flatten grad and param params, treedef = PyTorchBackend.tree_flatten(None, params) grads, _ = PyTorchBackend.tree_flatten(None, grads) if self.is_init is False: self.optimizer = self.optimizer(params) self.is_init = True with torchlib.no_grad(): for g, p in zip(grads, params): p.grad = g self.optimizer.step() self.optimizer.zero_grad() # reorg the param params = PyTorchBackend.tree_unflatten(None, treedef, params) return params
def _conj_torch(self: Any, tensor: Tensor) -> Tensor: t = torchlib.conj(tensor) return t.resolve_conj() # any side effect? def _sum_torch( self: Any, tensor: Tensor, axis: Optional[Sequence[int]] = None, keepdims: bool = False, ) -> Tensor: if axis is None: axis = tuple([i for i in range(len(tensor.shape))]) return torchlib.sum(tensor, dim=axis, keepdim=keepdims) def _qr_torch( self: Any, tensor: Tensor, pivot_axis: int = -1, non_negative_diagonal: bool = False, ) -> Tuple[Tensor, Tensor]: """ Computes the QR decomposition of a tensor. The QR decomposition is performed by treating the tensor as a matrix, with an effective left (row) index resulting from combining the axes `tensor.shape[:pivot_axis]` and an effective right (column) index resulting from combining the axes `tensor.shape[pivot_axis:]`. :Example: If `tensor` had a shape (2, 3, 4, 5) and `pivot_axis` was 2, then `q` would have shape (2, 3, 6), and `r` would have shape (6, 4, 5). The output consists of two tensors `Q, R` such that: Q[i1,...,iN, j] * R[j, k1,...,kM] == tensor[i1,...,iN, k1,...,kM] Note that the output ordering matches numpy.linalg.svd rather than tf.svd. :param tensor: A tensor to be decomposed. :type tensor: Tensor :param pivot_axis: Where to split the tensor's axes before flattening into a matrix. :type pivot_axis: int, optional :param non_negative_diagonal: a bool indicating whether the tenor is diagonal non-negative matrix. :type non_negative_diagonal: bool, optional :returns: Q, the left tensor factor, and R, the right tensor factor. :rtype: Tuple[Tensor, Tensor] """ from .pytorch_ops import torchqr left_dims = list(tensor.shape[:pivot_axis]) right_dims = list(tensor.shape[pivot_axis:]) tensor = torchlib.reshape(tensor, [reduce(mul, left_dims), reduce(mul, right_dims)]) q, r = torchqr.apply(tensor) if non_negative_diagonal: phases = torchlib.sign(torchlib.linalg.diagonal(r)) q = q * phases r = phases[:, None] * r center_dim = q.shape[1] q = torchlib.reshape(q, left_dims + [center_dim]) r = torchlib.reshape(r, [center_dim] + right_dims) return q, r def _rq_torch( self: Any, tensor: Tensor, pivot_axis: int = 1, non_negative_diagonal: bool = False, ) -> Tuple[Tensor, Tensor]: """ Computes the RQ decomposition of a tensor. The QR decomposition is performed by treating the tensor as a matrix, with an effective left (row) index resulting from combining the axes `tensor.shape[:pivot_axis]` and an effective right (column) index resulting from combining the axes `tensor.shape[pivot_axis:]`. :Example: If `tensor` had a shape (2, 3, 4, 5) and `pivot_axis` was 2, then `r` would have shape (2, 3, 6), and `q` would have shape (6, 4, 5). The output consists of two tensors `Q, R` such that: Q[i1,...,iN, j] * R[j, k1,...,kM] == tensor[i1,...,iN, k1,...,kM] Note that the output ordering matches numpy.linalg.svd rather than tf.svd. :param tensor: A tensor to be decomposed. :type tensor: Tensor :param pivot_axis: Where to split the tensor's axes before flattening into a matrix. :type pivot_axis: int, optional :param non_negative_diagonal: a bool indicating whether the tenor is diagonal non-negative matrix. :type non_negative_diagonal: bool, optional :returns: Q, the left tensor factor, and R, the right tensor factor. :rtype: Tuple[Tensor, Tensor] """ from .pytorch_ops import torchqr left_dims = list(tensor.shape[:pivot_axis]) right_dims = list(tensor.shape[pivot_axis:]) tensor = torchlib.reshape(tensor, [reduce(mul, left_dims), reduce(mul, right_dims)]) q, r = torchqr.apply(tensor.adjoint()) if non_negative_diagonal: phases = torchlib.sign(torchlib.linalg.diagonal(r)) q = q * phases r = phases[:, None] * r r, q = r.adjoint(), q.adjoint() # M=r*q at this point center_dim = r.shape[1] r = torchlib.reshape(r, left_dims + [center_dim]) q = torchlib.reshape(q, [center_dim] + right_dims) return r, q tensornetwork.backends.pytorch.pytorch_backend.PyTorchBackend.sum = _sum_torch tensornetwork.backends.pytorch.pytorch_backend.PyTorchBackend.conj = _conj_torch tensornetwork.backends.pytorch.pytorch_backend.PyTorchBackend.qr = _qr_torch tensornetwork.backends.pytorch.pytorch_backend.PyTorchBackend.rq = _rq_torch def _patched_eigh(self: Any, matrix: Any) -> Tuple[Any, Any]: eigvals, eigvecs = torchlib.linalg.eigh(matrix) return eigvals, eigvecs pytorch_backend.PyTorchBackend.eigh = _patched_eigh
[docs] class PyTorchBackend(pytorch_backend.PyTorchBackend, ExtendedBackend): # type: ignore """ See the original backend API at `pytorch backend <https://github.com/google/TensorNetwork/blob/master/tensornetwork/backends/pytorch/pytorch_backend.py>`_ Note the functionality provided by pytorch backend is incomplete, it currenly lacks native efficicent jit and vmap support. """
[docs] def __init__(self) -> None: super(PyTorchBackend, self).__init__() global torchlib try: import torch except ImportError: raise ImportError( "PyTorch not installed, please switch to a different " "backend or install PyTorch." ) torchlib = torch self.name = "pytorch"
[docs] def eye( self, N: int, dtype: Optional[str] = None, M: Optional[int] = None ) -> Tensor: if dtype is None: dtype = dtypestr if not M: M = N r = torchlib.eye(n=N, m=M) return self.cast(r, dtype)
[docs] def ones(self, shape: Tuple[int, ...], dtype: Optional[str] = None) -> Tensor: if dtype is None: dtype = dtypestr r = torchlib.ones(shape) return self.cast(r, dtype)
[docs] def exp(self, tensor: Tensor) -> Tensor: return torchlib.exp(tensor)
[docs] def zeros(self, shape: Tuple[int, ...], dtype: Optional[str] = None) -> Tensor: if dtype is None: dtype = dtypestr r = torchlib.zeros(shape) return self.cast(r, dtype)
[docs] def zeros_like(self, a: Tensor, dtype: Optional[str] = None) -> Tensor: if dtype is None: dtype = self.dtype(a) r = torchlib.zeros_like(a) return self.cast(r, dtype)
[docs] def ones_like(self, a: Tensor, dtype: Optional[str] = None) -> Tensor: if dtype is None: dtype = self.dtype(a) r = torchlib.ones_like(a) return self.cast(r, dtype)
[docs] def copy(self, a: Tensor) -> Tensor: return a.clone()
[docs] def convert_to_tensor(self, tensor: Tensor, dtype: Optional[str] = None) -> Tensor: if self.is_tensor(tensor): result = tensor else: result = torchlib.tensor(tensor) if dtype is not None: result = self.cast(result, dtype) return result
[docs] def set_random_state( self, seed: Optional[int] = None, get_only: bool = False ) -> Any: if isinstance(seed, torchlib.Generator): g = seed else: g = torchlib.Generator() if seed is not None: g.manual_seed(seed) if get_only is False: self.g = g return g
[docs] def stateful_randn( self, g: Any, shape: Union[int, Sequence[int]] = 1, mean: float = 0, stddev: float = 1, dtype: Optional[str] = None, ) -> Tensor: if dtype is None: dtype = rdtypestr if isinstance(dtype, str): dtype = dtype[-2:] if isinstance(shape, int): shape = (shape,) if dtype == "32": dtyper = torchlib.float32 elif dtype == "64": dtyper = torchlib.float64 elif not isinstance(dtype, str): dtyper = dtype else: raise ValueError("unspported `dtype` %s" % dtype) r = torchlib.randn(size=shape, generator=g, dtype=dtyper) * stddev + mean return r
[docs] def stateful_randu( self, g: Any, shape: Union[int, Sequence[int]] = 1, low: float = 0, high: float = 1, dtype: Optional[str] = None, ) -> Tensor: if dtype is None: dtype = rdtypestr if isinstance(dtype, str): dtype = dtype[-2:] if isinstance(shape, int): shape = (shape,) if dtype == "32": dtyper = torchlib.float32 elif dtype == "64": dtyper = torchlib.float64 elif not isinstance(dtype, str): dtyper = dtype else: raise ValueError("unspported `dtype` %s" % dtype) r = torchlib.rand(size=shape, generator=g, dtype=dtyper) * (high - low) + low return r
[docs] def stateful_randc( self, g: Any, a: Union[int, Sequence[int], Tensor], shape: Union[int, Sequence[int]], p: Optional[Union[Sequence[float], Tensor]] = None, ) -> Tensor: if isinstance(shape, int): shape = (shape,) if isinstance(a, int): possible_values = torchlib.arange(a) n = a else: possible_values = self.convert_to_tensor(a) n = possible_values.shape[0] if p is None: indices = torchlib.randint(0, n, size=shape, generator=g) else: p = self.convert_to_tensor(p) num_samples = reduce(mul, shape) indices = torchlib.multinomial( p, num_samples, replacement=True, generator=g ) indices = torchlib.reshape(indices, shape) return possible_values[indices]
[docs] def expm(self, a: Tensor) -> Tensor: return torchlib.linalg.matrix_exp(a)
# raise NotImplementedError("pytorch backend doesn't support expm") # in 2020, torch has no expm, hmmm. but that's ok, # it doesn't support complex numbers which is more severe issue. # see https://github.com/pytorch/pytorch/issues/9983
[docs] def sin(self, a: Tensor) -> Tensor: return torchlib.sin(a)
[docs] def cos(self, a: Tensor) -> Tensor: return torchlib.cos(a)
[docs] def acos(self, a: Tensor) -> Tensor: return torchlib.acos(a)
[docs] def acosh(self, a: Tensor) -> Tensor: return torchlib.acosh(a)
[docs] def asin(self, a: Tensor) -> Tensor: return torchlib.asin(a)
[docs] def asinh(self, a: Tensor) -> Tensor: return torchlib.asinh(a)
[docs] def atan(self, a: Tensor) -> Tensor: return torchlib.atan(a)
[docs] def atan2(self, y: Tensor, x: Tensor) -> Tensor: return torchlib.atan2(y, x)
[docs] def atanh(self, a: Tensor) -> Tensor: return torchlib.atanh(a)
[docs] def cosh(self, a: Tensor) -> Tensor: return torchlib.cosh(a)
[docs] def tan(self, a: Tensor) -> Tensor: return torchlib.tan(a)
[docs] def tanh(self, a: Tensor) -> Tensor: return torchlib.tanh(a)
[docs] def sinh(self, a: Tensor) -> Tensor: return torchlib.sinh(a)
[docs] def size(self, a: Tensor) -> Tensor: return a.size()
[docs] def eigvalsh(self, a: Tensor) -> Tensor: return torchlib.linalg.eigvalsh(a)
[docs] def lobpcg_standard( self, a: Union[Tensor, Callable[[Tensor], Tensor]], x0: Tensor, m: int = 100, tol: Optional[Union[Tensor, float]] = None, ) -> Tuple[Tensor, Tensor, int]: """ PyTorch LOBPCG implementation. Note: 1. Complex input is not officially supported and is numerically unstable. 2. Callable input for operator 'a' is not supported. """ if callable(a): raise NotImplementedError( "PyTorch backend `lobpcg` does not support callable linear operator yet." ) theta, x = torchlib.lobpcg( a, X=x0, niter=m, tol=tol, largest=True, ortho_fparams={"eps": 1e-6} ) return theta, x, m
[docs] def kron(self, a: Tensor, b: Tensor) -> Tensor: return torchlib.kron(a, b)
[docs] def numpy(self, a: Tensor) -> Tensor: if self.is_sparse(a): a = a.coalesce() return coo_matrix((a.values().numpy(), a.indices().numpy()), shape=a.shape) a = a.cpu() if a.is_conj(): return a.resolve_conj().numpy() if a.requires_grad: return a.detach().numpy() return a.numpy()
[docs] def i(self, dtype: Any = None) -> Tensor: if not dtype: dtype = getattr(torchlib, dtypestr) if isinstance(dtype, str): dtype = getattr(torchlib, dtype) return torchlib.tensor(1j, dtype=dtype)
[docs] def det(self, a: Tensor) -> Tensor: return torchlib.linalg.det(a)
[docs] def real(self, a: Tensor) -> Tensor: try: a = torchlib.real(a) except RuntimeError: pass return a
[docs] def imag(self, a: Tensor) -> Tensor: try: a = torchlib.imag(a) except RuntimeError: pass return a
[docs] def dtype(self, a: Tensor) -> str: return a.dtype.__str__().split(".")[-1] # type: ignore
[docs] def stack(self, a: Sequence[Tensor], axis: int = 0) -> Tensor: return torchlib.stack(a, dim=axis)
[docs] def concat(self, a: Sequence[Tensor], axis: int = 0) -> Tensor: return torchlib.cat(a, dim=axis)
[docs] def tile(self, a: Tensor, rep: Tensor) -> Tensor: return torchlib.tile(a, rep)
[docs] def mean( self, a: Tensor, axis: Optional[Sequence[int]] = None, keepdims: bool = False, ) -> Tensor: if axis is None: axis = tuple([i for i in range(len(a.shape))]) return torchlib.mean(a, dim=axis, keepdim=keepdims)
[docs] def std( self, a: Tensor, axis: Optional[Sequence[int]] = None, keepdims: bool = False ) -> Tensor: if axis is None: axis = tuple([i for i in range(len(a.shape))]) return torchlib.std(a, dim=axis, unbiased=False, keepdim=keepdims)
[docs] def min(self, a: Tensor, axis: Optional[int] = None) -> Tensor: if axis is None: return torchlib.min(a) return torchlib.min(a, dim=axis).values
[docs] def max(self, a: Tensor, axis: Optional[int] = None) -> Tensor: if axis is None: return torchlib.max(a) return torchlib.max(a, dim=axis).values
[docs] def argmax(self, a: Tensor, axis: int = 0) -> Tensor: return torchlib.argmax(a, dim=axis)
[docs] def argmin(self, a: Tensor, axis: int = 0) -> Tensor: return torchlib.argmin(a, dim=axis)
[docs] def sort(self, a: Tensor, axis: int = -1) -> Tensor: return torchlib.sort(a, dim=axis).values
[docs] def top_k(self, a: Tensor, k: int) -> Tuple[Tensor, Tensor]: r = torchlib.topk(a, k) return r.values, r.indices
[docs] def lexsort(self, keys: Any, axis: int = -1) -> Any: # PyTorch stable sort for lexsort if not keys: return None idx = torchlib.arange(keys[0].size(0), device=keys[0].device) for k in keys: idx = idx[torchlib.argsort(k[idx], stable=True)] return idx
[docs] def repeat(self, a: Any, repeats: Any, axis: Optional[int] = None) -> Any: return torchlib.repeat_interleave(a, repeats, dim=axis)
[docs] def popc(self, a: Any) -> Any: if hasattr(torchlib, "bitwise_count"): # PyTorch 1.10+ return torchlib.bitwise_count(a) # Fallback c = a.to(torchlib.int64) c = c - ((c >> 1) & 0x5555555555555555) c = (c & 0x3333333333333333) + ((c >> 2) & 0x3333333333333333) weight_w = (((c + (c >> 4)) & 0x0F0F0F0F0F0F0F0F) * 0x0101010101010101) >> 56 return weight_w.to(torchlib.int32)
[docs] def argsort(self, a: Tensor, axis: int = -1) -> Tensor: return torchlib.argsort(a, dim=axis)
[docs] def unique_with_counts(self, a: Tensor, **kws: Any) -> Tuple[Tensor, Tensor]: return torchlib.unique(a, return_counts=True) # type: ignore
[docs] def sigmoid(self, a: Tensor) -> Tensor: return torchlib.sigmoid(a)
[docs] def relu(self, a: Tensor) -> Tensor: return torchlib.relu(a)
[docs] def softmax(self, a: Sequence[Tensor], axis: Optional[int] = None) -> Tensor: return torchlib.nn.Softmax(a, dim=axis)
[docs] def onehot(self, a: Tensor, num: int) -> Tensor: a = a.long() return torchlib.nn.functional.one_hot(a, num)
[docs] def cumsum(self, a: Tensor, axis: Optional[int] = None) -> Tensor: if axis is None: a = self.reshape(a, [-1]) return torchlib.cumsum(a, dim=0) else: return torchlib.cumsum(a, dim=axis)
[docs] def is_tensor(self, a: Any) -> bool: if isinstance(a, torchlib.Tensor): return True return False
[docs] def cast(self, a: Tensor, dtype: Union[str, Any]) -> Tensor: if isinstance(dtype, str): torch_dtype = getattr(torchlib, dtype) else: torch_dtype = dtype if torchlib.is_complex(a) and not getattr(torch_dtype, "is_complex", False): a = torchlib.real(a) return a.to(torch_dtype)
[docs] def arange(self, start: int, stop: Optional[int] = None, step: int = 1) -> Tensor: if stop is None: return torchlib.arange(start=0, end=start, step=step) return torchlib.arange(start=start, end=stop, step=step)
[docs] def mod(self, x: Tensor, y: Tensor) -> Tensor: return torchlib.fmod(x, y)
[docs] def floor_divide(self, x: Tensor, y: Tensor) -> Tensor: return torchlib.floor_divide(x, y)
[docs] def floor(self, a: Tensor) -> Tensor: return torchlib.floor(a)
[docs] def clip(self, a: Tensor, a_min: Tensor, a_max: Tensor) -> Tensor: return torchlib.clamp(a, a_min, a_max)
[docs] def right_shift(self, x: Tensor, y: Tensor) -> Tensor: return torchlib.bitwise_right_shift(x, y)
[docs] def left_shift(self, x: Tensor, y: Tensor) -> Tensor: return torchlib.bitwise_left_shift(x, y)
[docs] def bitwise_and(self, x: Tensor, y: Tensor) -> Tensor: return torchlib.bitwise_and(x, y)
[docs] def bitwise_xor(self, x: Tensor, y: Tensor) -> Tensor: return torchlib.bitwise_xor(x, y)
[docs] def bitwise_or(self, x: Tensor, y: Tensor) -> Tensor: return torchlib.bitwise_or(x, y)
[docs] def any(self, a: Tensor) -> Any: return torchlib.any(a)
[docs] def all(self, a: Tensor, axis: Optional[Sequence[int]] = None) -> Any: if axis is None: return torchlib.all(a) return torchlib.all(a, dim=axis)
[docs] def solve(self, A: Tensor, b: Tensor, **kws: Any) -> Tensor: return torchlib.linalg.solve(A, b)
[docs] def searchsorted(self, a: Tensor, v: Tensor, side: str = "left") -> Tensor: if not self.is_tensor(a): a = self.convert_to_tensor(a) if not self.is_tensor(v): v = self.convert_to_tensor(v) return torchlib.searchsorted(a, v, side=side)
[docs] def where( self, condition: Tensor, x: Optional[Tensor] = None, y: Optional[Tensor] = None, ) -> Tensor: if x is None and y is None: return torchlib.where(condition) assert x is not None and y is not None return torchlib.where(condition, x, y)
[docs] def scatter( self, operand: Tensor, indices: Tensor, updates: Tensor, mode: str = "update" ) -> Tensor: operand_new = operand.clone() index_depth = indices.shape[-1] idx_tuple = tuple(indices[..., i] for i in range(index_depth)) if mode == "update": return operand_new.index_put(idx_tuple, updates, accumulate=False) elif mode == "add": return operand_new.index_put(idx_tuple, updates, accumulate=True) elif mode == "sub": return operand_new.index_put(idx_tuple, -updates, accumulate=True) else: raise ValueError(f"Unsupported scatter mode: {mode}")
[docs] def reverse(self, a: Tensor) -> Tensor: return torchlib.flip(a, dims=(-1,))
[docs] def coo_sparse_matrix( self, indices: Tensor, values: Tensor, shape: Tensor ) -> Tensor: # Convert COO format to PyTorch sparse tensor indices = self.convert_to_tensor(indices) return torchlib.sparse_coo_tensor(self.transpose(indices), values, shape)
[docs] def sparse_dense_matmul( self, sp_a: Tensor, b: Tensor, ) -> Tensor: # Matrix multiplication between sparse and dense tensor return torchlib.sparse.mm(sp_a, b)
[docs] def sparse_csr_from_coo(self, coo: Tensor, strict: bool = False) -> Tensor: try: # Convert COO to CSR format if supported return coo.to_sparse_csr() except AttributeError as e: if not strict: return coo else: raise e
[docs] def to_dense(self, sp_a: Tensor) -> Tensor: # Convert sparse tensor to dense return sp_a.to_dense()
[docs] def is_sparse(self, a: Tensor) -> bool: # Check if tensor is sparse return a.is_sparse or a.is_sparse_csr # type: ignore
[docs] def tree_map(self, f: Callable[..., Any], *pytrees: Any) -> Any: # torch native tree_map not support multiple pytree args # return torchlib.utils._pytree.tree_map(f, *pytrees) args = [] for pytree in pytrees: flat_args, spec = self.tree_flatten(pytree) args.append(flat_args) res = [ f(*[args[i][k] for i in range(len(pytrees))]) for k in range(len(flat_args)) ] return self.tree_unflatten(spec, res)
[docs] def tree_flatten(self: Any, pytree: Any) -> Tuple[Any, Any]: return torchlib.utils._pytree.tree_flatten(pytree) # type: ignore
[docs] def tree_unflatten(self: Any, treedef: Any, leaves: Any) -> Any: return torchlib.utils._pytree.tree_unflatten(leaves, treedef)
[docs] def from_dlpack(self, a: Any) -> Tensor: return torchlib.utils.dlpack.from_dlpack(a)
[docs] def to_dlpack(self, a: Tensor) -> Any: return torchlib.utils.dlpack.to_dlpack(a)
[docs] def cond( self, pred: bool, true_fun: Callable[[], Tensor], false_fun: Callable[[], Tensor], ) -> Tensor: if pred: return true_fun() return false_fun()
[docs] def switch(self, index: Tensor, branches: Sequence[Callable[[], Tensor]]) -> Tensor: return branches[index.numpy()]()
[docs] def device(self, a: Tensor) -> str: dev = a.device return self._dev2str(dev)
[docs] def device_move(self, a: Tensor, dev: Any) -> Tensor: if not isinstance(dev, str): dev = self._dev2str(dev) if dev.startswith("gpu"): dev = "cuda:" + dev.split(":")[-1] return a.to(device=dev)
def _dev2str(self, dev: Any) -> str: if dev.type == "cpu": return "cpu" if dev.type == "cuda": return "gpu:" + str(dev.index) raise ValueError("PyTorchBackend don't support non-GPU/CPU device") def _str2dev(self, str_: str) -> Any: if str_ == "cpu": return torchlib.device("cpu") if str_.startswith("gpu"): _id = int(str_.split(":")[-1]) return torchlib.cuda.device(_id) raise ValueError("PyTorchBackend don't support non-GPU/CPU device")
[docs] def stop_gradient(self, a: Tensor) -> Tensor: return a.detach()
[docs] def grad( self, f: Callable[..., Any], argnums: Union[int, Sequence[int]] = 0, has_aux: bool = False, ) -> Callable[..., Any]: def wrapper(*args: Any, **kws: Any) -> Any: y, gr = self.value_and_grad(f, argnums, has_aux)(*args, **kws) if has_aux: return gr, y[1:] return gr return wrapper
# def wrapper(*args: Any, **kws: Any) -> Any: # x = [] # if isinstance(argnums, int): # argnumsl = [argnums] # # if you also call lhs as argnums, something weird may happen # # the reason is that python then take it as local vars # else: # argnumsl = argnums # type: ignore # for i, arg in enumerate(args): # if i in argnumsl: # x.append(arg.requires_grad_(True)) # else: # x.append(arg) # y = f(*x, **kws) # y.backward() # gs = [x[i].grad for i in argnumsl] # if len(gs) == 1: # gs = gs[0] # return gs # return wrapper
[docs] def value_and_grad( self, f: Callable[..., Any], argnums: Union[int, Sequence[int]] = 0, has_aux: bool = False, ) -> Callable[..., Tuple[Any, Any]]: def wrapper(*args: Any, **kws: Any) -> Any: gavf = torchlib.func.grad_and_value(f, argnums=argnums, has_aux=has_aux) g, v = gavf(*args, **kws) return v, g return wrapper
# def ask_require(t: Tensor) -> Any: # t.requires_grad_(True) # return t # def get_grad(t: Tensor) -> Tensor: # return t.grad # def wrapper(*args: Any, **kws: Any) -> Any: # # x = [] # if isinstance(argnums, int): # argnumsl = [argnums] # # if you also call lhs as argnums, something weird may happen # # the reason is that python then take it as local vars # else: # argnumsl = argnums # type: ignore # args = list(args) # for i, arg in enumerate(args): # if i in argnumsl: # args[i] = self.tree_map(ask_require, arg) # args = tuple(args) # y = f(*args, **kws) # if has_aux: # y[0].backward() # else: # y.backward() # gs = [self.tree_map(get_grad, x[i]) for i in argnumsl] # if len(gs) == 1: # gs = gs[0] # return y, gs # return wrapper
[docs] def vjp( self, f: Callable[..., Any], inputs: Union[Tensor, Sequence[Tensor]], v: Union[Tensor, Sequence[Tensor]], ) -> Tuple[Union[Tensor, Sequence[Tensor]], Union[Tensor, Sequence[Tensor]]]: if isinstance(inputs, list): inputs = tuple(inputs) if isinstance(v, list): v = tuple(v) return torchlib.autograd.functional.vjp(f, inputs, v) # type: ignore
[docs] def jvp( self, f: Callable[..., Any], inputs: Union[Tensor, Sequence[Tensor]], v: Union[Tensor, Sequence[Tensor]], ) -> Tuple[Union[Tensor, Sequence[Tensor]], Union[Tensor, Sequence[Tensor]]]: if isinstance(inputs, list): inputs = tuple(inputs) if isinstance(v, list): v = tuple(v) # for both tf and torch # behind the scene: https://j-towns.github.io/2017/06/12/A-new-trick.html # to be investigate whether the overhead issue remains as in # https://github.com/renmengye/tensorflow-forward-ad/issues/2 return torchlib.autograd.functional.jvp(f, inputs, v) # type: ignore
[docs] def vmap( self, f: Callable[..., Any], vectorized_argnums: Union[int, Sequence[int]] = 0, ) -> Any: if isinstance(vectorized_argnums, int): vectorized_argnums = (vectorized_argnums,) def wrapper(*args: Any, **kws: Any) -> Tensor: in_axes = tuple([0 if i in vectorized_argnums else None for i in range(len(args))]) # type: ignore return torchlib.vmap(f, in_axes, 0)(*args, **kws) return wrapper
# v3 # logger.warning( # "pytorch backend has no intrinsic vmap like interface" # ", use plain for loop for compatibility" # ) # # the vmap support is vey limited, f must return one tensor # # nested list of tensor as return is not supported # if isinstance(vectorized_argnums, int): # vectorized_argnums = (vectorized_argnums,) # def wrapper(*args: Any, **kws: Any) -> Tensor: # results = [] # for barg in zip(*[args[i] for i in vectorized_argnums]): # type: ignore # narg = [] # j = 0 # for k in range(len(args)): # if k in vectorized_argnums: # type: ignore # narg.append(barg[j]) # j += 1 # else: # narg.append(args[k]) # results.append(f(*narg, **kws)) # return torchlib.stack(results) # return wrapper # v2 # def vmapf(*args: Tensor, **kws: Any) -> Tensor: # r = [] # for i in range(args[0].shape[0]): # nargs = [arg[i] for arg in args] # r.append(f(*nargs, **kws)) # return torchlib.stack(r) # return vmapf # v1 # raise NotImplementedError("pytorch backend doesn't support vmap") # There seems to be no map like architecture in pytorch for now # see https://discuss.pytorch.org/t/fast-way-to-use-map-in-pytorch/70814
[docs] def jit( self, f: Callable[..., Any], static_argnums: Optional[Union[int, Sequence[int]]] = None, jit_compile: Optional[bool] = None, **kws: Any, ) -> Any: if jit_compile is True: # experimental feature reusing the jit_compile flag for tf return torchlib.compile(f) return f
# return f # do nothing here until I figure out what torch.jit is for and how does it work # see https://github.com/pytorch/pytorch/issues/36910
[docs] def vectorized_value_and_grad( self, f: Callable[..., Any], argnums: Union[int, Sequence[int]] = 0, vectorized_argnums: Union[int, Sequence[int]] = 0, has_aux: bool = False, ) -> Callable[..., Tuple[Any, Any]]: # [WIP], not a consistent impl compared to tf and jax backend, but pytorch backend is not fully supported anyway if isinstance(vectorized_argnums, int): vectorized_argnums = (vectorized_argnums,) def wrapper( *args: Any, **kws: Any ) -> Tuple[Tensor, Union[Tensor, Tuple[Tensor, ...]]]: jf = self.value_and_grad(f, argnums=argnums, has_aux=has_aux) jf = self.vmap(jf, vectorized_argnums=vectorized_argnums) vs, gs = jf(*args, **kws) if isinstance(argnums, int): argnums_list = [argnums] gs = [gs] else: argnums_list = argnums # type: ignore gs = list(gs) for i, (j, g) in enumerate(zip(argnums_list, gs)): if j not in vectorized_argnums: # type: ignore gs[i] = self.tree_map(partial(torchlib.sum, dim=0), g) if isinstance(argnums, int): gs = gs[0] else: gs = tuple(gs) return vs, gs return wrapper
[docs] def expand_dims(self, a: Tensor, axis: int) -> Tensor: return torchlib.unsqueeze(a, dim=axis)
vvag = vectorized_value_and_grad
[docs] def meshgrid(self, *args: Any, **kws: Any) -> Tensor: return torchlib.meshgrid(*args, **kws)
optimizer = torch_optimizer