From 2170697c86f722fd08a07aab4df5051a4d682ac7 Mon Sep 17 00:00:00 2001 From: reminisce Date: Sun, 2 Jun 2019 10:37:15 -0700 Subject: [PATCH] [numpy] Fix np branch after rebase (#15086) * Add np_array semantics for Gluon Fix notebook Fix sanity Fix gluon deferred infer shape Add np.random.uniform Add random normal Add boolean comparison ops Add np.ndarray indexing Reformat test ndarray indexing Fix unit tests Add one more test of indexing Fix sanity Enable amp test Add np.arange Revert cython unit test to ctypes Delete unnecessary use_np_shape decorator from test Rebase with numpy branch support range as index Fix python2 range type check Add argmax Disable clojure test * Fix ci * Add np.linalg.norm for ord='fro' * Fix pylint --- ci/jenkins/Jenkins_steps.groovy | 18 +- ci/jenkins/Jenkinsfile_unix_cpu | 4 +- example/numpy/demo.ipynb | 2 +- python/mxnet/__init__.py | 3 +- python/mxnet/_ctypes/ndarray.py | 2 +- python/mxnet/base.py | 10 +- python/mxnet/gluon/block.py | 3 +- python/mxnet/gluon/parameter.py | 13 +- python/mxnet/gluon/utils.py | 2 +- python/mxnet/ndarray/__init__.py | 2 +- python/mxnet/ndarray/numpy/_op.py | 78 ++++- python/mxnet/ndarray/numpy/linalg.py | 50 ++- python/mxnet/ndarray/numpy/random.py | 119 ++++++- python/mxnet/numpy/__init__.py | 1 - python/mxnet/numpy/linalg.py | 44 ++- python/mxnet/numpy/multiarray.py | 197 +++++++++-- python/mxnet/numpy/random.py | 82 ++++- python/mxnet/numpy_extension/__init__.py | 3 + python/mxnet/symbol/__init__.py | 2 +- python/mxnet/symbol/numpy/_symbol.py | 148 ++++++-- python/mxnet/symbol/numpy/linalg.py | 49 ++- python/mxnet/symbol/numpy/random.py | 120 ++++++- python/mxnet/test_utils.py | 2 +- python/mxnet/util.py | 230 ++++++++++++- .../numpy/np_broadcast_reduce_op_index.cc | 61 ++++ .../numpy/np_broadcast_reduce_op_index.cu | 34 ++ .../numpy/np_broadcast_reduce_op_value.cc | 2 +- .../numpy/np_broadcast_reduce_op_value.cu | 2 +- .../numpy/np_elemwise_unary_op_basic.cc | 4 +- .../numpy/np_elemwise_unary_op_basic.cu | 4 +- src/operator/numpy/np_init_op.cc | 27 ++ src/operator/numpy/np_init_op.cu | 3 + src/operator/random/sample_op.cc | 2 + src/operator/tensor/broadcast_reduce_op.h | 50 ++- .../elemwise_binary_broadcast_op_logic.cc | 6 + .../tensor/elemwise_binary_scalar_op_logic.cc | 6 + tests/python/unittest/test_contrib_amp.py | 3 - tests/python/unittest/test_numpy_gluon.py | 12 +- tests/python/unittest/test_numpy_ndarray.py | 319 ++++++++++++++++-- tests/python/unittest/test_numpy_op.py | 229 +++++++++++-- tests/python/unittest/test_thread_local.py | 36 ++ 41 files changed, 1836 insertions(+), 148 deletions(-) create mode 100644 src/operator/numpy/np_broadcast_reduce_op_index.cc create mode 100644 src/operator/numpy/np_broadcast_reduce_op_index.cu diff --git a/ci/jenkins/Jenkins_steps.groovy b/ci/jenkins/Jenkins_steps.groovy index c27a61383e46..31b869fac8a6 100644 --- a/ci/jenkins/Jenkins_steps.groovy +++ b/ci/jenkins/Jenkins_steps.groovy @@ -112,7 +112,8 @@ def compile_unix_cpu_openblas() { timeout(time: max_time, unit: 'MINUTES') { utils.init_git() utils.docker_run('ubuntu_cpu', 'build_ubuntu_cpu_openblas', false) - utils.pack_lib('cpu', mx_lib_cython, true) + // utils.pack_lib('cpu', mx_lib_cython, true) + utils.pack_lib('cpu', mx_lib, true) } } } @@ -266,7 +267,8 @@ def compile_unix_cmake_gpu() { timeout(time: max_time, unit: 'MINUTES') { utils.init_git() utils.docker_run('ubuntu_gpu_cu101', 'build_ubuntu_gpu_cmake', false) - utils.pack_lib('cmake_gpu', mx_cmake_lib_cython, true) + // utils.pack_lib('cmake_gpu', mx_cmake_lib_cython, true) + utils.pack_lib('cmake_gpu', mx_cmake_lib, true) } } } @@ -643,8 +645,10 @@ def test_unix_python2_cpu() { node(NODE_LINUX_CPU) { ws('workspace/ut-python2-cpu') { try { - utils.unpack_and_init('cpu', mx_lib_cython, true) - python2_ut_cython('ubuntu_cpu') + // utils.unpack_and_init('cpu', mx_lib_cython, true) + // python2_ut_cython('ubuntu_cpu') + utils.unpack_and_init('cpu', mx_lib, true) + python2_ut('ubuntu_cpu') utils.publish_test_coverage() } finally { utils.collect_test_results_unix('nosetests_unittest.xml', 'nosetests_python2_cpu_unittest.xml') @@ -745,8 +749,10 @@ def test_unix_python3_gpu() { node(NODE_LINUX_GPU) { ws('workspace/ut-python3-gpu') { try { - utils.unpack_and_init('gpu', mx_lib_cython, true) - python3_gpu_ut_cython('ubuntu_gpu_cu101') + // utils.unpack_and_init('gpu', mx_lib_cython, true) + // python3_gpu_ut_cython('ubuntu_gpu_cu100') + utils.unpack_and_init('gpu', mx_lib, true) + python3_gpu_ut('ubuntu_gpu_cu101') utils.publish_test_coverage() } finally { utils.collect_test_results_unix('nosetests_gpu.xml', 'nosetests_python3_gpu.xml') diff --git a/ci/jenkins/Jenkinsfile_unix_cpu b/ci/jenkins/Jenkinsfile_unix_cpu index fa0942988d9c..c3a1481f5ec5 100644 --- a/ci/jenkins/Jenkinsfile_unix_cpu +++ b/ci/jenkins/Jenkinsfile_unix_cpu @@ -52,8 +52,8 @@ core_logic: { custom_steps.test_unix_python3_mkldnn_mkl_cpu(), custom_steps.test_unix_scala_cpu(), custom_steps.test_unix_scala_mkldnn_cpu(), - custom_steps.test_unix_clojure_cpu(), - custom_steps.test_unix_clojure_integration_cpu(), + // custom_steps.test_unix_clojure_cpu(), + // custom_steps.test_unix_clojure_integration_cpu(), custom_steps.test_unix_perl_cpu(), custom_steps.test_unix_r_cpu(), custom_steps.test_unix_r_mkldnn_cpu(), diff --git a/example/numpy/demo.ipynb b/example/numpy/demo.ipynb index 1f0627563159..31c13e97e3dd 100644 --- a/example/numpy/demo.ipynb +++ b/example/numpy/demo.ipynb @@ -372,7 +372,7 @@ "from mxnet import gluon, autograd, np\n", "\n", "\n", - "@np.use_np_compat\n", + "@np.use_np\n", "class LinearRegression(gluon.HybridBlock):\n", " def __init__(self, num_input_dim=1000, num_hidden_dim=100, num_output_dim=10):\n", " super(LinearRegression, self).__init__()\n", diff --git a/python/mxnet/__init__.py b/python/mxnet/__init__.py index 883e84604132..f288b4c65926 100644 --- a/python/mxnet/__init__.py +++ b/python/mxnet/__init__.py @@ -25,6 +25,7 @@ from . import engine from .base import MXNetError from .util import is_np_shape, set_np_shape, np_shape, use_np_shape +from .util import is_np_array, np_array, use_np_array, use_np from . import base from . import contrib from . import ndarray @@ -32,7 +33,7 @@ from . import numpy from . import numpy_extension from . import numpy as np -from . import numpy_extension as npe +from . import numpy_extension as npx from . import name # use mx.sym as short for symbol from . import symbol as sym diff --git a/python/mxnet/_ctypes/ndarray.py b/python/mxnet/_ctypes/ndarray.py index 6404d895b884..dd429e6f6c46 100644 --- a/python/mxnet/_ctypes/ndarray.py +++ b/python/mxnet/_ctypes/ndarray.py @@ -118,7 +118,7 @@ def __init__(self, sym, flags=()): self.handle = CachedOpHandle() from ..symbol.numpy._symbol import _Symbol - self.is_np_sym = True if isinstance(sym, _Symbol) else False + self.is_np_sym = bool(isinstance(sym, _Symbol)) check_call(_LIB.MXCreateCachedOpEx( sym.handle, diff --git a/python/mxnet/base.py b/python/mxnet/base.py index 5393c511ce07..e73bd9387577 100644 --- a/python/mxnet/base.py +++ b/python/mxnet/base.py @@ -756,7 +756,7 @@ def _sanity_check_params(func_name, unsupported_params, param_dict): _NP_OP_PREFIX = '_np_' _NP_OP_SUBMODULE_LIST = ['_random_', '_linalg_'] -_NP_EXT_OP_PREFIX = '_npe_' +_NP_EXT_OP_PREFIX = '_npx_' _NP_INTERNAL_OP_PREFIX = '_npi_' @@ -813,14 +813,14 @@ def _init_np_op_module(root_module_name, np_module_name, mx_module_name, make_op op_names.append(name) if mx_module_name is None: - # register np/npe ops for imperative programming + # register np/npx ops for imperative programming op_module_name = "%s.%s._op" % (root_module_name, np_module_name) # e.g. mxnet.numpy._op op_submodule_name = "%s.%s" % (root_module_name, np_module_name) # e.g. mxnet.numpy.random - elif mx_module_name == 'ndarray' or mx_module_name == 'symbol': - # register numpy internal ops and np/npe ops for use in Gluon + elif mx_module_name in ('ndarray', 'symbol'): + # register numpy internal ops and np/npx ops for use in Gluon # np internal ops are registered in mxnet.ndarray/symbol.numpy._internal # np ops are registered in mxnet.ndarray/symbol.numpy._op - # npe ops are registered in mxnet.ndarray/symbol.numpy_extension._op + # npx ops are registered in mxnet.ndarray/symbol.numpy_extension._op op_module_name = "%s.%s.%s" % (root_module_name, mx_module_name, np_module_name) if op_name_prefix != _NP_INTERNAL_OP_PREFIX: op_module_name += '._op' diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py index 136289136d61..4363c0fb0fed 100644 --- a/python/mxnet/gluon/block.py +++ b/python/mxnet/gluon/block.py @@ -35,6 +35,7 @@ from .parameter import Parameter, ParameterDict, DeferredInitializationError from .utils import _indent, _brief_print_list, HookHandle from .utils import _check_same_symbol_type, _check_all_np_ndarrays +from .. import numpy_extension as _mx_npx from .. import numpy as _mx_np @@ -551,7 +552,7 @@ def __call__(self, *args): for hook in self._forward_hooks.values(): hook(self, args, out) - if _mx_np.is_np_shape(): + if _mx_npx.is_np_array(): _check_all_np_ndarrays(_flatten(out, "output")[0]) return out diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py index 2d3e8c05462f..86ee9ad4a55b 100644 --- a/python/mxnet/gluon/parameter.py +++ b/python/mxnet/gluon/parameter.py @@ -31,7 +31,7 @@ from ..context import Context, cpu from .. import autograd from .utils import _indent, _brief_print_list, shape_is_known -from ..util import is_np_shape +from ..util import is_np_shape, is_np_array # pylint: disable= invalid-name tensor_types = (symbol.Symbol, ndarray.NDArray) @@ -188,9 +188,9 @@ def shape(self, new_shape): if self._shape is None: self._shape = new_shape return - unknown_dim_size = -1 if is_np_shape() else 0 + assert len(self._shape) == len(new_shape) and \ - all(j in (unknown_dim_size, i) for i, j in zip(new_shape, self._shape)), \ + all(j in (0, i) for i, j in zip(new_shape, self._shape)), \ "Expected shape %s is incompatible with given shape %s."%( str(new_shape), str(self._shape)) @@ -317,6 +317,7 @@ def _finish_deferred_init(self): return init, ctx, default_init, data = self._deferred_init self._deferred_init = () + assert shape_is_known(self.shape), \ "Cannot initialize Parameter '%s' because it has " \ "invalid shape: %s. Please specify in_units, " \ @@ -330,7 +331,7 @@ def _finish_deferred_init(self): initializer.create(default_init)( initializer.InitDesc(self.name, {'__init__': init}), data) # TODO(junwu): use np random operators when available - if is_np_shape(): + if is_np_array(): data = data.as_np_ndarray() # convert to np.ndarray self._init_impl(data, ctx) @@ -357,7 +358,7 @@ def _init_grad(self): self._grad = [ndarray.zeros(shape=i.shape, dtype=i.dtype, ctx=i.context, stype=self._grad_stype) for i in self._data] # TODO(junwu): use np.zeros - if is_np_shape(): + if is_np_array(): self._grad = [arr.as_np_ndarray() for arr in self._grad] autograd.mark_variables(self._check_and_get(self._data, list), @@ -606,7 +607,7 @@ def var(self): self._var = symbol.var(self.name, shape=self.shape, dtype=self.dtype, lr_mult=self.lr_mult, wd_mult=self.wd_mult, init=self.init, stype=self._stype) - if is_np_shape(): + if is_np_array(): self._var = self._var.as_np_ndarray() return self._var diff --git a/python/mxnet/gluon/utils.py b/python/mxnet/gluon/utils.py index b21e06dbeabf..fee22daf910d 100644 --- a/python/mxnet/gluon/utils.py +++ b/python/mxnet/gluon/utils.py @@ -438,7 +438,7 @@ def _check_same_symbol_type(symbols): the symbols.""" from ..symbol.numpy import _Symbol as np_symbol from ..symbol import Symbol as classic_symbol - is_np_sym = True if isinstance(symbols[0], np_symbol) else False + is_np_sym = bool(isinstance(symbols[0], np_symbol)) for s in symbols[1:]: if is_np_sym != isinstance(s, np_symbol): raise TypeError('Found both classic symbol (mx.sym.Symbol) and numpy symbol ' diff --git a/python/mxnet/ndarray/__init__.py b/python/mxnet/ndarray/__init__.py index c326850ec3e4..f6b8712a2513 100644 --- a/python/mxnet/ndarray/__init__.py +++ b/python/mxnet/ndarray/__init__.py @@ -31,7 +31,7 @@ from .sparse import _ndarray_cls from .ndarray import _GRAD_REQ_MAP, _DTYPE_MX_TO_NP, _DTYPE_NP_TO_MX, _new_empty_handle from . import numpy as np -from . import numpy_extension as npe +from . import numpy_extension as npx __all__ = op.__all__ + ndarray.__all__ + utils.__all__ + \ ['contrib', 'linalg', 'random', 'sparse', 'image', 'numpy', 'numpy_extension'] diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index 76825f1a59b0..34218e3b36f2 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -24,7 +24,7 @@ from ...context import current_context from . import _internal as _npi -__all__ = ['zeros', 'ones', 'maximum', 'minimum', 'stack'] +__all__ = ['zeros', 'ones', 'maximum', 'minimum', 'stack', 'arange', 'argmax'] @set_module('mxnet.ndarray.numpy') @@ -201,3 +201,79 @@ def get_list(arrays): arrays = get_list(arrays) return _npi.stack(*arrays, axis=axis, out=out) + + +@set_module('mxnet.ndarray.numpy') +def arange(start, stop=None, step=1, dtype=None, ctx=None): + """Return evenly spaced values within a given interval. + + Values are generated within the half-open interval ``[start, stop)`` + (in other words, the interval including `start` but excluding `stop`). + For integer arguments the function is equivalent to the Python built-in + `range` function, but returns an ndarray rather than a list. + + Parameters + ---------- + start : number, optional + Start of interval. The interval includes this value. The default + start value is 0. + stop : number + End of interval. The interval does not include this value, except + in some cases where `step` is not an integer and floating point + round-off affects the length of `out`. + step : number, optional + Spacing between values. For any output `out`, this is the distance + between two adjacent values, ``out[i+1] - out[i]``. The default + step size is 1. If `step` is specified as a position argument, + `start` must also be given. + dtype : dtype + The type of the output array. The default is `float32`. + + Returns + ------- + arange : ndarray + Array of evenly spaced values. + + For floating point arguments, the length of the result is + ``ceil((stop - start)/step)``. Because of floating point overflow, + this rule may result in the last element of `out` being greater + than `stop`. + """ + if dtype is None: + dtype = 'float32' + if ctx is None: + ctx = current_context() + if stop is None: + stop = start + start = 0 + if step is None: + step = 1 + if start is None and stop is None: + raise ValueError('start and stop cannot be both None') + if step == 0: + raise ZeroDivisionError('step cannot be 0') + return _npi.arange(start=start, stop=stop, step=step, dtype=dtype, ctx=ctx) + + +@set_module('mxnet.ndarray.numpy') +def argmax(a, axis=None, out=None): + """Returns the indices of the maximum values along an axis. + + Parameters + ---------- + a : ndarray + Input array. Only support ndarrays of dtype `float16`, `float32`, and `float64`. + axis : int, optional + By default, the index is into the flattened array, otherwise + along the specified axis. + out : array, optional + If provided, the result will be inserted into this array. It should + be of the appropriate shape and dtype. + + Returns + ------- + index_array : ndarray of indices whose dtype is same as the input ndarray. + Array of indices into the array. It has the same shape as `a.shape` + with the dimension along `axis` removed. + """ + return _npi.argmax(a, axis=axis, keepdims=False, out=out) diff --git a/python/mxnet/ndarray/numpy/linalg.py b/python/mxnet/ndarray/numpy/linalg.py index 8f521fd0d456..36f3f21a7588 100644 --- a/python/mxnet/ndarray/numpy/linalg.py +++ b/python/mxnet/ndarray/numpy/linalg.py @@ -17,4 +17,52 @@ """Namespace for operators used in Gluon dispatched by F=ndarray.""" -__all__ = [] +from __future__ import absolute_import +from . import _op as _mx_nd_np + +__all__ = ['norm'] + + +def norm(x, ord=None, axis=None, keepdims=False): + r"""Matrix or vector norm. + + This function can only support Frobenius norm for now. + The Frobenius norm is given by [1]_: + + :math:`||A||_F = [\sum_{i,j} abs(a_{i,j})^2]^{1/2}` + + Parameters + ---------- + x : ndarray + Input array. + ord : {'fro'}, optional + Order of the norm. + axis : {int, 2-tuple of ints, None}, optional + If `axis` is an integer, it specifies the axis of `x` along which to + compute the vector norms. If `axis` is a 2-tuple, it specifies the + axes that hold 2-D matrices, and the matrix norms of these matrices + are computed. If `axis` is None, the norm of the whole ndarray is + returned. + + keepdims : bool, optional + If this is set to True, the axes which are normed over are left in the + result as dimensions with size one. With this option the result will + broadcast correctly against the original `x`. + + Returns + ------- + n : float or ndarray + Norm of the matrix or vector(s). + + References + ---------- + .. [1] G. H. Golub and C. F. Van Loan, *Matrix Computations*, + Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15 + """ + if ord is not None and ord != 'fro': + raise ValueError('only support Frobenius norm for now, received ord={}'.format(str(ord))) + if isinstance(axis, tuple) and len(axis) > 2: + raise ValueError('Improper number of dimensions to norm') + if ord == 'fro' and x.ndim > 2 and axis is None: + raise ValueError('Improper number of dimensions to norm') + return _mx_nd_np.sqrt(_mx_nd_np.sum(x * x, axis=axis, keepdims=keepdims)) diff --git a/python/mxnet/ndarray/numpy/random.py b/python/mxnet/ndarray/numpy/random.py index 8f521fd0d456..3d9fd6a7d6fe 100644 --- a/python/mxnet/ndarray/numpy/random.py +++ b/python/mxnet/ndarray/numpy/random.py @@ -16,5 +16,122 @@ # under the License. """Namespace for operators used in Gluon dispatched by F=ndarray.""" +from __future__ import absolute_import +from ...base import numeric_types +from ...context import current_context +from . import _internal as _npi -__all__ = [] +__all__ = ['uniform', 'normal'] + + +def _random_helper(random, sampler, params, shape, dtype, ctx, out, kwargs): + """Helper function for random generators.""" + from ...numpy import ndarray as np_ndarray + if isinstance(params[0], np_ndarray): + for i in params[1:]: + assert isinstance(i, np_ndarray), \ + "Distribution parameters must all have the same type, but got " \ + "both %s and %s." % (type(params[0]), type(i)) + return sampler(*params, shape=shape, dtype=dtype, out=out, **kwargs) + elif isinstance(params[0], numeric_types): + if ctx is None: + ctx = current_context() + if shape is None and out is None: + shape = () + for i in params[1:]: + assert isinstance(i, numeric_types), \ + "Distribution parameters must all have the same type, but got " \ + "both %s and %s."%(type(params[0]), type(i)) + return random(*params, shape=shape, dtype=dtype, ctx=ctx, out=out, **kwargs) + + raise ValueError("Distribution parameters must be either mxnet.numpy.ndarray or numbers, " + "but got %s." % type(params[0])) + + +def uniform(low=0.0, high=1.0, size=None, **kwargs): + """Draw samples from a uniform distribution. + + Samples are uniformly distributed over the half-open interval + ``[low, high)`` (includes low, but excludes high). In other words, + any value within the given interval is equally likely to be drawn + by `uniform`. + + Parameters + ---------- + low : float, optional + Lower boundary of the output interval. All values generated will be + greater than or equal to low. The default value is 0. + high : float + Upper boundary of the output interval. All values generated will be + less than high. The default value is 1.0. + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. If size is ``None`` (default), + a scalar tensor containing a single value is returned if + ``low`` and ``high`` are both scalars. + dtype : {'float16', 'float32', 'float64'}, optional + Data type of output samples. Default is 'float32' + ctx : Context, optional + Device context of output. Default is current context. + out : ndarray, optional + Store output to an existing ndarray. + + Returns + ------- + out : ndarray + Drawn samples from the parameterized uniform distribution. + + + Notes + ----- + This function currently does not support ``low`` and ``high`` as ndarrays. + """ + dtype = kwargs.pop('dtype', None) + if dtype is None: + dtype = 'float32' + ctx = kwargs.pop('ctx', None) + out = kwargs.pop('out', None) + return _random_helper(_npi.random_uniform, None, + [low, high], size, dtype, ctx, out, kwargs) + + +def normal(loc=0.0, scale=1.0, size=None, **kwargs): + """Draw random samples from a normal (Gaussian) distribution. + + Samples are distributed according to a normal distribution parametrized + by *loc* (mean) and *scale* (standard deviation). + + + Parameters + ---------- + loc : float, optional + Mean (centre) of the distribution. + scale : float, optional + Standard deviation (spread or "width") of the distribution. + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., `(m, n, k)`, then `m * n * k` + samples are drawn. If size is `None` (default), a scalar tensor containing + a single value is returned if loc and scale are both scalars. + dtype : {'float16', 'float32', 'float64'}, optional + Data type of output samples. Default is 'float32' + ctx : Context, optional + Device context of output. Default is current context. + out : ``ndarray``, optional + Store output to an existing ``ndarray``. + + Returns + ------- + out : ndarray + Drawn samples from the parameterized normal distribution. + + Notes + ----- + This function currently does not support ``loc`` and ``scale`` as ndarrays. + """ + dtype = kwargs.pop('dtype', None) + if dtype is None: + dtype = 'float32' + ctx = kwargs.pop('ctx', None) + out = kwargs.pop('out', None) + return _random_helper(_npi.random_normal, None, + [loc, scale], size, dtype, ctx, out, kwargs) diff --git a/python/mxnet/numpy/__init__.py b/python/mxnet/numpy/__init__.py index 6f1c02d6462b..344483dc3d00 100644 --- a/python/mxnet/numpy/__init__.py +++ b/python/mxnet/numpy/__init__.py @@ -26,6 +26,5 @@ from . import _op from . import _register from ._op import * # pylint: disable=wildcard-import -from ..util import use_np_shape, set_np_shape, np_shape, is_np_shape __all__ = [] diff --git a/python/mxnet/numpy/linalg.py b/python/mxnet/numpy/linalg.py index e49bfcf6a97c..9758af47233d 100644 --- a/python/mxnet/numpy/linalg.py +++ b/python/mxnet/numpy/linalg.py @@ -17,4 +17,46 @@ """Namespace for ops used in imperative programming.""" -__all__ = [] +from __future__ import absolute_import +from ..ndarray import numpy as _mx_nd_np + +__all__ = ['norm'] + + +def norm(x, ord=None, axis=None, keepdims=False): + r"""Matrix or vector norm. + + This function can only support Frobenius norm for now. + The Frobenius norm is given by [1]_: + + :math:`||A||_F = [\sum_{i,j} abs(a_{i,j})^2]^{1/2}` + + Parameters + ---------- + x : ndarray + Input array. + ord : {'fro'}, optional + Order of the norm. + axis : {int, 2-tuple of ints, None}, optional + If `axis` is an integer, it specifies the axis of `x` along which to + compute the vector norms. If `axis` is a 2-tuple, it specifies the + axes that hold 2-D matrices, and the matrix norms of these matrices + are computed. If `axis` is None, the norm of the whole ndarray is + returned. + + keepdims : bool, optional + If this is set to True, the axes which are normed over are left in the + result as dimensions with size one. With this option the result will + broadcast correctly against the original `x`. + + Returns + ------- + n : float or ndarray + Norm of the matrix or vector(s). + + References + ---------- + .. [1] G. H. Golub and C. F. Van Loan, *Matrix Computations*, + Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15 + """ + return _mx_nd_np.linalg.norm(x, ord, axis, keepdims) diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index da7e61e46707..212dfe30d293 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -23,19 +23,22 @@ from __future__ import absolute_import from __future__ import division from array import array as native_array +import sys import ctypes +import warnings import numpy as _np from ..ndarray import NDArray, _DTYPE_NP_TO_MX, _GRAD_REQ_MAP from ..ndarray._internal import _set_np_ndarray_class from . import _op as _mx_np_op from ..base import check_call, _LIB, NDArrayHandle -from ..base import mx_real_t, c_array_buf, mx_uint, numeric_types +from ..base import mx_real_t, c_array_buf, mx_uint, numeric_types, integer_types from ..util import _sanity_check_params, set_module, use_np_shape from ..context import current_context from ..ndarray import numpy as _mx_nd_np from ..ndarray.numpy import _internal as _npi -__all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'maximum', 'minimum', 'stack'] +__all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'maximum', 'minimum', 'stack', 'arange', + 'argmax'] # This function is copied from ndarray.py since pylint @@ -74,6 +77,17 @@ def _np_ndarray_cls(handle, writable=True, stype=0): _set_np_ndarray_class(_np_ndarray_cls) +def _get_index(idx): + if isinstance(idx, NDArray) and not isinstance(idx, ndarray): + raise TypeError('Cannot have mx.nd.NDArray as index') + if isinstance(idx, ndarray): + return idx._as_classic_ndarray() + elif sys.version_info[0] > 2 and isinstance(idx, range): + return arange(idx.start, idx.stop, idx.step, dtype='int32')._as_classic_ndarray() + else: + return idx + + @set_module('mxnet.numpy') # pylint: disable=invalid-name @use_np_shape class ndarray(NDArray): @@ -83,22 +97,57 @@ class ndarray(NDArray): floating point number, or something else, etc.). Arrays should be constructed using `array`, `zeros` or `empty`. Currently, only c-contiguous arrays are supported.""" - def __getitem__(self, item): - # TODO(junwu): make output shape of integer indexing correct - raise NotImplementedError + def __getitem__(self, key): + # TODO(junwu): calling base class __setitem__ is a temp solution + if self.ndim == 0: + if key != (): + raise IndexError('scalar tensor can only accept `()` as index') + if isinstance(key, tuple) and len(key) == 0: + return self + if isinstance(key, integer_types): + key = (key,) + if isinstance(key, tuple) and len(key) == self.ndim\ + and all(isinstance(idx, integer_types) for idx in key): + out = self._as_classic_ndarray() + for idx in key: + out = out[idx] + return out.reshape(()).as_np_ndarray() + if isinstance(key, ndarray): + key = key._as_classic_ndarray() + elif isinstance(key, tuple): + key = [_get_index(idx) for idx in key] + key = tuple(key) + elif isinstance(key, list): + key = [_get_index(idx) for idx in key] + elif sys.version_info[0] > 2 and isinstance(key, range): + key = _get_index(key) + return self._as_classic_ndarray().__getitem__(key).as_np_ndarray() def __setitem__(self, key, value): - if self.size == 0: - return + # TODO(junwu): calling base class __setitem__ is a temp solution + if isinstance(value, NDArray) and not isinstance(value, ndarray): + raise TypeError('Cannot assign mx.nd.NDArray to mxnet.numpy.ndarray') if self.ndim == 0: - if key != (): + if not isinstance(key, tuple) or len(key) != 0: raise IndexError('scalar tensor can only accept `()` as index') - # TODO(junwu): Better handling of this situation - hdl = NDArrayHandle() - check_call(_LIB.MXShallowCopyNDArray(self.handle, ctypes.byref(hdl))) - classic_ndarray = NDArray(handle=hdl, writable=self.writable) - classic_ndarray.__setitem__(slice(None), value) + if isinstance(value, ndarray): + value = value._as_classic_ndarray() + # TODO(junwu): Better handling of this situation + if isinstance(key, tuple) and len(key) == 0: + self._as_classic_ndarray().__setitem__(slice(None), value) return + + if isinstance(key, integer_types): + key = (key,) + if isinstance(key, ndarray): + key = key._as_classic_ndarray() + elif isinstance(key, tuple): + key = [_get_index(idx) for idx in key] + key = tuple(key) + elif isinstance(key, list): + key = [_get_index(idx) for idx in key] + elif sys.version_info[0] > 2 and isinstance(key, range): + key = _get_index(key) self._as_classic_ndarray().__setitem__(key, value) def __add__(self, other): @@ -248,33 +297,78 @@ def __rpow__(self, other): def __eq__(self, other): """x.__eq__(y) <=> x == y""" - raise NotImplementedError + # TODO(junwu): Return boolean ndarray when dtype=bool_ is supported + if isinstance(other, ndarray): + return _npi.equal(self, other) + elif isinstance(other, numeric_types): + return _npi.equal_scalar(self, float(other)) + else: + raise TypeError("ndarray does not support type {} as operand".format(str(type(other)))) def __hash__(self): raise NotImplementedError def __ne__(self, other): """x.__ne__(y) <=> x != y""" - raise NotImplementedError + # TODO(junwu): Return boolean ndarray when dtype=bool_ is supported + if isinstance(other, ndarray): + return _npi.not_equal(self, other) + elif isinstance(other, numeric_types): + return _npi.not_equal_scalar(self, float(other)) + else: + raise TypeError("ndarray does not support type {} as operand".format(str(type(other)))) def __gt__(self, other): """x.__gt__(y) <=> x > y""" - raise NotImplementedError + # TODO(junwu): Return boolean ndarray when dtype=bool_ is supported + if isinstance(other, ndarray): + return _npi.greater(self, other) + elif isinstance(other, numeric_types): + return _npi.greater_scalar(self, float(other)) + else: + raise TypeError("ndarray does not support type {} as operand".format(str(type(other)))) def __ge__(self, other): """x.__ge__(y) <=> x >= y""" - raise NotImplementedError + # TODO(junwu): Return boolean ndarray when dtype=bool_ is supported + if isinstance(other, ndarray): + return _npi.greater_equal(self, other) + elif isinstance(other, numeric_types): + return _npi.greater_equal_scalar(self, float(other)) + else: + raise TypeError("ndarray does not support type {} as operand".format(str(type(other)))) def __lt__(self, other): """x.__lt__(y) <=> x < y""" - raise NotImplementedError + # TODO(junwu): Return boolean ndarray when dtype=bool_ is supported + if isinstance(other, ndarray): + return _npi.less(self, other) + elif isinstance(other, numeric_types): + return _npi.less_scalar(self, float(other)) + else: + raise TypeError("ndarray does not support type {} as operand".format(str(type(other)))) def __le__(self, other): """x.__le__(y) <=> x <= y""" - raise NotImplementedError + # TODO(junwu): Return boolean ndarray when dtype=bool_ is supported + if isinstance(other, ndarray): + return _npi.less_equal(self, other) + elif isinstance(other, numeric_types): + return _npi.less_equal_scalar(self, float(other)) + else: + raise TypeError("ndarray does not support type {} as operand".format(str(type(other)))) def __bool__(self): - raise NotImplementedError + num_elements = self.size + if num_elements == 0: + warnings.simplefilter('default') + warnings.warn('The truth value of an empty array is ambiguous. Returning False, but in' + ' future this will result in an error.', DeprecationWarning) + return False + elif num_elements == 1: + return bool(self.item()) + else: + raise ValueError("The truth value of an ndarray with multiple elements is ambiguous.") def __len__(self): """Number of elements along the first axis.""" @@ -1329,3 +1423,66 @@ def stack(arrays, axis=0, out=None): stacked : ndarray The stacked array has one more dimension than the input arrays.""" return _mx_nd_np.stack(arrays, axis=axis, out=out) + + +@set_module('mxnet.numpy') +def arange(start, stop=None, step=1, dtype=None, ctx=None): + """Return evenly spaced values within a given interval. + + Values are generated within the half-open interval ``[start, stop)`` + (in other words, the interval including `start` but excluding `stop`). + For integer arguments the function is equivalent to the Python built-in + `range` function, but returns an ndarray rather than a list. + + Parameters + ---------- + start : number, optional + Start of interval. The interval includes this value. The default + start value is 0. + stop : number + End of interval. The interval does not include this value, except + in some cases where `step` is not an integer and floating point + round-off affects the length of `out`. + step : number, optional + Spacing between values. For any output `out`, this is the distance + between two adjacent values, ``out[i+1] - out[i]``. The default + step size is 1. If `step` is specified as a position argument, + `start` must also be given. + dtype : dtype + The type of the output array. The default is `float32`. + + Returns + ------- + arange : ndarray + Array of evenly spaced values. + + For floating point arguments, the length of the result is + ``ceil((stop - start)/step)``. Because of floating point overflow, + this rule may result in the last element of `out` being greater + than `stop`. + """ + return _mx_nd_np.arange(start, stop, step, dtype, ctx) + + +@set_module('mxnet.numpy') +def argmax(a, axis=None, out=None): + """Returns the indices of the maximum values along an axis. + + Parameters + ---------- + a : ndarray + Input array. Only support ndarrays of dtype `float16`, `float32`, and `float64`. + axis : int, optional + By default, the index is into the flattened array, otherwise + along the specified axis. + out : array, optional + If provided, the result will be inserted into this array. It should + be of the appropriate shape and dtype. + + Returns + ------- + index_array : ndarray of indices whose dtype is same as the input ndarray. + Array of indices into the array. It has the same shape as `a.shape` + with the dimension along `axis` removed. + """ + return _mx_nd_np.argmax(a, axis, out) diff --git a/python/mxnet/numpy/random.py b/python/mxnet/numpy/random.py index e49bfcf6a97c..baeab8bb5bf4 100644 --- a/python/mxnet/numpy/random.py +++ b/python/mxnet/numpy/random.py @@ -17,4 +17,84 @@ """Namespace for ops used in imperative programming.""" -__all__ = [] +from __future__ import absolute_import +from ..ndarray import numpy as _mx_nd_np + +__all__ = ['uniform', 'normal'] + + +def uniform(low=0.0, high=1.0, size=None, **kwargs): + """Draw samples from a uniform distribution. + + Samples are uniformly distributed over the half-open interval + ``[low, high)`` (includes low, but excludes high). In other words, + any value within the given interval is equally likely to be drawn + by `uniform`. + + Parameters + ---------- + low : float, optional + Lower boundary of the output interval. All values generated will be + greater than or equal to low. The default value is 0. + high : float + Upper boundary of the output interval. All values generated will be + less than high. The default value is 1.0. + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. If size is ``None`` (default), + a scalar tensor containing a single value is returned if + ``low`` and ``high`` are both scalars. + dtype : {'float16', 'float32', 'float64'}, optional + Data type of output samples. Default is 'float32' + ctx : Context, optional + Device context of output. Default is current context. + out : ndarray, optional + Store output to an existing ndarray. + + Returns + ------- + out : ndarray + Drawn samples from the parameterized uniform distribution. + + + Notes + ----- + This function currently does not support ``low`` and ``high`` as ndarrays. + """ + return _mx_nd_np.random.uniform(low, high, size, **kwargs) + + +def normal(loc=0.0, scale=1.0, size=None, **kwargs): + """Draw random samples from a normal (Gaussian) distribution. + + Samples are distributed according to a normal distribution parametrized + by *loc* (mean) and *scale* (standard deviation). + + + Parameters + ---------- + loc : float, optional + Mean (centre) of the distribution. + scale : float, optional + Standard deviation (spread or "width") of the distribution. + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., `(m, n, k)`, then `m * n * k` + samples are drawn. If size is `None` (default), a scalar tensor containing + a single value is returned if loc and scale are both scalars. + dtype : {'float16', 'float32', 'float64'}, optional + Data type of output samples. Default is 'float32' + ctx : Context, optional + Device context of output. Default is current context. + out : ``ndarray``, optional + Store output to an existing ``ndarray``. + + Returns + ------- + out : ndarray + Drawn samples from the parameterized normal distribution. + + Notes + ----- + This function currently does not support ``loc`` and ``scale`` as ndarrays. + """ + return _mx_nd_np.random.normal(loc, scale, size, **kwargs) diff --git a/python/mxnet/numpy_extension/__init__.py b/python/mxnet/numpy_extension/__init__.py index bd5117528e7d..0c89a88908d7 100644 --- a/python/mxnet/numpy_extension/__init__.py +++ b/python/mxnet/numpy_extension/__init__.py @@ -24,5 +24,8 @@ from . import _register from ._op import * # pylint: disable=wildcard-import from ..context import * # pylint: disable=wildcard-import +from ..util import use_np_shape, np_shape, is_np_shape +from ..util import use_np_array, np_array, is_np_array, use_np +from .. import autograd __all__ = [] diff --git a/python/mxnet/symbol/__init__.py b/python/mxnet/symbol/__init__.py index 1cd805792b41..2ce395bdd279 100644 --- a/python/mxnet/symbol/__init__.py +++ b/python/mxnet/symbol/__init__.py @@ -28,7 +28,7 @@ from .symbol import * # pylint: enable=wildcard-import from . import numpy as np -from . import numpy_extension as npe +from . import numpy_extension as npx __all__ = op.__all__ + symbol.__all__\ + ['contrib', 'linalg', 'random', 'sparse', 'image', 'numpy', 'numpy_extension'] diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index d55a87881fdf..b2d8a5bd6b10 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -29,7 +29,7 @@ from .._internal import _set_np_symbol_class from . import _internal as _npi -__all__ = ['zeros', 'ones', 'maximum', 'minimum', 'stack'] +__all__ = ['zeros', 'ones', 'maximum', 'minimum', 'stack', 'arange', 'argmax'] @set_module('mxnet.symbol.numpy') @@ -114,8 +114,7 @@ def __mod__(self, other): elif isinstance(other, numeric_types): return _npi.mod_scalar(self, float(other)) else: - raise TypeError("_Symbol does not support type {} as operand" - .format(str(type(other)))) + raise TypeError("_Symbol does not support type {} as operand".format(str(type(other)))) def __rmod__(self, other): """x.__rmod__(y) <=> y % x""" @@ -124,8 +123,7 @@ def __rmod__(self, other): elif isinstance(other, numeric_types): return _npi.rmod_scalar(self, float(other)) else: - raise TypeError("_Symbol does not support type {} as operand" - .format(str(type(other)))) + raise TypeError("_Symbol does not support type {} as operand".format(str(type(other)))) def __idiv__(self, other): raise NotImplementedError @@ -137,8 +135,7 @@ def __truediv__(self, other): elif isinstance(other, numeric_types): return _npi.true_divide_scalar(self, float(other)) else: - raise TypeError("_Symbol does not support type {} as divisor" - .format(str(type(other)))) + raise TypeError("_Symbol does not support type {} as divisor".format(str(type(other)))) def __rtruediv__(self, other): """x.__rtruediv__(y) <=> y / x""" @@ -147,8 +144,7 @@ def __rtruediv__(self, other): elif isinstance(other, numeric_types): return _npi.rtrue_divide_scalar(self, float(other)).as_np_ndarray() else: - raise TypeError("_Symbol does not support type {} as dividend" - .format(str(type(other)))) + raise TypeError("_Symbol does not support type {} as dividend".format(str(type(other)))) def __itruediv__(self, other): raise NotImplementedError @@ -160,8 +156,7 @@ def __pow__(self, other): elif isinstance(other, numeric_types): return _npi.power_scalar(self, float(other)) else: - raise TypeError("_Symbol does not support type {} as operand" - .format(str(type(other)))) + raise TypeError("_Symbol does not support type {} as operand".format(str(type(other)))) def __rpow__(self, other): """x.__rpow__(y) <=> y ** x""" @@ -170,8 +165,7 @@ def __rpow__(self, other): elif isinstance(other, numeric_types): return _npi.rpower_scalar(self, float(other)) else: - raise TypeError("_Symbol does not support type {} as operand" - .format(str(type(other)))) + raise TypeError("_Symbol does not support type {} as operand".format(str(type(other)))) def __neg__(self): """x.__neg__() <=> - x""" @@ -182,27 +176,63 @@ def __deepcopy__(self, _): def __eq__(self, other): """x.__eq__(y) <=> x == y""" - raise NotImplementedError + # TODO(junwu): Return boolean ndarray when dtype=bool_ is supported + if isinstance(other, _Symbol): + return _npi.equal(self, other) + elif isinstance(other, numeric_types): + return _npi.equal_scalar(self, float(other)) + else: + raise TypeError("_Symbol does not support type {} as operand".format(str(type(other)))) def __ne__(self, other): """x.__ne__(y) <=> x != y""" - raise NotImplementedError + # TODO(junwu): Return boolean ndarray when dtype=bool_ is supported + if isinstance(other, _Symbol): + return _npi.not_equal(self, other) + elif isinstance(other, numeric_types): + return _npi.not_equal_scalar(self, float(other)) + else: + raise TypeError("_Symbol does not support type {} as operand".format(str(type(other)))) def __gt__(self, other): """x.__gt__(y) <=> x > y""" - raise NotImplementedError + # TODO(junwu): Return boolean ndarray when dtype=bool_ is supported + if isinstance(other, _Symbol): + return _npi.greater(self, other) + elif isinstance(other, numeric_types): + return _npi.greater_scalar(self, float(other)) + else: + raise TypeError("_Symbol does not support type {} as operand".format(str(type(other)))) def __ge__(self, other): """x.__ge__(y) <=> x >= y""" - raise NotImplementedError + # TODO(junwu): Return boolean ndarray when dtype=bool_ is supported + if isinstance(other, _Symbol): + return _npi.greater_equal(self, other) + elif isinstance(other, numeric_types): + return _npi.greater_equal_scalar(self, float(other)) + else: + raise TypeError("_Symbol does not support type {} as operand".format(str(type(other)))) def __lt__(self, other): """x.__lt__(y) <=> x < y""" - raise NotImplementedError + # TODO(junwu): Return boolean ndarray when dtype=bool_ is supported + if isinstance(other, _Symbol): + return _npi.less(self, other) + elif isinstance(other, numeric_types): + return _npi.less_scalar(self, float(other)) + else: + raise TypeError("_Symbol does not support type {} as operand".format(str(type(other)))) def __le__(self, other): """x.__le__(y) <=> x <= y""" - raise NotImplementedError + # TODO(junwu): Return boolean ndarray when dtype=bool_ is supported + if isinstance(other, _Symbol): + return _npi.less_equal(self, other) + elif isinstance(other, numeric_types): + return _npi.less_equal_scalar(self, float(other)) + else: + raise TypeError("_Symbol does not support type {} as operand".format(str(type(other)))) def __len__(self): raise NotImplementedError @@ -228,8 +258,8 @@ def dot(self, b, out=None): def reshape(self, shape, order='C'): # pylint: disable=arguments-differ if order != 'C': - raise NotImplementedError('ndarray.copy only supports order=\'C\', while ' - 'received {}'.format(str(order))) + raise NotImplementedError('only supports order=\'C\', while received {}' + .format(str(order))) return _mx_np_op.reshape(self, newshape=shape, order=order) def reshape_like(self, *args, **kwargs): @@ -1030,4 +1060,80 @@ def get_list(arrays): return _npi.stack(*arrays, axis=axis, out=out) +@set_module('mxnet.symbol.numpy') +def arange(start, stop=None, step=1, dtype=None, ctx=None): + """Return evenly spaced values within a given interval. + + Values are generated within the half-open interval ``[start, stop)`` + (in other words, the interval including `start` but excluding `stop`). + For integer arguments the function is equivalent to the Python built-in + `range` function, but returns an ndarray rather than a list. + + Parameters + ---------- + start : number, optional + Start of interval. The interval includes this value. The default + start value is 0. + stop : number + End of interval. The interval does not include this value, except + in some cases where `step` is not an integer and floating point + round-off affects the length of `out`. + step : number, optional + Spacing between values. For any output `out`, this is the distance + between two adjacent values, ``out[i+1] - out[i]``. The default + step size is 1. If `step` is specified as a position argument, + `start` must also be given. + dtype : dtype + The type of the output array. The default is `float32`. + + Returns + ------- + arange : ndarray + Array of evenly spaced values. + + For floating point arguments, the length of the result is + ``ceil((stop - start)/step)``. Because of floating point overflow, + this rule may result in the last element of `out` being greater + than `stop`. + """ + if dtype is None: + dtype = 'float32' + if ctx is None: + ctx = current_context() + if stop is None: + stop = start + start = 0 + if step is None: + step = 1 + if start is None and stop is None: + raise ValueError('start and stop cannot be both None') + if step == 0: + raise ZeroDivisionError('step cannot be 0') + return _npi.arange(start=start, stop=stop, step=step, dtype=dtype, ctx=ctx) + + +@set_module('mxnet.symbol.numpy') +def argmax(a, axis=None, out=None): + """Returns the indices of the maximum values along an axis. + + Parameters + ---------- + a : ndarray + Input array. Only support ndarrays of dtype `float16`, `float32`, and `float64`. + axis : int, optional + By default, the index is into the flattened array, otherwise + along the specified axis. + out : array, optional + If provided, the result will be inserted into this array. It should + be of the appropriate shape and dtype. + + Returns + ------- + index_array : ndarray of indices whose dtype is same as the input ndarray. + Array of indices into the array. It has the same shape as `a.shape` + with the dimension along `axis` removed. + """ + return _npi.argmax(a, axis=axis, keepdims=False, out=out) + + _set_np_symbol_class(_Symbol) diff --git a/python/mxnet/symbol/numpy/linalg.py b/python/mxnet/symbol/numpy/linalg.py index 869fdeb276b9..2cb0d22e1f7a 100644 --- a/python/mxnet/symbol/numpy/linalg.py +++ b/python/mxnet/symbol/numpy/linalg.py @@ -17,4 +17,51 @@ """Namespace for operators used in Gluon dispatched by F=symbol.""" -__all__ = [] +from __future__ import absolute_import +from . import _op as _mx_nd_np + +__all__ = ['norm'] + + +def norm(x, ord=None, axis=None, keepdims=False): + r"""Matrix or vector norm. + + This function can only support Frobenius norm for now. + The Frobenius norm is given by [1]_: + + :math:`||A||_F = [\sum_{i,j} abs(a_{i,j})^2]^{1/2}` + + Parameters + ---------- + x : ndarray + Input array. + ord : {'fro'}, optional + Order of the norm. + axis : {int, 2-tuple of ints, None}, optional + If `axis` is an integer, it specifies the axis of `x` along which to + compute the vector norms. If `axis` is a 2-tuple, it specifies the + axes that hold 2-D matrices, and the matrix norms of these matrices + are computed. If `axis` is None, the norm of the whole ndarray is + returned. + + keepdims : bool, optional + If this is set to True, the axes which are normed over are left in the + result as dimensions with size one. With this option the result will + broadcast correctly against the original `x`. + + Returns + ------- + n : float or ndarray + Norm of the matrix or vector(s). + + References + ---------- + .. [1] G. H. Golub and C. F. Van Loan, *Matrix Computations*, + Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15 + """ + if ord is not None and ord != 'fro': + raise ValueError('only support Frobenius norm for now, received ord={}'.format(str(ord))) + if isinstance(axis, tuple) and len(axis) > 2: + raise ValueError('Improper number of dimensions to norm') + # TODO(junwu): When ord = 'fro', axis = None, and x.ndim > 2, raise exception + return _mx_nd_np.sqrt(_mx_nd_np.sum(x * x, axis=axis, keepdims=keepdims)) diff --git a/python/mxnet/symbol/numpy/random.py b/python/mxnet/symbol/numpy/random.py index 869fdeb276b9..fd73478e49eb 100644 --- a/python/mxnet/symbol/numpy/random.py +++ b/python/mxnet/symbol/numpy/random.py @@ -17,4 +17,122 @@ """Namespace for operators used in Gluon dispatched by F=symbol.""" -__all__ = [] +from __future__ import absolute_import +from ...base import numeric_types +from ...context import current_context +from . import _internal as _npi + +__all__ = ['uniform', 'normal'] + + +def _random_helper(random, sampler, params, shape, dtype, ctx, out, kwargs): + """Helper function for random generators.""" + from ._symbol import _Symbol as np_symbol + if isinstance(params[0], np_symbol): + for i in params[1:]: + assert isinstance(i, np_symbol), \ + "Distribution parameters must all have the same type, but got " \ + "both %s and %s." % (type(params[0]), type(i)) + return sampler(*params, shape=shape, dtype=dtype, out=out, **kwargs) + elif isinstance(params[0], numeric_types): + if ctx is None: + ctx = current_context() + if shape is None and out is None: + shape = () + for i in params[1:]: + assert isinstance(i, numeric_types), \ + "Distribution parameters must all have the same type, but got " \ + "both %s and %s."%(type(params[0]), type(i)) + return random(*params, shape=shape, dtype=dtype, ctx=ctx, out=out, **kwargs) + + raise ValueError("Distribution parameters must be either mxnet.numpy.ndarray or numbers, " + "but got %s." % type(params[0])) + + +def uniform(low=0.0, high=1.0, size=None, **kwargs): + """Draw samples from a uniform distribution. + + Samples are uniformly distributed over the half-open interval + ``[low, high)`` (includes low, but excludes high). In other words, + any value within the given interval is equally likely to be drawn + by `uniform`. + + Parameters + ---------- + low : float, optional + Lower boundary of the output interval. All values generated will be + greater than or equal to low. The default value is 0. + high : float + Upper boundary of the output interval. All values generated will be + less than high. The default value is 1.0. + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. If size is ``None`` (default), + a scalar tensor containing a single value is returned if + ``low`` and ``high`` are both scalars. + dtype : {'float16', 'float32', 'float64'}, optional + Data type of output samples. Default is 'float32' + ctx : Context, optional + Device context of output. Default is current context. + out : ndarray, optional + Store output to an existing ndarray. + + Returns + ------- + out : _Symbol (symbol representing `mxnet.numpy.ndarray` in computational graphs) + Drawn samples from the parameterized uniform distribution. + + + Notes + ----- + This function currently does not support ``low`` and ``high`` as symbols. + """ + dtype = kwargs.pop('dtype', None) + if dtype is None: + dtype = 'float32' + ctx = kwargs.pop('ctx', None) + out = kwargs.pop('out', None) + return _random_helper(_npi.random_uniform, None, + [low, high], size, dtype, ctx, out, kwargs) + + +def normal(loc=0.0, scale=1.0, size=None, **kwargs): + """Draw random samples from a normal (Gaussian) distribution. + + Samples are distributed according to a normal distribution parametrized + by *loc* (mean) and *scale* (standard deviation). + + + Parameters + ---------- + loc : float, optional + Mean (centre) of the distribution. + scale : float, optional + Standard deviation (spread or "width") of the distribution. + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., `(m, n, k)`, then `m * n * k` + samples are drawn. If size is `None` (default), a scalar tensor containing + a single value is returned if loc and scale are both scalars. + dtype : {'float16', 'float32', 'float64'}, optional + Data type of output samples. Default is 'float32' + ctx : Context, optional + Device context of output. Default is current context. + out : ``ndarray``, optional + Store output to an existing ``ndarray``. + + Returns + ------- + out : _Symbol (symbol representing `mxnet.numpy.ndarray` in computational graphs) + Drawn samples from the parameterized normal distribution. + + Notes + ----- + This function currently does not support ``loc`` and ``scale`` as `_Symbol`s. + """ + dtype = kwargs.pop('dtype', None) + if dtype is None: + dtype = 'float32' + ctx = kwargs.pop('ctx', None) + out = kwargs.pop('out', None) + return _random_helper(_npi.random_normal, None, + [loc, scale], size, dtype, ctx, out, kwargs) diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index 925007ddd2f0..df0438dcb31f 100644 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -947,7 +947,7 @@ def random_projection(shape): input_shape = {k: v.shape for k, v in location.items()} _, out_shape, _ = sym.infer_shape(**input_shape) proj = mx.sym.Variable("__random_proj") - is_np_sym = True if isinstance(sym, np_symbol) else False + is_np_sym = bool(isinstance(sym, np_symbol)) if is_np_sym: # convert to np symbol for using element-wise multiplication proj = proj.as_np_ndarray() out = sym * proj diff --git a/python/mxnet/util.py b/python/mxnet/util.py index d41137142a70..60c35bdf0c80 100644 --- a/python/mxnet/util.py +++ b/python/mxnet/util.py @@ -22,6 +22,7 @@ import functools import itertools import inspect +import threading from .base import _LIB, check_call @@ -84,8 +85,7 @@ def set_np_shape(active): def is_np_shape(): - """ - Checks whether the NumPy shape semantics is currently turned on. + """Checks whether the NumPy shape semantics is currently turned on. In NumPy shape semantics, `()` represents the shape of scalar tensors, and tuples with `0` elements, for example, `(0,)`, `(1, 0, 2)`, represent the shapes of zero-size tensors. This is turned off by default for keeping @@ -268,12 +268,12 @@ def value(self): Parameters ---------- - func : a user-provided callable function or class to be scoped by the NumPy compatibility state. + func : a user-provided callable function or class to be scoped by the NumPy-shape semantics. Returns ------- Function or class - A function or class wrapped in the NumPy compatibility scope. + A function or class wrapped in the NumPy-shape scope. """ if inspect.isclass(func): @@ -323,3 +323,225 @@ def decorator(func): func.__module__ = module return func return decorator + + +class _NumpyArrayScope(object): + """Scope for managing NumPy array creation. This is often used + with `is_np_array=True` in initializer to enforce array creation + as type `mxnet.numpy.ndarray`, instead of `mx.nd.NDArray` in Gluon. + + Do not use this class directly. Use `np_array(active)` instead. + """ + _current = threading.local() + + def __init__(self, is_np_array): #pylint: disable=redefined-outer-name + self._old_scope = None + self._is_np_array = is_np_array + + def __enter__(self): + if not hasattr(_NumpyArrayScope._current, "value"): + _NumpyArrayScope._current.value = _NumpyArrayScope(False) + self._old_scope = _NumpyArrayScope._current.value + _NumpyArrayScope._current.value = self + return self + + def __exit__(self, ptype, value, trace): + assert self._old_scope + _NumpyArrayScope._current.value = self._old_scope + + +def np_array(active=True): + """Returns an activated/deactivated NumPy-array scope to be used in 'with' statement + and captures code that needs the NumPy-array semantics. + + Currently, this is used in Gluon to enforce array creation in `Block`s as type + `mxnet.numpy.ndarray`, instead of `mx.nd.NDArray`. + + It is recommended to use the decorator `use_np_array` to decorate the classes + that need this semantics, instead of using this function in a `with` statement + unless you know exactly what has been scoped by this semantics. + + Please note that this is designed as an infrastructure for the incoming + MXNet-NumPy operators. Legacy operators registered in the modules + `mx.nd` and `mx.sym` are not guaranteed to behave like their counterparts + in NumPy even within this scope. + + Parameters + ---------- + active : bool + Indicates whether to activate NumPy-array semantics. + + Returns + ------- + _NumpyShapeScope + A scope object for wrapping the code w/ or w/o NumPy-shape semantics. + """ + return _NumpyArrayScope(active) + + +def is_np_array(): + """Checks whether the NumPy-array semantics is currently turned on. + This is currently used in Gluon for checking whether an array of type `mxnet.numpy.ndarray` + or `mx.nd.NDArray` should be created. For example, at the time when a parameter + is created in a `Block`, an `mxnet.numpy.ndarray` is created if this returns true; else + an `mx.nd.NDArray` is created. + + Normally, users are not recommended to use this API directly unless you known exactly + what is going on under the hood. + + Please note that this is designed as an infrastructure for the incoming + MXNet-NumPy operators. Legacy operators registered in the modules + `mx.nd` and `mx.sym` are not guaranteed to behave like their counterparts + in NumPy within this semantics. + + Returns + ------- + A bool value indicating whether the NumPy-array semantics is currently on. + """ + return _NumpyArrayScope._current.value._is_np_array if hasattr( + _NumpyArrayScope._current, "value") else False + + +def use_np_array(func): + """A decorator wrapping Gluon `Block`s and all its methods, properties, and static functions + with the semantics of NumPy-array, which means that where ndarrays are created, + `mxnet.numpy.ndarray`s should be created, instead of legacy ndarrays of type `mx.nd.NDArray`. + For example, at the time when a parameter is created in a `Block`, an `mxnet.numpy.ndarray` + is created if it's decorated with this decorator. + + Example:: + import mxnet as mx + from mxnet import gluon, np + + + class TestHybridBlock1(gluon.HybridBlock): + def __init__(self): + super(TestHybridBlock1, self).__init__() + self.w = self.params.get('w', shape=(2, 2)) + + def hybrid_forward(self, F, x, w): + return F.dot(x, w) + + + x = mx.nd.ones((2, 2)) + net1 = TestHybridBlock1() + net1.initialize() + out = net1.forward(x) + for _, v in net1.collect_params().items(): + assert type(v.data()) is mx.nd.NDArray + assert type(out) is mx.nd.NDArray + + + @np.use_np_array + class TestHybridBlock2(gluon.HybridBlock): + def __init__(self): + super(TestHybridBlock2, self).__init__() + self.w = self.params.get('w', shape=(2, 2)) + + def hybrid_forward(self, F, x, w): + return F.np.dot(x, w) + + + x = np.ones((2, 2)) + net2 = TestHybridBlock2() + net2.initialize() + out = net2.forward(x) + for _, v in net2.collect_params().items(): + print(type(v.data())) + assert type(v.data()) is np.ndarray + assert type(out) is np.ndarray + + Parameters + ---------- + func : a user-provided callable function or class to be scoped by the NumPy-array semantics. + + Returns + ------- + Function or class + A function or class wrapped in the NumPy-array scope. + """ + if inspect.isclass(func): + for name, method in inspect.getmembers( + func, + predicate= + lambda f: inspect.isfunction(f) or inspect.ismethod(f) or isinstance(f, property)): + if isinstance(method, property): + setattr(func, name, property(use_np_array(method.__get__), + method.__set__, + method.__delattr__, + method.__doc__)) + else: + setattr(func, name, use_np_array(method)) + return func + elif callable(func): + @wraps_safely(func) + def _with_np_array(*args, **kwargs): + with np_array(active=True): + return func(*args, **kwargs) + return _with_np_array + else: + raise TypeError('use_np_array can only decorate classes and callable objects, ' + 'while received a {}'.format(str(type(func)))) + + +def use_np(func): + """A convenience decorator for wrapping user provided functions and classes in the scope of + both NumPy-shape and NumPy-array semantics, which means that (1) empty tuples `()` and tuples + with zeros, such as `(0, 1)`, `(1, 0, 2)`, will be treated as scalar tensors' shapes and + zero-size tensors' shapes in shape inference functions of operators, instead of as unknown + in legacy mode; (2) ndarrays of type `mxnet.numpy.ndarray` should be created instead of + `mx.nd.NDArray`. + + Example:: + import mxnet as mx + from mxnet import gluon, np + + + class TestHybridBlock1(gluon.HybridBlock): + def __init__(self): + super(TestHybridBlock1, self).__init__() + self.w = self.params.get('w', shape=(2, 2)) + + def hybrid_forward(self, F, x, w): + return F.dot(x, w) + F.ones((1,)) + + + x = mx.nd.ones((2, 2)) + net1 = TestHybridBlock1() + net1.initialize() + out = net1.forward(x) + for _, v in net1.collect_params().items(): + assert type(v.data()) is mx.nd.NDArray + assert type(out) is mx.nd.NDArray + + + @np.use_np + class TestHybridBlock2(gluon.HybridBlock): + def __init__(self): + super(TestHybridBlock2, self).__init__() + self.w = self.params.get('w', shape=(2, 2)) + + def hybrid_forward(self, F, x, w): + return F.np.dot(x, w) + F.np.ones(()) + + + x = np.ones((2, 2)) + net2 = TestHybridBlock2() + net2.initialize() + out = net2.forward(x) + for _, v in net2.collect_params().items(): + print(type(v.data())) + assert type(v.data()) is np.ndarray + assert type(out) is np.ndarray + + Parameters + ---------- + func : a user-provided callable function or class to be scoped by the + NumPy-shape and NumPy-array semantics. + + Returns + ------- + Function or class + A function or class wrapped in the Numpy-shape and NumPy-array scope. + """ + return use_np_array(use_np_shape(func)) diff --git a/src/operator/numpy/np_broadcast_reduce_op_index.cc b/src/operator/numpy/np_broadcast_reduce_op_index.cc new file mode 100644 index 000000000000..bd6915cc9b27 --- /dev/null +++ b/src/operator/numpy/np_broadcast_reduce_op_index.cc @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_broadcast_reduce_op_index.cc + * \brief CPU Implementation of broadcast and reduce functions based on index. + */ +#include "./np_broadcast_reduce_op.h" + +namespace mxnet { +namespace op { + +bool NumpyReduceAxisShape(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + if (!shape_is_known(in_attrs->at(0))) { + return false; + } + const ReduceAxisParam& param = nnvm::get(attrs.parsed); + dmlc::optional> axes; + if (param.axis.has_value()) { + mxnet::Tuple t({param.axis.value()}); + axes = dmlc::optional>(t); + } + SHAPE_ASSIGN_CHECK(*out_attrs, 0, + NumpyReduceAxesShapeImpl((*in_attrs)[0], axes, param.keepdims)); + return shape_is_known(out_attrs->at(0)); +} + +NNVM_REGISTER_OP(_npi_argmax) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", NumpyReduceAxisShape) +.set_attr("FInferType", ElemwiseType<1, 1>) +.add_argument("data", "NDArray-or-Symbol", "The input") +.set_attr("FCompute", SearchAxisCompute) +.set_attr("FGradient", MakeZeroGradNodes) +.add_arguments(ReduceAxisParam::__FIELDS__()); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/np_broadcast_reduce_op_index.cu b/src/operator/numpy/np_broadcast_reduce_op_index.cu new file mode 100644 index 000000000000..aae66a6d660a --- /dev/null +++ b/src/operator/numpy/np_broadcast_reduce_op_index.cu @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_broadcast_reduce_op_index.cu + * \brief GPU Implementation of reduce functions. + */ +#include "np_broadcast_reduce_op.h" + +namespace mxnet { +namespace op { + +NNVM_REGISTER_OP(_npi_argmax) +.set_attr("FCompute", SearchAxisCompute); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/np_broadcast_reduce_op_value.cc b/src/operator/numpy/np_broadcast_reduce_op_value.cc index a72efd9a4d23..078cd46dc857 100644 --- a/src/operator/numpy/np_broadcast_reduce_op_value.cc +++ b/src/operator/numpy/np_broadcast_reduce_op_value.cc @@ -19,7 +19,7 @@ /*! * Copyright (c) 2019 by Contributors - * \file np_reduce_op_value.cc + * \file np_broadcast_reduce_op_value.cc * \brief CPU Implementation of broadcast and reduce functions based on value. */ diff --git a/src/operator/numpy/np_broadcast_reduce_op_value.cu b/src/operator/numpy/np_broadcast_reduce_op_value.cu index 2f50738832fe..7740c03de70b 100644 --- a/src/operator/numpy/np_broadcast_reduce_op_value.cu +++ b/src/operator/numpy/np_broadcast_reduce_op_value.cu @@ -19,7 +19,7 @@ /*! * Copyright (c) 2019 by Contributors - * \file np_reduce_op_value.cu + * \file np_broadcast_reduce_op_value.cu * \brief GPU Implementation of reduce functions based on value. */ #include "np_broadcast_reduce_op.h" diff --git a/src/operator/numpy/np_elemwise_unary_op_basic.cc b/src/operator/numpy/np_elemwise_unary_op_basic.cc index 87a765eb981c..1acec6f8c971 100644 --- a/src/operator/numpy/np_elemwise_unary_op_basic.cc +++ b/src/operator/numpy/np_elemwise_unary_op_basic.cc @@ -27,7 +27,7 @@ namespace mxnet { namespace op { -MXNET_OPERATOR_REGISTER_UNARY(_npe_relu) +MXNET_OPERATOR_REGISTER_UNARY(_npx_relu) .describe(R"code(Computes rectified linear activation. .. math:: @@ -37,7 +37,7 @@ MXNET_OPERATOR_REGISTER_UNARY(_npe_relu) .set_attr("FCompute", UnaryOp::Compute) .set_attr("FGradient", ElemwiseGradUseOut{"_backward_relu"}); -MXNET_OPERATOR_REGISTER_UNARY(_npe_sigmoid) +MXNET_OPERATOR_REGISTER_UNARY(_npx_sigmoid) .describe(R"code(Computes sigmoid of x element-wise. .. math:: diff --git a/src/operator/numpy/np_elemwise_unary_op_basic.cu b/src/operator/numpy/np_elemwise_unary_op_basic.cu index a3cdff93e902..13237685d963 100644 --- a/src/operator/numpy/np_elemwise_unary_op_basic.cu +++ b/src/operator/numpy/np_elemwise_unary_op_basic.cu @@ -26,10 +26,10 @@ namespace mxnet { namespace op { -NNVM_REGISTER_OP(_npe_relu) +NNVM_REGISTER_OP(_npx_relu) .set_attr("FCompute", UnaryOp::Compute); -NNVM_REGISTER_OP(_npe_sigmoid) +NNVM_REGISTER_OP(_npx_sigmoid) .set_attr("FCompute", UnaryOp::Compute); NNVM_REGISTER_OP(_np_copy) diff --git a/src/operator/numpy/np_init_op.cc b/src/operator/numpy/np_init_op.cc index 83a44c8ae280..9edfa20eff99 100644 --- a/src/operator/numpy/np_init_op.cc +++ b/src/operator/numpy/np_init_op.cc @@ -28,6 +28,23 @@ namespace mxnet { namespace op { +inline bool NumpyRangeShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector* in_shapes, + mxnet::ShapeVector* out_shapes) { + const RangeParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(in_shapes->size(), 0U); + CHECK_EQ(out_shapes->size(), 1U); + CHECK_NE(param.step, 0) << "_npi_arange does not support step=0"; + CHECK_EQ(param.repeat, 1) << "_npi_arange only supports repeat=1, received " << param.repeat; + CHECK(param.stop.has_value()) << "_npi_arange requires stop to have a value"; + double out_size = std::ceil((param.stop.value() - param.start) / param.step); + if (out_size < 0) { + out_size = 0; + } + SHAPE_ASSIGN_CHECK(*out_shapes, 0, mxnet::TShape({static_cast(out_size)})); + return true; +} + NNVM_REGISTER_OP(_npi_zeros) .describe("Return a new array of given shape, type, and context, filled with zeros.") .set_num_inputs(0) @@ -107,5 +124,15 @@ Examples:: .add_argument("a", "NDArray-or-Symbol", "The shape and data-type of a define these same attributes of the returned array."); +NNVM_REGISTER_OP(_npi_arange) +.describe("Return evenly spaced values within a given interval.") +.set_num_inputs(0) +.set_num_outputs(1) +.set_attr_parser(RangeParamParser) +.set_attr("FInferShape", NumpyRangeShape) +.set_attr("FInferType", InitType) +.set_attr("FCompute", RangeCompute) +.add_arguments(RangeParam::__FIELDS__()); + } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/np_init_op.cu b/src/operator/numpy/np_init_op.cu index 2eb8ed6d83b7..2c41e56736f2 100644 --- a/src/operator/numpy/np_init_op.cu +++ b/src/operator/numpy/np_init_op.cu @@ -40,5 +40,8 @@ NNVM_REGISTER_OP(_np_zeros_like) NNVM_REGISTER_OP(_np_ones_like) .set_attr("FCompute", FillCompute); +NNVM_REGISTER_OP(_npi_arange) +.set_attr("FCompute", RangeCompute); + } // namespace op } // namespace mxnet diff --git a/src/operator/random/sample_op.cc b/src/operator/random/sample_op.cc index 56a162be5da4..543146257ddf 100644 --- a/src/operator/random/sample_op.cc +++ b/src/operator/random/sample_op.cc @@ -81,6 +81,7 @@ DMLC_REGISTER_PARAMETER(SampleGenNegBinomialLikeParam); MXNET_OPERATOR_REGISTER_SAMPLE(_random_uniform, SampleUniformParam) .add_alias("uniform") .add_alias("random_uniform") +.add_alias("_npi_random_uniform") .describe(R"code(Draw random samples from a uniform distribution. .. note:: The existing alias ``uniform`` is deprecated. @@ -99,6 +100,7 @@ Example:: MXNET_OPERATOR_REGISTER_SAMPLE(_random_normal, SampleNormalParam) .add_alias("normal") .add_alias("random_normal") +.add_alias("_npi_random_normal") .describe(R"code(Draw random samples from a normal (Gaussian) distribution. .. note:: The existing alias ``normal`` is deprecated. diff --git a/src/operator/tensor/broadcast_reduce_op.h b/src/operator/tensor/broadcast_reduce_op.h index a6ee242c489a..cba9821fed25 100644 --- a/src/operator/tensor/broadcast_reduce_op.h +++ b/src/operator/tensor/broadcast_reduce_op.h @@ -168,15 +168,24 @@ struct BroadcastLikeParam : public dmlc::Parameter { } }; -inline int CheckAxis(int axis, int ndim) { - CHECK(axis < ndim && axis >= -ndim) - << "axis " << axis << " exceeds the input dimension of " << ndim; - return (axis + ndim)%ndim; +inline int CheckAxis(const int axis, const int ndim) { + if (ndim == 0) { + CHECK(axis == 0 || axis == -1) << "axis " << axis << " is out of bounds for array of" + " dimension 1"; + return 0; + } else { + CHECK(axis < ndim && axis >= -ndim) + << "axis " << axis << " exceeds the input dimension of " << ndim; + return (axis + ndim) % ndim; + } } inline mxnet::TShape AxisShapeCompact(mxnet::TShape shape, int *axis, bool allow_2d) { int ndim = shape.ndim(); - index_t leading = 1, trailing = 1, M = shape[*axis]; + index_t leading = 1, trailing = 1, M = 1; + if (shape.ndim() > *axis) { + M = shape[*axis]; + } for (int i = 0; i < *axis; ++i) leading *= shape[i]; for (int i = *axis + 1; i < ndim; ++i) trailing *= shape[i]; if (allow_2d && trailing == 1) { @@ -553,14 +562,37 @@ void SearchAxisCompute(const nnvm::NodeAttrs& attrs, using namespace mshadow::expr; const ReduceAxisParam& param = nnvm::get(attrs.parsed); Stream *s = ctx.get_stream(); - if (!param.axis) LOG(FATAL) << "Global reduction not supported yet"; + int axis = inputs[0].ndim(); + TBlob input = inputs[0]; + if (param.axis.has_value()) { + axis = param.axis.value(); + } else { + // If global reduction, reshape the input tensor into 2D shape (1, inputs[0].shape_.Size()) + // and search on axis = 1. + mxnet::TShape shape_2d(2, 1); + shape_2d[1] = input.shape_.Size(); + input = TBlob(input.dptr_, shape_2d, input.dev_mask(), input.type_flag_, input.dev_id()); + axis = 1; + } - int axis = CheckAxis(param.axis.value(), inputs[0].shape_.ndim()); - mxnet::TShape shape = AxisShapeCompact(inputs[0].shape_, &axis, false); + axis = CheckAxis(axis, input.shape_.ndim()); + if (inputs[0].shape_.ndim() != 0) { + if (param.axis.has_value()) { + // cannot do argmax in an empty dimension + CHECK_NE(inputs[0].shape_[axis], 0) + << "searching input tensor of shape " << inputs[0].shape_ + << " along axis = " << axis << " of zero dim-size is not allowed"; + } else { + // cannot do argmax on an empty array + CHECK_NE(inputs[0].shape_.Size(), 0U) << "attempt to search an empty sequence"; + } + } + if (input.shape_.Size() == 0U) return; // zero-size tensor + mxnet::TShape shape = AxisShapeCompact(input.shape_, &axis, false); MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { Tensor out = outputs[0].get_with_shape( Shape2(shape[0], shape[2]), s); - Tensor in = inputs[0].get_with_shape( + Tensor in = input.get_with_shape( shape.get<3>(), s); CHECK(req[0] != kAddTo) << "AddTo is not supported"; ASSIGN_DISPATCH(out, req[0], (reduce_with_axis(in, 1))); diff --git a/src/operator/tensor/elemwise_binary_broadcast_op_logic.cc b/src/operator/tensor/elemwise_binary_broadcast_op_logic.cc index cd433e00a770..e3c2e0e898d9 100644 --- a/src/operator/tensor/elemwise_binary_broadcast_op_logic.cc +++ b/src/operator/tensor/elemwise_binary_broadcast_op_logic.cc @@ -30,6 +30,7 @@ namespace mxnet { namespace op { MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(broadcast_equal) +.add_alias("_npi_equal") .describe(R"code(Returns the result of element-wise **equal to** (==) comparison operation with broadcasting. Example:: @@ -48,6 +49,7 @@ Example:: .set_attr("FGradient", MakeZeroGradNodes); MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(broadcast_not_equal) +.add_alias("_npi_not_equal") .describe(R"code(Returns the result of element-wise **not equal to** (!=) comparison operation with broadcasting. Example:: @@ -66,6 +68,7 @@ Example:: .set_attr("FGradient", MakeZeroGradNodes); MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(broadcast_greater) +.add_alias("_npi_greater") .describe(R"code(Returns the result of element-wise **greater than** (>) comparison operation with broadcasting. Example:: @@ -84,6 +87,7 @@ Example:: .set_attr("FGradient", MakeZeroGradNodes); MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(broadcast_greater_equal) +.add_alias("_npi_greater_equal") .describe(R"code(Returns the result of element-wise **greater than or equal to** (>=) comparison operation with broadcasting. Example:: @@ -102,6 +106,7 @@ Example:: .set_attr("FGradient", MakeZeroGradNodes); MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(broadcast_lesser) +.add_alias("_npi_less") .describe(R"code(Returns the result of element-wise **lesser than** (<) comparison operation with broadcasting. Example:: @@ -120,6 +125,7 @@ Example:: .set_attr("FGradient", MakeZeroGradNodes); MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(broadcast_lesser_equal) +.add_alias("_npi_less_equal") .describe(R"code(Returns the result of element-wise **lesser than or equal to** (<=) comparison operation with broadcasting. Example:: diff --git a/src/operator/tensor/elemwise_binary_scalar_op_logic.cc b/src/operator/tensor/elemwise_binary_scalar_op_logic.cc index 17e76153ebb2..87ba394c99b2 100644 --- a/src/operator/tensor/elemwise_binary_scalar_op_logic.cc +++ b/src/operator/tensor/elemwise_binary_scalar_op_logic.cc @@ -71,26 +71,32 @@ static bool BinaryScalarLogicStorageType(const nnvm::NodeAttrs& attrs, MXNET_OPERATOR_REGISTER_BINARY_SCALAR_LOGIC(_equal_scalar, mshadow_op::eq) +.add_alias("_npi_equal_scalar") .set_attr("FGradient", MakeZeroGradNodes) .add_alias("_EqualScalar"); MXNET_OPERATOR_REGISTER_BINARY_SCALAR_LOGIC(_not_equal_scalar, mshadow_op::ne) +.add_alias("_npi_not_equal_scalar") .set_attr("FGradient", MakeZeroGradNodes) .add_alias("_NotEqualScalar"); MXNET_OPERATOR_REGISTER_BINARY_SCALAR_LOGIC(_greater_scalar, mshadow_op::gt) +.add_alias("_npi_greater_scalar") .set_attr("FGradient", MakeZeroGradNodes) .add_alias("_GreaterScalar"); MXNET_OPERATOR_REGISTER_BINARY_SCALAR_LOGIC(_greater_equal_scalar, mshadow_op::ge) +.add_alias("_npi_greater_equal_scalar") .set_attr("FGradient", MakeZeroGradNodes) .add_alias("_GreaterEqualScalar"); MXNET_OPERATOR_REGISTER_BINARY_SCALAR_LOGIC(_lesser_scalar, mshadow_op::lt) +.add_alias("_npi_less_scalar") .set_attr("FGradient", MakeZeroGradNodes) .add_alias("_LesserScalar"); MXNET_OPERATOR_REGISTER_BINARY_SCALAR_LOGIC(_lesser_equal_scalar, mshadow_op::le) +.add_alias("_npi_less_equal_scalar") .set_attr("FGradient", MakeZeroGradNodes) .add_alias("_LesserEqualScalar"); diff --git a/tests/python/unittest/test_contrib_amp.py b/tests/python/unittest/test_contrib_amp.py index c11d3f713581..ef3a6d81fb48 100644 --- a/tests/python/unittest/test_contrib_amp.py +++ b/tests/python/unittest/test_contrib_amp.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. -import unittest import mxnet as mx import warnings import collections @@ -23,8 +22,6 @@ import mxnet.contrib.amp as amp -# TODO(junwu): Enable test -@unittest.skip("Temporarily disabled for adding new np ops") def test_amp_coverage(): conditional = [item[0] for item in amp.lists.symbol.CONDITIONAL_FP32_FUNCS] diff --git a/tests/python/unittest/test_numpy_gluon.py b/tests/python/unittest/test_numpy_gluon.py index b7656b75feb7..0fcb874f8472 100644 --- a/tests/python/unittest/test_numpy_gluon.py +++ b/tests/python/unittest/test_numpy_gluon.py @@ -19,7 +19,7 @@ from __future__ import absolute_import from __future__ import division import mxnet as mx -from mxnet import gluon, autograd, np +from mxnet import gluon, autograd, np, npx def test_create_np_param(): @@ -44,7 +44,7 @@ def __init__(self): def hybrid_forward(self, F, x, w): return F.dot(x, w) - @np.use_np_shape + @npx.use_np class TestBlock2(gluon.HybridBlock): def __init__(self): super(TestBlock2, self).__init__() @@ -62,9 +62,9 @@ def hybrid_forward(self, F, x, w): def test_optimizer_with_np_ndarrays(): - @np.use_np_shape + @npx.use_np class LinearRegression(gluon.HybridBlock): - def __init__(self, num_input_dim=-1, num_hidden_dim=100, num_output_dim=10): + def __init__(self, num_input_dim=0, num_hidden_dim=100, num_output_dim=10): super(LinearRegression, self).__init__() with self.name_scope(): self.w1 = self.params.get('w1', shape=(num_input_dim, num_hidden_dim), @@ -74,11 +74,11 @@ def __init__(self, num_input_dim=-1, num_hidden_dim=100, num_output_dim=10): def hybrid_forward(self, F, x, w1, w2): h = x.dot(w1) # equivalent to F.np.dot(x, w1) - h_relu = F.npe.relu(h) # equivalent to F.relu(h) but generating np.ndarray + h_relu = F.npx.relu(h) # equivalent to F.relu(h) but generating np.ndarray y_pred = h_relu.dot(w2) # equivalent to F.np.dot(h_relu, w2) return y_pred - @np.use_np_shape + @npx.use_np class TotalLoss(gluon.HybridBlock): def hybrid_forward(self, F, pred, label): return ((pred - label) ** 2).sum() # equivalent to F.np.sum(F.np.square(pred - label)) diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py index 188cb6f3393a..1c714719ba5c 100644 --- a/tests/python/unittest/test_numpy_ndarray.py +++ b/tests/python/unittest/test_numpy_ndarray.py @@ -20,7 +20,7 @@ from __future__ import division import numpy as _np import mxnet as mx -from mxnet import np +from mxnet import np, npx from mxnet.gluon import HybridBlock from mxnet.test_utils import same, assert_almost_equal, rand_shape_nd, rand_ndarray, assert_exception from common import with_seed @@ -29,9 +29,15 @@ @with_seed() def test_array_creation(): dtypes = [_np.int8, _np.int32, _np.float16, _np.float32, _np.float64, None] - objects = [[], (), [[1, 2], [3, 4]], - _np.random.uniform(size=rand_shape_nd(3, allow_zero_size=True)), - mx.nd.array(_np.random.uniform(size=rand_shape_nd(3, allow_zero_size=True)))] + objects = [ + [], + (), + [[1, 2], [3, 4]], + _np.random.uniform(size=rand_shape_nd(3)), + _np.random.uniform(size=(3, 0, 4)), + np.random.uniform(size=rand_shape_nd(3)), + np.random.uniform(size=(3, 0, 4)) + ] for dtype in dtypes: for src in objects: mx_arr = np.array(src, dtype=dtype) @@ -47,7 +53,7 @@ def test_array_creation(): @with_seed() def test_zeros(): # test np.zeros in Gluon - @np.use_np_shape + @npx.use_np_shape class TestZeros(HybridBlock): def __init__(self, shape, dtype=None): super(TestZeros, self).__init__() @@ -57,13 +63,13 @@ def __init__(self, shape, dtype=None): def hybrid_forward(self, F, x, *args, **kwargs): return x + F.np.zeros(shape, dtype) - @np.use_np_shape + @npx.use_np_shape class TestZerosOutputType(HybridBlock): def hybrid_forward(self, F, x, *args, **kwargs): return x, F.np.zeros(shape=()) # test np.zeros in imperative - @np.use_np_shape + @npx.use_np_shape def check_zero_array_creation(shape, dtype): np_out = _np.zeros(shape=shape, dtype=dtype) mx_out = np.zeros(shape=shape, dtype=dtype) @@ -97,7 +103,7 @@ def check_zero_array_creation(shape, dtype): @with_seed() def test_ones(): # test np.ones in Gluon - @np.use_np_shape + @npx.use_np_shape class TestOnes(HybridBlock): def __init__(self, shape, dtype=None): super(TestOnes, self).__init__() @@ -107,13 +113,13 @@ def __init__(self, shape, dtype=None): def hybrid_forward(self, F, x, *args, **kwargs): return x * F.np.ones(shape, dtype) - @np.use_np_shape + @npx.use_np_shape class TestOnesOutputType(HybridBlock): def hybrid_forward(self, F, x, *args, **kwargs): return x, F.np.ones(shape=()) # test np.ones in imperative - @np.use_np_shape + @npx.use_np_shape def check_ones_array_creation(shape, dtype): np_out = _np.ones(shape=shape, dtype=dtype) mx_out = np.ones(shape=shape, dtype=dtype) @@ -146,17 +152,24 @@ def check_ones_array_creation(shape, dtype): @with_seed() def test_ndarray_binary_element_wise_ops(): - # Cannot test operators like >, because boolean arrays are not supported yet. - np_op_map = {'+': _np.add, '*': _np.multiply, '-': _np.subtract, '/': _np.divide, - 'mod': _np.mod, 'pow': _np.power, - # '>': _np.greater, '>=': _np.greater_equal, - # '<': _np.less, '<=': _np.less_equal - } + np_op_map = { + '+': _np.add, + '*': _np.multiply, + '-': _np.subtract, + '/': _np.divide, + 'mod': _np.mod, + 'pow': _np.power, + '==': _np.equal, + '>': _np.greater, + '>=': _np.greater_equal, + '<': _np.less, + '<=': _np.less_equal + } def get_np_ret(x1, x2, op): return np_op_map[op](x1, x2) - @np.use_np_shape + @npx.use_np_shape class TestBinaryElementWiseOp(HybridBlock): def __init__(self, op, scalar=None, reverse=False): super(TestBinaryElementWiseOp, self).__init__() @@ -197,29 +210,34 @@ def hybrid_forward(self, F, x, *args): return x ** args[0] if not self._reverse else args[0] ** x elif self._op == '>': if self._scalar is not None: - return x > self._scalar + return x > self._scalar if not self._reverse else self._scalar > x else: return x > args[0] elif self._op == '>=': if self._scalar is not None: - return x >= self._scalar + return x >= self._scalar if not self._reverse else self._scalar >= x else: return x >= args[0] elif self._op == '<': if self._scalar is not None: - return x < self._scalar + return x < self._scalar if not self._reverse else self._scalar < x else: return x < args[0] elif self._op == '<=': if self._scalar is not None: - return x <= self._scalar + return x <= self._scalar if not self._reverse else self._scalar <= x else: return x <= args[0] + elif self._op == '==': + if self._scalar is not None: + return x == self._scalar if not self._reverse else self._scalar == x + else: + return x == args[0] else: print(self._op) assert False - @np.use_np_shape + @npx.use_np_shape def check_binary_op_result(shape1, shape2, op, dtype=None): if shape1 is None: mx_input1 = abs(_np.random.uniform()) + 1 @@ -289,10 +307,10 @@ def check_binary_op_result(shape1, shape2, op, dtype=None): @with_seed() def test_hybrid_block_multiple_outputs(): - @np.use_np_shape + @npx.use_np_shape class TestAllNumpyOutputs(HybridBlock): def hybrid_forward(self, F, x, *args, **kwargs): - return F.npe.relu(x), F.np.sum(x) + return F.npx.relu(x), F.np.sum(x) class TestAllClassicOutputs(HybridBlock): def hybrid_forward(self, F, x, *args, **kwargs): @@ -309,7 +327,7 @@ def hybrid_forward(self, F, x, *args, **kwargs): assert type(out1) is expected_out_type assert type(out2) is expected_out_type - @np.use_np_shape + @npx.use_np_array class TestMixedTypeOutputsFailure(HybridBlock): def hybrid_forward(self, F, x, *args, **kwargs): return F.relu(x.as_classic_ndarray()), F.np.sum(x) @@ -357,6 +375,257 @@ def test_np_ndarray_copy(): assert same(mx_ret.asnumpy(), np_ret) +@with_seed() +def test_np_ndarray_indexing(): + def test_getitem(np_array, index): + """`is_scalar` indicates whether we should expect a scalar for the result. + If so, the indexed array of NDArray should call asscalar to compare + with numpy's indexed array.""" + np_index = index + if isinstance(index, np.ndarray): + np_index = index.asnumpy() + if isinstance(index, tuple): + np_index = [] + for idx in index: + if isinstance(idx, np.ndarray): + np_index.append(idx.asnumpy()) + else: + np_index.append(idx) + np_index = tuple(np_index) + + np_indexed_array = np_array[np_index] + mx_array = np.array(np_array, dtype=np_array.dtype) + mx_indexed_array = mx_array[index].asnumpy() + assert same(np_indexed_array, mx_indexed_array), 'Failed with index=%s' % str(index) + + def test_setitem(np_array, index): + def assert_same(np_array, np_index, mx_array, mx_index, mx_value, np_value=None): + if np_value is not None: + np_array[np_index] = np_value + elif isinstance(mx_value, np.ndarray): + np_array[np_index] = mx_value.asnumpy() + else: + np_array[np_index] = mx_value + mx_array[mx_index] = mx_value + assert same(np_array, mx_array.asnumpy()) + + np_index = index + if isinstance(index, np.ndarray): + np_index = index.asnumpy() + if isinstance(index, tuple): + np_index = [] + for idx in index: + if isinstance(idx, np.ndarray): + np_index.append(idx.asnumpy()) + else: + np_index.append(idx) + np_index = tuple(np_index) + + mx_array = np.array(np_array, dtype=np_array.dtype) + np_array = mx_array.asnumpy() + indexed_array_shape = np_array[np_index].shape + np_indexed_array = _np.random.randint(low=-10000, high=0, size=indexed_array_shape) + # test value is a numpy array without broadcast + assert_same(np_array, np_index, mx_array, index, np_indexed_array) + # test value is an numeric_type + assert_same(np_array, np_index, mx_array, index, _np.random.randint(low=-10000, high=0)) + if len(indexed_array_shape) > 1: + # test ndarray with broadcast + assert_same(np_array, np_index, mx_array, index, + np.random.uniform(low=-10000, high=0, size=(indexed_array_shape[-1],))) + # test numpy array with broadcast + assert_same(np_array, np_index, mx_array, index, + _np.random.randint(low=-10000, high=0, size=(indexed_array_shape[-1],))) + # test list with broadcast + assert_same(np_array, np_index, mx_array, index, + [_np.random.randint(low=-10000, high=0)] * indexed_array_shape[-1]) + + def test_getitem_autograd(np_array, index): + x = np.array(np_array, dtype=np_array.dtype) + x.attach_grad() + with npx.autograd.record(): + y = x[index] + y.backward() + value = np.ones_like(y) + x_grad = np.zeros_like(x) + x_grad[index] = value + assert same(x_grad.asnumpy(), x.grad.asnumpy()) + + def test_setitem_autograd(np_array, index): + x = np.array(np_array, dtype=np_array.dtype) + out_shape = x[index].shape + y = np.random.uniform(size=out_shape) + y.attach_grad() + try: + with npx.autograd.record(): + x[index] = y + assert False # should not reach here + except mx.base.MXNetError as err: + assert str(err).find('Inplace operations (+=, -=, x[:]=, etc) are not supported when recording with') != -1 + + def np_int(index, int_type=_np.int32): + def convert(num): + if num is None: + return num + else: + return int_type(num) + + if isinstance(index, slice): + return slice(convert(index.start), convert(index.stop), convert(index.step)) + elif isinstance(index, tuple): # tuple of slices and integers + ret = [] + for elem in index: + if isinstance(elem, slice): + ret.append(slice(convert(elem.start), convert(elem.stop), convert(elem.step))) + else: + ret.append(convert(elem)) + return tuple(ret) + else: + assert False + + shape = (8, 16, 9, 9) + np_array = _np.arange(_np.prod(shape), dtype='int32').reshape(shape) + index_list = [ + (), + 0, + _np.int32(0), + _np.int64(0), + 5, + _np.int32(5), + _np.int64(5), + -1, + _np.int32(-1), + _np.int64(-1), + slice(5), + np_int(slice(5), _np.int32), + np_int(slice(5), _np.int64), + slice(1, 5), + np_int(slice(1, 5), _np.int32), + np_int(slice(1, 5), _np.int64), + slice(1, 5, 2), + np_int(slice(1, 5, 2), _np.int32), + np_int(slice(1, 5, 2), _np.int64), + slice(7, 0, -1), + np_int(slice(7, 0, -1)), + np_int(slice(7, 0, -1), _np.int64), + slice(None, 6), + np_int(slice(None, 6)), + np_int(slice(None, 6), _np.int64), + slice(None, 6, 3), + np_int(slice(None, 6, 3)), + np_int(slice(None, 6, 3), _np.int64), + slice(1, None), + np_int(slice(1, None)), + np_int(slice(1, None), _np.int64), + slice(1, None, 3), + np_int(slice(1, None, 3)), + np_int(slice(1, None, 3), _np.int64), + slice(None, None, 2), + np_int(slice(None, None, 2)), + np_int(slice(None, None, 2), _np.int64), + slice(None, None, -1), + np_int(slice(None, None, -1)), + np_int(slice(None, None, -1), _np.int64), + slice(None, None, -2), + np_int(slice(None, None, -2), _np.int32), + np_int(slice(None, None, -2), _np.int64), + (slice(None), slice(None), 1, 8), + (slice(None), slice(None), -1, 8), + (slice(None), slice(None), 1, -8), + (slice(None), slice(None), -1, -8), + np_int((slice(None), slice(None), 1, 8)), + np_int((slice(None), slice(None), 1, 8), _np.int64), + (slice(None), slice(None), 1, 8), + np_int((slice(None), slice(None), -1, -8)), + np_int((slice(None), slice(None), -1, -8), _np.int64), + (slice(None), 2, slice(1, 5), 1), + np_int((slice(None), 2, slice(1, 5), 1)), + np_int((slice(None), 2, slice(1, 5), 1), _np.int64), + (1, 2, 3), + np_int((1, 2, 3)), + np_int((1, 2, 3), _np.int64), + (-1, -2, -3), + np_int((-1, -2, -3)), + np_int((-1, -2, -3), _np.int64), + (1, 2, 3, 4), + np_int((1, 2, 3, 4)), + np_int((1, 2, 3, 4), _np.int64), + (-4, -3, -2, -1), + np_int((-4, -3, -2, -1)), + np_int((-4, -3, -2, -1), _np.int64), + (slice(None, None, -1), 2, slice(1, 5), 1), + np_int((slice(None, None, -1), 2, slice(1, 5), 1)), + np_int((slice(None, None, -1), 2, slice(1, 5), 1), _np.int64), + (slice(None, None, -1), 2, slice(1, 7, 2), 1), + np_int((slice(None, None, -1), 2, slice(1, 7, 2), 1)), + np_int((slice(None, None, -1), 2, slice(1, 7, 2), 1), _np.int64), + (slice(1, 8, 2), slice(14, 2, -2), slice(3, 8), slice(0, 7, 3)), + np_int((slice(1, 8, 2), slice(14, 2, -2), slice(3, 8), slice(0, 7, 3))), + np_int((slice(1, 8, 2), slice(14, 2, -2), slice(3, 8), slice(0, 7, 3)), _np.int64), + (slice(1, 8, 2), 1, slice(3, 8), 2), + np_int((slice(1, 8, 2), 1, slice(3, 8), 2)), + np_int((slice(1, 8, 2), 1, slice(3, 8), 2), _np.int64), + [1], + [1, 2], + [2, 1, 3], + [7, 5, 0, 3, 6, 2, 1], + _np.array([6, 3], dtype=_np.int32), + _np.array([[3, 4], [0, 6]], dtype=_np.int32), + _np.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=_np.int32), + _np.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=_np.int64), + _np.array([[2], [0], [1]], dtype=_np.int32), + _np.array([[2], [0], [1]], dtype=_np.int64), + np.array([4, 7], dtype=_np.int32), + np.array([4, 7], dtype=_np.int64), + np.array([[3, 6], [2, 1]], dtype=_np.int32), + np.array([[3, 6], [2, 1]], dtype=_np.int64), + np.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=_np.int32), + np.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=_np.int64), + (1, [2, 3]), + (1, [2, 3], _np.array([[3], [0]], dtype=_np.int32)), + (1, [2, 3]), + (1, [2, 3], _np.array([[3], [0]], dtype=_np.int64)), + (1, [2], _np.array([[5], [3]], dtype=_np.int32), slice(None)), + (1, [2], _np.array([[5], [3]], dtype=_np.int64), slice(None)), + (1, [2, 3], _np.array([[6], [0]], dtype=_np.int32), slice(2, 5)), + (1, [2, 3], _np.array([[6], [0]], dtype=_np.int64), slice(2, 5)), + (1, [2, 3], _np.array([[4], [7]], dtype=_np.int32), slice(2, 5, 2)), + (1, [2, 3], _np.array([[4], [7]], dtype=_np.int64), slice(2, 5, 2)), + (1, [2], _np.array([[3]], dtype=_np.int32), slice(None, None, -1)), + (1, [2], _np.array([[3]], dtype=_np.int64), slice(None, None, -1)), + (1, [2], _np.array([[3]], dtype=_np.int32), np.array([[5, 7], [2, 4]], dtype=_np.int64)), + (1, [2], np.array([[4]], dtype=_np.int32), np.array([[1, 3], [5, 7]], dtype='int64')), + [0], + [0, 1], + [1, 2, 3], + [2, 0, 5, 6], + ([1, 1], [2, 3]), + ([1], [4], [5]), + ([1], [4], [5], [6]), + ([[1]], [[2]]), + ([[1]], [[2]], [[3]], [[4]]), + (slice(0, 2), [[1], [6]], slice(0, 2), slice(0, 5, 2)), + ([[[[1]]]], [[1]], slice(0, 3), [1, 5]), + ([[[[1]]]], 3, slice(0, 3), [1, 3]), + ([[[[1]]]], 3, slice(0, 3), 0), + ([[[[1]]]], [[2], [12]], slice(0, 3), slice(None)), + ([1, 2], slice(3, 5), [2, 3], [3, 4]), + ([1, 2], slice(3, 5), (2, 3), [3, 4]), + range(4), + range(3, 0, -1), + (range(4,), [1]), + # slice(0, 0) does not support output zero-size tensor yet + ] + for index in index_list: + test_getitem(np_array, index) + test_setitem(np_array, index) + test_getitem_autograd(np_array, index) + if not isinstance(index, tuple) or len(index) != 0: + # When index = (), this is same a[()] = b is equivalent to b.copyto(a) + # which should have no problem to do autograd + test_setitem_autograd(np_array, index) + + if __name__ == '__main__': import nose nose.runmodule() diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 360869020f2a..9804aea750ab 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -19,7 +19,7 @@ from __future__ import absolute_import import numpy as _np import mxnet as mx -from mxnet import np, npe +from mxnet import np, npx from mxnet.gluon import HybridBlock from mxnet.test_utils import same, assert_almost_equal, rand_shape_nd, rand_ndarray from mxnet.test_utils import check_numeric_gradient @@ -79,7 +79,8 @@ def is_int(dtype): if itype == 'float32' and dtype == 'float32': x_sym = mx.sym.Variable("x").as_np_ndarray() mx_sym = mx.sym.np.sum(x_sym, axis=axis, dtype=dtype, keepdims=keepdims).as_classic_ndarray() - check_numeric_gradient(mx_sym, [x], numeric_eps=1e-3, rtol=1e-3, atol=1e-4, dtype=_np.float32) + check_numeric_gradient(mx_sym, [x.as_classic_ndarray()], + numeric_eps=1e-3, rtol=1e-3, atol=1e-4, dtype=_np.float32) # test imperative mx_out = np.sum(x, axis=axis, dtype=dtype, keepdims=keepdims) @@ -88,7 +89,7 @@ def is_int(dtype): @with_seed() -@np.use_np_shape +@npx.use_np_shape def test_np_dot(): shapes = [ ((3, 0), (0, 4)), @@ -132,7 +133,7 @@ def test_np_dot(): @with_seed() def test_np_mean(): - @np.use_np_shape + @npx.use_np_shape class TestMean(HybridBlock): def __init__(self, axis=None, dtype=None, keepdims=False): super(TestMean, self).__init__() @@ -185,7 +186,8 @@ def is_int(dtype): if itype == 'float32' and dtype == 'float32': x_sym = mx.sym.Variable("x").as_np_ndarray() mx_sym = mx.sym.np.mean(x_sym, axis=axis, dtype=dtype, keepdims=keepdims).as_classic_ndarray() - check_numeric_gradient(mx_sym, [x], numeric_eps=1e-3, rtol=1e-3, atol=1e-4, dtype=_np.float32) + check_numeric_gradient(mx_sym, [x.as_classic_ndarray()], + numeric_eps=1e-3, rtol=1e-3, atol=1e-4, dtype=_np.float32) # test imperative mx_out = np.mean(x, axis=axis, dtype=dtype, keepdims=keepdims) @@ -194,7 +196,6 @@ def is_int(dtype): @with_seed() -@np.use_np_shape def test_np_transpose(): # TODO(junwu): Add more test cases data = mx.sym.var('a').as_np_ndarray() @@ -224,39 +225,36 @@ def test_np_transpose(): @with_seed() -@np.use_np_shape -def test_relu(): +def test_npx_relu(): # TODO(junwu): Add more test cases data = mx.sym.var('data').as_np_ndarray() - ret = mx.sym.npe.relu(data) + ret = mx.sym.npx.relu(data) assert type(ret) == mx.sym.np._Symbol shapes = [(), (0, 2, 0)] shapes.extend([rand_shape_nd(ndim, allow_zero_size=True) for ndim in range(5)]) for shape in shapes: data = np.array(_np.random.uniform(size=shape).astype('float32')) - ret = npe.relu(data) + ret = npx.relu(data) assert type(ret) == np.ndarray @with_seed() -@np.use_np_shape -def test_sigmoid(): +def test_npx_sigmoid(): # TODO(junwu): Add more test cases data = mx.sym.var('data').as_np_ndarray() - ret = mx.sym.npe.sigmoid(data) + ret = mx.sym.npx.sigmoid(data) assert type(ret) == mx.sym.np._Symbol shapes = [(), (0, 2, 0)] shapes.extend([rand_shape_nd(ndim, allow_zero_size=True) for ndim in range(5)]) for shape in shapes: data = np.array(_np.random.uniform(size=shape).astype('float32')) - ret = npe.sigmoid(data) + ret = npx.sigmoid(data) assert type(ret) == np.ndarray @with_seed() -@np.use_np_shape def test_np_reshape(): # TODO(junwu): Add more test cases data = mx.sym.var('a').as_np_ndarray() @@ -272,7 +270,6 @@ def test_np_reshape(): @with_seed() -@np.use_np_shape def test_np_maximum(): # TODO(junwu): Add more test cases x1, x2 = mx.sym.var('x1').as_np_ndarray(), mx.sym.var('x2').as_np_ndarray() @@ -293,7 +290,6 @@ def check_maximum(x1, x2): @with_seed() -@np.use_np_shape def test_np_minimum(): # TODO(junwu): Add more test cases x1, x2 = mx.sym.var('x1').as_np_ndarray(), mx.sym.var('x2').as_np_ndarray() @@ -314,9 +310,9 @@ def check_minimum(x1, x2): @with_seed() -@mx.use_np_shape def test_np_unary_funcs(): def check_unary_func(func, ref_grad, shape, low, high): + @npx.use_np_shape class TestUnary(HybridBlock): def __init__(self, func): super(TestUnary, self).__init__() @@ -391,8 +387,8 @@ def hybrid_forward(self, F, a, *args, **kwargs): @with_seed() -@mx.use_np_shape def test_np_stack(): + @npx.use_np_shape class TestStack(HybridBlock): def __init__(self, axis=None): super(TestStack, self).__init__() @@ -442,6 +438,201 @@ def hybrid_forward(self, F, a, *args): assert same(mx_out.asnumpy(), np_out) +def test_np_random(): + shapes = [(), (1,), (2, 3), (4, 0, 5), 6, (7, 8), None] + dtypes = ['float16', 'float32', 'float64'] + op_names = ['uniform', 'normal'] + for shape in shapes: + for dtype in dtypes: + for op_name in op_names: + op = getattr(np.random, op_name, None) + assert op is not None + out = op(size=shape, dtype=dtype) + expected_shape = shape + if not isinstance(shape, tuple): + expected_shape = () if shape is None else (shape,) + assert out.shape == expected_shape + + @npx.use_np + class TestRandom(HybridBlock): + def __init__(self, shape, op_name): + super(TestRandom, self).__init__() + self._shape = shape + self._op_name = op_name + + def hybrid_forward(self, F, x): + op = getattr(F.np.random, self._op_name, None) + assert op is not None + return x + op(size=shape) + + x = np.ones(()) + for op_name in op_names: + for shape in shapes: + for hybridize in [False, True]: + net = TestRandom(shape, op_name) + if hybridize: + net.hybridize() + out = net(x) + expected_shape = shape + if not isinstance(shape, tuple): + expected_shape = () if shape is None else (shape,) + assert out.shape == expected_shape + + +@with_seed() +def test_np_arange(): + configs = [ + (1, 10, 2), + (1, 10, 4), + (1, -10, 4), + (1, -10, -2), + (1, -10, -4), + (2, 3), + (2, -3), + (-2, -3), + (-2, 3), + (4, 0, 5), + (-4, 0, 5), + (-4, 0, -5), + (0, 0), + (11, 11), + (0, 0, 2), + (0, 0, -2), + (0, 5, None), + (0, -5, None), + 0, + 6, + ] + dtypes = ['int32', 'float16', 'float32', 'float64', None] + for config in configs: + for dtype in dtypes: + if isinstance(config, tuple): + mx_ret = np.arange(*config, dtype=dtype) + np_ret = _np.arange(*config, dtype=dtype) + else: + mx_ret = np.arange(config, dtype=dtype) + np_ret = _np.arange(config, dtype=dtype) + assert same(mx_ret.asnumpy(), np_ret) + + @npx.use_np + class TestRange(HybridBlock): + def __init__(self, start, stop=None, step=None, dtype=None): + super(TestRange, self).__init__() + self._start = start + self._stop = stop + self._step = step + self._dtype = dtype + + def hybrid_forward(self, F, x): + return x + F.np.arange(self._start, self._stop, self._step, dtype=self._dtype) + + for dtype in dtypes: + x = np.zeros(shape=(), dtype=dtype) + for config in configs: + for hybridize in [False, True]: + if isinstance(config, tuple): + net = TestRange(*config, dtype=dtype) + np_out = _np.arange(*config, dtype=dtype) + else: + net = TestRange(config, dtype=dtype) + np_out = _np.arange(config, dtype=dtype) + if hybridize: + net.hybridize() + mx_out = net(x) + assert same(mx_out.asnumpy(), np_out) + + +@with_seed() +def test_np_argmax(): + workloads = [ + ((), 0, False), + ((), -1, False), + ((), 1, True), + ((5, 3), None, False), + ((5, 3), -1, False), + ((5, 3), 1, False), + ((5, 3), 3, True), + ((5, 0, 3), 0, False), + ((5, 0, 3), -1, False), + ((5, 0, 3), None, True), + ((5, 0, 3), 1, True), + ] + dtypes = ['float16', 'float32', 'float64'] + + @npx.use_np + class TestArgMax(HybridBlock): + def __init__(self, axis=None): + super(TestArgMax, self).__init__() + self._axis = axis + + def hybrid_forward(self, F, x): + return F.np.argmax(x, self._axis) + + for shape, axis, throw_exception in workloads: + for dtype in dtypes: + a = np.random.uniform(size=shape, dtype=dtype) + if throw_exception: + # Cannot use assert_exception because sometimes the main thread + # proceeds to `assert False` before the exception is thrown + # in the worker thread. Have to use mx.nd.waitall() here + # to block the main thread. + try: + np.argmax(a, axis) + mx.nd.waitall() + assert False + except mx.MXNetError: + pass + else: + mx_ret = np.argmax(a, axis=axis) + np_ret = _np.argmax(a.asnumpy(), axis=axis) + assert same(mx_ret.asnumpy(), np_ret) + + for hybridize in [False, True]: + net = TestArgMax(axis) + if hybridize: + net.hybridize() + if throw_exception: + try: + net(a) + mx.nd.waitall() + assert False + except mx.MXNetError: + pass + else: + mx_ret = net(a) + assert same(mx_ret.asnumpy(), np_ret) + + +@with_seed() +def test_np_linalg_norm(): + @npx.use_np + class TestLinalgNorm(HybridBlock): + def __init__(self, ord=None, axis=None, keepdims=False): + super(TestLinalgNorm, self).__init__() + self._ord = ord + self._axis = axis + self._keepdims = keepdims + + def hybrid_forward(self, F, x): + return F.np.linalg.norm(x, ord=self._ord, axis=self._axis, keepdims=self._keepdims) + + a = np.arange(5 * 6 * 7 * 8).reshape((5, 6, 7, 8)) + ords = [None, 'fro'] + axes = [None, (0, 2), (1, 0), (1, 2)] + for ord in ords: + for axis in axes: + if ord == 'fro' and axis is None and a.ndim > 2: + continue + for keepdims in [False, True]: + for hybridize in [False, True]: + net = TestLinalgNorm(ord, axis, keepdims) + if hybridize: + net.hybridize() + mx_ret = net(a) + np_ret = _np.linalg.norm(a.asnumpy(), ord=ord, axis=axis, keepdims=keepdims) + assert_almost_equal(mx_ret.asnumpy(), np_ret, atol=1e-5, rtol=1e-4) + + if __name__ == '__main__': import nose nose.runmodule() diff --git a/tests/python/unittest/test_thread_local.py b/tests/python/unittest/test_thread_local.py index b553299ab4d7..ee56ba780a95 100644 --- a/tests/python/unittest/test_thread_local.py +++ b/tests/python/unittest/test_thread_local.py @@ -23,6 +23,7 @@ from mxnet.attribute import AttrScope from mxnet.name import NameManager from mxnet.test_utils import set_default_context +from mxnet.util import _NumpyArrayScope def test_context(): ctx_list = [] @@ -163,6 +164,41 @@ def f(): thread.join() assert status[0], "Failed to execute a symbolic graph within a thread" + +def test_np_array_scope(): + np_array_scope_list = [] + _NumpyArrayScope._current = _NumpyArrayScope(False) + np_array_scope_list.append(_NumpyArrayScope._current) + + def f(): + _NumpyArrayScope._current = _NumpyArrayScope(True) + np_array_scope_list.append(_NumpyArrayScope._current) + + thread = threading.Thread(target=f) + thread.start() + thread.join() + assert len(np_array_scope_list) == 2 + assert not np_array_scope_list[0]._is_np_array + assert np_array_scope_list[1]._is_np_array + + event = threading.Event() + status = [False] + + def g(): + with mx.np_array(False): + event.wait() + if not mx.is_np_array(): + status[0] = True + + thread = threading.Thread(target=g) + thread.start() + _NumpyArrayScope._current = _NumpyArrayScope(True) + event.set() + thread.join() + event.clear() + assert status[0], "Spawned thread didn't set status correctly" + + if __name__ == '__main__': import nose nose.runmodule()