From c89b0ceec36129f8c8dc96aacc9dc4678d27b3bb Mon Sep 17 00:00:00 2001 From: trbromley Date: Thu, 4 Feb 2021 11:26:33 -0500 Subject: [PATCH] Add jax skip --- tests/conftest.py | 5 +++++ tests/qnn/test_cost.py | 4 +++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index b36df94e564..6753c647cfd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -140,6 +140,11 @@ def skip_if_no_tf_support(tf_support): pytest.skip("Skipped, no tf support") +@pytest.fixture +def skip_if_no_jax_support(): + pytest.importorskip("jax") + + @pytest.fixture(scope="module", params=[1, 2, 3]) def seed(request): diff --git a/tests/qnn/test_cost.py b/tests/qnn/test_cost.py index 57113403f5d..716dc5ef96a 100644 --- a/tests/qnn/test_cost.py +++ b/tests/qnn/test_cost.py @@ -42,7 +42,9 @@ def skip_if_no_torch_support(): @pytest.mark.parametrize("interface", ALLOWED_INTERFACES) -@pytest.mark.usefixtures("skip_if_no_torch_support", "skip_if_no_tf_support") +@pytest.mark.usefixtures( + "skip_if_no_torch_support", "skip_if_no_tf_support", "skip_if_no_jax_support" +) class TestSquaredErrorLoss: def test_no_target(self, interface): with pytest.raises(ValueError, match="The target cannot be None"):