Skip to content

Commit

Permalink
add multi dispatch for various binary funcs
Browse files Browse the repository at this point in the history
  • Loading branch information
jcmgray committed May 12, 2023
1 parent 4b67e22 commit d5d0c27
Showing 1 changed file with 14 additions and 5 deletions.
19 changes: 14 additions & 5 deletions autoray/autoray.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,8 @@ def import_lib_fn(backend, fn):
full_location = ".".join([full_location] + split_fn[:-1])
only_fn = split_fn[-1]

# try aliases for global (not function specific) modules and submodules
# try aliases for global (not function specific) modules and
# submodules:
# e.g. 'decimal' -> 'math'
# e.g. 'cupy.scipy' -> 'cupyx.scipy'
# we don't do this if the function location has been explicitly
Expand Down Expand Up @@ -1325,14 +1326,22 @@ def einsum_dispatcher(*args, **_):
register_dispatch("einsum", einsum_dispatcher)


def tensordot_dispatcher(*args, **_):
"""There are cases when we want to take into account both backends."""
def binary_dispatcher(*args, **_):
"""There are cases when we want to take into account both backends of two
arguments, e.g. a lazy variable and a constant array.
"""
return infer_backend_multi(*args[:2])


register_dispatch("tensordot", tensordot_dispatcher)
register_dispatch("tensordot", binary_dispatcher)
register_dispatch("matmul", binary_dispatcher)
register_dispatch("multiply", binary_dispatcher)
register_dispatch("divide", binary_dispatcher)
register_dispatch("true_divide", binary_dispatcher)
register_dispatch("add", binary_dispatcher)
register_dispatch("subtract", binary_dispatcher)

# TODO: register other binary functions such as add, matmul etc?
# TODO: register other binary functions?

# --------------- object to act as drop-in replace for numpy ---------------- #

Expand Down

0 comments on commit d5d0c27

Please sign in to comment.