Skip to content

Commit

Permalink
Implement custom tensor.isin
Browse files Browse the repository at this point in the history
  • Loading branch information
Daniel Walmsley committed Jun 15, 2024
1 parent bd2f992 commit 6696abf
Showing 1 changed file with 16 additions and 2 deletions.
18 changes: 16 additions & 2 deletions TTS/tts/layers/xtts/stream_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,20 @@
)
from transformers.generation.utils import GenerateOutput, SampleOutput, logger

def custom_isin(elements, test_elements):
# Flatten the tensors
elements_flat = elements.view(-1)
test_elements_flat = test_elements.view(-1)

# Create a mask tensor
mask = torch.zeros_like(elements_flat, dtype=torch.bool)

# Compare each element
for test_element in test_elements_flat:
mask |= (elements_flat == test_element)

# Reshape the mask to the original elements shape
return mask.view(elements.shape)

def setup_seed(seed):
if seed == -1:
Expand Down Expand Up @@ -202,10 +216,10 @@ def generate(
default_attention_mask = torch.ones(inputs_tensor.shape[:2], dtype=torch.long, device=inputs_tensor.device)

is_pad_token_in_inputs = (pad_token_tensor is not None) and (
torch.isin(elements=inputs_tensor, test_elements=pad_token_tensor).any()
custom_isin(elements=inputs_tensor, test_elements=pad_token_tensor).any()
)
is_pad_token_not_equal_to_eos_token_id = (eos_token_tensor is None) or ~(
torch.isin(elements=eos_token_tensor, test_elements=pad_token_tensor).any()
custom_isin(elements=eos_token_tensor, test_elements=pad_token_tensor).any()
)
can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id
attention_mask_from_padding = inputs_tensor.ne(pad_token_tensor).long()
Expand Down

0 comments on commit 6696abf

Please sign in to comment.