diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index 34218e3b36f2..6c83e1f6b178 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -24,7 +24,7 @@ from ...context import current_context from . import _internal as _npi -__all__ = ['zeros', 'ones', 'maximum', 'minimum', 'stack', 'arange', 'argmax'] +__all__ = ['zeros', 'ones', 'maximum', 'minimum', 'stack', 'concatenate', 'arange', 'argmax'] @set_module('mxnet.ndarray.numpy') @@ -277,3 +277,28 @@ def argmax(a, axis=None, out=None): with the dimension along `axis` removed. """ return _npi.argmax(a, axis=axis, keepdims=False, out=out) + + +@set_module('mxnet.ndarray.numpy') +def concatenate(seq, axis=0, out=None): + """Join a sequence of arrays along an existing axis. + + Parameters + ---------- + a1, a2, ... : sequence of array_like + The arrays must have the same shape, except in the dimension + corresponding to `axis` (the first, by default). + axis : int, optional + The axis along which the arrays will be joined. If axis is None, + arrays are flattened before use. Default is 0. + out : ndarray, optional + If provided, the destination to place the result. The shape must be + correct, matching that of what concatenate would have returned if no + out argument were specified. + + Returns + ------- + res : ndarray + The concatenated array. + """ + return _npi.concatenate(*seq, dim=axis, out=out) diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 212dfe30d293..6b3dcde73490 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -37,8 +37,8 @@ from ..ndarray import numpy as _mx_nd_np from ..ndarray.numpy import _internal as _npi -__all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'maximum', 'minimum', 'stack', 'arange', - 'argmax'] +__all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'maximum', 'minimum', 'stack', + 'concatenate', 'arange', 'argmax'] # This function is copied from ndarray.py since pylint @@ -1486,3 +1486,28 @@ def argmax(a, axis=None, out=None): with the dimension along `axis` removed. """ return _mx_nd_np.argmax(a, axis, out) + + +@set_module('mxnet.numpy') +def concatenate(seq, axis=0, out=None): + """Join a sequence of arrays along an existing axis. + + Parameters + ---------- + a1, a2, ... : sequence of array_like + The arrays must have the same shape, except in the dimension + corresponding to `axis` (the first, by default). + axis : int, optional + The axis along which the arrays will be joined. If axis is None, + arrays are flattened before use. Default is 0. + out : ndarray, optional + If provided, the destination to place the result. The shape must be + correct, matching that of what concatenate would have returned if no + out argument were specified. + + Returns + ------- + res : ndarray + The concatenated array. + """ + return _mx_nd_np.concatenate(seq, axis=axis, out=out) diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index b2d8a5bd6b10..7a555471f8c1 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -29,7 +29,7 @@ from .._internal import _set_np_symbol_class from . import _internal as _npi -__all__ = ['zeros', 'ones', 'maximum', 'minimum', 'stack', 'arange', 'argmax'] +__all__ = ['zeros', 'ones', 'maximum', 'minimum', 'stack', 'concatenate', 'arange', 'argmax'] @set_module('mxnet.symbol.numpy') @@ -1060,6 +1060,31 @@ def get_list(arrays): return _npi.stack(*arrays, axis=axis, out=out) +@set_module('mxnet.symbol.numpy') +def concatenate(seq, axis=0, out=None): + """Join a sequence of arrays along an existing axis. + + Parameters + ---------- + a1, a2, ... : sequence of array_like + The arrays must have the same shape, except in the dimension + corresponding to `axis` (the first, by default). + axis : int, optional + The axis along which the arrays will be joined. If axis is None, + arrays are flattened before use. Default is 0. + out : ndarray, optional + If provided, the destination to place the result. The shape must be + correct, matching that of what concatenate would have returned if no + out argument were specified. + + Returns + ------- + res : ndarray + The concatenated array. + """ + return _npi.concatenate(*seq, dim=axis, out=out) + + @set_module('mxnet.symbol.numpy') def arange(start, stop=None, step=1, dtype=None, ctx=None): """Return evenly spaced values within a given interval. diff --git a/src/operator/nn/concat.cc b/src/operator/nn/concat.cc index 8fb229889332..cda9c9a01f0d 100644 --- a/src/operator/nn/concat.cc +++ b/src/operator/nn/concat.cc @@ -32,9 +32,9 @@ namespace mxnet { namespace op { -static bool ConcatShape(const nnvm::NodeAttrs& attrs, - mxnet::ShapeVector *in_shape, - mxnet::ShapeVector *out_shape) { +bool ConcatShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector *in_shape, + mxnet::ShapeVector *out_shape) { using namespace mshadow; const ConcatParam& param_ = nnvm::get(attrs.parsed); CHECK_EQ(in_shape->size(), static_cast(param_.num_args)); @@ -138,9 +138,9 @@ static bool RNNParamConcatShape(const nnvm::NodeAttrs& attrs, return shape_is_known(dshape); } -static bool ConcatType(const nnvm::NodeAttrs& attrs, - std::vector *in_type, - std::vector *out_type) { +bool ConcatType(const nnvm::NodeAttrs& attrs, + std::vector *in_type, + std::vector *out_type) { const ConcatParam& param_ = nnvm::get(attrs.parsed); int dtype = -1; diff --git a/src/operator/numpy/np_matrix_op.cc b/src/operator/numpy/np_matrix_op.cc index db479a0331d8..80d70e5269ff 100644 --- a/src/operator/numpy/np_matrix_op.cc +++ b/src/operator/numpy/np_matrix_op.cc @@ -24,6 +24,7 @@ */ #include "./np_matrix_op-inl.h" +#include "../nn/concat-inl.h" namespace mxnet { namespace op { @@ -252,5 +253,62 @@ Examples:: .add_argument("data", "NDArray-or-Symbol[]", "List of arrays to stack") .add_arguments(StackParam::__FIELDS__()); +bool ConcatShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector *in_shape, + mxnet::ShapeVector *out_shape); + +bool ConcatType(const nnvm::NodeAttrs& attrs, + std::vector *in_type, + std::vector *out_type); + +struct NumpyConcatGrad { + const char *op_name; + std::vector operator()(const nnvm::NodePtr& n, + const std::vector& ograds) const { + CHECK_EQ(ograds.size(), 1); + std::vector heads(ograds.begin(), ograds.end()); + return MakeGradNode(op_name, n, heads, n->attrs.dict); + } +}; + + +NNVM_REGISTER_OP(_npi_concatenate) +.describe(R"code(Join a sequence of arrays along an existing axis.)code" ADD_FILELINE) +.set_num_inputs([](const NodeAttrs& attrs) { + const ConcatParam& params = nnvm::get(attrs.parsed); + return params.num_args; +}) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + const ConcatParam& params = nnvm::get(attrs.parsed); + std::vector ret; + for (int i = 0; i < params.num_args; ++i) { + ret.push_back(std::string("data") + std::to_string(i)); + } + return ret; +}) +.set_attr("FListOutputNames", + [](const NodeAttrs& attrs) { + return std::vector{"out"}; +}) +.set_attr("key_var_num_args", "num_args") +.set_attr("FInferType", ConcatType) +.set_attr("FInferShape", ConcatShape) +.set_attr("FCompute", ConcatCompute) +.set_attr("FGradient", NumpyConcatGrad{"_backward_np_concat"}) +.add_argument("data", "NDArray-or-Symbol[]", "List of arrays to concatenate") +.add_arguments(ConcatParam::__FIELDS__()); + +NNVM_REGISTER_OP(_backward_np_concat) +.set_num_outputs([](const NodeAttrs& attrs) { + const ConcatParam& params = nnvm::get(attrs.parsed); + return params.num_args; +}) +.set_attr_parser(ParamParser) +.set_attr("TIsBackward", true) +.set_attr("FCompute", ConcatGradCompute); + } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/np_matrix_op.cu b/src/operator/numpy/np_matrix_op.cu index 615dd26ff246..5980e81f1a0b 100644 --- a/src/operator/numpy/np_matrix_op.cu +++ b/src/operator/numpy/np_matrix_op.cu @@ -23,6 +23,7 @@ * \brief GPU Implementation of numpy matrix operations */ #include "./np_matrix_op-inl.h" +#include "../nn/concat-inl.h" namespace mxnet { namespace op { @@ -36,5 +37,8 @@ NNVM_REGISTER_OP(_np_reshape) NNVM_REGISTER_OP(_npi_stack) .set_attr("FCompute", StackOpForward); +NNVM_REGISTER_OP(_npi_concatenate) +.set_attr("FCompute", ConcatCompute); + } // namespace op } // namespace mxnet diff --git a/src/operator/quantization/quantized_concat.cc b/src/operator/quantization/quantized_concat.cc index f7a810b1e404..5835701497d9 100644 --- a/src/operator/quantization/quantized_concat.cc +++ b/src/operator/quantization/quantized_concat.cc @@ -28,8 +28,8 @@ namespace mxnet { namespace op { -static bool ConcatShape(const nnvm::NodeAttrs& attrs, mxnet::ShapeVector* in_shape, - mxnet::ShapeVector* out_shape) { +static bool QuantizedConcatShape(const nnvm::NodeAttrs& attrs, mxnet::ShapeVector* in_shape, + mxnet::ShapeVector* out_shape) { const ConcatParam& param_ = nnvm::get(attrs.parsed); CHECK_EQ(in_shape->size(), static_cast(param_.num_args * 3)); CHECK_EQ(out_shape->size(), 3U); @@ -74,8 +74,8 @@ static bool ConcatShape(const nnvm::NodeAttrs& attrs, mxnet::ShapeVector* in_sha return shape_is_known(dshape); } -static bool ConcatType(const nnvm::NodeAttrs& attrs, std::vector* in_type, - std::vector* out_type) { +static bool QuantizedConcatType(const nnvm::NodeAttrs& attrs, std::vector* in_type, + std::vector* out_type) { const ConcatParam& param_ = nnvm::get(attrs.parsed); CHECK_EQ(in_type->size(), static_cast(param_.num_args * 3)); CHECK_EQ(out_type->size(), 3U); @@ -130,8 +130,8 @@ If any input holds int8, then the output will be int8. Otherwise output will be // TODO(Xinyu): a temp solution to enable GluonCV INT8 flow, // will be reverted after the improvement of CachedOP is done. .set_attr("FGradient", MakeZeroGradNodes) -.set_attr("FInferType", ConcatType) -.set_attr("FInferShape", ConcatShape) +.set_attr("FInferType", QuantizedConcatType) +.set_attr("FInferShape", QuantizedConcatShape) .set_attr("key_var_num_args", "num_args") .add_argument("data", "NDArray-or-Symbol[]", "List of arrays to concatenate") .add_arguments(ConcatParam::__FIELDS__()); diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 9804aea750ab..d00573ebfbfc 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -633,6 +633,57 @@ def hybrid_forward(self, F, x): assert_almost_equal(mx_ret.asnumpy(), np_ret, atol=1e-5, rtol=1e-4) +@with_seed() +@npx.use_np_shape +def test_np_concat(): + class TestConcat(HybridBlock): + def __init__(self, axis=None): + super(TestConcat, self).__init__() + self._axis = axis + + def hybrid_forward(self, F, a, *args): + return F.np.concatenate([a] + list(args), axis=self._axis) + + def get_new_shape(shape, axis): + shape_lst = list(shape) + shape_lst[axis] = random.randint(0, 3) + return tuple(shape_lst) + + for shape in [(0, 0), (2, 3)]: + for hybridize in [True, False]: + for axis in range(2): + # test gluon + test_concat = TestConcat(axis=axis) + if hybridize: + test_concat.hybridize() + + a = mx.nd.random.uniform(-1.0, 1.0, shape=get_new_shape(shape, axis)).as_np_ndarray() + a.attach_grad() + b = mx.nd.random.uniform(-1.0, 1.0, shape=get_new_shape(shape, axis)).as_np_ndarray() + b.attach_grad() + c = mx.nd.random.uniform(-1.0, 1.0, shape=get_new_shape(shape, axis)).as_np_ndarray() + c.attach_grad() + d = mx.nd.random.uniform(-1.0, 1.0, shape=get_new_shape(shape, axis)).as_np_ndarray() + d.attach_grad() + expected_ret = _np.concatenate([a.asnumpy(), b.asnumpy(), c.asnumpy(), d.asnumpy()], axis=axis) + with mx.autograd.record(): + y = test_concat(a, b, c, d) + assert y.shape == expected_ret.shape + assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3, atol=1e-5) + + y.backward() + + assert_almost_equal(a.grad.asnumpy(), _np.ones(a.shape), rtol=1e-3, atol=1e-5) + assert_almost_equal(b.grad.asnumpy(), _np.ones(b.shape), rtol=1e-3, atol=1e-5) + assert_almost_equal(c.grad.asnumpy(), _np.ones(c.shape), rtol=1e-3, atol=1e-5) + assert_almost_equal(d.grad.asnumpy(), _np.ones(d.shape), rtol=1e-3, atol=1e-5) + + # test imperative + mx_out = np.concatenate([a, b, c, d], axis=axis) + np_out = _np.concatenate([a.asnumpy(), b.asnumpy(), c.asnumpy(), d.asnumpy()], axis=axis) + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5) + + if __name__ == '__main__': import nose nose.runmodule()