
[Whisper] Inconsistent return types for Whisper generation #32202

Closed
2 of 4 tasks
benniekiss opened this issue Jul 24, 2024 · 2 comments
benniekiss commented Jul 24, 2024

System Info

  • transformers version: 4.43.1
  • Platform: Linux-6.9.7-201.fsync.fc40.x86_64-x86_64-with-glibc2.35
  • Python version: 3.12.4
  • Huggingface_hub version: 0.24.1
  • Safetensors version: 0.4.3
  • Accelerate version: 0.31.0
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.3.1+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?:
  • Using GPU in script?:
  • GPU type: NVIDIA GeForce RTX 3080

Who can help?

@sanchit-gandhi @ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Shortform:

import torch
from transformers import AutoProcessor, WhisperForConditionalGeneration
from datasets import load_dataset

processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

inputs = processor(ds[0]["audio"]["array"], sampling_rate=16_000, return_tensors="pt")
input_features = inputs.input_features

generated_ids = model.generate(
    inputs=input_features,
    temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0),

    ## MODIFIED: Comment out this line to test. 
    ## Results in `GenerateEncoderDecoderOutput` as return type, but `torch.Tensor` is expected
    logprob_threshold=-1, 

    ## MODIFIED: This should have no impact, but it does. 
    ## Results in `dict` as return type, but `torch.Tensor` is expected
    return_segments=True, 
)
print(type(generated_ids))
print(generated_ids)
transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
transcription

Longform:

import torch
from transformers import AutoProcessor, WhisperForConditionalGeneration
from datasets import load_dataset, Audio

processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
model.cuda()
# load audios > 30 seconds
ds = load_dataset("distil-whisper/meanwhile", "default")["test"]
# resample to 16kHz
ds = ds.cast_column("audio", Audio(sampling_rate=16000))
# take first 8 audios and retrieve array
audio = ds[:8]["audio"]
audio = [x["array"] for x in audio]

# make sure to NOT truncate the input audio, to return the `attention_mask` and to pad to the longest audio
inputs = processor(audio, return_tensors="pt", truncation=False, padding="longest", return_attention_mask=True, sampling_rate=16_000)
inputs = inputs.to("cuda", torch.float32)

# transcribe audio to ids
generated_ids = model.generate(
    **inputs, 
    temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0),

    ## MODIFIED: Results in `torch.Tensor` as the return type, as expected
    logprob_threshold=-1, 

    ## MODIFIED: Comment out this line to test. Results in `dict` as return type, as expected
    return_segments=True,
)
print(type(generated_ids))
print(generated_ids)

transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)
transcription[0]

Expected behavior

I noticed that setting logprob_threshold resulted in return_dict_in_generate being set to True, which in turn impacts the return type for shortform generation only.

Since shortform generation now supports fallback, the documentation for the parameter, which states that logprob_threshold only applies to longform audio, no longer seems accurate.

I think there's also now some inconsistency in how the return_* parameters are handled by Whisper.generate(). return_segments, which should only apply to longform audio, actually applies to both shortform and longform audio.

Likewise, the return type when return_dict_in_generate=True is GenerateEncoderDecoderOutput, but for return_segments=True, it is dict, and without either (or logprob_threshold), it is torch.Tensor.
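Until the return types are unified, one defensive workaround is to normalize whatever generate() returns before calling batch_decode(). A minimal sketch of that idea (extract_token_ids is a hypothetical helper, not part of transformers; it relies only on the three return shapes described above):

```python
def extract_token_ids(generate_output):
    """Normalize the possible return types of Whisper.generate()
    into a single sequence-of-token-ids object.

    Hypothetical helper, not part of transformers.
    """
    # return_segments=True path: a plain dict with a "sequences" key
    if isinstance(generate_output, dict):
        return generate_output["sequences"]
    # return_dict_in_generate=True path: a ModelOutput
    # (e.g. GenerateEncoderDecoderOutput) exposing a .sequences attribute
    if hasattr(generate_output, "sequences"):
        return generate_output.sequences
    # default path: already a plain tensor of token ids
    return generate_output
```

Callers can then write `processor.batch_decode(extract_token_ids(generated_ids), skip_special_tokens=True)` regardless of which generation flags were set.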

I'm willing to contribute a PR cleaning this up a little, but I wasn't sure how to proceed.

It looks like return_segments could be removed as a parameter, relying instead on return_dict_in_generate for both longform and shortform audio. The return handling could then be combined so that parameters such as return_token_timestamps would apply to longform audio as well, and the return type for both longform and shortform generation would be GenerateEncoderDecoderOutput when return_dict_in_generate=True.

Likewise, should setting logprob_threshold still result in return_dict_in_generate=True?

Edit: opening a separate issue for this.
Finally, I noticed that the attention_mask is not actually passed down to generate_with_fallback, so it never reaches the underlying super().generate() call. As a result, this error is shown even when attention_mask is set:

The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
@benniekiss benniekiss added the bug label Jul 24, 2024
@benniekiss

I see that this is actually being addressed already by #32178!

@ArthurZucker

Glad this was fixed for you!
