Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
np.broadcast_to extension
Browse files Browse the repository at this point in the history
  • Loading branch information
haojin2 committed Jan 17, 2020
1 parent 22c7ef7 commit 03eaedc
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 5 deletions.
14 changes: 9 additions & 5 deletions src/operator/numpy/np_broadcast_reduce_op_value.cc
Original file line number Diff line number Diff line change
Expand Up @@ -467,17 +467,21 @@ bool NumpyBroadcastToShape(const nnvm::NodeAttrs& attrs,
mxnet::TShape& ishape = (*in_attrs)[0];
if (!mxnet::shape_is_known(ishape)) return false;
const BroadcastToParam& param = nnvm::get<BroadcastToParam>(attrs.parsed);
CHECK(mxnet::shape_is_known(param.shape))
<< "the objective shape for broadcasting array must be known";
CHECK_LE(ishape.ndim(), param.shape.ndim())
<< "shape " << ishape << " is not broadcastable to " << param.shape;
TShape pshape = param.shape;
for (int i = param.shape.ndim() - 1; i >= 0; --i) {
int j = i - param.shape.ndim() + ishape.ndim();
if (j < 0) break;
CHECK(ishape[j] == param.shape[i] || ishape[j] == 1)
<< "shape " << ishape << " is not broadcastable to " << param.shape;
if (pshape[i] == -2) {
pshape[i] = ishape[j];
}
CHECK(ishape[j] == pshape[i] || ishape[j] == 1)
<< "shape " << ishape << " is not broadcastable to " << pshape;
}
SHAPE_ASSIGN_CHECK(*out_attrs, 0, param.shape);
CHECK(mxnet::shape_is_known(pshape))
<< "the objective shape for broadcasting array must be known";
SHAPE_ASSIGN_CHECK(*out_attrs, 0, pshape);
return true;
}

Expand Down
27 changes: 27 additions & 0 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1550,6 +1550,7 @@ def hybrid_forward(self, F, x):
((4, 1), (1, 2, 3, 4, 5)),
((4, 1), (1, 0, 3, 4, 5))
]

for src_shape, dst_shape in shapes:
for hybridize in [True, False]:
test_broadcast_to = TestBroadcastTo(dst_shape)
Expand Down Expand Up @@ -1578,6 +1579,32 @@ def hybrid_forward(self, F, x):
ret = test_scalar_broadcast_to(np.empty(()))
assert_almost_equal(ret.asnumpy(), expected_ret, rtol=1e-5, atol=1e-6, use_broadcast=False)

# Test npx functionality
shapes = [
((5,), (3, 4, -2), (3, 4, 5)),
((5,), (0, -2), (0, 5)),
((1, 0), (2, -2, -2), (2, 1, 0)),
((3, 4), (1, 2, 3, -2), (1, 2, 3, 4)),
((3, 4), (1, 0, -2, 4), (1, 0, 3, 4))
]

for src_shape, npx_dst_shape, np_dst_shape in shapes:
for hybridize in [True, False]:
test_broadcast_to = TestBroadcastTo(npx_dst_shape)
if hybridize:
test_broadcast_to.hybridize()

a = _np.random.uniform(size=src_shape).astype(np.float32)
expected_ret = _np.broadcast_to(a, np_dst_shape)
a_mx = np.array(a, dtype=a.dtype)
a_mx.attach_grad()
with mx.autograd.record():
ret = test_broadcast_to(a_mx)
assert_almost_equal(ret.asnumpy(), expected_ret, rtol=1e-5, atol=1e-6, use_broadcast=False)
ret.backward()
expected_grad = collapse_sum_like(_np.ones_like(expected_ret), src_shape)
assert_almost_equal(a_mx.grad.asnumpy(), expected_grad, rtol=1e-5, atol=1e-6, use_broadcast=False)


@with_seed()
@use_np
Expand Down

0 comments on commit 03eaedc

Please sign in to comment.