[numpy] Fix np branch after rebase (apache#15086)
* Add np_array semantics for Gluon

Fix notebook

Fix sanity

Fix gluon deferred infer shape

Add np.random.uniform

Add random normal

Add boolean comparison ops

Add np.ndarray indexing

Reformat test ndarray indexing

Fix unit tests

Add one more test of indexing

Fix sanity

Enable amp test

Add np.arange

Revert cython unit test to ctypes

Delete unnecessary use_np_shape decorator from test

Rebase with numpy branch

Support range as index

Fix python2 range type check

Add argmax

Disable clojure test

* Fix CI

* Add np.linalg.norm for ord='fro'

* Fix pylint
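
Taken together, the user-facing additions above read roughly like the following hypothetical session against a numpy-branch build; exact signatures and return dtypes may differ from this sketch:

```python
from mxnet import np  # requires a build of the numpy branch

a = np.arange(6)                      # new: np.arange, default dtype float32
print(np.argmax(a))                   # new: np.argmax -> 5.0
u = np.random.uniform(0, 1, (2, 3))   # new: np.random.uniform
n = np.random.normal(0, 1, (2, 3))    # new: np.random.normal
mask = u > 0.5                        # new: boolean comparison ops
print(u[1])                           # new: np.ndarray indexing
print(a[range(2, 5)])                 # new: range supported as an index
```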
reminisce authored and haojin2 committed Aug 1, 2019
1 parent 26e682e commit 2170697
Showing 41 changed files with 1,836 additions and 148 deletions.
18 changes: 12 additions & 6 deletions ci/jenkins/Jenkins_steps.groovy
@@ -112,7 +112,8 @@ def compile_unix_cpu_openblas() {
         timeout(time: max_time, unit: 'MINUTES') {
           utils.init_git()
           utils.docker_run('ubuntu_cpu', 'build_ubuntu_cpu_openblas', false)
-          utils.pack_lib('cpu', mx_lib_cython, true)
+          // utils.pack_lib('cpu', mx_lib_cython, true)
+          utils.pack_lib('cpu', mx_lib, true)
         }
       }
     }
@@ -266,7 +267,8 @@ def compile_unix_cmake_gpu() {
         timeout(time: max_time, unit: 'MINUTES') {
           utils.init_git()
           utils.docker_run('ubuntu_gpu_cu101', 'build_ubuntu_gpu_cmake', false)
-          utils.pack_lib('cmake_gpu', mx_cmake_lib_cython, true)
+          // utils.pack_lib('cmake_gpu', mx_cmake_lib_cython, true)
+          utils.pack_lib('cmake_gpu', mx_cmake_lib, true)
         }
       }
     }
@@ -643,8 +645,10 @@ def test_unix_python2_cpu() {
     node(NODE_LINUX_CPU) {
       ws('workspace/ut-python2-cpu') {
         try {
-          utils.unpack_and_init('cpu', mx_lib_cython, true)
-          python2_ut_cython('ubuntu_cpu')
+          // utils.unpack_and_init('cpu', mx_lib_cython, true)
+          // python2_ut_cython('ubuntu_cpu')
+          utils.unpack_and_init('cpu', mx_lib, true)
+          python2_ut('ubuntu_cpu')
           utils.publish_test_coverage()
         } finally {
           utils.collect_test_results_unix('nosetests_unittest.xml', 'nosetests_python2_cpu_unittest.xml')
@@ -745,8 +749,10 @@ def test_unix_python3_gpu() {
     node(NODE_LINUX_GPU) {
       ws('workspace/ut-python3-gpu') {
         try {
-          utils.unpack_and_init('gpu', mx_lib_cython, true)
-          python3_gpu_ut_cython('ubuntu_gpu_cu101')
+          // utils.unpack_and_init('gpu', mx_lib_cython, true)
+          // python3_gpu_ut_cython('ubuntu_gpu_cu100')
+          utils.unpack_and_init('gpu', mx_lib, true)
+          python3_gpu_ut('ubuntu_gpu_cu101')
           utils.publish_test_coverage()
         } finally {
           utils.collect_test_results_unix('nosetests_gpu.xml', 'nosetests_python3_gpu.xml')
4 changes: 2 additions & 2 deletions ci/jenkins/Jenkinsfile_unix_cpu
@@ -52,8 +52,8 @@ core_logic: {
     custom_steps.test_unix_python3_mkldnn_mkl_cpu(),
     custom_steps.test_unix_scala_cpu(),
     custom_steps.test_unix_scala_mkldnn_cpu(),
-    custom_steps.test_unix_clojure_cpu(),
-    custom_steps.test_unix_clojure_integration_cpu(),
+    // custom_steps.test_unix_clojure_cpu(),
+    // custom_steps.test_unix_clojure_integration_cpu(),
     custom_steps.test_unix_perl_cpu(),
     custom_steps.test_unix_r_cpu(),
     custom_steps.test_unix_r_mkldnn_cpu(),
2 changes: 1 addition & 1 deletion example/numpy/demo.ipynb
@@ -372,7 +372,7 @@
     "from mxnet import gluon, autograd, np\n",
     "\n",
     "\n",
-    "@np.use_np_compat\n",
+    "@np.use_np\n",
     "class LinearRegression(gluon.HybridBlock):\n",
     "    def __init__(self, num_input_dim=1000, num_hidden_dim=100, num_output_dim=10):\n",
     "        super(LinearRegression, self).__init__()\n",
3 changes: 2 additions & 1 deletion python/mxnet/__init__.py
@@ -25,14 +25,15 @@
 from . import engine
 from .base import MXNetError
 from .util import is_np_shape, set_np_shape, np_shape, use_np_shape
+from .util import is_np_array, np_array, use_np_array, use_np
 from . import base
 from . import contrib
 from . import ndarray
 from . import ndarray as nd
 from . import numpy
 from . import numpy_extension
 from . import numpy as np
-from . import numpy_extension as npe
+from . import numpy_extension as npx
 from . import name
 # use mx.sym as short for symbol
 from . import symbol as sym
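
The effect of the alias change, sketched as a hypothetical session (what each module actually exports depends on the build):

```python
import mxnet as mx

x = mx.np.zeros((2, 3))   # mxnet.numpy, the NumPy-compatible ndarray namespace
y = mx.nd.zeros((2, 3))   # classic mxnet.ndarray is untouched
print(mx.npx)             # numpy_extension is now mx.npx (previously mx.npe)
print(mx.use_np)          # new top-level re-export from mxnet.util
```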
2 changes: 1 addition & 1 deletion python/mxnet/_ctypes/ndarray.py
@@ -118,7 +118,7 @@ def __init__(self, sym, flags=()):
         self.handle = CachedOpHandle()

         from ..symbol.numpy._symbol import _Symbol
-        self.is_np_sym = True if isinstance(sym, _Symbol) else False
+        self.is_np_sym = bool(isinstance(sym, _Symbol))

         check_call(_LIB.MXCreateCachedOpEx(
             sym.handle,
10 changes: 5 additions & 5 deletions python/mxnet/base.py
@@ -756,7 +756,7 @@ def _sanity_check_params(func_name, unsupported_params, param_dict):
 _NP_OP_PREFIX = '_np_'
 _NP_OP_SUBMODULE_LIST = ['_random_', '_linalg_']

-_NP_EXT_OP_PREFIX = '_npe_'
+_NP_EXT_OP_PREFIX = '_npx_'

 _NP_INTERNAL_OP_PREFIX = '_npi_'

@@ -813,14 +813,14 @@ def _init_np_op_module(root_module_name, np_module_name, mx_module_name, make_op
             op_names.append(name)

     if mx_module_name is None:
-        # register np/npe ops for imperative programming
+        # register np/npx ops for imperative programming
         op_module_name = "%s.%s._op" % (root_module_name, np_module_name)  # e.g. mxnet.numpy._op
         op_submodule_name = "%s.%s" % (root_module_name, np_module_name)  # e.g. mxnet.numpy.random
-    elif mx_module_name == 'ndarray' or mx_module_name == 'symbol':
-        # register numpy internal ops and np/npe ops for use in Gluon
+    elif mx_module_name in ('ndarray', 'symbol'):
+        # register numpy internal ops and np/npx ops for use in Gluon
         # np internal ops are registered in mxnet.ndarray/symbol.numpy._internal
         # np ops are registered in mxnet.ndarray/symbol.numpy._op
-        # npe ops are registered in mxnet.ndarray/symbol.numpy_extension._op
+        # npx ops are registered in mxnet.ndarray/symbol.numpy_extension._op
         op_module_name = "%s.%s.%s" % (root_module_name, mx_module_name, np_module_name)
         if op_name_prefix != _NP_INTERNAL_OP_PREFIX:
             op_module_name += '._op'
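
A rough sketch of the routing these prefixes drive, mirroring the comments in the hunk above. This is illustrative only, not the actual `_init_np_op_module` body, and the helper name is made up:

```python
def _target_module(op_name, mx_module_name=None, root='mxnet'):
    """Illustrative: map a backend op name to the front-end module it lands in."""
    if op_name.startswith('_npi_'):
        # np internal ops live in mxnet.ndarray/symbol.numpy._internal
        return '%s.%s.numpy._internal' % (root, mx_module_name)
    sub = 'numpy_extension' if op_name.startswith('_npx_') else 'numpy'
    if mx_module_name is None:
        return '%s.%s._op' % (root, sub)                      # imperative front end
    return '%s.%s.%s._op' % (root, mx_module_name, sub)       # Gluon front end

assert _target_module('_np_sum') == 'mxnet.numpy._op'
assert _target_module('_npx_relu', 'ndarray') == 'mxnet.ndarray.numpy_extension._op'
assert _target_module('_npi_arange', 'symbol') == 'mxnet.symbol.numpy._internal'
```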
3 changes: 2 additions & 1 deletion python/mxnet/gluon/block.py
@@ -35,6 +35,7 @@
 from .parameter import Parameter, ParameterDict, DeferredInitializationError
 from .utils import _indent, _brief_print_list, HookHandle
 from .utils import _check_same_symbol_type, _check_all_np_ndarrays
+from .. import numpy_extension as _mx_npx
 from .. import numpy as _mx_np


@@ -551,7 +552,7 @@ def __call__(self, *args):

         for hook in self._forward_hooks.values():
             hook(self, args, out)
-        if _mx_np.is_np_shape():
+        if _mx_npx.is_np_array():
             _check_all_np_ndarrays(_flatten(out, "output")[0])
         return out

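
The distinction being drawn here: `is_np_shape()` governs shape semantics (e.g. unknown dimensions), while `is_np_array()` governs whether Gluon deals in `mxnet.numpy` ndarrays, which is what the output check cares about. A sketch, assuming `np_shape` and `np_array` are the context managers re-exported from `mxnet.util` in the `__init__.py` change above:

```python
from mxnet import util

# np shape semantics: zero-size and unknown (-1) dimensions in shape inference
with util.np_shape(True):
    assert util.is_np_shape()

# np array semantics: Gluon blocks consume/produce mxnet.numpy ndarrays;
# this is the flag Block.__call__ now checks before validating outputs
with util.np_array(True):
    assert util.is_np_array()
```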
13 changes: 7 additions & 6 deletions python/mxnet/gluon/parameter.py
@@ -31,7 +31,7 @@
 from ..context import Context, cpu
 from .. import autograd
 from .utils import _indent, _brief_print_list, shape_is_known
-from ..util import is_np_shape
+from ..util import is_np_shape, is_np_array

 # pylint: disable= invalid-name
 tensor_types = (symbol.Symbol, ndarray.NDArray)
@@ -188,9 +188,9 @@ def shape(self, new_shape):
         if self._shape is None:
             self._shape = new_shape
             return
-
+        unknown_dim_size = -1 if is_np_shape() else 0
         assert len(self._shape) == len(new_shape) and \
-            all(j in (0, i) for i, j in zip(new_shape, self._shape)), \
+            all(j in (unknown_dim_size, i) for i, j in zip(new_shape, self._shape)), \
             "Expected shape %s is incompatible with given shape %s."%(
                 str(new_shape), str(self._shape))

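The rule the assertion encodes, as a standalone sketch (hypothetical helper; under NumPy semantics the unknown-dimension placeholder is -1, classically it is 0):

```python
def _shapes_compatible(stored, new, np_shape_active):
    # hypothetical helper mirroring Parameter.shape's assertion above
    unknown = -1 if np_shape_active else 0
    return len(stored) == len(new) and \
        all(j in (unknown, i) for i, j in zip(new, stored))

assert _shapes_compatible((-1, 128), (64, 128), True)       # unknown dim deferred
assert _shapes_compatible((0, 128), (64, 128), False)       # classic placeholder
assert not _shapes_compatible((32, 128), (64, 128), True)   # real mismatch
```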
@@ -317,6 +317,7 @@ def _finish_deferred_init(self):
             return
         init, ctx, default_init, data = self._deferred_init
         self._deferred_init = ()
+
         assert shape_is_known(self.shape), \
             "Cannot initialize Parameter '%s' because it has " \
             "invalid shape: %s. Please specify in_units, " \
@@ -330,7 +331,7 @@ def _finish_deferred_init(self):
                 initializer.create(default_init)(
                     initializer.InitDesc(self.name, {'__init__': init}), data)
                 # TODO(junwu): use np random operators when available
-                if is_np_shape():
+                if is_np_array():
                     data = data.as_np_ndarray()  # convert to np.ndarray

             self._init_impl(data, ctx)
@@ -357,7 +358,7 @@ def _init_grad(self):
         self._grad = [ndarray.zeros(shape=i.shape, dtype=i.dtype, ctx=i.context,
                                     stype=self._grad_stype) for i in self._data]
         # TODO(junwu): use np.zeros
-        if is_np_shape():
+        if is_np_array():
             self._grad = [arr.as_np_ndarray() for arr in self._grad]

         autograd.mark_variables(self._check_and_get(self._data, list),
@@ -606,7 +607,7 @@ def var(self):
             self._var = symbol.var(self.name, shape=self.shape, dtype=self.dtype,
                                    lr_mult=self.lr_mult, wd_mult=self.wd_mult,
                                    init=self.init, stype=self._stype)
-            if is_np_shape():
+            if is_np_array():
                 self._var = self._var.as_np_ndarray()
         return self._var

2 changes: 1 addition & 1 deletion python/mxnet/gluon/utils.py
@@ -438,7 +438,7 @@ def _check_same_symbol_type(symbols):
     the symbols."""
     from ..symbol.numpy import _Symbol as np_symbol
     from ..symbol import Symbol as classic_symbol
-    is_np_sym = True if isinstance(symbols[0], np_symbol) else False
+    is_np_sym = bool(isinstance(symbols[0], np_symbol))
     for s in symbols[1:]:
         if is_np_sym != isinstance(s, np_symbol):
             raise TypeError('Found both classic symbol (mx.sym.Symbol) and numpy symbol '
2 changes: 1 addition & 1 deletion python/mxnet/ndarray/__init__.py
@@ -31,7 +31,7 @@
 from .sparse import _ndarray_cls
 from .ndarray import _GRAD_REQ_MAP, _DTYPE_MX_TO_NP, _DTYPE_NP_TO_MX, _new_empty_handle
 from . import numpy as np
-from . import numpy_extension as npe
+from . import numpy_extension as npx

 __all__ = op.__all__ + ndarray.__all__ + utils.__all__ + \
     ['contrib', 'linalg', 'random', 'sparse', 'image', 'numpy', 'numpy_extension']
78 changes: 77 additions & 1 deletion python/mxnet/ndarray/numpy/_op.py
@@ -24,7 +24,7 @@
 from ...context import current_context
 from . import _internal as _npi

-__all__ = ['zeros', 'ones', 'maximum', 'minimum', 'stack']
+__all__ = ['zeros', 'ones', 'maximum', 'minimum', 'stack', 'arange', 'argmax']


 @set_module('mxnet.ndarray.numpy')
@@ -201,3 +201,79 @@ def get_list(arrays):

     arrays = get_list(arrays)
     return _npi.stack(*arrays, axis=axis, out=out)
+
+
+@set_module('mxnet.ndarray.numpy')
+def arange(start, stop=None, step=1, dtype=None, ctx=None):
+    """Return evenly spaced values within a given interval.
+
+    Values are generated within the half-open interval ``[start, stop)``
+    (in other words, the interval including `start` but excluding `stop`).
+    For integer arguments the function is equivalent to the Python built-in
+    `range` function, but returns an ndarray rather than a list.
+
+    Parameters
+    ----------
+    start : number, optional
+        Start of interval. The interval includes this value. The default
+        start value is 0.
+    stop : number
+        End of interval. The interval does not include this value, except
+        in some cases where `step` is not an integer and floating point
+        round-off affects the length of `out`.
+    step : number, optional
+        Spacing between values. For any output `out`, this is the distance
+        between two adjacent values, ``out[i+1] - out[i]``. The default
+        step size is 1. If `step` is specified as a positional argument,
+        `start` must also be given.
+    dtype : dtype
+        The type of the output array. The default is `float32`.
+
+    Returns
+    -------
+    arange : ndarray
+        Array of evenly spaced values.
+        For floating point arguments, the length of the result is
+        ``ceil((stop - start)/step)``. Because of floating point overflow,
+        this rule may result in the last element of `out` being greater
+        than `stop`.
+    """
+    if start is None and stop is None:
+        raise ValueError('start and stop cannot both be None')
+    if dtype is None:
+        dtype = 'float32'
+    if ctx is None:
+        ctx = current_context()
+    if stop is None:
+        stop = start
+        start = 0
+    if step is None:
+        step = 1
+    if step == 0:
+        raise ZeroDivisionError('step cannot be 0')
+    return _npi.arange(start=start, stop=stop, step=step, dtype=dtype, ctx=ctx)
+
+
+@set_module('mxnet.ndarray.numpy')
+def argmax(a, axis=None, out=None):
+    """Returns the indices of the maximum values along an axis.
+
+    Parameters
+    ----------
+    a : ndarray
+        Input array. Only ndarrays of dtype `float16`, `float32`, and `float64` are supported.
+    axis : int, optional
+        By default, the index is into the flattened array, otherwise
+        along the specified axis.
+    out : array, optional
+        If provided, the result will be inserted into this array. It should
+        be of the appropriate shape and dtype.
+
+    Returns
+    -------
+    index_array : ndarray of indices whose dtype is the same as that of the input ndarray.
+        Array of indices into the array. It has the same shape as `a.shape`
+        with the dimension along `axis` removed.
+    """
+    return _npi.argmax(a, axis=axis, keepdims=False, out=out)
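
A hypothetical imperative session exercising the two new operators (values assume the default float32 dtype described above):

```python
from mxnet import np  # assumes a build of the numpy branch

a = np.arange(3, 9, 2)   # array([3., 5., 7.]) -- half-open interval, step 2
b = np.arange(5)         # start defaults to 0: array([0., 1., 2., 3., 4.])
print(np.argmax(a))      # 2.0 -- flattened index; dtype matches the input
```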
50 changes: 49 additions & 1 deletion python/mxnet/ndarray/numpy/linalg.py
@@ -17,4 +17,52 @@

 """Namespace for operators used in Gluon dispatched by F=ndarray."""

-__all__ = []
+from __future__ import absolute_import
+from . import _op as _mx_nd_np
+
+__all__ = ['norm']
+
+
+def norm(x, ord=None, axis=None, keepdims=False):
+    r"""Matrix or vector norm.
+
+    This function only supports the Frobenius norm for now.
+    The Frobenius norm is given by [1]_:
+    :math:`||A||_F = [\sum_{i,j} abs(a_{i,j})^2]^{1/2}`
+
+    Parameters
+    ----------
+    x : ndarray
+        Input array.
+    ord : {'fro'}, optional
+        Order of the norm.
+    axis : {int, 2-tuple of ints, None}, optional
+        If `axis` is an integer, it specifies the axis of `x` along which to
+        compute the vector norms. If `axis` is a 2-tuple, it specifies the
+        axes that hold 2-D matrices, and the matrix norms of these matrices
+        are computed. If `axis` is None, the norm of the whole ndarray is
+        returned.
+    keepdims : bool, optional
+        If this is set to True, the axes which are normed over are left in the
+        result as dimensions with size one. With this option the result will
+        broadcast correctly against the original `x`.
+
+    Returns
+    -------
+    n : float or ndarray
+        Norm of the matrix or vector(s).
+
+    References
+    ----------
+    .. [1] G. H. Golub and C. F. Van Loan, *Matrix Computations*,
+           Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15
+    """
+    if ord is not None and ord != 'fro':
+        raise ValueError('only the Frobenius norm is supported for now, received ord={}'.format(str(ord)))
+    if isinstance(axis, tuple) and len(axis) > 2:
+        raise ValueError('Improper number of dimensions to norm')
+    if ord == 'fro' and x.ndim > 2 and axis is None:
+        raise ValueError('Improper number of dimensions to norm')
+    return _mx_nd_np.sqrt(_mx_nd_np.sum(x * x, axis=axis, keepdims=keepdims))
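
A short usage sketch (assuming the same function is also exposed as `mxnet.numpy.linalg.norm`, per the commit message):

```python
from mxnet import np  # assumes a build of the numpy branch

x = np.ones((2, 3))
print(np.linalg.norm(x))             # Frobenius norm of the whole array: sqrt(6) ~ 2.449
print(np.linalg.norm(x, ord='fro'))  # explicit ord, same value
print(np.linalg.norm(x, axis=1))     # per-row vector norms: [sqrt(3), sqrt(3)]
```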