From bf9a38fabdf61dda54e74cf93a471a257d0753f1 Mon Sep 17 00:00:00 2001
From: Daniel Walmsley
Date: Mon, 8 Jul 2024 14:40:35 -0700
Subject: [PATCH] Fix for latest TF

---
 TTS/tts/layers/xtts/stream_generator.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/TTS/tts/layers/xtts/stream_generator.py b/TTS/tts/layers/xtts/stream_generator.py
index 451c783af0..06b55be90e 100644
--- a/TTS/tts/layers/xtts/stream_generator.py
+++ b/TTS/tts/layers/xtts/stream_generator.py
@@ -430,7 +430,7 @@ def generate(
 
         elif is_sample_gen_mode:
             # 11. prepare logits warper
-            logits_warper = self._get_logits_warper(generation_config)
+            logits_warper = self._get_logits_warper(generation_config, inputs_tensor.device)
 
             # 12. expand input_ids with `num_return_sequences` additional sequences per batch
             input_ids, model_kwargs = self._expand_inputs_for_generation(
@@ -455,7 +455,7 @@ def generate(
             )
         elif is_sample_gen_stream_mode:
             # 11. prepare logits warper
-            logits_warper = self._get_logits_warper(generation_config)
+            logits_warper = self._get_logits_warper(generation_config, inputs_tensor.device)
 
             # 12. expand input_ids with `num_return_sequences` additional sequences per batch
             input_ids, model_kwargs = self._expand_inputs_for_generation(
@@ -517,7 +517,7 @@ def generate(
 
         elif is_beam_sample_gen_mode:
            # 11. prepare logits warper
-            logits_warper = self._get_logits_warper(generation_config)
+            logits_warper = self._get_logits_warper(generation_config, inputs_tensor.device)
 
             if stopping_criteria.max_length is None:
                 raise ValueError("`max_length` needs to be a stopping_criteria for now.")
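
Note: the three edits above adapt the streaming generate() path to newer Hugging Face transformers releases, where _get_logits_warper() also expects the device of the input tensor. If the surrounding code needed to keep working against both old and new transformers versions, one option would be a small helper that branches on the method's signature. The sketch below is a hypothetical illustration only (build_logits_warper is not part of this patch); it assumes the installed transformers exposes _get_logits_warper accepting either (generation_config) or (generation_config, device).

# Hypothetical compatibility sketch, not part of the patch above.
# Passes `device` to _get_logits_warper only when the installed
# transformers version declares that parameter.
import inspect

def build_logits_warper(model, generation_config, device):
    """Return the logits warper, adapting to the transformers API in use."""
    params = inspect.signature(model._get_logits_warper).parameters
    if "device" in params:
        # Newer transformers: _get_logits_warper(generation_config, device)
        return model._get_logits_warper(generation_config, device)
    # Older transformers: _get_logits_warper(generation_config)
    return model._get_logits_warper(generation_config)

With such a helper, each of the three call sites would read, for example, logits_warper = build_logits_warper(self, generation_config, inputs_tensor.device).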