diff --git a/tests/conftest.py b/tests/conftest.py index b36df94e564..6753c647cfd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -140,6 +140,11 @@ def skip_if_no_tf_support(tf_support): pytest.skip("Skipped, no tf support") +@pytest.fixture +def skip_if_no_jax_support(): + pytest.importorskip("jax") + + @pytest.fixture(scope="module", params=[1, 2, 3]) def seed(request): diff --git a/tests/qnn/test_cost.py b/tests/qnn/test_cost.py index 57113403f5d..716dc5ef96a 100644 --- a/tests/qnn/test_cost.py +++ b/tests/qnn/test_cost.py @@ -42,7 +42,9 @@ def skip_if_no_torch_support(): @pytest.mark.parametrize("interface", ALLOWED_INTERFACES) -@pytest.mark.usefixtures("skip_if_no_torch_support", "skip_if_no_tf_support") +@pytest.mark.usefixtures( + "skip_if_no_torch_support", "skip_if_no_tf_support", "skip_if_no_jax_support" +) class TestSquaredErrorLoss: def test_no_target(self, interface): with pytest.raises(ValueError, match="The target cannot be None"):