"""
Backend magic inherited from tensornetwork: pytorch backend
"""
# pylint: disable=invalid-name
import logging
from typing import Any, Callable, Optional, Sequence, Tuple, Union
from operator import mul
from functools import reduce, partial
from scipy.sparse import coo_matrix
import tensornetwork
from tensornetwork.backends.pytorch import pytorch_backend
from .abstract_backend import ExtendedBackend
# Names of the default dtypes used when a caller passes no dtype; they are only
# declared here (no value) — presumably assigned by the package at runtime, TODO confirm.
dtypestr: str
rdtypestr: str
# Loose aliases used purely for annotations.
Tensor = Any
pytree = Any
# Handle to the ``torch`` module; bound inside ``PyTorchBackend.__init__``.
torchlib: Any
logger = logging.getLogger(__name__)
# TODO(@refraction-ray): lack stateful random methods implementation for now
# TODO(@refraction-ray): lack scatter impl for now
# To be added once pytorch backend is ready
[docs]
class torch_jit_func:
    """
    Lazy ``torch.jit.trace`` wrapper: tracing is postponed until the first call
    (consistent with the tf and jax mechanism), using that call's positional
    arguments as example inputs. Later calls reuse the traced function.
    """

    def __init__(self, f: Callable[..., Any]):
        self.f = f
        self.compiled = False

    def __call__(self, *args: Any, **kws: Any) -> Any:
        if self.compiled is False:
            # keyword arguments are not part of the example inputs
            traced = torchlib.jit.trace(self.f, example_inputs=args)
            self.f, self.compiled = traced, True
        return self.f(*args, **kws)
[docs]
class torch_optimizer:
    """
    Functional-style wrapper around a torch optimizer factory.

    ``optimizer`` is a callable that receives the flat parameter list and
    returns a torch optimizer; it is invoked lazily on the first ``update``.
    """

    def __init__(self, optimizer: Any) -> None:
        self.optimizer = optimizer
        self.is_init = False

    def update(self, grads: pytree, params: pytree) -> pytree:
        """
        Apply one optimizer step.

        :param grads: pytree of gradients with the same structure as ``params``.
        :param params: pytree of parameters, updated in place by the optimizer.
        :returns: the parameters, re-assembled into the original pytree structure.
        """
        # work on flat leaf lists; remember the structure for the way back
        flat_params, treedef = PyTorchBackend.tree_flatten(None, params)
        flat_grads, _ = PyTorchBackend.tree_flatten(None, grads)
        if self.is_init is False:
            # first call: turn the factory into a real optimizer instance
            self.optimizer = self.optimizer(flat_params)
            self.is_init = True
        with torchlib.no_grad():
            for grad_leaf, param_leaf in zip(flat_grads, flat_params):
                param_leaf.grad = grad_leaf
        self.optimizer.step()
        self.optimizer.zero_grad()
        return PyTorchBackend.tree_unflatten(None, treedef, flat_params)
def _conj_torch(self: Any, tensor: Tensor) -> Tensor:
t = torchlib.conj(tensor)
return t.resolve_conj() # any side effect?
def _sum_torch(
self: Any,
tensor: Tensor,
axis: Optional[Sequence[int]] = None,
keepdims: bool = False,
) -> Tensor:
if axis is None:
axis = tuple([i for i in range(len(tensor.shape))])
return torchlib.sum(tensor, dim=axis, keepdim=keepdims)
def _qr_torch(
    self: Any,
    tensor: Tensor,
    pivot_axis: int = -1,
    non_negative_diagonal: bool = False,
) -> Tuple[Tensor, Tensor]:
    """
    Computes the QR decomposition of a tensor.
    The QR decomposition is performed by treating the tensor as a matrix,
    with an effective left (row) index resulting from combining the
    axes `tensor.shape[:pivot_axis]` and an effective right (column)
    index resulting from combining the axes `tensor.shape[pivot_axis:]`.

    :Example:

    If `tensor` had a shape (2, 3, 4, 5) and `pivot_axis` was 2,
    then `q` would have shape (2, 3, 6), and `r` would
    have shape (6, 4, 5).
    The output consists of two tensors `Q, R` such that:

    Q[i1,...,iN, j] * R[j, k1,...,kM] == tensor[i1,...,iN, k1,...,kM]

    :param tensor: A tensor to be decomposed.
    :type tensor: Tensor
    :param pivot_axis: Where to split the tensor's axes before flattening into a matrix.
    :type pivot_axis: int, optional
    :param non_negative_diagonal: if True, rescale `Q` and `R` by the phases of
        the diagonal of `R` so that this diagonal becomes non-negative.
    :type non_negative_diagonal: bool, optional
    :returns: Q, the left tensor factor, and R, the right tensor factor.
    :rtype: Tuple[Tensor, Tensor]
    """
    from .pytorch_ops import torchqr

    left_dims = list(tensor.shape[:pivot_axis])
    right_dims = list(tensor.shape[pivot_axis:])
    # the initial value 1 keeps the product defined when one side has no axes
    tensor = torchlib.reshape(
        tensor, [reduce(mul, left_dims, 1), reduce(mul, right_dims, 1)]
    )
    q, r = torchqr.apply(tensor)
    if non_negative_diagonal:
        # sgn (unlike sign) accepts complex input; for real input they coincide
        phases = torchlib.sgn(torchlib.linalg.diagonal(r))
        q = q * phases
        # the conjugate undoes the phase put on q, so q @ r is unchanged
        r = torchlib.conj(phases)[:, None] * r
    center_dim = q.shape[1]
    q = torchlib.reshape(q, left_dims + [center_dim])
    r = torchlib.reshape(r, [center_dim] + right_dims)
    return q, r
def _rq_torch(
    self: Any,
    tensor: Tensor,
    pivot_axis: int = 1,
    non_negative_diagonal: bool = False,
) -> Tuple[Tensor, Tensor]:
    """
    Computes the RQ decomposition of a tensor.
    The RQ decomposition is performed by treating the tensor as a matrix,
    with an effective left (row) index resulting from combining the axes
    `tensor.shape[:pivot_axis]` and an effective right (column) index
    resulting from combining the axes `tensor.shape[pivot_axis:]`.
    Internally the QR decomposition of the adjoint matrix is computed and
    both factors are adjointed back.

    :Example:

    If `tensor` had a shape (2, 3, 4, 5) and `pivot_axis` was 2,
    then `r` would have shape (2, 3, 6), and `q` would
    have shape (6, 4, 5).
    The output consists of two tensors `R, Q` such that:

    R[i1,...,iN, j] * Q[j, k1,...,kM] == tensor[i1,...,iN, k1,...,kM]

    :param tensor: A tensor to be decomposed.
    :type tensor: Tensor
    :param pivot_axis: Where to split the tensor's axes before flattening into a matrix.
    :type pivot_axis: int, optional
    :param non_negative_diagonal: if True, rescale the factors by the phases of
        the diagonal of `R` so that this diagonal becomes non-negative.
    :type non_negative_diagonal: bool, optional
    :returns: R, the left tensor factor, and Q, the right tensor factor.
    :rtype: Tuple[Tensor, Tensor]
    """
    from .pytorch_ops import torchqr

    left_dims = list(tensor.shape[:pivot_axis])
    right_dims = list(tensor.shape[pivot_axis:])
    # the initial value 1 keeps the product defined when one side has no axes
    tensor = torchlib.reshape(
        tensor, [reduce(mul, left_dims, 1), reduce(mul, right_dims, 1)]
    )
    q, r = torchqr.apply(tensor.adjoint())
    if non_negative_diagonal:
        # sgn (unlike sign) accepts complex input; for real input they coincide
        phases = torchlib.sgn(torchlib.linalg.diagonal(r))
        q = q * phases
        # the conjugate undoes the phase put on q, so q @ r is unchanged
        r = torchlib.conj(phases)[:, None] * r
    r, q = r.adjoint(), q.adjoint()
    # M=r*q at this point
    center_dim = r.shape[1]
    r = torchlib.reshape(r, left_dims + [center_dim])
    q = torchlib.reshape(q, [center_dim] + right_dims)
    return r, q
# Monkey-patch the upstream tensornetwork PyTorch backend class so that its
# sum/conj/qr/rq attributes point at the implementations defined above.
tensornetwork.backends.pytorch.pytorch_backend.PyTorchBackend.sum = _sum_torch
tensornetwork.backends.pytorch.pytorch_backend.PyTorchBackend.conj = _conj_torch
tensornetwork.backends.pytorch.pytorch_backend.PyTorchBackend.qr = _qr_torch
tensornetwork.backends.pytorch.pytorch_backend.PyTorchBackend.rq = _rq_torch
def _patched_eigh(self: Any, matrix: Any) -> Tuple[Any, Any]:
eigvals, eigvecs = torchlib.linalg.eigh(matrix)
return eigvals, eigvecs
# Replace the upstream backend's eigh with the torch.linalg.eigh based version above.
pytorch_backend.PyTorchBackend.eigh = _patched_eigh
[docs]
class PyTorchBackend(pytorch_backend.PyTorchBackend, ExtendedBackend): # type: ignore
"""
See the original backend API at `pytorch backend
<https://github.com/google/TensorNetwork/blob/master/tensornetwork/backends/pytorch/pytorch_backend.py>`_
Note the functionality provided by the pytorch backend is incomplete;
it currently lacks native efficient jit and vmap support.
"""
[docs]
def __init__(self) -> None:
    """
    Initialize the parent backend and bind the module-level ``torchlib``
    handle to the ``torch`` module.

    :raises ImportError: if PyTorch is not installed.
    """
    super(PyTorchBackend, self).__init__()
    global torchlib
    try:
        import torch
    except ImportError as err:
        # keep the original failure attached as the cause
        raise ImportError(
            "PyTorch not installed, please switch to a different "
            "backend or install PyTorch."
        ) from err
    torchlib = torch
    self.name = "pytorch"
[docs]
def eye(
    self, N: int, dtype: Optional[str] = None, M: Optional[int] = None
) -> Tensor:
    """
    Return an ``N x M`` matrix with ones on the diagonal.

    :param N: number of rows.
    :param dtype: dtype name, defaults to the module-level ``dtypestr``.
    :param M: number of columns, defaults to ``N``.
    :returns: the identity-like matrix cast to ``dtype``.
    """
    if dtype is None:
        dtype = dtypestr
    # only a missing M falls back to N; an explicit M=0 is a valid width
    if M is None:
        M = N
    r = torchlib.eye(n=N, m=M)
    return self.cast(r, dtype)
[docs]
def ones(self, shape: Tuple[int, ...], dtype: Optional[str] = None) -> Tensor:
    """Return a tensor of ones with ``shape``, cast to ``dtype`` (default ``dtypestr``)."""
    target = dtypestr if dtype is None else dtype
    return self.cast(torchlib.ones(shape), target)
[docs]
def exp(self, tensor: Tensor) -> Tensor:
    """Return the element-wise exponential of ``tensor``."""
    result = torchlib.exp(tensor)
    return result
[docs]
def zeros(self, shape: Tuple[int, ...], dtype: Optional[str] = None) -> Tensor:
    """Return a tensor of zeros with ``shape``, cast to ``dtype`` (default ``dtypestr``)."""
    target = dtypestr if dtype is None else dtype
    return self.cast(torchlib.zeros(shape), target)
[docs]
def zeros_like(self, a: Tensor, dtype: Optional[str] = None) -> Tensor:
    """Return zeros shaped like ``a``, cast to ``dtype`` (default: dtype of ``a``)."""
    target = self.dtype(a) if dtype is None else dtype
    return self.cast(torchlib.zeros_like(a), target)
[docs]
def ones_like(self, a: Tensor, dtype: Optional[str] = None) -> Tensor:
    """Return ones shaped like ``a``, cast to ``dtype`` (default: dtype of ``a``)."""
    target = self.dtype(a) if dtype is None else dtype
    return self.cast(torchlib.ones_like(a), target)
[docs]
def copy(self, a: Tensor) -> Tensor:
    """Return a clone of ``a`` (via ``Tensor.clone``)."""
    duplicate = a.clone()
    return duplicate
[docs]
def convert_to_tensor(self, tensor: Tensor, dtype: Optional[str] = None) -> Tensor:
    """
    Convert ``tensor`` to a torch tensor (tensors pass through unchanged) and
    optionally cast it to ``dtype``.
    """
    result = tensor if self.is_tensor(tensor) else torchlib.tensor(tensor)
    if dtype is None:
        return result
    return self.cast(result, dtype)
[docs]
def set_random_state(
    self, seed: Optional[int] = None, get_only: bool = False
) -> Any:
    """
    Build (or accept) a ``torch.Generator``.

    :param seed: an integer seed, an existing ``torch.Generator`` (used as is),
        or None for an unseeded new generator.
    :param get_only: if False the generator is also stored on ``self.g``.
    :returns: the generator.
    """
    if isinstance(seed, torchlib.Generator):
        g = seed
    else:
        g = torchlib.Generator()
        if seed is not None:
            g.manual_seed(seed)
    if get_only is not False:
        return g
    self.g = g
    return g
[docs]
def stateful_randn(
    self,
    g: Any,
    shape: Union[int, Sequence[int]] = 1,
    mean: float = 0,
    stddev: float = 1,
    dtype: Optional[str] = None,
) -> Tensor:
    """
    Draw normally distributed samples using generator ``g``.

    :param g: ``torch.Generator`` driving the sampling.
    :param shape: output shape; an int is treated as a 1-tuple.
    :param mean: mean of the distribution.
    :param stddev: standard deviation of the distribution.
    :param dtype: dtype name whose last two characters select the float width
        ("32"/"64"), or a non-string dtype object used as is; defaults to ``rdtypestr``.
    :raises ValueError: if a string dtype does not end in "32" or "64".
    """
    if dtype is None:
        dtype = rdtypestr
    if isinstance(shape, int):
        shape = (shape,)
    if isinstance(dtype, str):
        # only the bit width suffix matters, e.g. "float32"/"complex64"
        dtype = dtype[-2:]
        by_width = {"32": torchlib.float32, "64": torchlib.float64}
        if dtype not in by_width:
            raise ValueError("unspported `dtype` %s" % dtype)
        dtyper = by_width[dtype]
    else:
        dtyper = dtype
    standard = torchlib.randn(size=shape, generator=g, dtype=dtyper)
    return standard * stddev + mean
[docs]
def stateful_randu(
    self,
    g: Any,
    shape: Union[int, Sequence[int]] = 1,
    low: float = 0,
    high: float = 1,
    dtype: Optional[str] = None,
) -> Tensor:
    """
    Draw uniformly distributed samples in ``[low, high)`` using generator ``g``.

    :param g: ``torch.Generator`` driving the sampling.
    :param shape: output shape; an int is treated as a 1-tuple.
    :param low: lower bound of the interval.
    :param high: upper bound of the interval.
    :param dtype: dtype name whose last two characters select the float width
        ("32"/"64"), or a non-string dtype object used as is; defaults to ``rdtypestr``.
    :raises ValueError: if a string dtype does not end in "32" or "64".
    """
    if dtype is None:
        dtype = rdtypestr
    if isinstance(shape, int):
        shape = (shape,)
    if isinstance(dtype, str):
        # only the bit width suffix matters, e.g. "float32"/"complex64"
        dtype = dtype[-2:]
        by_width = {"32": torchlib.float32, "64": torchlib.float64}
        if dtype not in by_width:
            raise ValueError("unspported `dtype` %s" % dtype)
        dtyper = by_width[dtype]
    else:
        dtyper = dtype
    unit = torchlib.rand(size=shape, generator=g, dtype=dtyper)
    return unit * (high - low) + low
[docs]
def stateful_randc(
    self,
    g: Any,
    a: Union[int, Sequence[int], Tensor],
    shape: Union[int, Sequence[int]],
    p: Optional[Union[Sequence[float], Tensor]] = None,
) -> Tensor:
    """
    Randomly choose elements from ``a`` using generator ``g``.

    :param g: ``torch.Generator`` driving the sampling.
    :param a: an int ``n`` (choose from ``arange(n)``) or the candidate values.
    :param shape: output shape; an int is treated as a 1-tuple.
    :param p: optional weights for the candidates; uniform when None.
    :returns: tensor of chosen values with the requested shape.
    """
    if isinstance(shape, int):
        shape = (shape,)
    if isinstance(a, int):
        n = a
        possible_values = torchlib.arange(a)
    else:
        possible_values = self.convert_to_tensor(a)
        n = possible_values.shape[0]
    if p is not None:
        p = self.convert_to_tensor(p)
        # multinomial samples a flat batch which is then reshaped
        num_samples = reduce(mul, shape)
        flat = torchlib.multinomial(p, num_samples, replacement=True, generator=g)
        indices = torchlib.reshape(flat, shape)
    else:
        indices = torchlib.randint(0, n, size=shape, generator=g)
    return possible_values[indices]
[docs]
def expm(self, a: Tensor) -> Tensor:
    """Return the matrix exponential of ``a`` via ``torch.linalg.matrix_exp``."""
    result = torchlib.linalg.matrix_exp(a)
    return result
# raise NotImplementedError("pytorch backend doesn't support expm")
# in 2020, torch has no expm, hmmm. but that's ok,
# it doesn't support complex numbers which is more severe issue.
# see https://github.com/pytorch/pytorch/issues/9983
[docs]
def sin(self, a: Tensor) -> Tensor:
    """Return the element-wise sine of ``a``."""
    result = torchlib.sin(a)
    return result
[docs]
def cos(self, a: Tensor) -> Tensor:
    """Return the element-wise cosine of ``a``."""
    result = torchlib.cos(a)
    return result
[docs]
def acos(self, a: Tensor) -> Tensor:
    """Return the element-wise inverse cosine of ``a``."""
    result = torchlib.acos(a)
    return result
[docs]
def acosh(self, a: Tensor) -> Tensor:
    """Return the element-wise inverse hyperbolic cosine of ``a``."""
    result = torchlib.acosh(a)
    return result
[docs]
def asin(self, a: Tensor) -> Tensor:
    """Return the element-wise inverse sine of ``a``."""
    result = torchlib.asin(a)
    return result
[docs]
def asinh(self, a: Tensor) -> Tensor:
    """Return the element-wise inverse hyperbolic sine of ``a``."""
    result = torchlib.asinh(a)
    return result
[docs]
def atan(self, a: Tensor) -> Tensor:
    """Return the element-wise inverse tangent of ``a``."""
    result = torchlib.atan(a)
    return result
[docs]
def atan2(self, y: Tensor, x: Tensor) -> Tensor:
    """Return the element-wise two-argument arctangent of ``y`` and ``x``."""
    result = torchlib.atan2(y, x)
    return result
[docs]
def atanh(self, a: Tensor) -> Tensor:
    """Return the element-wise inverse hyperbolic tangent of ``a``."""
    result = torchlib.atanh(a)
    return result
[docs]
def cosh(self, a: Tensor) -> Tensor:
    """Return the element-wise hyperbolic cosine of ``a``."""
    result = torchlib.cosh(a)
    return result
[docs]
def tan(self, a: Tensor) -> Tensor:
    """Return the element-wise tangent of ``a``."""
    result = torchlib.tan(a)
    return result
[docs]
def tanh(self, a: Tensor) -> Tensor:
    """Return the element-wise hyperbolic tangent of ``a``."""
    result = torchlib.tanh(a)
    return result
[docs]
def sinh(self, a: Tensor) -> Tensor:
    """Return the element-wise hyperbolic sine of ``a``."""
    result = torchlib.sinh(a)
    return result
[docs]
def size(self, a: Tensor) -> Tensor:
    """
    Return ``a.size()``.

    NOTE(review): for torch tensors this is the shape (``torch.Size``), not the
    total element count -- confirm this matches what callers of ``size`` expect.
    """
    dims = a.size()
    return dims
[docs]
def eigvalsh(self, a: Tensor) -> Tensor:
    """Return the eigenvalues of ``a`` via ``torch.linalg.eigvalsh``."""
    values = torchlib.linalg.eigvalsh(a)
    return values
[docs]
def lobpcg_standard(
    self,
    a: Union[Tensor, Callable[[Tensor], Tensor]],
    x0: Tensor,
    m: int = 100,
    tol: Optional[Union[Tensor, float]] = None,
) -> Tuple[Tensor, Tensor, int]:
    """
    PyTorch LOBPCG implementation (searches for the largest eigenpairs).

    Note:
    1. Complex input is not officially supported and is numerically unstable.
    2. Callable input for operator 'a' is not supported.

    :param a: the matrix; a callable operator raises ``NotImplementedError``.
    :param x0: initial guess handed to ``torch.lobpcg`` as ``X``.
    :param m: iteration budget (``niter``).
    :param tol: convergence tolerance.
    :returns: ``(theta, x, m)`` -- the third entry is the budget ``m`` passed in,
        not a measured iteration count.
    """
    if callable(a):
        raise NotImplementedError(
            "PyTorch backend `lobpcg` does not support callable linear operator yet."
        )
    solver_options = dict(
        X=x0, niter=m, tol=tol, largest=True, ortho_fparams={"eps": 1e-6}
    )
    theta, x = torchlib.lobpcg(a, **solver_options)
    return theta, x, m
[docs]
def kron(self, a: Tensor, b: Tensor) -> Tensor:
    """Return the Kronecker product of ``a`` and ``b``."""
    product = torchlib.kron(a, b)
    return product
[docs]
def numpy(self, a: Tensor) -> Tensor:
    """
    Convert a torch tensor to numpy (dense input) or to a
    ``scipy.sparse.coo_matrix`` (sparse input).

    The tensor is detached and moved to CPU first, and lazy conjugate/negative
    views are materialized, so tensors that both require grad and carry the
    conj bit convert without error.

    :param a: dense or sparse torch tensor.
    :returns: ``numpy.ndarray`` for dense input, ``coo_matrix`` for sparse input.
    """
    if self.is_sparse(a):
        a = a.detach().cpu().coalesce()
        return coo_matrix((a.values().numpy(), a.indices().numpy()), shape=a.shape)
    # detach before .numpy(): numpy() refuses tensors that require grad
    a = a.detach().cpu()
    # resolve_* are no-ops unless the corresponding lazy bit is set
    return a.resolve_conj().resolve_neg().numpy()
[docs]
def i(self, dtype: Any = None) -> Tensor:
    """
    Return the imaginary unit as a scalar tensor.

    :param dtype: dtype name or dtype object; a falsy value selects ``dtypestr``.
    """
    if not dtype:
        resolved = getattr(torchlib, dtypestr)
    else:
        resolved = dtype
    if isinstance(resolved, str):
        resolved = getattr(torchlib, resolved)
    return torchlib.tensor(1j, dtype=resolved)
[docs]
def det(self, a: Tensor) -> Tensor:
    """Return the determinant of ``a`` via ``torch.linalg.det``."""
    value = torchlib.linalg.det(a)
    return value
[docs]
def real(self, a: Tensor) -> Tensor:
    """
    Return the real part of ``a``; if ``torch.real`` raises ``RuntimeError``
    the input is returned unchanged.
    """
    try:
        return torchlib.real(a)
    except RuntimeError:
        return a
[docs]
def imag(self, a: Tensor) -> Tensor:
    """
    Return the imaginary part of ``a``.

    ``torch.imag`` rejects non-complex tensors; a real tensor has a zero
    imaginary part, so zeros of the same shape and dtype are returned for it
    (rather than the input itself).

    :param a: input tensor, real or complex.
    :returns: the imaginary part.
    """
    if torchlib.is_complex(a):
        return torchlib.imag(a)
    return torchlib.zeros_like(a)
[docs]
def dtype(self, a: Tensor) -> str:
    """Return the dtype of ``a`` as a bare name, e.g. ``"float32"``."""
    qualified = str(a.dtype)  # e.g. "torch.float32"
    return qualified.split(".")[-1]  # type: ignore
[docs]
def stack(self, a: Sequence[Tensor], axis: int = 0) -> Tensor:
    """Stack the tensors in ``a`` along a new dimension ``axis``."""
    stacked = torchlib.stack(a, dim=axis)
    return stacked
[docs]
def concat(self, a: Sequence[Tensor], axis: int = 0) -> Tensor:
    """Concatenate the tensors in ``a`` along the existing dimension ``axis``."""
    joined = torchlib.cat(a, dim=axis)
    return joined
[docs]
def tile(self, a: Tensor, rep: Tensor) -> Tensor:
    """Repeat ``a`` according to the per-dimension counts ``rep`` (``torch.tile``)."""
    tiled = torchlib.tile(a, rep)
    return tiled
[docs]
def mean(
    self,
    a: Tensor,
    axis: Optional[Sequence[int]] = None,
    keepdims: bool = False,
) -> Tensor:
    """Return the mean of ``a`` over ``axis`` (all axes when None)."""
    dims = tuple(range(len(a.shape))) if axis is None else axis
    return torchlib.mean(a, dim=dims, keepdim=keepdims)
[docs]
def std(
    self, a: Tensor, axis: Optional[Sequence[int]] = None, keepdims: bool = False
) -> Tensor:
    """
    Return the population standard deviation (``unbiased=False``) of ``a``
    over ``axis`` (all axes when None).
    """
    dims = tuple(range(len(a.shape))) if axis is None else axis
    return torchlib.std(a, dim=dims, unbiased=False, keepdim=keepdims)
[docs]
def min(self, a: Tensor, axis: Optional[int] = None) -> Tensor:
    """Return the minimum of ``a``, globally or along ``axis`` (values only)."""
    if axis is not None:
        return torchlib.min(a, dim=axis).values
    return torchlib.min(a)
[docs]
def max(self, a: Tensor, axis: Optional[int] = None) -> Tensor:
    """Return the maximum of ``a``, globally or along ``axis`` (values only)."""
    if axis is not None:
        return torchlib.max(a, dim=axis).values
    return torchlib.max(a)
[docs]
def argmax(self, a: Tensor, axis: int = 0) -> Tensor:
    """Return the indices of the maximum values of ``a`` along ``axis``."""
    indices = torchlib.argmax(a, dim=axis)
    return indices
[docs]
def argmin(self, a: Tensor, axis: int = 0) -> Tensor:
    """Return the indices of the minimum values of ``a`` along ``axis``."""
    indices = torchlib.argmin(a, dim=axis)
    return indices
[docs]
def sort(self, a: Tensor, axis: int = -1) -> Tensor:
    """Return ``a`` sorted in ascending order along ``axis`` (values only)."""
    ordered = torchlib.sort(a, dim=axis)
    return ordered.values
[docs]
def top_k(self, a: Tensor, k: int) -> Tuple[Tensor, Tensor]:
    """Return ``(values, indices)`` of the ``k`` largest entries of ``a`` (``torch.topk``)."""
    values, indices = torchlib.topk(a, k)
    return values, indices
[docs]
def lexsort(self, keys: Any, axis: int = -1) -> Any:
    """
    Indirect stable sort on a sequence of keys; later keys take precedence
    (the last key is the primary one).

    :param keys: sequence of tensors indexed along their first dimension.
    :param axis: accepted but not used by this implementation.
    :returns: the sorting indices, or None when ``keys`` is empty.
    """
    if not keys:
        return None
    first = keys[0]
    order = torchlib.arange(first.size(0), device=first.device)
    # re-sorting stably by each successive key keeps earlier keys as tie-breakers
    for key in keys:
        permutation = torchlib.argsort(key[order], stable=True)
        order = order[permutation]
    return order
[docs]
def repeat(self, a: Any, repeats: Any, axis: Optional[int] = None) -> Any:
    """Repeat the elements of ``a`` via ``torch.repeat_interleave`` along ``axis``."""
    repeated = torchlib.repeat_interleave(a, repeats, dim=axis)
    return repeated
[docs]
def popc(self, a: Any) -> Any:
    """
    Count the set bits of every element of ``a`` (population count).

    Uses ``torchlib.bitwise_count`` when the installed torch exposes it;
    otherwise falls back to a branch-free bit-twiddling count on the int64
    cast of ``a`` and returns that result as int32.
    """
    if hasattr(torchlib, "bitwise_count"): # PyTorch 1.10+
        return torchlib.bitwise_count(a)
    # Fallback
    c = a.to(torchlib.int64)
    # pairwise step: each 2-bit field now holds the count of its two bits
    c = c - ((c >> 1) & 0x5555555555555555)
    # merge neighbouring 2-bit counts into 4-bit fields
    c = (c & 0x3333333333333333) + ((c >> 2) & 0x3333333333333333)
    # merge into per-byte counts; the multiply accumulates all bytes into the
    # top byte, which the final shift extracts
    weight_w = (((c + (c >> 4)) & 0x0F0F0F0F0F0F0F0F) * 0x0101010101010101) >> 56
    return weight_w.to(torchlib.int32)
[docs]
def argsort(self, a: Tensor, axis: int = -1) -> Tensor:
    """Return the indices that sort ``a`` in ascending order along ``axis``."""
    order = torchlib.argsort(a, dim=axis)
    return order
[docs]
def unique_with_counts(self, a: Tensor, **kws: Any) -> Tuple[Tensor, Tensor]:
    """
    Return ``(unique_values, counts)`` of ``a``; extra keyword arguments are
    accepted and ignored.
    """
    values, counts = torchlib.unique(a, return_counts=True)
    return values, counts  # type: ignore
[docs]
def sigmoid(self, a: Tensor) -> Tensor:
    """Return the element-wise logistic sigmoid of ``a``."""
    result = torchlib.sigmoid(a)
    return result
[docs]
def relu(self, a: Tensor) -> Tensor:
    """Return the element-wise rectified linear unit of ``a``."""
    result = torchlib.relu(a)
    return result
[docs]
def softmax(self, a: Sequence[Tensor], axis: Optional[int] = None) -> Tensor:
    """
    Softmax of ``a``.

    Uses the functional ``torch.nn.functional.softmax``; ``torch.nn.Softmax``
    is a module class whose constructor takes only ``dim``, so it cannot be
    called with the data directly.

    :param a: input tensor.
    :param axis: dimension along which to normalize. ``None`` normalizes over
        all elements, matching how the other ``axis=None`` reductions in this
        backend (``sum``, ``mean``, ``std``) treat a missing axis.
    :returns: tensor of the same shape as ``a`` whose entries sum to one along
        ``axis`` (or over the whole tensor when ``axis`` is None).
    """
    functional_softmax = torchlib.nn.functional.softmax
    if axis is None:
        flat = torchlib.reshape(a, (-1,))
        return torchlib.reshape(functional_softmax(flat, dim=0), a.shape)
    return functional_softmax(a, dim=axis)
[docs]
def onehot(self, a: Tensor, num: int) -> Tensor:
    """One-hot encode the (integer-cast) entries of ``a`` into ``num`` classes."""
    labels = a.long()  # one_hot needs int64 indices
    return torchlib.nn.functional.one_hot(labels, num)
[docs]
def cumsum(self, a: Tensor, axis: Optional[int] = None) -> Tensor:
    """
    Cumulative sum of ``a`` along ``axis``; with ``axis=None`` the tensor is
    flattened first.
    """
    if axis is not None:
        return torchlib.cumsum(a, dim=axis)
    flattened = self.reshape(a, [-1])
    return torchlib.cumsum(flattened, dim=0)
[docs]
def is_tensor(self, a: Any) -> bool:
    """Return True iff ``a`` is a ``torch.Tensor`` instance."""
    return isinstance(a, torchlib.Tensor)
[docs]
def cast(self, a: Tensor, dtype: Union[str, Any]) -> Tensor:
    """
    Cast ``a`` to ``dtype`` (a dtype name or dtype object).

    When a complex tensor is cast to a non-complex dtype, only its real part
    is kept.
    """
    torch_dtype = getattr(torchlib, dtype) if isinstance(dtype, str) else dtype
    target_is_complex = getattr(torch_dtype, "is_complex", False)
    if torchlib.is_complex(a) and not target_is_complex:
        # drop the imaginary part explicitly before the conversion
        a = torchlib.real(a)
    return a.to(torch_dtype)
[docs]
def arange(self, start: int, stop: Optional[int] = None, step: int = 1) -> Tensor:
    """
    Return evenly spaced integers; with a single argument the range is
    ``[0, start)``, otherwise ``[start, stop)``.
    """
    if stop is None:
        lower, upper = 0, start
    else:
        lower, upper = start, stop
    return torchlib.arange(start=lower, end=upper, step=step)
[docs]
def mod(self, x: Tensor, y: Tensor) -> Tensor:
    """
    Return the element-wise remainder of ``x / y`` via ``torch.fmod``.

    NOTE(review): ``fmod`` gives the result the sign of the dividend, unlike
    Python's ``%`` -- confirm this is the convention callers expect.
    """
    remainder = torchlib.fmod(x, y)
    return remainder
[docs]
def floor_divide(self, x: Tensor, y: Tensor) -> Tensor:
    """Return the element-wise result of ``torch.floor_divide(x, y)``."""
    quotient = torchlib.floor_divide(x, y)
    return quotient
[docs]
def floor(self, a: Tensor) -> Tensor:
    """Return the element-wise floor of ``a``."""
    result = torchlib.floor(a)
    return result
[docs]
def clip(self, a: Tensor, a_min: Tensor, a_max: Tensor) -> Tensor:
    """Clamp the entries of ``a`` into ``[a_min, a_max]`` (``torch.clamp``)."""
    clamped = torchlib.clamp(a, a_min, a_max)
    return clamped
[docs]
def right_shift(self, x: Tensor, y: Tensor) -> Tensor:
    """Return ``x`` shifted right by ``y`` bits, element-wise."""
    shifted = torchlib.bitwise_right_shift(x, y)
    return shifted
[docs]
def left_shift(self, x: Tensor, y: Tensor) -> Tensor:
    """Return ``x`` shifted left by ``y`` bits, element-wise."""
    shifted = torchlib.bitwise_left_shift(x, y)
    return shifted
[docs]
def bitwise_and(self, x: Tensor, y: Tensor) -> Tensor:
    """Return the element-wise bitwise AND of ``x`` and ``y``."""
    result = torchlib.bitwise_and(x, y)
    return result
[docs]
def bitwise_xor(self, x: Tensor, y: Tensor) -> Tensor:
    """Return the element-wise bitwise XOR of ``x`` and ``y``."""
    result = torchlib.bitwise_xor(x, y)
    return result
[docs]
def bitwise_or(self, x: Tensor, y: Tensor) -> Tensor:
    """Return the element-wise bitwise OR of ``x`` and ``y``."""
    result = torchlib.bitwise_or(x, y)
    return result
[docs]
def any(self, a: Tensor) -> Any:
    """Return whether any element of ``a`` is truthy (as a 0-d tensor)."""
    result = torchlib.any(a)
    return result
[docs]
def all(self, a: Tensor, axis: Optional[Sequence[int]] = None) -> Any:
    """Return whether all elements of ``a`` are truthy, globally or along ``axis``."""
    if axis is not None:
        return torchlib.all(a, dim=axis)
    return torchlib.all(a)
[docs]
def solve(self, A: Tensor, b: Tensor, **kws: Any) -> Tensor:
    """
    Solve the linear system ``A x = b`` via ``torch.linalg.solve``; extra
    keyword arguments are accepted and ignored.
    """
    solution = torchlib.linalg.solve(A, b)
    return solution
[docs]
def searchsorted(self, a: Tensor, v: Tensor, side: str = "left") -> Tensor:
    """
    Find the insertion indices of ``v`` in the sorted sequence ``a``;
    non-tensor inputs are converted to tensors first.
    """
    sorted_seq = a if self.is_tensor(a) else self.convert_to_tensor(a)
    queries = v if self.is_tensor(v) else self.convert_to_tensor(v)
    return torchlib.searchsorted(sorted_seq, queries, side=side)
[docs]
def where(
    self,
    condition: Tensor,
    x: Optional[Tensor] = None,
    y: Optional[Tensor] = None,
) -> Tensor:
    """
    With only ``condition``: return the indices where it holds.
    With ``x`` and ``y``: select element-wise from ``x`` (true) or ``y`` (false).
    Supplying exactly one of ``x``/``y`` trips the assertion.
    """
    neither_given = x is None and y is None
    if neither_given:
        return torchlib.where(condition)
    assert x is not None and y is not None
    return torchlib.where(condition, x, y)
[docs]
def scatter(
    self, operand: Tensor, indices: Tensor, updates: Tensor, mode: str = "update"
) -> Tensor:
    """
    Return a copy of ``operand`` with ``updates`` written at ``indices``.

    ``operand`` itself is never modified: ``Tensor.index_put`` is the
    out-of-place variant and already returns a new tensor, so no explicit
    clone is needed.

    :param operand: tensor to scatter into.
    :param indices: integer tensor whose last dimension enumerates one
        coordinate per leading dimension of ``operand``.
    :param updates: values to write.
    :param mode: ``"update"`` (overwrite), ``"add"`` or ``"sub"`` (accumulate).
    :raises ValueError: for any other ``mode``.
    :returns: the updated tensor.
    """
    # validate first so an invalid mode fails before any tensor work
    if mode not in ("update", "add", "sub"):
        raise ValueError(f"Unsupported scatter mode: {mode}")
    index_depth = indices.shape[-1]
    idx_tuple = tuple(indices[..., i] for i in range(index_depth))
    if mode == "update":
        return operand.index_put(idx_tuple, updates, accumulate=False)
    values = updates if mode == "add" else -updates
    return operand.index_put(idx_tuple, values, accumulate=True)
[docs]
def reverse(self, a: Tensor) -> Tensor:
    """Return ``a`` with the order of its last dimension reversed."""
    flipped = torchlib.flip(a, dims=(-1,))
    return flipped
[docs]
def coo_sparse_matrix(
    self, indices: Tensor, values: Tensor, shape: Tensor
) -> Tensor:
    """
    Build a torch sparse COO tensor.

    :param indices: coordinates, converted to a tensor and transposed before
        being handed to ``torch.sparse_coo_tensor``.
    :param values: the non-zero values.
    :param shape: shape of the resulting sparse tensor.
    """
    coords = self.convert_to_tensor(indices)
    coords_t = self.transpose(coords)
    return torchlib.sparse_coo_tensor(coords_t, values, shape)
[docs]
def sparse_dense_matmul(
    self,
    sp_a: Tensor,
    b: Tensor,
) -> Tensor:
    """Multiply the sparse matrix ``sp_a`` with the dense matrix ``b`` (``torch.sparse.mm``)."""
    product = torchlib.sparse.mm(sp_a, b)
    return product
[docs]
def sparse_csr_from_coo(self, coo: Tensor, strict: bool = False) -> Tensor:
    """
    Convert ``coo`` to CSR layout via ``coo.to_sparse_csr()``.

    If the object has no such method, the input is returned unchanged unless
    ``strict`` is True, in which case the ``AttributeError`` propagates.
    """
    try:
        converted = coo.to_sparse_csr()
    except AttributeError as e:
        if strict:
            raise e
        return coo
    return converted
[docs]
def to_dense(self, sp_a: Tensor) -> Tensor:
    """Return the dense counterpart of the sparse tensor ``sp_a``."""
    dense = sp_a.to_dense()
    return dense
[docs]
def is_sparse(self, a: Tensor) -> bool:
    """Return a truthy value when ``a`` is a sparse COO or sparse CSR tensor."""
    if a.is_sparse:
        return a.is_sparse
    return a.is_sparse_csr  # type: ignore
[docs]
def tree_map(self, f: Callable[..., Any], *pytrees: Any) -> Any:
    """
    Apply ``f`` leaf-wise across one or more pytrees of matching structure.

    Implemented by hand because the torch native tree_map does not accept
    multiple pytree arguments. The structure and the number of leaves are
    taken from the last pytree.
    """
    flattened = []
    for tree in pytrees:
        leaves, spec = self.tree_flatten(tree)
        flattened.append(leaves)
    mapped = []
    for k in range(len(leaves)):
        # k-th leaf of every input tree, in argument order
        kth_leaves = [tree_leaves[k] for tree_leaves in flattened]
        mapped.append(f(*kth_leaves))
    return self.tree_unflatten(spec, mapped)
[docs]
def tree_flatten(self: Any, pytree: Any) -> Tuple[Any, Any]:
    """Flatten ``pytree`` into ``(leaves, treedef)`` using ``torch.utils._pytree``."""
    leaves, treedef = torchlib.utils._pytree.tree_flatten(pytree)
    return leaves, treedef  # type: ignore
[docs]
def tree_unflatten(self: Any, treedef: Any, leaves: Any) -> Any:
    """
    Rebuild a pytree from ``treedef`` and ``leaves``.

    Note the argument order here is (treedef, leaves), the reverse of the
    underlying ``torch.utils._pytree.tree_unflatten`` call.
    """
    rebuilt = torchlib.utils._pytree.tree_unflatten(leaves, treedef)
    return rebuilt
[docs]
def from_dlpack(self, a: Any) -> Tensor:
    """Build a torch tensor from the DLPack object ``a``."""
    tensor = torchlib.utils.dlpack.from_dlpack(a)
    return tensor
[docs]
def to_dlpack(self, a: Tensor) -> Any:
    """Export the tensor ``a`` as a DLPack object."""
    capsule = torchlib.utils.dlpack.to_dlpack(a)
    return capsule
[docs]
def cond(
    self,
    pred: bool,
    true_fun: Callable[[], Tensor],
    false_fun: Callable[[], Tensor],
) -> Tensor:
    """Call and return ``true_fun()`` when ``pred`` is truthy, else ``false_fun()``."""
    chosen = true_fun if pred else false_fun
    return chosen()
[docs]
def switch(self, index: Tensor, branches: Sequence[Callable[[], Tensor]]) -> Tensor:
    """Call the branch selected by the tensor ``index`` and return its result."""
    position = index.numpy()
    selected = branches[position]
    return selected()
[docs]
def device(self, a: Tensor) -> str:
    """Return the device of ``a`` as a backend string (``"cpu"`` or ``"gpu:<index>"``)."""
    return self._dev2str(a.device)
[docs]
def device_move(self, a: Tensor, dev: Any) -> Tensor:
    """
    Move ``a`` to ``dev``.

    :param dev: a device string (``"cpu"``, ``"gpu:<index>"``) or a device
        object, which is first converted to such a string.
    """
    name = dev if isinstance(dev, str) else self._dev2str(dev)
    if name.startswith("gpu"):
        # translate the backend's "gpu:<i>" spelling into torch's "cuda:<i>"
        name = "cuda:" + name.split(":")[-1]
    return a.to(device=name)
def _dev2str(self, dev: Any) -> str:
if dev.type == "cpu":
return "cpu"
if dev.type == "cuda":
return "gpu:" + str(dev.index)
raise ValueError("PyTorchBackend don't support non-GPU/CPU device")
def _str2dev(self, str_: str) -> Any:
if str_ == "cpu":
return torchlib.device("cpu")
if str_.startswith("gpu"):
_id = int(str_.split(":")[-1])
return torchlib.cuda.device(_id)
raise ValueError("PyTorchBackend don't support non-GPU/CPU device")
[docs]
def stop_gradient(self, a: Tensor) -> Tensor:
    """Return ``a`` detached from the autograd graph."""
    detached = a.detach()
    return detached
[docs]
def grad(
    self,
    f: Callable[..., Any],
    argnums: Union[int, Sequence[int]] = 0,
    has_aux: bool = False,
) -> Callable[..., Any]:
    """
    Return a function computing the gradient of ``f``.

    Built on ``value_and_grad``: the value is dropped, and with ``has_aux``
    the returned function yields ``(gradient, value[1:])``.
    """

    def wrapper(*args: Any, **kws: Any) -> Any:
        value_and_grad_fn = self.value_and_grad(f, argnums, has_aux)
        value, gradient = value_and_grad_fn(*args, **kws)
        if not has_aux:
            return gradient
        return gradient, value[1:]

    return wrapper
# def wrapper(*args: Any, **kws: Any) -> Any:
# x = []
# if isinstance(argnums, int):
# argnumsl = [argnums]
# # if you also call lhs as argnums, something weird may happen
# # the reason is that python then take it as local vars
# else:
# argnumsl = argnums # type: ignore
# for i, arg in enumerate(args):
# if i in argnumsl:
# x.append(arg.requires_grad_(True))
# else:
# x.append(arg)
# y = f(*x, **kws)
# y.backward()
# gs = [x[i].grad for i in argnumsl]
# if len(gs) == 1:
# gs = gs[0]
# return gs
# return wrapper
[docs]
def value_and_grad(
    self,
    f: Callable[..., Any],
    argnums: Union[int, Sequence[int]] = 0,
    has_aux: bool = False,
) -> Callable[..., Tuple[Any, Any]]:
    """
    Return a function evaluating ``f`` and its gradient with respect to
    ``argnums``, as ``(value, gradient)``.

    Delegates to ``torch.func.grad_and_value`` and swaps its
    ``(gradient, value)`` result into ``(value, gradient)`` order.
    """

    def wrapper(*args: Any, **kws: Any) -> Any:
        grad_and_value_fn = torchlib.func.grad_and_value(
            f, argnums=argnums, has_aux=has_aux
        )
        gradient, value = grad_and_value_fn(*args, **kws)
        return value, gradient

    return wrapper
# def ask_require(t: Tensor) -> Any:
# t.requires_grad_(True)
# return t
# def get_grad(t: Tensor) -> Tensor:
# return t.grad
# def wrapper(*args: Any, **kws: Any) -> Any:
# # x = []
# if isinstance(argnums, int):
# argnumsl = [argnums]
# # if you also call lhs as argnums, something weird may happen
# # the reason is that python then take it as local vars
# else:
# argnumsl = argnums # type: ignore
# args = list(args)
# for i, arg in enumerate(args):
# if i in argnumsl:
# args[i] = self.tree_map(ask_require, arg)
# args = tuple(args)
# y = f(*args, **kws)
# if has_aux:
# y[0].backward()
# else:
# y.backward()
# gs = [self.tree_map(get_grad, x[i]) for i in argnumsl]
# if len(gs) == 1:
# gs = gs[0]
# return y, gs
# return wrapper
[docs]
def vjp(
    self,
    f: Callable[..., Any],
    inputs: Union[Tensor, Sequence[Tensor]],
    v: Union[Tensor, Sequence[Tensor]],
) -> Tuple[Union[Tensor, Sequence[Tensor]], Union[Tensor, Sequence[Tensor]]]:
    """
    Vector-Jacobian product of ``f`` at ``inputs`` with cotangent ``v``
    (``torch.autograd.functional.vjp``). List arguments are converted to tuples.

    :returns: ``(f(inputs), vjp)``.
    """
    primals = tuple(inputs) if isinstance(inputs, list) else inputs
    cotangents = tuple(v) if isinstance(v, list) else v
    return torchlib.autograd.functional.vjp(f, primals, cotangents)  # type: ignore
[docs]
def jvp(
    self,
    f: Callable[..., Any],
    inputs: Union[Tensor, Sequence[Tensor]],
    v: Union[Tensor, Sequence[Tensor]],
) -> Tuple[Union[Tensor, Sequence[Tensor]], Union[Tensor, Sequence[Tensor]]]:
    """
    Jacobian-vector product of ``f`` at ``inputs`` with tangent ``v``
    (``torch.autograd.functional.jvp``). List arguments are converted to tuples.

    :returns: ``(f(inputs), jvp)``.
    """
    primals = tuple(inputs) if isinstance(inputs, list) else inputs
    tangents = tuple(v) if isinstance(v, list) else v
    # for both tf and torch
    # behind the scene: https://j-towns.github.io/2017/06/12/A-new-trick.html
    # to be investigate whether the overhead issue remains as in
    # https://github.com/renmengye/tensorflow-forward-ad/issues/2
    return torchlib.autograd.functional.jvp(f, primals, tangents)  # type: ignore
[docs]
def vmap(
    self,
    f: Callable[..., Any],
    vectorized_argnums: Union[int, Sequence[int]] = 0,
) -> Any:
    """
    Vectorize ``f`` over the leading axis of the positional arguments listed
    in ``vectorized_argnums`` (via ``torch.vmap``); other positional arguments
    are passed through unbatched.
    """
    if isinstance(vectorized_argnums, int):
        vectorized_argnums = (vectorized_argnums,)

    def wrapper(*args: Any, **kws: Any) -> Tensor:
        # batch axis 0 for the selected positions, None (not batched) elsewhere
        in_axes = tuple(
            0 if position in vectorized_argnums else None  # type: ignore
            for position in range(len(args))
        )
        batched = torchlib.vmap(f, in_axes, 0)
        return batched(*args, **kws)

    return wrapper
# v3
# logger.warning(
# "pytorch backend has no intrinsic vmap like interface"
# ", use plain for loop for compatibility"
# )
# # the vmap support is vey limited, f must return one tensor
# # nested list of tensor as return is not supported
# if isinstance(vectorized_argnums, int):
# vectorized_argnums = (vectorized_argnums,)
# def wrapper(*args: Any, **kws: Any) -> Tensor:
# results = []
# for barg in zip(*[args[i] for i in vectorized_argnums]): # type: ignore
# narg = []
# j = 0
# for k in range(len(args)):
# if k in vectorized_argnums: # type: ignore
# narg.append(barg[j])
# j += 1
# else:
# narg.append(args[k])
# results.append(f(*narg, **kws))
# return torchlib.stack(results)
# return wrapper
# v2
# def vmapf(*args: Tensor, **kws: Any) -> Tensor:
# r = []
# for i in range(args[0].shape[0]):
# nargs = [arg[i] for arg in args]
# r.append(f(*nargs, **kws))
# return torchlib.stack(r)
# return vmapf
# v1
# raise NotImplementedError("pytorch backend doesn't support vmap")
# There seems to be no map like architecture in pytorch for now
# see https://discuss.pytorch.org/t/fast-way-to-use-map-in-pytorch/70814
[docs]
def jit(
    self,
    f: Callable[..., Any],
    static_argnums: Optional[Union[int, Sequence[int]]] = None,
    jit_compile: Optional[bool] = None,
    **kws: Any,
) -> Any:
    """
    Return ``f`` unchanged, unless ``jit_compile is True`` in which case
    ``torch.compile(f)`` is returned (experimental, reusing the tf flag).
    ``static_argnums`` and extra keyword arguments are accepted but unused.
    """
    if jit_compile is not True:
        return f
    # experimental feature reusing the jit_compile flag for tf
    return torchlib.compile(f)
# return f # do nothing here until I figure out what torch.jit is for and how does it work
# see https://github.com/pytorch/pytorch/issues/36910
[docs]
def vectorized_value_and_grad(
    self,
    f: Callable[..., Any],
    argnums: Union[int, Sequence[int]] = 0,
    vectorized_argnums: Union[int, Sequence[int]] = 0,
    has_aux: bool = False,
) -> Callable[..., Tuple[Any, Any]]:
    """
    Batched ``value_and_grad``: ``value_and_grad(f)`` is vmapped over
    ``vectorized_argnums``; gradients with respect to arguments that are not
    vectorized are summed over the batch axis.

    [WIP], not a consistent impl compared to tf and jax backend, but pytorch
    backend is not fully supported anyway.
    """
    if isinstance(vectorized_argnums, int):
        vectorized_argnums = (vectorized_argnums,)

    def wrapper(
        *args: Any, **kws: Any
    ) -> Tuple[Tensor, Union[Tensor, Tuple[Tensor, ...]]]:
        single = self.value_and_grad(f, argnums=argnums, has_aux=has_aux)
        batched = self.vmap(single, vectorized_argnums=vectorized_argnums)
        vs, gs = batched(*args, **kws)
        single_argnum = isinstance(argnums, int)
        if single_argnum:
            argnums_list = [argnums]
            gs = [gs]
        else:
            argnums_list = argnums  # type: ignore
            gs = list(gs)
        batch_sum = partial(torchlib.sum, dim=0)
        for slot, (argnum, g) in enumerate(zip(argnums_list, gs)):
            if argnum not in vectorized_argnums:  # type: ignore
                # shared (unbatched) argument: collapse per-sample gradients
                gs[slot] = self.tree_map(batch_sum, g)
        gs = gs[0] if single_argnum else tuple(gs)
        return vs, gs

    return wrapper
[docs]
def expand_dims(self, a: Tensor, axis: int) -> Tensor:
    """Insert a dimension of size one into ``a`` at position ``axis``."""
    expanded = torchlib.unsqueeze(a, dim=axis)
    return expanded
# Short alias for ``vectorized_value_and_grad``.
vvag = vectorized_value_and_grad
[docs]
def meshgrid(self, *args: Any, **kws: Any) -> Tensor:
    """Forward all arguments to ``torch.meshgrid`` and return its result."""
    grids = torchlib.meshgrid(*args, **kws)
    return grids
# Expose the torch optimizer wrapper class as an attribute of the backend.
optimizer = torch_optimizer