diff --git a/project/algorithms/jax_ppo_test.py b/project/algorithms/jax_ppo_test.py index ecc17f75..bbb44a6b 100644 --- a/project/algorithms/jax_ppo_test.py +++ b/project/algorithms/jax_ppo_test.py @@ -3,7 +3,6 @@ import dataclasses import functools import operator -import sys import time from collections.abc import Callable, Iterable, Sequence from logging import getLogger @@ -206,7 +205,7 @@ def test_ours_with_trainer( algo.visualize(ts_i, gif_path=gif_path, eval_rng=eval_rng_i) -@pytest.mark.xfail(sys.platform == "darwin" and IN_GITHUB_CI, reason="Fails on macOS in CI.") +@pytest.mark.xfail(not torch.cuda.is_available(), reason="Fails on CPU in the CI") def test_results_are_same_with_or_without_jax_trainer( results_ours: tuple[PPOState, EvalMetrics], results_ours_with_trainer: tuple[PPOState, EvalMetrics],