Skip to content

Commit

Permalink
Numpy compatible linspace (apache#15256)
Browse files Browse the repository at this point in the history
* draft

* finish linspace implementation

* finish linspace

* delete newline

* fix pylint

* add more unit test

* address comment

* add more test case

* disable too-many-arguments

* resolve confliction

* add ctx
  • Loading branch information
stu1130 authored and haojin2 committed Jul 31, 2019
1 parent ced5a52 commit 623c68e
Show file tree
Hide file tree
Showing 5 changed files with 238 additions and 3 deletions.
63 changes: 62 additions & 1 deletion python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,11 @@
from ...util import _sanity_check_params, set_module
from ...context import current_context
from . import _internal as _npi
from ..ndarray import NDArray

__all__ = ['zeros', 'ones', 'maximum', 'minimum', 'stack', 'arange', 'argmax',
'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'concatenate',
'clip', 'split', 'swapaxes', 'expand_dims', 'tile']
'clip', 'split', 'swapaxes', 'expand_dims', 'tile', 'linspace']


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -629,3 +630,63 @@ def tile(A, reps):
The tiled output array.
"""
return _npi.tile(A, reps)


@set_module('mxnet.ndarray.numpy')
def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0, **kwargs): #pylint: disable=too-many-arguments
"""Return evenly spaced numbers over a specified interval.
Returns num evenly spaced samples, calculated over the interval [start, stop].
The endpoint of the interval can optionally be excluded.
Parameters
----------
start : array_like
The starting value of the sequence.
stop : array_like
The end value of the sequence, unless endpoint is set to False. In
that case, the sequence consists of all but the last of num + 1
evenly spaced samples, so that stop is excluded. Note that the step
size changes when endpoint is False.
num : int, optional
Number of samples to generate. Default is 50. Must be non-negative.
endpoint : bool, optional
If True, stop is the last sample. Otherwise, it is not included.
Default is True.
retstep: bool, optional
If True, return (samples, step), where step is the spacing between samples.
dtype: dtype, optional
The type of the output array. If dtype is not given, infer the data
type from the other input arguments.
axis : int, optional
The axis in the result to store the samples. Relevant only if start or
stop are array-like. By default (0), the samples will be along a new
axis inserted at the beginning. Use -1 to get an axis at the end.
Returns
-------
samples : ndarray
There are num equally spaced samples in the closed interval
`[start, stop]` or the half-open interval `[start, stop)`
(depending on whether endpoint is True or False).
step : float, optional
Only returned if retstep is True
Size of spacing between samples.
Notes
-----
This function currently does not support ``start`` and ``stop`` as ndarrays and
axis could only be 0 now.
"""
if isinstance(start, (list, _np.ndarray, NDArray)) or \
isinstance(stop, (list, _np.ndarray, NDArray)):
raise NotImplementedError('start and stop only support int')
if axis != 0:
raise NotImplementedError("the function only support axis 0")
ctx = kwargs.pop('ctx', current_context())
if ctx is None:
ctx = current_context()
if retstep:
step = (stop - start) / (num - 1)
return (_npi.linspace(start=start, stop=stop, num=num, endpoint=endpoint, ctx=ctx, dtype=dtype), step)
else:
return _npi.linspace(start=start, stop=stop, num=num, endpoint=endpoint, ctx=ctx, dtype=dtype)
45 changes: 44 additions & 1 deletion python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@

__all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'maximum', 'minimum', 'stack', 'arange',
'argmax', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'concatenate',
'clip', 'split', 'swapaxes', 'expand_dims', 'tile']
'clip', 'split', 'swapaxes', 'expand_dims', 'tile', 'linspace']


# This function is copied from ndarray.py since pylint
Expand Down Expand Up @@ -1790,3 +1790,46 @@ def tile(A, reps):
The tiled output array.
"""
return _npi.tile(A, reps)


@set_module('mxnet.numpy')
def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0, **kwargs):
"""Return evenly spaced numbers over a specified interval.
Returns num evenly spaced samples, calculated over the interval [start, stop].
The endpoint of the interval can optionally be excluded.
Parameters
----------
start : array_like
The starting value of the sequence.
stop : array_like
The end value of the sequence, unless endpoint is set to False. In
that case, the sequence consists of all but the last of num + 1
evenly spaced samples, so that stop is excluded. Note that the step
size changes when endpoint is False.
num : int, optional
Number of samples to generate. Default is 50. Must be non-negative.
endpoint : bool, optional
If True, stop is the last sample. Otherwise, it is not included.
Default is True.
retstep: bool, optional
If True, return (samples, step), where step is the spacing between samples.
dtype: dtype, optional
The type of the output array. If dtype is not given, infer the data
type from the other input arguments.
axis : int, optional
The axis in the result to store the samples. Relevant only if start or
stop are array-like. By default (0), the samples will be along a new
axis inserted at the beginning. Use -1 to get an axis at the end.
Returns
-------
samples : ndarray
There are num equally spaced samples in the closed interval
`[start, stop]` or the half-open interval `[start, stop)`
(depending on whether endpoint is True or False).
step : float, optional
Only returned if retstep is True
Size of spacing between samples.
"""
return _mx_nd_np.linspace(start, stop, num, endpoint, retstep, dtype, axis, **kwargs)
62 changes: 61 additions & 1 deletion python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

__all__ = ['zeros', 'ones', 'maximum', 'minimum', 'stack', 'concatenate', 'arange', 'argmax',
'clip', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'split', 'swapaxes',
'expand_dims', 'tile']
'expand_dims', 'tile', 'linspace']


def _num_outputs(sym):
Expand Down Expand Up @@ -1307,4 +1307,64 @@ def tile(A, reps):
return _npi.tile(A, reps)


@set_module('mxnet.symbol.numpy')
def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0, **kwargs): # pylint: disable=too-many-arguments
"""Return evenly spaced numbers over a specified interval.
Returns num evenly spaced samples, calculated over the interval [start, stop].
The endpoint of the interval can optionally be excluded.
Parameters
----------
start : array_like
The starting value of the sequence.
stop : array_like
The end value of the sequence, unless endpoint is set to False. In
that case, the sequence consists of all but the last of num + 1
evenly spaced samples, so that stop is excluded. Note that the step
size changes when endpoint is False.
num : int, optional
Number of samples to generate. Default is 50. Must be non-negative.
endpoint : bool, optional
If True, stop is the last sample. Otherwise, it is not included.
Default is True.
retstep: bool, optional
If True, return (samples, step), where step is the spacing between samples.
dtype: dtype, optional
The type of the output array. If dtype is not given, infer the data
type from the other input arguments.
axis : int, optional
The axis in the result to store the samples. Relevant only if start or
stop are array-like. By default (0), the samples will be along a new
axis inserted at the beginning. Use -1 to get an axis at the end.
Returns
-------
samples : ndarray
There are num equally spaced samples in the closed interval
`[start, stop]` or the half-open interval `[start, stop)`
(depending on whether endpoint is True or False).
step : float, optional
Only returned if retstep is True
Size of spacing between samples.
Notes
-----
This function currently does not support ``start`` and ``stop`` as ndarrays and
axis could only be 0 now.
"""
if isinstance(start, (list, _np.ndarray)) or \
isinstance(stop, (list, _np.ndarray)):
raise NotImplementedError('start and stop only support int')
if axis != 0:
raise NotImplementedError("the function only support axis 0")
ctx = kwargs.pop('ctx', current_context())
if ctx is None:
ctx = current_context()
if retstep:
step = (stop - start) / (num - 1)
return (_npi.linspace(start=start, stop=stop, num=num, endpoint=endpoint, ctx=ctx, dtype=dtype), step)
else:
return _npi.linspace(start=start, stop=stop, num=num, endpoint=endpoint, ctx=ctx, dtype=dtype)


_set_np_symbol_class(_Symbol)
1 change: 1 addition & 0 deletions src/operator/tensor/init_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ Examples::
.add_argument("data", "NDArray-or-Symbol", "The input");

NNVM_REGISTER_OP(_linspace)
.add_alias("_npi_linspace")
.describe("Return evenly spaced numbers over a specified interval. Similar to Numpy")
.set_num_inputs(0)
.set_num_outputs(1)
Expand Down
70 changes: 70 additions & 0 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,76 @@ def hybrid_forward(self, F, x):
assert same(mx_out.asnumpy(), np_out)


@with_seed()
@npx.use_np_shape
def test_np_linspace():
configs = [
(0.0, 1.0, 10),
(-2, 4, 30),
(5.234324, 8.98324, 324),
(2, 10, 100)
]
exception_configs = [
(0, 10, -1),
(0, 1, 2.5)
]
dtypes = ['int32', 'float16', 'float32', 'float64', None]
for config in configs:
for dtype in dtypes:
for endpoint in [False, True]:
for retstep in [False, True]:
if isinstance(config, tuple):
mx_ret = np.linspace(*config, endpoint=endpoint, retstep=retstep, dtype=dtype)
np_ret = _np.linspace(*config, endpoint=endpoint, retstep=retstep, dtype=dtype)
else:
mx_ret = np.linspace(config, endpoint=endpoint, retstep=retstep, dtype=dtype)
np_ret = _np.linspace(config, endpoint=endpoint, retstep=retstep, dtype=dtype)
if retstep:
assert_almost_equal(mx_ret[0].asnumpy(), np_ret[0], atol=1e-3, rtol=1e-5)
same(mx_ret[1], np_ret[1])
else:
assert_almost_equal(mx_ret.asnumpy(), np_ret, atol=1e-3, rtol=1e-5)
# check for exception input
for config in exception_configs:
assertRaises(MXNetError, np.linspace, *config)
# check linspace equivalent to arange
for test_index in range(1000):
assert_almost_equal(mx.np.linspace(0, test_index, test_index + 1).asnumpy(), mx.np.arange(test_index + 1).asnumpy())
@npx.use_np
class TestLinspace(HybridBlock):
def __init__(self, start, stop, num=50, endpoint=None, retstep=False, dtype=None, axis=0):
super(TestLinspace, self).__init__()
self._start = start
self._stop = stop
self._num = num
self._endpoint = endpoint
self._retstep = retstep
self._dtype = dtype

def hybrid_forward(self, F, x):
if self._retstep:
raise ValueError("linspace didn't support retstep = True inside HybridBlock")
else:
return x + F.np.linspace(self._start, self._stop, self._num, \
self._endpoint, self._retstep, self._dtype)

for dtype in dtypes:
x = np.zeros(shape=(), dtype=dtype)
for config in configs:
for hybridize in [False, True]:
for endpoint in [False, True]:
if isinstance(config, tuple):
net = TestLinspace(*config, endpoint=endpoint, dtype=dtype)
np_out = _np.linspace(*config, endpoint=endpoint, dtype=dtype)
else:
net = TestLinspace(config, endpoint=endpoint, dtype=dtype)
np_out = _np.linspace(config, endpoint=endpoint, dtype=dtype)
if hybridize:
net.hybridize()
mx_out = net(x)
assert_almost_equal(mx_out.asnumpy(), np_out, atol=1e-3, rtol=1e-5)


@with_seed()
@npx.use_np_shape
def test_np_argmax():
Expand Down

0 comments on commit 623c68e

Please sign in to comment.