Skip to content

Commit

Permalink
using rowvar flag in dpnp.cov
Browse files Browse the repository at this point in the history
  • Loading branch information
vtavana committed Apr 12, 2023
1 parent d65a635 commit 4e04716
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 8 deletions.
18 changes: 11 additions & 7 deletions dpnp/dpnp_iface_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@

from dpnp.dpnp_algo import *
from dpnp.dpnp_utils import *
from dpnp.dpnp_array import dpnp_array
import dpnp


Expand Down Expand Up @@ -244,11 +245,10 @@ def cov(x1, y=None, rowvar=True, bias=False, ddof=None, fweights=None, aweights=
Limitations
-----------
Input array ``m`` is supported as :obj:`dpnp.ndarray`.
Dimension of input array ``m`` is limited by ``m.ndim > 2``.
Input array ``x1`` is supported as :obj:`dpnp.ndarray`.
Dimension of input array ``x1`` is limited by ``x1.ndim > 2``.
Size and shape of input arrays are supported to be equal.
Prameters ``y`` is supported only with default value ``None``.
Prameters ``rowvar`` is supported only with default value ``True``.
Prameters ``bias`` is supported only with default value ``False``.
Prameters ``ddof`` is supported only with default value ``None``.
Prameters ``fweights`` is supported only with default value ``None``.
Expand Down Expand Up @@ -280,8 +280,6 @@ def cov(x1, y=None, rowvar=True, bias=False, ddof=None, fweights=None, aweights=
pass
elif y is not None:
pass
elif not rowvar:
pass
elif bias:
pass
elif ddof is not None:
Expand All @@ -291,8 +289,14 @@ def cov(x1, y=None, rowvar=True, bias=False, ddof=None, fweights=None, aweights=
elif aweights is not None:
pass
else:
if x1_desc.dtype != dpnp.float64:
x1_desc = dpnp.get_dpnp_descriptor(dpnp.astype(x1, dpnp.float64), copy_when_nondefault_queue=False)
if not rowvar and x1.shape[0] != 1:
x1 = x1.get_array() if isinstance(x1, dpnp_array) else x1
x1 = dpnp_array._create_from_usm_ndarray(x1.mT)
x1 = dpnp.astype(x1, dpnp.float64) if x1_desc.dtype != dpnp.float64 else x1
x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False)
elif x1_desc.dtype != dpnp.float64:
x1 = dpnp.astype(x1, dpnp.float64)
x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False)

return dpnp_cov(x1_desc).get_pyobj()

Expand Down
17 changes: 16 additions & 1 deletion tests/test_statistics.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest

from .helper import get_all_dtypes
import dpnp

import numpy
Expand Down Expand Up @@ -114,3 +114,18 @@ def test_bincount_weights(self, array, weights):
expected = numpy.bincount(np_a, weights=weights)
result = dpnp.bincount(dpnp_a, weights=weights)
numpy.testing.assert_array_equal(expected, result)

@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True, no_none=True, no_complex=True))
def test_cov_rowvar1(dtype):
a = dpnp.array([[0, 2], [1, 1], [2, 0]], dtype=dtype)
b = numpy.array([[0, 2], [1, 1], [2, 0]], dtype=dtype)
numpy.testing.assert_array_equal(dpnp.cov(a.T), dpnp.cov(a,rowvar=False))
numpy.testing.assert_array_equal(numpy.cov(b,rowvar=False), dpnp.cov(a,rowvar=False))

@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True, no_none=True, no_complex=True))
def test_cov_rowvar2(dtype):
a = dpnp.array([[0, 1, 2]], dtype=dtype)
b = numpy.array([[0, 1, 2]], dtype=dtype)
numpy.testing.assert_array_equal(numpy.cov(b,rowvar=False), dpnp.cov(a,rowvar=False))


0 comments on commit 4e04716

Please sign in to comment.