Skip to content

Commit

Permalink
add torch "expand_dims"
Browse files Browse the repository at this point in the history
  • Loading branch information
jcmgray committed Aug 1, 2023
1 parent 84b070c commit 9f3065a
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions autoray/autoray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2003,6 +2003,7 @@ def numpy_like(N, M=None, dtype=None, **kwargs):
_FUNC_ALIASES["torch", "linalg.expm"] = "matrix_exp"
_FUNC_ALIASES["torch", "conjugate"] = "conj"
_FUNC_ALIASES["torch", "split"] = "tensor_split"
_FUNC_ALIASES["torch", "expand_dims"] = "unsqueeze"

_SUBMODULE_ALIASES["torch", "linalg.expm"] = "torch"
_SUBMODULE_ALIASES["torch", "random.normal"] = "torch"
Expand Down Expand Up @@ -2034,6 +2035,9 @@ def numpy_like(N, M=None, dtype=None, **kwargs):
_CUSTOM_WRAPPERS["torch", "take"] = make_translator(
[("a", ("input",)), ("indices", ("index",)), ("axis", ("dim",))]
)
_CUSTOM_WRAPPERS["torch", "expand_dims"] = make_translator(
[("a", ("input",)), ("axis", ("dim",))]
)

# for older versions of torch, can provide some alternative implementations
_MODULE_ALIASES["torch[alt]"] = "torch"
Expand Down

0 comments on commit 9f3065a

Please sign in to comment.