Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add polyval
Browse files Browse the repository at this point in the history
  • Loading branch information
Alicia1529 committed Feb 13, 2020
1 parent f5a1014 commit 8a385b8
Show file tree
Hide file tree
Showing 9 changed files with 548 additions and 8 deletions.
66 changes: 63 additions & 3 deletions python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer',
'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum',
'true_divide', 'nonzero', 'quantile', 'percentile', 'shares_memory', 'may_share_memory',
'diff', 'resize', 'nan_to_num', 'isnan', 'isinf', 'where', 'bincount']
'diff', 'resize', 'polyval', 'nan_to_num', 'isnan', 'isinf', 'where', 'bincount']


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -203,7 +203,11 @@ def zeros_like(a, dtype=None, order='C', ctx=None, out=None):
>>> np.zeros_like(y)
array([0., 0., 0.], dtype=float64)
"""
return _npi.full_like(a, fill_value=0, dtype=dtype, ctx=None, out=None)
if order != 'C':
raise NotImplementedError
if ctx is None:
ctx = current_context()
return _npi.full_like(a, fill_value=0, dtype=dtype, ctx=ctx, out=out)


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -259,7 +263,11 @@ def ones_like(a, dtype=None, order='C', ctx=None, out=None):
>>> np.ones_like(y)
array([1., 1., 1.], dtype=float64)
"""
return _npi.full_like(a, fill_value=1, dtype=dtype, ctx=None, out=None)
if order != 'C':
raise NotImplementedError
if ctx is None:
ctx = current_context()
return _npi.full_like(a, fill_value=1, dtype=dtype, ctx=ctx, out=out)


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -6922,6 +6930,58 @@ def where(condition, x=None, y=None): # pylint: disable=too-many-return-stateme
raise TypeError('type {0} and {1} not supported'.format(str(type(x)), str(type(y))))


@set_module('mxnet.ndarray.numpy')
def polyval(p, x):
"""
Evaluate a polynomial at specific values.
If p is of length N, this function returns the value:
p[0]*x**(N-1) + p[1]*x**(N-2) + ... + p[N-2]*x + p[N-1]
If x is a sequence, then p(x) is returned for each element of x.
If x is another polynomial then the composite polynomial p(x(t)) is returned.
Parameters
----------
p : ndarray
1D array of polynomial coefficients (including coefficients equal to zero)
from highest degree to the constant term.
x : ndarray
An array of numbers, at which to evaluate p.
Returns
-------
values : ndarray
Result array of polynomials
Notes
-----
This function differs from the original `numpy.polyval
<https://numpy.org/devdocs/reference/generated/numpy.polyval.html>`_ in
the following way(s):
- Does not support poly1d.
- X should be ndarray type even if it contains only one element.
Examples
--------
>>> p = np.array([3, 0, 1])
array([3., 0., 1.])
>>> x = np.array([5])
array([5.])
>>> np.polyval(p, x) # 3 * 5**2 + 0 * 5**1 + 1
array([76.])
>>> x = np.array([5, 4])
array([5., 4.])
>>> np.polyval(p, x)
array([76., 49.])
"""
from ...numpy import ndarray
if isinstance(p, ndarray) and isinstance(x, ndarray):
return _npi.polyval(p, x)
elif not isinstance(p, ndarray) and not isinstance(x, ndarray):
return _np.polyval(p, x)
else:
raise TypeError('type not supported')


@set_module('mxnet.ndarray.numpy')
def bincount(x, weights=None, minlength=0):
"""
Expand Down
54 changes: 50 additions & 4 deletions python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal',
'greater', 'less', 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero',
'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'diff', 'resize', 'matmul',
'nan_to_num', 'isnan', 'isinf', 'where', 'bincount']
'nan_to_num', 'isnan', 'isinf', 'polyval', 'where', 'bincount']

__all__ += fallback.__all__

Expand Down Expand Up @@ -8596,7 +8596,7 @@ def full_like(a, fill_value, dtype=None, order='C', ctx=None, out=None): # pylin
>>> np.full_like(y, 0.1)
array([0.1, 0.1, 0.1, 0.1, 0.1, 0.1])
"""
return _mx_nd_np.full_like(a, fill_value=fill_value, dtype=dtype, order=order, ctx=None, out=None)
return _mx_nd_np.full_like(a, fill_value=fill_value, dtype=dtype, order=order, ctx=ctx, out=out)


@set_module('mxnet.numpy')
Expand Down Expand Up @@ -8652,7 +8652,7 @@ def zeros_like(a, dtype=None, order='C', ctx=None, out=None):
>>> np.zeros_like(y)
array([0., 0., 0.], dtype=float64)
"""
return _mx_nd_np.full_like(a, fill_value=0, dtype=dtype, order=order, ctx=None, out=None)
return _mx_nd_np.full_like(a, fill_value=0, dtype=dtype, order=order, ctx=ctx, out=ctx)


@set_module('mxnet.numpy')
Expand Down Expand Up @@ -8708,7 +8708,7 @@ def ones_like(a, dtype=None, order='C', ctx=None, out=None):
>>> np.ones_like(y)
array([1., 1., 1.], dtype=float64)
"""
return _mx_nd_np.full_like(a, fill_value=1, dtype=dtype, order=order, ctx=None, out=None)
return _mx_nd_np.full_like(a, fill_value=1, dtype=dtype, order=order, ctx=ctx, out=out)


@set_module('mxnet.numpy')
Expand Down Expand Up @@ -8972,6 +8972,52 @@ def where(condition, x=None, y=None):
return _mx_nd_np.where(condition, x, y)


@set_module('mxnet.numpy')
def polyval(p, x):
"""
Evaluate a polynomial at specific values.
If p is of length N, this function returns the value:
p[0]*x**(N-1) + p[1]*x**(N-2) + ... + p[N-2]*x + p[N-1]
If x is a sequence, then p(x) is returned for each element of x.
If x is another polynomial then the composite polynomial p(x(t)) is returned.
Parameters
----------
p : ndarray
1D array of polynomial coefficients (including coefficients equal to zero)
from highest degree to the constant term.
x : ndarray
An array of numbers, at which to evaluate p.
Returns
-------
values : ndarray
Result array of polynomials
Notes
-----
This function differs from the original `numpy.polyval
<https://numpy.org/devdocs/reference/generated/numpy.polyval.html>`_ in
the following way(s):
- Does not support poly1d.
- X should be ndarray type even if it contains only one element.
Examples
--------
>>> p = np.array([3, 0, 1])
array([3., 0., 1.])
>>> x = np.array([5])
array([5.])
>>> np.polyval(p, x) # 3 * 5**2 + 0 * 5**1 + 1
array([76.])
>>> x = np.array([5, 4])
array([5., 4.])
>>> np.polyval(p, x)
array([76., 49.])
"""
return _mx_nd_np.polyval(p, x)


@set_module('mxnet.numpy')
def bincount(x, weights=None, minlength=0):
"""
Expand Down
1 change: 1 addition & 0 deletions python/mxnet/numpy_dispatch_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ def _run_with_array_ufunc_proto(*args, **kwargs):
'meshgrid',
'outer',
'einsum',
'polyval',
'shares_memory',
'may_share_memory',
'quantile',
Expand Down
40 changes: 39 additions & 1 deletion python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer',
'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum',
'true_divide', 'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'diff',
'resize', 'nan_to_num', 'isnan', 'isinf', 'where', 'bincount']
'resize', 'polyval', 'nan_to_num', 'isnan', 'isinf', 'where', 'bincount']


@set_module('mxnet.symbol.numpy')
Expand Down Expand Up @@ -6182,6 +6182,44 @@ def load_json(json_str):
return _Symbol(handle)


@set_module('mxnet.symbol.numpy')
def polyval(p, x):
"""
Evaluate a polynomial at specific values.
If p is of length N, this function returns the value:
p[0]*x**(N-1) + p[1]*x**(N-2) + ... + p[N-2]*x + p[N-1]
If x is a sequence, then p(x) is returned for each element of x.
If x is another polynomial then the composite polynomial p(x(t)) is returned.
Parameters
----------
p : _Symbol
1D array of polynomial coefficients (including coefficients equal to zero)
from highest degree to the constant term.
x : _Symbol
An array of numbers, at which to evaluate p.
Returns
-------
values : _Symbol
Result array of polynomials
Notes
-----
This function differs from the original `numpy.polyval
<https://numpy.org/devdocs/reference/generated/numpy.polyval.html>`_ in
the following way(s):
- Does not support poly1d.
- X should be ndarray type even if it contains only one element.
"""
if isinstance(p, Symbol) and isinstance(x, Symbol):
return _npi.polyval(p, x)
elif not isinstance(p, Symbol) and not isinstance(x, Symbol):
return _np.polyval(p, x)
else:
raise TypeError('type not supported')


@set_module('mxnet.symbol.numpy')
def bincount(x, weights=None, minlength=0):
"""
Expand Down
100 changes: 100 additions & 0 deletions src/operator/numpy/np_polynomial_op-inl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file np_polynomial_op.h
* \brief Functions for dealing with polynomials.
*/

#ifndef MXNET_OPERATOR_NUMPY_NP_POLYNOMIAL_OP_INL_H_
#define MXNET_OPERATOR_NUMPY_NP_POLYNOMIAL_OP_INL_H_

#include <mxnet/base.h>
#include <string>
#include <vector>
#include <type_traits>
#include "../mxnet_op.h"
#include "../../common/utils.h"
#include "../tensor/elemwise_binary_broadcast_op.h"


namespace mxnet {
namespace op {

inline bool NumpyPolyvalShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector *in_attrs,
mxnet::ShapeVector *out_attrs) {
CHECK_EQ(in_attrs->size(), 2U);
CHECK_EQ(out_attrs->size(), 1U);

const mxnet::TShape& p_shape = in_attrs->at(0);
const mxnet::TShape& x_shape = in_attrs->at(1);
const mxnet::TShape& v_shape = out_attrs->at(0);
CHECK_EQ(p_shape.ndim(), 1U) << "ValueError: p has to be an 1-D array.";
SHAPE_ASSIGN_CHECK(*out_attrs, 0, x_shape);
SHAPE_ASSIGN_CHECK(*in_attrs, 1, v_shape);
return shape_is_known(*in_attrs) && shape_is_known(*out_attrs);
}

template<int req>
struct polyval_forward {
template<typename DType>
MSHADOW_XINLINE static void Map(int i,
DType* out_data,
const DType* p_data,
const DType* x_data,
const index_t p_size) {
DType val = 0;
for (index_t j = 0; j < p_size; j++) {
val = val * x_data[i] + p_data[j];
}
KERNEL_ASSIGN(out_data[i], req, val);
}
};

template<typename xpu>
void NumpyPolyvalForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mxnet;
CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req.size(), 1U);
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
const TBlob& p_data = inputs[0];
const TBlob& x_data = inputs[1];
const TBlob& out_data = outputs[0];
const size_t p_size = p_data.Size();
using namespace mxnet_op;

MSHADOW_TYPE_SWITCH(x_data.type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
Kernel<polyval_forward<req_type>, xpu>::Launch(
s, out_data.Size(), out_data.dptr<DType>(),
p_data.dptr<DType>(), x_data.dptr<DType>(), p_size);
});
});
}

} // namespace op
} // namespace mxnet

#endif // MXNET_OPERATOR_NUMPY_NP_POLYNOMIAL_OP_INL_H_
Loading

0 comments on commit 8a385b8

Please sign in to comment.