Skip to content

Commit

Permalink
various fixes for torch
Browse files Browse the repository at this point in the history
  • Loading branch information
jcmgray committed Oct 21, 2024
1 parent 924e081 commit 9acdbde
Showing 1 changed file with 13 additions and 2 deletions.
15 changes: 13 additions & 2 deletions autoray/autoray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1866,7 +1866,7 @@ def tensorflow_indices(dimensions):
@shape.register("torch")
def torch_shape(x):
# torch returns a Size object, we want tuple[int]
return tuple(x.shape)
return tuple(map(int, x.shape))


@size.register("torch")
Expand Down Expand Up @@ -2003,7 +2003,7 @@ def torch_zeros_ones_wrap(fn):
def numpy_like(shape, dtype=None, **kwargs):
if dtype is not None:
dtype = to_backend_dtype(dtype, like="torch")
return fn(shape, dtype=dtype)
return fn(shape, dtype=dtype, **kwargs)

return numpy_like

Expand All @@ -2021,6 +2021,14 @@ def numpy_like(N, M=None, dtype=None, **kwargs):
return numpy_like


def torch_sort_wrap(fn):
@functools.wraps(fn)
def numpy_like(a, axis=-1):
return fn(a, dim=axis)[0]

return numpy_like


def torch_indices(dimensions):
_meshgrid = get_lib_fn("torch", "meshgrid")
_arange = get_lib_fn("torch", "arange")
Expand All @@ -2045,6 +2053,7 @@ def torch_indices(dimensions):
_FUNC_ALIASES["torch", "conjugate"] = "conj"
_FUNC_ALIASES["torch", "expand_dims"] = "unsqueeze"
_FUNC_ALIASES["torch", "linalg.expm"] = "matrix_exp"
_FUNC_ALIASES["torch", "scipy.linalg.expm"] = "matrix_exp"
_FUNC_ALIASES["torch", "max"] = "amax"
_FUNC_ALIASES["torch", "min"] = "amin"
_FUNC_ALIASES["torch", "power"] = "pow"
Expand All @@ -2055,6 +2064,7 @@ def torch_indices(dimensions):
_FUNC_ALIASES["torch", "identity"] = "eye"

_SUBMODULE_ALIASES["torch", "linalg.expm"] = "torch"
_SUBMODULE_ALIASES["torch", "scipy.linalg.expm"] = "torch"
_SUBMODULE_ALIASES["torch", "random.normal"] = "torch"
_SUBMODULE_ALIASES["torch", "random.uniform"] = "torch"

Expand Down Expand Up @@ -2087,6 +2097,7 @@ def torch_indices(dimensions):
_CUSTOM_WRAPPERS["torch", "expand_dims"] = make_translator(
[("a", ("input",)), ("axis", ("dim",))]
)
_CUSTOM_WRAPPERS["torch", "sort"] = torch_sort_wrap

# for older versions of torch, can provide some alternative implementations
_MODULE_ALIASES["torch[alt]"] = "torch"
Expand Down

0 comments on commit 9acdbde

Please sign in to comment.