Require 16gb vram for finetuning tests
Signed-off-by: Fabrice Normandin <normandf@mila.quebec>
lebrice committed Nov 15, 2024
1 parent 827d09f commit b84bba7
Showing 2 changed files with 16 additions and 11 deletions.
project/algorithms/llm_finetuning_test.py (13 changes: 2 additions & 11 deletions)
```diff
@@ -21,9 +21,7 @@
 )
 from project.algorithms.testsuites.algorithm_tests import LearningAlgorithmTests
 from project.configs.config import Config
-from project.conftest import command_line_overrides
-from project.utils.env_vars import SLURM_JOB_ID
-from project.utils.testutils import IN_GITHUB_COULD_CI, run_for_all_configs_of_type
+from project.utils.testutils import run_for_all_configs_of_type, total_vram_gb
 from project.utils.typing_utils import PyTree
 from project.utils.typing_utils.protocols import DataModule
 
@@ -77,14 +75,7 @@ def _tuple_to_ndarray(v: tuple) -> np.ndarray:
     return [to_ndarray(v_i) for v_i in v]  # type: ignore
 
 
-@pytest.mark.skipif(
-    IN_GITHUB_COULD_CI, reason="This test is too resource-intensive to run on the GitHub CI."
-)
-@pytest.mark.parametrize(
-    command_line_overrides.__name__,
-    ["trainer.strategy=auto" if SLURM_JOB_ID is None else ""],
-    indirect=True,
-)
+@pytest.mark.skipif(total_vram_gb() < 16, reason="Not enough VRAM to run this test.")
 @run_for_all_configs_of_type("algorithm", LLMFinetuningExample)
 class TestLLMFinetuningExample(LearningAlgorithmTests[LLMFinetuningExample]):
     @pytest.fixture(scope="function")
```
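With this change, the hardware requirement is checked directly rather than inferred from the environment: `total_vram_gb()` runs once when pytest imports the test module, so the whole test class is skipped (not failed) on any machine with less than 16 GB of total VRAM, whether that is the GitHub CI or a local workstation. A minimal, self-contained sketch of the same pattern, assuming a standalone test file (the `test_heavy_gpu_workload` name and the threshold are illustrative, not part of the commit):

```python
import pytest
import torch


def total_vram_gb() -> float:
    """Total VRAM across all visible CUDA devices, in GiB (0.0 on CPU-only hosts)."""
    if not torch.cuda.is_available():
        return 0.0
    total_bytes = sum(
        torch.cuda.get_device_properties(i).total_memory
        for i in range(torch.cuda.device_count())
    )
    return total_bytes / 1024**3


# The condition is evaluated once, at collection time; every test under the
# marker is reported as skipped when the machine does not meet the requirement.
@pytest.mark.skipif(total_vram_gb() < 16, reason="Not enough VRAM to run this test.")
def test_heavy_gpu_workload():
    ...
```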
project/utils/testutils.py (14 changes: 14 additions & 0 deletions)
```diff
@@ -9,6 +9,7 @@
 from logging import getLogger as get_logger
 
 import pytest
+import torch
 import torchvision.models
 
 from project.datamodules.image_classification.fashion_mnist import FashionMNISTDataModule
@@ -207,3 +208,16 @@ def run_for_all_configs_in_group(
         ],
         indirect=True,
     )
+
+
+def total_vram_gb() -> float:
+    """Returns the total VRAM in GB."""
+    if not torch.cuda.is_available():
+        return 0.0
+    return (
+        sum(
+            torch.cuda.get_device_properties(i).total_memory
+            for i in range(torch.cuda.device_count())
+        )
+        / 1024**3
+    )
```
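Two details of the helper are worth noting: `total_memory` is reported in bytes, so dividing by 1024**3 converts to GiB, and because the sum runs over every visible device, several smaller GPUs can collectively satisfy the 16 GB gate. A quick, hypothetical check of the arithmetic:

```python
>>> 17_179_869_184 / 1024**3     # one 16 GiB device, reported in bytes
16.0
>>> (2 * 8 * 1024**3) / 1024**3  # two 8 GiB devices, summed
16.0
```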
