Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix broadcast_like (#20169)
Browse files Browse the repository at this point in the history
  • Loading branch information
barry-jin authored Apr 15, 2021
1 parent f591b62 commit 2cac53b
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 2 deletions.
8 changes: 6 additions & 2 deletions src/api/operator/numpy_extension/npx_broadcast_like_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,18 @@ MXNET_REGISTER_API("_npx.broadcast_like")
// lhs_axes
if (args[2].type_code() == kNull) {
param.lhs_axes = dmlc::optional<mxnet::TShape>();
} else if (args[2].type_code() == kDLInt) {
param.lhs_axes = TShape(1, args[2].operator int64_t());
} else {
param.lhs_axes = mxnet::TShape(args[2].operator ObjectRef());
}
// rhs_axes
if (args[2].type_code() == kNull) {
if (args[3].type_code() == kNull) {
param.rhs_axes = dmlc::optional<mxnet::TShape>();
} else if (args[3].type_code() == kDLInt) {
param.rhs_axes = TShape(1, args[3].operator int64_t());
} else {
param.rhs_axes = mxnet::TShape(args[2].operator ObjectRef());
param.rhs_axes = mxnet::TShape(args[3].operator ObjectRef());
}

attrs.op = op;
Expand Down
11 changes: 11 additions & 0 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -10338,3 +10338,14 @@ def test_modulated_deformable_convolution(num_batch, num_channel_data, num_defor
rtol, atol = 1.0, 1e-2
else:
rtol, atol = 0.05, 1e-3


@use_np
def test_broadcast_like_different_types():
x = mx.np.zeros((2, 1))
y = mx.np.ones((2, 2))

y = mx.np.array(y).astype('int32')
z = mx.npx.broadcast_like(x, y, 1, 1)
assert_almost_equal(z.asnumpy(), np.array([[0,0],[0,0]]))
assert x.dtype == z.dtype

0 comments on commit 2cac53b

Please sign in to comment.