From 8b4a7b0e27af75c1b887a93e57580362888e6de8 Mon Sep 17 00:00:00 2001 From: alicia <32725332+Alicia1529@users.noreply.github.com> Date: Fri, 14 Feb 2020 02:07:26 +0800 Subject: [PATCH] add polyval (#17416) --- python/mxnet/ndarray/numpy/_op.py | 66 +++++++++- python/mxnet/numpy/multiarray.py | 54 +++++++- python/mxnet/numpy_dispatch_protocol.py | 1 + python/mxnet/symbol/numpy/_symbol.py | 40 +++++- src/operator/numpy/np_polynomial_op-inl.h | 100 +++++++++++++++ src/operator/numpy/np_polynomial_op.cc | 120 ++++++++++++++++++ src/operator/numpy/np_polynomial_op.cu | 93 ++++++++++++++ .../unittest/test_numpy_interoperability.py | 15 +++ tests/python/unittest/test_numpy_op.py | 67 ++++++++++ 9 files changed, 548 insertions(+), 8 deletions(-) create mode 100644 src/operator/numpy/np_polynomial_op-inl.h create mode 100644 src/operator/numpy/np_polynomial_op.cc create mode 100644 src/operator/numpy/np_polynomial_op.cu diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index 5b0c9755143f..f560248f5a0e 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -44,7 +44,7 @@ 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'quantile', 'percentile', 'shares_memory', 'may_share_memory', - 'diff', 'resize', 'nan_to_num', 'isnan', 'isinf', 'where', 'bincount'] + 'diff', 'resize', 'polyval', 'nan_to_num', 'isnan', 'isinf', 'where', 'bincount'] @set_module('mxnet.ndarray.numpy') @@ -203,7 +203,11 @@ def zeros_like(a, dtype=None, order='C', ctx=None, out=None): >>> np.zeros_like(y) array([0., 0., 0.], dtype=float64) """ - return _npi.full_like(a, fill_value=0, dtype=dtype, ctx=None, out=None) + if order != 'C': + raise NotImplementedError + if ctx is None: + ctx = current_context() + return _npi.full_like(a, fill_value=0, dtype=dtype, ctx=ctx, out=out) @set_module('mxnet.ndarray.numpy') @@ -259,7 +263,11 @@ def ones_like(a, dtype=None, order='C', ctx=None, out=None): >>> np.ones_like(y) array([1., 1., 1.], dtype=float64) """ - return _npi.full_like(a, fill_value=1, dtype=dtype, ctx=None, out=None) + if order != 'C': + raise NotImplementedError + if ctx is None: + ctx = current_context() + return _npi.full_like(a, fill_value=1, dtype=dtype, ctx=ctx, out=out) @set_module('mxnet.ndarray.numpy') @@ -6922,6 +6930,58 @@ def where(condition, x=None, y=None): # pylint: disable=too-many-return-stateme raise TypeError('type {0} and {1} not supported'.format(str(type(x)), str(type(y)))) +@set_module('mxnet.ndarray.numpy') +def polyval(p, x): + """ + Evaluate a polynomial at specific values. + If p is of length N, this function returns the value: + p[0]*x**(N-1) + p[1]*x**(N-2) + ... + p[N-2]*x + p[N-1] + If x is a sequence, then p(x) is returned for each element of x. + If x is another polynomial then the composite polynomial p(x(t)) is returned. + + Parameters + ---------- + p : ndarray + 1D array of polynomial coefficients (including coefficients equal to zero) + from highest degree to the constant term. + x : ndarray + An array of numbers, at which to evaluate p. + + Returns + ------- + values : ndarray + Result array of polynomials + + Notes + ----- + This function differs from the original `numpy.polyval + `_ in + the following way(s): + - Does not support poly1d. + - X should be ndarray type even if it contains only one element. + + Examples + -------- + >>> p = np.array([3, 0, 1]) + array([3., 0., 1.]) + >>> x = np.array([5]) + array([5.]) + >>> np.polyval(p, x) # 3 * 5**2 + 0 * 5**1 + 1 + array([76.]) + >>> x = np.array([5, 4]) + array([5., 4.]) + >>> np.polyval(p, x) + array([76., 49.]) + """ + from ...numpy import ndarray + if isinstance(p, ndarray) and isinstance(x, ndarray): + return _npi.polyval(p, x) + elif not isinstance(p, ndarray) and not isinstance(x, ndarray): + return _np.polyval(p, x) + else: + raise TypeError('type not supported') + + @set_module('mxnet.ndarray.numpy') def bincount(x, weights=None, minlength=0): """ diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index fc1a489ca902..4a7a85ad6b31 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -65,7 +65,7 @@ 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'diff', 'resize', 'matmul', - 'nan_to_num', 'isnan', 'isinf', 'where', 'bincount'] + 'nan_to_num', 'isnan', 'isinf', 'polyval', 'where', 'bincount'] __all__ += fallback.__all__ @@ -8596,7 +8596,7 @@ def full_like(a, fill_value, dtype=None, order='C', ctx=None, out=None): # pylin >>> np.full_like(y, 0.1) array([0.1, 0.1, 0.1, 0.1, 0.1, 0.1]) """ - return _mx_nd_np.full_like(a, fill_value=fill_value, dtype=dtype, order=order, ctx=None, out=None) + return _mx_nd_np.full_like(a, fill_value=fill_value, dtype=dtype, order=order, ctx=ctx, out=out) @set_module('mxnet.numpy') @@ -8652,7 +8652,7 @@ def zeros_like(a, dtype=None, order='C', ctx=None, out=None): >>> np.zeros_like(y) array([0., 0., 0.], dtype=float64) """ - return _mx_nd_np.full_like(a, fill_value=0, dtype=dtype, order=order, ctx=None, out=None) + return _mx_nd_np.full_like(a, fill_value=0, dtype=dtype, order=order, ctx=ctx, out=ctx) @set_module('mxnet.numpy') @@ -8708,7 +8708,7 @@ def ones_like(a, dtype=None, order='C', ctx=None, out=None): >>> np.ones_like(y) array([1., 1., 1.], dtype=float64) """ - return _mx_nd_np.full_like(a, fill_value=1, dtype=dtype, order=order, ctx=None, out=None) + return _mx_nd_np.full_like(a, fill_value=1, dtype=dtype, order=order, ctx=ctx, out=out) @set_module('mxnet.numpy') @@ -8972,6 +8972,52 @@ def where(condition, x=None, y=None): return _mx_nd_np.where(condition, x, y) +@set_module('mxnet.numpy') +def polyval(p, x): + """ + Evaluate a polynomial at specific values. + If p is of length N, this function returns the value: + p[0]*x**(N-1) + p[1]*x**(N-2) + ... + p[N-2]*x + p[N-1] + If x is a sequence, then p(x) is returned for each element of x. + If x is another polynomial then the composite polynomial p(x(t)) is returned. + + Parameters + ---------- + p : ndarray + 1D array of polynomial coefficients (including coefficients equal to zero) + from highest degree to the constant term. + x : ndarray + An array of numbers, at which to evaluate p. + + Returns + ------- + values : ndarray + Result array of polynomials + + Notes + ----- + This function differs from the original `numpy.polyval + `_ in + the following way(s): + - Does not support poly1d. + - X should be ndarray type even if it contains only one element. + + Examples + -------- + >>> p = np.array([3, 0, 1]) + array([3., 0., 1.]) + >>> x = np.array([5]) + array([5.]) + >>> np.polyval(p, x) # 3 * 5**2 + 0 * 5**1 + 1 + array([76.]) + >>> x = np.array([5, 4]) + array([5., 4.]) + >>> np.polyval(p, x) + array([76., 49.]) + """ + return _mx_nd_np.polyval(p, x) + + @set_module('mxnet.numpy') def bincount(x, weights=None, minlength=0): """ diff --git a/python/mxnet/numpy_dispatch_protocol.py b/python/mxnet/numpy_dispatch_protocol.py index 2c21377efe9e..bf4341fff21d 100644 --- a/python/mxnet/numpy_dispatch_protocol.py +++ b/python/mxnet/numpy_dispatch_protocol.py @@ -162,6 +162,7 @@ def _run_with_array_ufunc_proto(*args, **kwargs): 'meshgrid', 'outer', 'einsum', + 'polyval', 'shares_memory', 'may_share_memory', 'quantile', diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index bb556f5e04f8..6c3b5234bd90 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -51,7 +51,7 @@ 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'diff', - 'resize', 'nan_to_num', 'isnan', 'isinf', 'where', 'bincount'] + 'resize', 'polyval', 'nan_to_num', 'isnan', 'isinf', 'where', 'bincount'] @set_module('mxnet.symbol.numpy') @@ -6182,6 +6182,44 @@ def load_json(json_str): return _Symbol(handle) +@set_module('mxnet.symbol.numpy') +def polyval(p, x): + """ + Evaluate a polynomial at specific values. + If p is of length N, this function returns the value: + p[0]*x**(N-1) + p[1]*x**(N-2) + ... + p[N-2]*x + p[N-1] + If x is a sequence, then p(x) is returned for each element of x. + If x is another polynomial then the composite polynomial p(x(t)) is returned. + + Parameters + ---------- + p : _Symbol + 1D array of polynomial coefficients (including coefficients equal to zero) + from highest degree to the constant term. + x : _Symbol + An array of numbers, at which to evaluate p. + + Returns + ------- + values : _Symbol + Result array of polynomials + + Notes + ----- + This function differs from the original `numpy.polyval + `_ in + the following way(s): + - Does not support poly1d. + - X should be ndarray type even if it contains only one element. + """ + if isinstance(p, Symbol) and isinstance(x, Symbol): + return _npi.polyval(p, x) + elif not isinstance(p, Symbol) and not isinstance(x, Symbol): + return _np.polyval(p, x) + else: + raise TypeError('type not supported') + + @set_module('mxnet.symbol.numpy') def bincount(x, weights=None, minlength=0): """ diff --git a/src/operator/numpy/np_polynomial_op-inl.h b/src/operator/numpy/np_polynomial_op-inl.h new file mode 100644 index 000000000000..f3b4424283dc --- /dev/null +++ b/src/operator/numpy/np_polynomial_op-inl.h @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file np_polynomial_op.h + * \brief Functions for dealing with polynomials. + */ + +#ifndef MXNET_OPERATOR_NUMPY_NP_POLYNOMIAL_OP_INL_H_ +#define MXNET_OPERATOR_NUMPY_NP_POLYNOMIAL_OP_INL_H_ + +#include +#include +#include +#include +#include "../mxnet_op.h" +#include "../../common/utils.h" +#include "../tensor/elemwise_binary_broadcast_op.h" + + +namespace mxnet { +namespace op { + +inline bool NumpyPolyvalShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector *in_attrs, + mxnet::ShapeVector *out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 1U); + + const mxnet::TShape& p_shape = in_attrs->at(0); + const mxnet::TShape& x_shape = in_attrs->at(1); + const mxnet::TShape& v_shape = out_attrs->at(0); + CHECK_EQ(p_shape.ndim(), 1U) << "ValueError: p has to be an 1-D array."; + SHAPE_ASSIGN_CHECK(*out_attrs, 0, x_shape); + SHAPE_ASSIGN_CHECK(*in_attrs, 1, v_shape); + return shape_is_known(*in_attrs) && shape_is_known(*out_attrs); +} + +template +struct polyval_forward { + template + MSHADOW_XINLINE static void Map(int i, + DType* out_data, + const DType* p_data, + const DType* x_data, + const index_t p_size) { + DType val = 0; + for (index_t j = 0; j < p_size; j++) { + val = val * x_data[i] + p_data[j]; + } + KERNEL_ASSIGN(out_data[i], req, val); + } +}; + +template +void NumpyPolyvalForward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mxnet; + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + CHECK_EQ(req.size(), 1U); + mshadow::Stream *s = ctx.get_stream(); + const TBlob& p_data = inputs[0]; + const TBlob& x_data = inputs[1]; + const TBlob& out_data = outputs[0]; + const size_t p_size = p_data.Size(); + using namespace mxnet_op; + + MSHADOW_TYPE_SWITCH(x_data.type_flag_, DType, { + MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { + Kernel, xpu>::Launch( + s, out_data.Size(), out_data.dptr(), + p_data.dptr(), x_data.dptr(), p_size); + }); + }); +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_NUMPY_NP_POLYNOMIAL_OP_INL_H_ diff --git a/src/operator/numpy/np_polynomial_op.cc b/src/operator/numpy/np_polynomial_op.cc new file mode 100644 index 000000000000..155c98fd1cc5 --- /dev/null +++ b/src/operator/numpy/np_polynomial_op.cc @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 202o by Contributors + * \file np_polynomial_op.cc +*/ + +#include +#include "np_polynomial_op-inl.h" + +namespace mxnet { +namespace op { + +template +struct polyval_backward_x { + template + MSHADOW_XINLINE static void Map(int i, const DType* p_dptr, const DType* x_dptr, + DType* igrad_x_dptr, const DType* ograd_dptr, + const index_t p_size) { + DType igrad_x = 0; + index_t j = p_size - 1; + while (j > 0) { + igrad_x = igrad_x * x_dptr[i] + p_dptr[p_size - j - 1] * j; + j--; + } + KERNEL_ASSIGN(igrad_x_dptr[i], req, igrad_x * ograd_dptr[i]); + } +}; + +template +struct polyval_backward_p { + template + MSHADOW_XINLINE static void Map(int i, const DType* p_dptr, const DType* x_dptr, + DType* igrad_p_dptr, const DType* ograd_dptr, + const index_t p_size, const index_t x_size) { + DType igrad_p = 0; + index_t j = x_size - 1; + while (j >= 0) { + igrad_p += pow(x_dptr[j], p_size - i - 1) * ograd_dptr[j]; + j--; + } + KERNEL_ASSIGN(igrad_p_dptr[i], req, igrad_p); + } +}; + +void NumpyPolyvalBackwardCPU(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), 3U); + CHECK_EQ(outputs.size(), 2U); + CHECK_NE(req[0], kWriteInplace); + + if (inputs[1].type_flag_ != inputs[2].type_flag_ || + !common::is_float(inputs[1].type_flag_) || + !common::is_float(inputs[2].type_flag_)) { + return; + } + + mshadow::Stream *s = ctx.get_stream(); + const TBlob& ograd = inputs[0]; + const TBlob& p = inputs[1]; + const TBlob& x = inputs[2]; + const TBlob& igrad_p = outputs[0]; + const TBlob& igrad_x = outputs[1]; + const size_t p_size = p.Size(); + + using namespace mxnet_op; + MSHADOW_REAL_TYPE_SWITCH(ograd.type_flag_, DType, { + MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { + Kernel, cpu>::Launch( + s, ograd.Size(), p.dptr(), x.dptr(), + igrad_x.dptr(), ograd.dptr(), p_size); + Kernel, cpu>::Launch( + s, p_size, p.dptr(), x.dptr(), + igrad_p.dptr(), ograd.dptr(), p_size, x.Size()); + }); + }); +} + +NNVM_REGISTER_OP(_npi_polyval) +.set_num_inputs(2) +.set_num_outputs(1) +.add_argument("p", "NDArray-or-Symbol", "polynomial coefficients") +.add_argument("x", "NDArray-or-Symbol", "variables") +.set_attr("FListInputNames", +[](const NodeAttrs& attrs) { + return std::vector{"p", "x"}; +}) +.set_attr("FInferShape", NumpyPolyvalShape) +.set_attr("FInferType", mxnet::op::ElemwiseType<2, 1>) +.set_attr("FCompute", NumpyPolyvalForward) +.set_attr("FGradient", ElemwiseGradUseIn{"_npi_backward_polyval"}); + +NNVM_REGISTER_OP(_npi_backward_polyval) +.set_num_inputs(3) +.set_num_outputs(2) +.set_attr("TIsBackward", true) +.set_attr("FCompute", NumpyPolyvalBackwardCPU); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/np_polynomial_op.cu b/src/operator/numpy/np_polynomial_op.cu new file mode 100644 index 000000000000..31f284b7a2a8 --- /dev/null +++ b/src/operator/numpy/np_polynomial_op.cu @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2020 by Contributors + * \file np_polynomial_op.cu + */ + +#include "np_polynomial_op-inl.h" +#include "../../common/cuda_utils.h" + +namespace mxnet { +namespace op { + +template +struct polyval_backward_gpu { + template + MSHADOW_XINLINE static void Map(int i, const DType* p_dptr, const DType* x_dptr, + DType* igrad_x_dptr, DType* igrad_p_dptr, + const DType* ograd_dptr, const index_t p_size) { + DType igrad_p = 1; + DType igrad_x = 0; + index_t j = p_size - 1; + while (j > 0) { + // atomic add since different threads could update same variable + atomicAdd(&igrad_p_dptr[j], igrad_p * ograd_dptr[i]); + igrad_p *= x_dptr[i]; + igrad_x = igrad_x * x_dptr[i] + p_dptr[p_size - j - 1] * j; + j--; + } + atomicAdd(&igrad_p_dptr[j], igrad_p * ograd_dptr[i]); + KERNEL_ASSIGN(igrad_x_dptr[i], req, igrad_x * ograd_dptr[i]); + } +}; + +void NumpyPolyvalBackwardGPU(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), 3U); + CHECK_EQ(outputs.size(), 2U); + CHECK_NE(req[0], kWriteInplace); + + if (inputs[1].type_flag_ != inputs[2].type_flag_ || + !common::is_float(inputs[1].type_flag_) || + !common::is_float(inputs[2].type_flag_)) { + return; + } + + mshadow::Stream *s = ctx.get_stream(); + const TBlob& ograd = inputs[0]; + const TBlob& p = inputs[1]; + const TBlob& x = inputs[2]; + const TBlob& igrad_p = outputs[0]; + const TBlob& igrad_x = outputs[1]; + const size_t p_size = p.Size(); + + using namespace mxnet_op; + MSHADOW_REAL_TYPE_SWITCH(ograd.type_flag_, DType, { + MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { + Kernel, gpu>::Launch( + s, ograd.Size(), p.dptr(), x.dptr(), + igrad_x.dptr(), igrad_p.dptr(), + ograd.dptr(), p_size); + }); + }); +} + +NNVM_REGISTER_OP(_npi_polyval) +.set_attr("FCompute", NumpyPolyvalForward); + +NNVM_REGISTER_OP(_npi_backward_polyval) +.set_attr("FCompute", NumpyPolyvalBackwardGPU); + +} // namespace op +} // namespace mxnet diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py index 586844a5d463..610b4caa69d5 100644 --- a/tests/python/unittest/test_numpy_interoperability.py +++ b/tests/python/unittest/test_numpy_interoperability.py @@ -1884,6 +1884,20 @@ def _add_workload_isinf(): OpArgMngr.add_workload('isinf', array3) +def _add_workload_polyval(): + p1 = np.arange(20) + p2 = np.arange(1) + x1 = np.arange(20) + x2 = np.ones((3,3)) + x3 = np.array(2) + OpArgMngr.add_workload('polyval', p1, x1) + OpArgMngr.add_workload('polyval', p1, x2) + OpArgMngr.add_workload('polyval', p1, x3) + OpArgMngr.add_workload('polyval', p2, x1) + OpArgMngr.add_workload('polyval', p2, x2) + OpArgMngr.add_workload('polyval', p2, x3) + + def _add_workload_linalg_cond(): A = np.array([[1., 0, 1], [0, -2., 0], [0, 0, 3.]]) OpArgMngr.add_workload('linalg.cond', A, np.inf) @@ -2069,6 +2083,7 @@ def _prepare_workloads(): _add_workload_nan_to_num() _add_workload_isnan() _add_workload_isinf() + _add_workload_polyval() _add_workload_heaviside() _add_workload_spacing() diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 5ee7f1f0a0e2..cb89879142f7 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -7423,6 +7423,73 @@ def hybrid_forward(self, F, a): check_unary_func("isinf") +@with_seed() +@use_np +def test_np_polyval(): + class TestPolyval(HybridBlock): + def __init__(self): + super(TestPolyval, self).__init__() + + def hybrid_forward(self, F, p, x, *args, **kwargs): + return F.np.polyval(p, x) + + def polyval_grad(p, x): + x_shape = x.shape + x = x.reshape((x.size, 1)) + x = _np.broadcast_to(x, (x.size, p.size)) + exp = _np.arange(p.size-1, -1, -1) + p_grad = _np.power(x, exp) + coeff = exp-1 + coeff[-1] = 0 + x_grad = _np.power(x, coeff) * p * exp + p_grad = _np.sum(p_grad, axis=0) + x_grad = _np.sum(x_grad, axis=-1).reshape(x_shape) + return (p_grad, x_grad) + + dtypes = ['float32', 'float64', 'int32', 'int64'] + x_shapes = [ + (5,), + (10), + (3, 3), + (3, 4), + (3, 3, 3), + (2, 2, 4, 3), + (2, 0, 2, 3) + ] + flags = [True, False] + for dtype, x_shape, hybridize in itertools.product(dtypes, x_shapes, flags): + p_shape = (random.randint(1, 8),) + test_polyval = TestPolyval() + if hybridize: + test_polyval.hybridize() + rtol = 1e-2 + atol = 1e-4 + if dtype in ['int32', 'int64']: + p = np.random.randint(-16, 16, p_shape, dtype=dtype) + x = np.random.randint(-5, 5, x_shape, dtype=dtype) + else: + p = np.random.uniform(-1.0, 1.0, size=p_shape, dtype=dtype) + x = np.random.uniform(-1.0, 1.0, size=x_shape, dtype=dtype) + + p.attach_grad() + x.attach_grad() + np_out = _np.polyval(p.asnumpy(), x.asnumpy()) + with mx.autograd.record(): + mx_out = test_polyval(p, x) + assert mx_out.shape == np_out.shape + assert_almost_equal(mx_out.asnumpy(), np_out, atol=atol, rtol=rtol) + + mx_out.backward() + if dtype in ['float16', 'float32', 'float64']: + p_grad, x_grad = polyval_grad(p.asnumpy(), x.asnumpy()) + assert_almost_equal(p.grad.asnumpy(), p_grad, atol=atol, rtol=rtol) + assert_almost_equal(x.grad.asnumpy(), x_grad, atol=atol, rtol=rtol) + + mx_out = np.polyval(p, x) + np_out = _np.polyval(p.asnumpy(), x.asnumpy()) + assert_almost_equal(mx_out.asnumpy(), np_out, atol=atol, rtol=rtol) + + @with_seed() @use_np def test_np_where():