Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add FFT functions #50

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions unumpy/fft/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from ._multimethods import *
116 changes: 116 additions & 0 deletions unumpy/fft/_multimethods.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import functools
import operator
from uarray import create_multimethod, mark_as, all_of_type, Dispatchable
import builtins

create_numpy = functools.partial(create_multimethod, domain="numpy")

from .._multimethods import ndarray, _identity_argreplacer, _self_argreplacer


@create_numpy(_self_argreplacer)
@all_of_type(ndarray)
def fft(a, n=None, axis=-1, norm=None):
return (a,)


@create_numpy(_self_argreplacer)
@all_of_type(ndarray)
def ifft(a, n=None, axis=-1, norm=None):
return (a,)


@create_numpy(_self_argreplacer)
@all_of_type(ndarray)
def fft2(a, s=None, axes=(-2, -1), norm=None):
return (a,)


@create_numpy(_self_argreplacer)
@all_of_type(ndarray)
def ifft2(a, s=None, axes=(-2, -1), norm=None):
return (a,)


@create_numpy(_self_argreplacer)
@all_of_type(ndarray)
def fftn(a, s=None, axes=None, norm=None):
return (a,)


@create_numpy(_self_argreplacer)
@all_of_type(ndarray)
def ifftn(a, s=None, axes=None, norm=None):
return (a,)


@create_numpy(_self_argreplacer)
@all_of_type(ndarray)
def rfft(a, n=None, axis=-1, norm=None):
return (a,)


@create_numpy(_self_argreplacer)
@all_of_type(ndarray)
def irfft(a, n=None, axis=-1, norm=None):
return (a,)


@create_numpy(_self_argreplacer)
@all_of_type(ndarray)
def rfft2(a, s=None, axes=(-2, -1), norm=None):
return (a,)


@create_numpy(_self_argreplacer)
@all_of_type(ndarray)
def irfft2(a, s=None, axes=(-2, -1), norm=None):
return (a,)


@create_numpy(_self_argreplacer)
@all_of_type(ndarray)
def rfftn(a, s=None, axes=None, norm=None):
return (a,)


@create_numpy(_self_argreplacer)
@all_of_type(ndarray)
def irfftn(a, s=None, axes=None, norm=None):
return (a,)


@create_numpy(_self_argreplacer)
@all_of_type(ndarray)
def hfft(a, n=None, axis=-1, norm=None):
return (a,)


@create_numpy(_self_argreplacer)
@all_of_type(ndarray)
def ihfft(a, n=None, axis=-1, norm=None):
return (a,)


@create_numpy(_identity_argreplacer)
@all_of_type(ndarray)
def fftfreq(n, d=1.0):
return ()


@create_numpy(_identity_argreplacer)
@all_of_type(ndarray)
def rfftfreq(n, d=1.0):
return ()


@create_numpy(_self_argreplacer)
@all_of_type(ndarray)
def fftshift(x, axes=None):
return (x,)


@create_numpy(_self_argreplacer)
@all_of_type(ndarray)
def ifftshift(x, axes=None):
return (x,)
10 changes: 6 additions & 4 deletions unumpy/numpy_backend.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np
from uarray import Dispatchable, wrap_single_convertor
from unumpy import ufunc, ufunc_list, ndarray, dtype
from unumpy import ufunc, ufunc_list, ndarray, dtype, fft
import unumpy
import functools

Expand All @@ -24,11 +24,13 @@ def __ua_function__(method, args, kwargs):
if method in _implementations:
return _implementations[method](*args, **kwargs)

if not hasattr(np, method.__name__):
if hasattr(np, method.__name__):
return getattr(np, method.__name__)(*args, **kwargs)
elif hasattr(np.linalg, method.__name__):
return getattr(np.linalg, method.__name__)(*args, **kwargs)
else:
return NotImplemented

return getattr(np, method.__name__)(*args, **kwargs)


@wrap_single_convertor
def __ua_convert__(value, dispatch_type, coerce):
Expand Down
38 changes: 38 additions & 0 deletions unumpy/tests/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,3 +281,41 @@ def test_array_creation(backend, method, args, kwargs):
assert ret.dtype == ndt(dtype)
else:
assert ret.dtype == dtype


@pytest.mark.parametrize(
"method, args, kwargs",
[
# (np.fft.fft, (4 * np.eye(4),), {}),
(np.fft.ifft, ([[1, 2], [3, 4]],), {}),
(np.fft.fft2, (4 * np.eye(4),), {}),
(np.fft.ifft2, (4 * np.eye(4),), {}),
(np.fft.fftn, (4 * np.eye(4),), {}),
(np.fft.ifftn, (4 * np.eye(4),), {}),
(np.fft.rfft, (4 * np.eye(4),), {}),
(np.fft.irfft, (4 * np.eye(4),), {}),
(np.fft.rfft2, (4 * np.eye(4),), {}),
(np.fft.irfft2, (4 * np.eye(4),), {}),
(np.fft.rfftn, (4 * np.eye(4),), {}),
(np.fft.irfftn, (4 * np.eye(4)), {}),
(np.fft.hfft, (4 * np.eye(4)), {}),
(np.fft.ihfft, (4 * np.eye(4)), {}),
(np.fft.rfftn, (4 * np.eye(4)), {}),
(np.fft.fftfreq, (10), {}),
(np.fft.rfftfreq, (10), {}),
Comment on lines +304 to +305
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are missing the commas that makes args a tuple.

Suggested change
(np.fft.fftfreq, (10), {}),
(np.fft.rfftfreq, (10), {}),
(np.fft.fftfreq, (10,), {}),
(np.fft.rfftfreq, (10,), {}),

I think that code coverage may be crashing because the ast is invalid?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, a bunch of the lists above are missing the comma as well.

Copy link
Collaborator

@hameerabbasi hameerabbasi Mar 10, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, the AST is valid, just that *a isn't valid when a isn't a sequence or iterable, which is a runtime error.

(np.fft.fftshift, (4 * np.eye(4),), {}),
(np.fft.ifftshift, (4 * np.eye(4),), {}),
],
)
def test_fft(backend, method, args, kwargs):
backend, types = backend
try:
with ua.set_backend(backend, coerce=True):
ret = method(*args, **kwargs)
except ua.BackendNotImplementedError:
if backend in FULLY_TESTED_BACKENDS and (backend, method) not in EXCEPTIONS:
raise
pytest.xfail(reason="The backend has no implementation for this ufunc.")

if isinstance(ret, da.Array):
ret.compute()