Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Use @mx.use_np_compat instead of mx.np_compat in index_array op tests
Browse files Browse the repository at this point in the history
  • Loading branch information
nickguletskii committed May 13, 2019
1 parent 3a430da commit 139567b
Showing 1 changed file with 21 additions and 21 deletions.
42 changes: 21 additions & 21 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8068,27 +8068,27 @@ def test_index_array_default():
check_symbolic_forward(index_array, [input_array], [expected])
check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)])

@mx.use_np_compat
def test_index_array_default_zero_dim():
with mx.np_compat(active=True):
data = mx.symbol.Variable("data")
index_array = mx.sym.contrib.index_array(data)
data = mx.symbol.Variable("data")
index_array = mx.sym.contrib.index_array(data)

input_array = np.ones(())
expected = np.zeros((0,))
input_array = np.ones(())
expected = np.zeros((0,))

check_symbolic_forward(index_array, [input_array], [expected])
check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)])
check_symbolic_forward(index_array, [input_array], [expected])
check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)])

@mx.use_np_compat
def test_index_array_default_zero_size():
with mx.np_compat(active=True):
data = mx.symbol.Variable("data")
index_array = mx.sym.contrib.index_array(data)
data = mx.symbol.Variable("data")
index_array = mx.sym.contrib.index_array(data)

input_array = np.ones((0, 0, 0))
expected = np.zeros((0, 0, 0, 3))
input_array = np.ones((0, 0, 0))
expected = np.zeros((0, 0, 0, 3))

check_symbolic_forward(index_array, [input_array], [expected])
check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)])
check_symbolic_forward(index_array, [input_array], [expected])
check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)])

def test_index_array_select_axes():
shape = (5, 7, 11, 13, 17, 19)
Expand All @@ -8103,16 +8103,16 @@ def test_index_array_select_axes():
check_symbolic_forward(index_array, [input_array], [expected])
check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)])

@mx.use_np_compat
def test_index_array_select_axes_zero_size():
with mx.np_compat(active=True):
data = mx.symbol.Variable("data")
index_array = mx.sym.contrib.index_array(data, axes=(2, 1))
data = mx.symbol.Variable("data")
index_array = mx.sym.contrib.index_array(data, axes=(2, 1))

input_array = np.ones((0, 0, 0, 0))
expected = np.zeros((0, 0, 2))
input_array = np.ones((0, 0, 0, 0))
expected = np.zeros((0, 0, 2))

check_symbolic_forward(index_array, [input_array], [expected])
check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)])
check_symbolic_forward(index_array, [input_array], [expected])
check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)])

test_index_array_default()
test_index_array_default_zero_dim()
Expand Down

0 comments on commit 139567b

Please sign in to comment.