Remove redundant tests (#92)
* Remove redundant tests

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Remove redundant comments

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Add xfail on flaky tests caused by HF hub errors

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

---------

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>
lebrice authored Nov 29, 2024
1 parent f922edf commit b40ef22
Showing 3 changed files with 15 additions and 130 deletions.
132 changes: 7 additions & 125 deletions project/algorithms/testsuites/lightning_module_tests.py
@@ -51,110 +51,10 @@ def forward_pass(self, algorithm: LightningModule, input: PyTree[torch.Tensor]):
return algorithm(**input)
return algorithm(input)

def test_initialization_is_deterministic(
self,
experiment_config: Config,
datamodule: lightning.LightningDataModule | None,
seed: int,
trainer: lightning.Trainer,
device: torch.device,
):
"""Checks that the weights initialization is consistent given the a random seed."""

with torch.random.fork_rng(devices=list(range(torch.cuda.device_count()))):
torch.random.manual_seed(seed)
algorithm_1 = instantiate_algorithm(experiment_config.algorithm, datamodule)
assert isinstance(algorithm_1, lightning.LightningModule)

with trainer.init_module(), device:
# A bit hacky, but we have to do this because the lightningmodule isn't associated
# with a Trainer.
algorithm_1._device = device
algorithm_1.configure_model()

with torch.random.fork_rng(devices=list(range(torch.cuda.device_count()))):
torch.random.manual_seed(seed)
algorithm_2 = instantiate_algorithm(experiment_config.algorithm, datamodule)
assert isinstance(algorithm_2, lightning.LightningModule)

with trainer.init_module(), device:
# A bit hacky, but we have to do this because the lightningmodule isn't associated
# with a Trainer.
algorithm_2._device = device
algorithm_2.configure_model()

torch.testing.assert_close(algorithm_1.state_dict(), algorithm_2.state_dict())

def test_forward_pass_is_deterministic(
self, forward_pass_input: Any, algorithm: AlgorithmType, seed: int
):
"""Checks that the forward pass output is consistent given the a random seed and a given
input."""

with torch.random.fork_rng(devices=list(range(torch.cuda.device_count()))):
torch.random.manual_seed(seed)
out1 = self.forward_pass(algorithm, forward_pass_input)
with torch.random.fork_rng(devices=list(range(torch.cuda.device_count()))):
torch.random.manual_seed(seed)
out2 = self.forward_pass(algorithm, forward_pass_input)

torch.testing.assert_close(out1, out2)

# @pytest.mark.timeout(10)
def test_backward_pass_is_deterministic(
self,
datamodule: LightningDataModule,
algorithm: AlgorithmType,
seed: int,
accelerator: str,
devices: int | list[int] | Literal["auto"],
tmp_path: Path,
):
"""Check that the backward pass is reproducible given the same input, weights, and random
seed."""

algorithm_1 = copy.deepcopy(algorithm)
algorithm_2 = copy.deepcopy(algorithm)

with torch.random.fork_rng(devices=list(range(torch.cuda.device_count()))):
torch.random.manual_seed(seed)
gradients_callback = GetStuffFromFirstTrainingStep()
self.do_one_step_of_training(
algorithm_1,
datamodule,
accelerator,
devices=devices,
callbacks=[gradients_callback],
tmp_path=tmp_path / "run1",
)

batch_1 = gradients_callback.batch
gradients_1 = gradients_callback.grads
training_step_outputs_1 = gradients_callback.outputs

with torch.random.fork_rng(devices=list(range(torch.cuda.device_count()))):
torch.random.manual_seed(seed)
gradients_callback = GetStuffFromFirstTrainingStep()
self.do_one_step_of_training(
algorithm_2,
datamodule,
accelerator=accelerator,
devices=devices,
callbacks=[gradients_callback],
tmp_path=tmp_path / "run2",
)
batch_2 = gradients_callback.batch
gradients_2 = gradients_callback.grads
training_step_outputs_2 = gradients_callback.outputs

torch.testing.assert_close(batch_1, batch_2)
torch.testing.assert_close(gradients_1, gradients_2)
torch.testing.assert_close(training_step_outputs_1, training_step_outputs_2)

def test_initialization_is_reproducible(
self,
experiment_config: Config,
datamodule: lightning.LightningDataModule,
datamodule: lightning.LightningDataModule | None,
seed: int,
tensor_regression: TensorRegressionFixture,
trainer: lightning.Trainer,
@@ -165,14 +65,15 @@ def test_initialization_is_reproducible(
torch.random.manual_seed(seed)
algorithm = instantiate_algorithm(experiment_config.algorithm, datamodule=datamodule)
assert isinstance(algorithm, lightning.LightningModule)
# A bit hacky, but we have to do this because the lightningmodule isn't associated
# with a Trainer here.
with trainer.init_module(), device:
# A bit hacky, but we have to do this because the lightningmodule isn't associated
# with a Trainer.
algorithm._device = device
algorithm.configure_model()

tensor_regression.check(
algorithm.state_dict(),
# todo: is this necessary? Shouldn't the weights be the same on CPU and GPU?
# Save the regression files on a different subfolder for each device (cpu / cuda)
additional_label=next(algorithm.parameters()).device.type,
include_gpu_name_in_stats=False,
@@ -236,33 +137,14 @@ def test_backward_pass_is_reproducible(
"grads": gradients_callback.grads,
"outputs": outputs,
},
default_tolerance={"rtol": 1e-5, "atol": 1e-6}, # some tolerance for the jax example.
# todo: this tolerance was mainly added for the jax example.
default_tolerance={"rtol": 1e-5, "atol": 1e-6}, # some tolerance
# todo: check if this actually differs between cpu / gpu.
# Save the regression files on a different subfolder for each device (cpu / cuda)
additional_label=accelerator if accelerator not in ["auto", "gpu"] else None,
include_gpu_name_in_stats=False,
)

def __init_subclass__(cls) -> None:
super().__init_subclass__()
# algorithm_under_test = _get_algorithm_class_from_generic_arg(cls)
# # find all algorithm configs that create algorithms of this type.
# configs_for_this_algorithm = get_all_configs_in_group_with_target(
# "algorithm", algorithm_under_test
# )
# # assert not hasattr(cls, "algorithm_config"), cls
# cls.algorithm_config = ParametrizedFixture(
# name="algorithm_config",
# values=configs_for_this_algorithm,
# ids=configs_for_this_algorithm,
# ,
# )

# TODO: Could also add a parametrize_when_used mark to parametrize the datamodule, network,
# etc, based on the type annotations of the algorithm constructor? For example, if an algo
# shows that it accepts any LightningDataModule, then parametrize it with all the datamodules,
# but if the algo says it only works with ImageNet, then parametrize with all the configs
# that have the ImageNet datamodule as their target (or a subclass of ImageNetDataModule).

@pytest.fixture(scope="session")
def forward_pass_input(self, training_batch: PyTree[torch.Tensor], device: torch.device):
"""Extracts the model input from a batch of data coming from the dataloader.
7 changes: 2 additions & 5 deletions project/datamodules/datamodules_test.py
@@ -19,8 +19,6 @@
from project.utils.typing_utils import is_sequence_of


# @use_overrides(["datamodule.num_workers=0"])
# @pytest.mark.timeout(25, func_only=True)
@pytest.mark.slow
@pytest.mark.parametrize(
"stage",
@@ -47,9 +45,8 @@ def test_first_batch(
stage: RunningStage,
datadir: Path,
):
# todo: skip this test if the dataset isn't already downloaded (for example on the GitHub CI).

# TODO: This causes hanging issues when tests fail, since dataloader workers aren't cleaned up.
# Note: using dataloader workers in tests can cause issues, since if a test fails, dataloader
# workers aren't always cleaned up properly.
if isinstance(datamodule, VisionDataModule) or hasattr(datamodule, "num_workers"):
datamodule.num_workers = 0 # type: ignore

6 changes: 6 additions & 0 deletions project/datamodules/text/text_classification_test.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import huggingface_hub.errors
import lightning
import pytest

@@ -61,6 +62,11 @@ def prepared_datamodule(
datamodule.working_path = _slurm_tmpdir_before


@pytest.mark.xfail(
raises=huggingface_hub.errors.HfHubHTTPError,
strict=False,
reason="Can sometimes get 'Too many requests for url'",
)
@pytest.mark.parametrize(datamodule.__name__, datamodule_configs, indirect=True)
def test_dataset_location(
prepared_datamodule: TextClassificationDataModule,
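Note on the xfail marker added above: a minimal, self-contained sketch of the same pattern, with a hypothetical test body (the model_info call and repo id are illustrative assumptions, not part of this commit). The non-strict marker only converts the named HfHubHTTPError (e.g. rate limiting) into an expected failure; the test still passes normally when the Hub request succeeds, and any other error still fails.

import huggingface_hub
import huggingface_hub.errors
import pytest


@pytest.mark.xfail(
    raises=huggingface_hub.errors.HfHubHTTPError,
    strict=False,  # still passes when the Hub request succeeds
    reason="Can sometimes get 'Too many requests for url'",
)
def test_hub_is_reachable():
    # Hypothetical test: any call that goes through the Hugging Face Hub can
    # raise HfHubHTTPError under rate limiting (HTTP 429); the marker turns
    # that specific error into an expected failure instead of a hard CI failure.
    info = huggingface_hub.model_info("bert-base-uncased")
    assert info is not None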