From 866e789714d874570388e2b1deef5f7bd8144367 Mon Sep 17 00:00:00 2001 From: Jake Lee Date: Thu, 20 Jun 2019 10:39:30 -0700 Subject: [PATCH] Numpy compatible linspace (#15256) * draft * finish linspace implementation * finish linspace * delete newline * fix pylint * add more unit test * address comment * add more test case * disable too-many-arguments * resolve confliction * add ctx --- python/mxnet/ndarray/numpy/_op.py | 63 ++++++++++++++++++++++- python/mxnet/numpy/multiarray.py | 45 ++++++++++++++++- python/mxnet/symbol/numpy/_symbol.py | 62 ++++++++++++++++++++++- src/operator/tensor/init_op.cc | 1 + tests/python/unittest/test_numpy_op.py | 70 ++++++++++++++++++++++++++ 5 files changed, 238 insertions(+), 3 deletions(-) diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index 04de2cd9a5d2..cf14d89bdbd2 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -23,10 +23,11 @@ from ...util import _sanity_check_params, set_module from ...context import current_context from . import _internal as _npi +from ..ndarray import NDArray __all__ = ['zeros', 'ones', 'maximum', 'minimum', 'stack', 'arange', 'argmax', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'concatenate', - 'clip', 'split', 'swapaxes', 'expand_dims', 'tile'] + 'clip', 'split', 'swapaxes', 'expand_dims', 'tile', 'linspace'] @set_module('mxnet.ndarray.numpy') @@ -629,3 +630,63 @@ def tile(A, reps): The tiled output array. """ return _npi.tile(A, reps) + + +@set_module('mxnet.ndarray.numpy') +def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0, **kwargs): #pylint: disable=too-many-arguments + """Return evenly spaced numbers over a specified interval. + + Returns num evenly spaced samples, calculated over the interval [start, stop]. + The endpoint of the interval can optionally be excluded. + + Parameters + ---------- + start : array_like + The starting value of the sequence. + stop : array_like + The end value of the sequence, unless endpoint is set to False. In + that case, the sequence consists of all but the last of num + 1 + evenly spaced samples, so that stop is excluded. Note that the step + size changes when endpoint is False. + num : int, optional + Number of samples to generate. Default is 50. Must be non-negative. + endpoint : bool, optional + If True, stop is the last sample. Otherwise, it is not included. + Default is True. + retstep: bool, optional + If True, return (samples, step), where step is the spacing between samples. + dtype: dtype, optional + The type of the output array. If dtype is not given, infer the data + type from the other input arguments. + axis : int, optional + The axis in the result to store the samples. Relevant only if start or + stop are array-like. By default (0), the samples will be along a new + axis inserted at the beginning. Use -1 to get an axis at the end. + Returns + ------- + samples : ndarray + There are num equally spaced samples in the closed interval + `[start, stop]` or the half-open interval `[start, stop)` + (depending on whether endpoint is True or False). + step : float, optional + Only returned if retstep is True + Size of spacing between samples. + + Notes + ----- + This function currently does not support ``start`` and ``stop`` as ndarrays and + axis could only be 0 now. + """ + if isinstance(start, (list, _np.ndarray, NDArray)) or \ + isinstance(stop, (list, _np.ndarray, NDArray)): + raise NotImplementedError('start and stop only support int') + if axis != 0: + raise NotImplementedError("the function only support axis 0") + ctx = kwargs.pop('ctx', current_context()) + if ctx is None: + ctx = current_context() + if retstep: + step = (stop - start) / (num - 1) + return (_npi.linspace(start=start, stop=stop, num=num, endpoint=endpoint, ctx=ctx, dtype=dtype), step) + else: + return _npi.linspace(start=start, stop=stop, num=num, endpoint=endpoint, ctx=ctx, dtype=dtype) diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 52a2cf414ddf..dd13c8e64cfc 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -45,7 +45,7 @@ __all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'maximum', 'minimum', 'stack', 'arange', 'argmax', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'concatenate', - 'clip', 'split', 'swapaxes', 'expand_dims', 'tile'] + 'clip', 'split', 'swapaxes', 'expand_dims', 'tile', 'linspace'] # This function is copied from ndarray.py since pylint @@ -1790,3 +1790,46 @@ def tile(A, reps): The tiled output array. """ return _npi.tile(A, reps) + + +@set_module('mxnet.numpy') +def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0, **kwargs): + """Return evenly spaced numbers over a specified interval. + + Returns num evenly spaced samples, calculated over the interval [start, stop]. + The endpoint of the interval can optionally be excluded. + + Parameters + ---------- + start : array_like + The starting value of the sequence. + stop : array_like + The end value of the sequence, unless endpoint is set to False. In + that case, the sequence consists of all but the last of num + 1 + evenly spaced samples, so that stop is excluded. Note that the step + size changes when endpoint is False. + num : int, optional + Number of samples to generate. Default is 50. Must be non-negative. + endpoint : bool, optional + If True, stop is the last sample. Otherwise, it is not included. + Default is True. + retstep: bool, optional + If True, return (samples, step), where step is the spacing between samples. + dtype: dtype, optional + The type of the output array. If dtype is not given, infer the data + type from the other input arguments. + axis : int, optional + The axis in the result to store the samples. Relevant only if start or + stop are array-like. By default (0), the samples will be along a new + axis inserted at the beginning. Use -1 to get an axis at the end. + Returns + ------- + samples : ndarray + There are num equally spaced samples in the closed interval + `[start, stop]` or the half-open interval `[start, stop)` + (depending on whether endpoint is True or False). + step : float, optional + Only returned if retstep is True + Size of spacing between samples. + """ + return _mx_nd_np.linspace(start, stop, num, endpoint, retstep, dtype, axis, **kwargs) diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index 11a1da81a855..e015b7a1a670 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -31,7 +31,7 @@ __all__ = ['zeros', 'ones', 'maximum', 'minimum', 'stack', 'concatenate', 'arange', 'argmax', 'clip', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'split', 'swapaxes', - 'expand_dims', 'tile'] + 'expand_dims', 'tile', 'linspace'] def _num_outputs(sym): @@ -1307,4 +1307,64 @@ def tile(A, reps): return _npi.tile(A, reps) +@set_module('mxnet.symbol.numpy') +def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0, **kwargs): # pylint: disable=too-many-arguments + """Return evenly spaced numbers over a specified interval. + + Returns num evenly spaced samples, calculated over the interval [start, stop]. + The endpoint of the interval can optionally be excluded. + + Parameters + ---------- + start : array_like + The starting value of the sequence. + stop : array_like + The end value of the sequence, unless endpoint is set to False. In + that case, the sequence consists of all but the last of num + 1 + evenly spaced samples, so that stop is excluded. Note that the step + size changes when endpoint is False. + num : int, optional + Number of samples to generate. Default is 50. Must be non-negative. + endpoint : bool, optional + If True, stop is the last sample. Otherwise, it is not included. + Default is True. + retstep: bool, optional + If True, return (samples, step), where step is the spacing between samples. + dtype: dtype, optional + The type of the output array. If dtype is not given, infer the data + type from the other input arguments. + axis : int, optional + The axis in the result to store the samples. Relevant only if start or + stop are array-like. By default (0), the samples will be along a new + axis inserted at the beginning. Use -1 to get an axis at the end. + Returns + ------- + samples : ndarray + There are num equally spaced samples in the closed interval + `[start, stop]` or the half-open interval `[start, stop)` + (depending on whether endpoint is True or False). + step : float, optional + Only returned if retstep is True + Size of spacing between samples. + + Notes + ----- + This function currently does not support ``start`` and ``stop`` as ndarrays and + axis could only be 0 now. + """ + if isinstance(start, (list, _np.ndarray)) or \ + isinstance(stop, (list, _np.ndarray)): + raise NotImplementedError('start and stop only support int') + if axis != 0: + raise NotImplementedError("the function only support axis 0") + ctx = kwargs.pop('ctx', current_context()) + if ctx is None: + ctx = current_context() + if retstep: + step = (stop - start) / (num - 1) + return (_npi.linspace(start=start, stop=stop, num=num, endpoint=endpoint, ctx=ctx, dtype=dtype), step) + else: + return _npi.linspace(start=start, stop=stop, num=num, endpoint=endpoint, ctx=ctx, dtype=dtype) + + _set_np_symbol_class(_Symbol) diff --git a/src/operator/tensor/init_op.cc b/src/operator/tensor/init_op.cc index 0cbdaa43c198..710e11c4a509 100644 --- a/src/operator/tensor/init_op.cc +++ b/src/operator/tensor/init_op.cc @@ -137,6 +137,7 @@ Examples:: .add_argument("data", "NDArray-or-Symbol", "The input"); NNVM_REGISTER_OP(_linspace) +.add_alias("_npi_linspace") .describe("Return evenly spaced numbers over a specified interval. Similar to Numpy") .set_num_inputs(0) .set_num_outputs(1) diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 031719c9088b..3ce04409bfce 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -644,6 +644,76 @@ def hybrid_forward(self, F, x): assert same(mx_out.asnumpy(), np_out) +@with_seed() +@npx.use_np_shape +def test_np_linspace(): + configs = [ + (0.0, 1.0, 10), + (-2, 4, 30), + (5.234324, 8.98324, 324), + (2, 10, 100) + ] + exception_configs = [ + (0, 10, -1), + (0, 1, 2.5) + ] + dtypes = ['int32', 'float16', 'float32', 'float64', None] + for config in configs: + for dtype in dtypes: + for endpoint in [False, True]: + for retstep in [False, True]: + if isinstance(config, tuple): + mx_ret = np.linspace(*config, endpoint=endpoint, retstep=retstep, dtype=dtype) + np_ret = _np.linspace(*config, endpoint=endpoint, retstep=retstep, dtype=dtype) + else: + mx_ret = np.linspace(config, endpoint=endpoint, retstep=retstep, dtype=dtype) + np_ret = _np.linspace(config, endpoint=endpoint, retstep=retstep, dtype=dtype) + if retstep: + assert_almost_equal(mx_ret[0].asnumpy(), np_ret[0], atol=1e-3, rtol=1e-5) + same(mx_ret[1], np_ret[1]) + else: + assert_almost_equal(mx_ret.asnumpy(), np_ret, atol=1e-3, rtol=1e-5) + # check for exception input + for config in exception_configs: + assertRaises(MXNetError, np.linspace, *config) + # check linspace equivalent to arange + for test_index in range(1000): + assert_almost_equal(mx.np.linspace(0, test_index, test_index + 1).asnumpy(), mx.np.arange(test_index + 1).asnumpy()) + @npx.use_np + class TestLinspace(HybridBlock): + def __init__(self, start, stop, num=50, endpoint=None, retstep=False, dtype=None, axis=0): + super(TestLinspace, self).__init__() + self._start = start + self._stop = stop + self._num = num + self._endpoint = endpoint + self._retstep = retstep + self._dtype = dtype + + def hybrid_forward(self, F, x): + if self._retstep: + raise ValueError("linspace didn't support retstep = True inside HybridBlock") + else: + return x + F.np.linspace(self._start, self._stop, self._num, \ + self._endpoint, self._retstep, self._dtype) + + for dtype in dtypes: + x = np.zeros(shape=(), dtype=dtype) + for config in configs: + for hybridize in [False, True]: + for endpoint in [False, True]: + if isinstance(config, tuple): + net = TestLinspace(*config, endpoint=endpoint, dtype=dtype) + np_out = _np.linspace(*config, endpoint=endpoint, dtype=dtype) + else: + net = TestLinspace(config, endpoint=endpoint, dtype=dtype) + np_out = _np.linspace(config, endpoint=endpoint, dtype=dtype) + if hybridize: + net.hybridize() + mx_out = net(x) + assert_almost_equal(mx_out.asnumpy(), np_out, atol=1e-3, rtol=1e-5) + + @with_seed() @npx.use_np_shape def test_np_argmax():