Skip to content

Commit

Permalink
Leveraged dpctl.tensor.stack() implementation (#1509)
Browse files Browse the repository at this point in the history
* Leveraged dpctl.tensor.stack() implementation

* Relaxed check in a test of SYCL queue to account the error of floating operations
  • Loading branch information
antonwolfy authored Aug 7, 2023
1 parent 6538397 commit 0a45a54
Show file tree
Hide file tree
Showing 6 changed files with 338 additions and 89 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/conda-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ env:
CHANNELS: '-c dppy/label/dev -c intel -c conda-forge --override-channels'
TEST_SCOPE: >-
test_arraycreation.py
test_arraymanipulation.py
test_dot.py
test_dparray.py
test_fft.py
Expand All @@ -23,6 +24,7 @@ env:
test_umath.py
test_usm_type.py
third_party/cupy/linalg_tests/test_product.py
third_party/cupy/manipulation_tests/test_join.py
third_party/cupy/math_tests/test_explog.py
third_party/cupy/math_tests/test_misc.py
third_party/cupy/math_tests/test_trigonometric.py
Expand Down
71 changes: 66 additions & 5 deletions dpnp/dpnp_iface_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def broadcast_to(x, /, shape, subok=False):
return call_origin(numpy.broadcast_to, x, shape=shape, subok=subok)


def concatenate(arrays, *, axis=0, out=None, dtype=None, **kwargs):
def concatenate(arrays, /, *, axis=0, out=None, dtype=None, **kwargs):
"""
Join a sequence of arrays along an existing axis.
Expand All @@ -253,8 +253,7 @@ def concatenate(arrays, *, axis=0, out=None, dtype=None, **kwargs):
Each array in `arrays` is supported as either :class:`dpnp.ndarray`
or :class:`dpctl.tensor.usm_ndarray`. Otherwise ``TypeError`` exeption
will be raised.
Parameter `out` is supported with default value.
Parameter `dtype` is supported with default value.
Parameters `out` and `dtype are supported with default value.
Keyword argument ``kwargs`` is currently unsupported.
Otherwise the function will be executed sequentially on CPU.
Expand Down Expand Up @@ -834,15 +833,77 @@ def squeeze(x, /, axis=None):
return call_origin(numpy.squeeze, x, axis)


def stack(arrays, axis=0, out=None):
def stack(arrays, /, *, axis=0, out=None, dtype=None, **kwargs):
"""
Join a sequence of arrays along a new axis.
For full documentation refer to :obj:`numpy.stack`.
Returns
-------
out : dpnp.ndarray
The stacked array which has one more dimension than the input arrays.
Limitations
-----------
Each array in `arrays` is supported as either :class:`dpnp.ndarray`
or :class:`dpctl.tensor.usm_ndarray`. Otherwise ``TypeError`` exeption
will be raised.
Parameters `out` and `dtype are supported with default value.
Keyword argument ``kwargs`` is currently unsupported.
Otherwise the function will be executed sequentially on CPU.
See Also
--------
:obj:`dpnp.concatenate` : Join a sequence of arrays along an existing axis.
:obj:`dpnp.block` : Assemble an nd-array from nested lists of blocks.
:obj:`dpnp.split` : Split array into a list of multiple sub-arrays of equal size.
Examples
--------
>>> import dpnp as np
>>> arrays = [np.random.randn(3, 4) for _ in range(10)]
>>> np.stack(arrays, axis=0).shape
(10, 3, 4)
>>> np.stack(arrays, axis=1).shape
(3, 10, 4)
>>> np.stack(arrays, axis=2).shape
(3, 4, 10)
>>> a = np.array([1, 2, 3])
>>> b = np.array([4, 5, 6])
>>> np.stack((a, b))
array([[1, 2, 3],
[4, 5, 6]])
>>> np.stack((a, b), axis=-1)
array([[1, 4],
[2, 5],
[3, 6]])
"""

return call_origin(numpy.stack, arrays, axis, out)
if kwargs:
pass
elif out is not None:
pass
elif dtype is not None:
pass
else:
usm_arrays = [dpnp.get_usm_ndarray(x) for x in arrays]
usm_res = dpt.stack(usm_arrays, axis=axis)
return dpnp_array._create_from_usm_ndarray(usm_res)

return call_origin(
numpy.stack,
arrays,
axis=axis,
out=out,
dtype=dtype,
**kwargs,
)


def swapaxes(x1, axis1, axis2):
Expand Down
8 changes: 8 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,14 @@ def allow_fall_back_on_numpy(monkeypatch):
)


@pytest.fixture
def suppress_complex_warning():
sup = numpy.testing.suppress_warnings("always")
sup.filter(numpy.ComplexWarning)
with sup:
yield


@pytest.fixture
def suppress_divide_numpy_warnings():
# divide: treatment for division by zero (infinite result obtained from finite numbers)
Expand Down
Loading

0 comments on commit 0a45a54

Please sign in to comment.