Skip to content

Commit

Permalink
#3188: refactor get_bcast_batch
Browse files Browse the repository at this point in the history
  • Loading branch information
dongjin-na committed Nov 23, 2023
1 parent 16464c9 commit ca97a68
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def compare(tt_out, torch_out, atol=0.2, rtop=0.2):
return allclose_result, isclose_true_ratio


@skip_for_wormhole_b0
@pytest.mark.parametrize(
"input_a_shape",
(
Expand Down Expand Up @@ -126,6 +127,7 @@ def test_moreh_matmul(input_a_shape, input_b_shape, device):
assert allclose_result or isclose_true_ratio > 0.95


@skip_for_wormhole_b0
@pytest.mark.parametrize(
"input_a_shape",
(
Expand Down Expand Up @@ -190,6 +192,7 @@ def test_batched_moreh_matmul(input_a_shape, input_b_shape, device):
assert allclose_result or isclose_true_ratio > 0.95


@skip_for_wormhole_b0
@pytest.mark.parametrize(
"input_a_shape",
(
Expand Down Expand Up @@ -250,6 +253,7 @@ def test_moreh_matmul_transpose_b(input_a_shape, input_b_shape, device):
assert allclose_result or isclose_true_ratio > 0.95


@skip_for_wormhole_b0
@pytest.mark.parametrize(
"input_a_shape",
(
Expand Down Expand Up @@ -314,6 +318,7 @@ def test_batched_moreh_matmul_transpose_b(input_a_shape, input_b_shape, device):
assert allclose_result or isclose_true_ratio > 0.95


@skip_for_wormhole_b0
@pytest.mark.parametrize(
"input_shape",
(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@ namespace operations {
namespace primary {

std::tuple<bool, bool> get_bcast_batch(const Shape &input0_shape, const Shape &input1_shape) {
bool in0_bcast_batch = input0_shape[1] < input1_shape[1] ? true : false;
bool in1_bcast_batch = input0_shape[1] > input1_shape[1] ? true : false;
return {in0_bcast_batch, in1_bcast_batch};
return {(input0_shape[1] < input1_shape[1]), (input0_shape[1] > input1_shape[1])};
}

operation::ProgramWithCallbacks moreh_matmul_multi_core(
Expand Down

0 comments on commit ca97a68

Please sign in to comment.