diff --git a/.regression_files/project/algorithms/llm_finetuning_test/test_training_batch_doesnt_change/llm_finetuning.yaml b/.regression_files/project/algorithms/llm_finetuning_test/test_training_batch_doesnt_change/llm_finetuning.yaml
new file mode 100644
index 00000000..9a3de835
--- /dev/null
+++ b/.regression_files/project/algorithms/llm_finetuning_test/test_training_batch_doesnt_change/llm_finetuning.yaml
@@ -0,0 +1,28 @@
+GPU: Quadro RTX 8000
+attention_mask:
+  device: cuda:0
+  max: 1
+  mean: '1.e+00'
+  min: 1
+  shape:
+  - 8
+  - 256
+  sum: 2048
+input_ids:
+  device: cuda:0
+  max: 50118
+  mean: '5.265e+03'
+  min: 2
+  shape:
+  - 8
+  - 256
+  sum: 10781837
+labels:
+  device: cuda:0
+  max: 50118
+  mean: '5.265e+03'
+  min: 2
+  shape:
+  - 8
+  - 256
+  sum: 10781837
diff --git a/project/algorithms/llm_finetuning_test.py b/project/algorithms/llm_finetuning_test.py
index 794cc9c9..6e1a3312 100644
--- a/project/algorithms/llm_finetuning_test.py
+++ b/project/algorithms/llm_finetuning_test.py
@@ -2,6 +2,7 @@
 
 import copy
 import operator
+from typing import Any
 
 import jax
 import lightning
@@ -82,7 +83,8 @@ def training_batch(
         with torch.random.fork_rng(list(range(torch.cuda.device_count()))):
             # TODO: This ugliness is because torchvision transforms use the global pytorch RNG!
-            torch.random.manual_seed(42)
+            # torch.random.manual_seed(42)
+            lightning.seed_everything(42, workers=True)
             batch = next(dataloader_iterator)
 
         return jax.tree.map(operator.methodcaller("to", device=device), batch)
 
@@ -97,11 +99,15 @@ def forward_pass_input(self, training_batch: PyTree[torch.Tensor], device: torch
         assert isinstance(training_batch, dict)
         return training_batch
 
-    # Checking all the weights against the 900mb reference .npz file is a bit slow.
+    def test_training_batch_doesnt_change(
+        self, training_batch: dict, tensor_regression: TensorRegressionFixture
+    ):
+        tensor_regression.check(training_batch)
+
     @pytest.mark.xfail(
         SLURM_JOB_ID is not None, reason="TODO: Seems to be failing when run on a SLURM cluster."
     )
-    @pytest.mark.slow
+    @pytest.mark.slow  # Checking against the 900mb reference .npz file is a bit slow.
     def test_initialization_is_reproducible(
         self,
         experiment_config: Config,
@@ -117,3 +123,20 @@ def test_initialization_is_reproducible(
             tensor_regression=tensor_regression,
             trainer=trainer,
         )
+
+    @pytest.mark.xfail(
+        SLURM_JOB_ID is not None, reason="TODO: Seems to be failing when run on a SLURM cluster."
+    )
+    def test_forward_pass_is_reproducible(
+        self,
+        forward_pass_input: Any,
+        algorithm: LLMFinetuningExample,
+        seed: int,
+        tensor_regression: TensorRegressionFixture,
+    ):
+        return super().test_forward_pass_is_reproducible(
+            forward_pass_input=forward_pass_input,
+            algorithm=algorithm,
+            seed=seed,
+            tensor_regression=tensor_regression,
+        )
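
Note on the seeding change in the `training_batch` fixture above (an editor's sketch, not part of the patch): `lightning.seed_everything(42, workers=True)` seeds Python's `random`, NumPy, and PyTorch from a single seed, whereas `torch.random.manual_seed(42)` only touches PyTorch's global RNG. With `workers=True` it also sets the `PL_SEED_WORKERS` environment variable so Lightning can seed `DataLoader` worker processes deterministically, which is what makes the batch captured by the new `test_training_batch_doesnt_change` regression test reproducible. A minimal standalone illustration, independent of this repo's fixtures:

    import random

    import lightning
    import numpy as np
    import torch

    # Seeds `random`, NumPy, and PyTorch in one call; workers=True additionally
    # sets PL_SEED_WORKERS=1 so Lightning seeds DataLoader workers deterministically.
    lightning.seed_everything(42, workers=True)

    # All three RNGs now start from a deterministic state:
    print(random.random(), np.random.rand(), torch.rand(()))

    # By contrast, torch.random.manual_seed(42) seeds only PyTorch's global RNG;
    # `random`, NumPy, and dataloader worker processes would stay unseeded.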