Skip to content

Commit

Permalink
[VITS] Fix nightly tests (#25986)
Browse files Browse the repository at this point in the history
* fix tokenizer

* make bs even

* fix multi gpu test

* style

* model forward

* fix torch import

* revert tok pin
  • Loading branch information
sanchit-gandhi authored Sep 7, 2023
1 parent 3744126 commit 2af87d0
Showing 1 changed file with 25 additions and 0 deletions.
25 changes: 25 additions & 0 deletions tests/models/vits/test_modeling_vits.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
is_flaky,
is_torch_available,
require_torch,
require_torch_multi_gpu,
slow,
torch_device,
)
Expand Down Expand Up @@ -177,6 +178,30 @@ def test_model_forward(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model_forward(*config_and_inputs)

@require_torch_multi_gpu
# override to force all elements of the batch to have the same sequence length across GPUs
def test_multi_gpu_data_parallel_forward(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.use_stochastic_duration_prediction = False

# move input tensors to cuda:O
for key, value in inputs_dict.items():
if torch.is_tensor(value):
# make all elements of the batch the same -> ensures the output seq lengths are the same for DP
value[1:] = value[0]
inputs_dict[key] = value.to(0)

for model_class in self.all_model_classes:
model = model_class(config=config)
model.to(0)
model.eval()

# Wrap model in nn.DataParallel
model = torch.nn.DataParallel(model)
set_seed(555)
with torch.no_grad():
_ = model(**self._prepare_for_class(inputs_dict, model_class)).waveform

@unittest.skip("VITS is not deterministic")
def test_determinism(self):
pass
Expand Down

0 comments on commit 2af87d0

Please sign in to comment.