System Info

transformers version: 4.43.1

Who can help?

@sanchit-gandhi @ArthurZucker

Tasks
An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
My own task or dataset (give details below)
Reproduction
Shortform:
import torch
from transformers import AutoProcessor, WhisperForConditionalGeneration
from datasets import load_dataset

processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
input_features = inputs.input_features

generated_ids = model.generate(
    inputs=input_features,
    temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
    # MODIFIED: Comment out this line to test.
    # Results in `GenerateEncoderDecoderOutput` as return type, but `torch.Tensor` is expected
    logprob_threshold=-1,
    # MODIFIED: This should have no impact, but it does.
    # Results in `dict` as return type, but `torch.Tensor` is expected
    return_segments=True,
)
print(type(generated_ids))
print(generated_ids)

transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
transcription
Longform:
import torch
from transformers import AutoProcessor, WhisperForConditionalGeneration
from datasets import load_dataset, Audio

processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
model.cuda()

# load audios > 30 seconds
ds = load_dataset("distil-whisper/meanwhile", "default")["test"]
# resample to 16kHz
ds = ds.cast_column("audio", Audio(sampling_rate=16000))
# take first 8 audios and retrieve array
audio = ds[:8]["audio"]
audio = [x["array"] for x in audio]

# make sure to NOT truncate the input audio, to return the `attention_mask` and to pad to the longest audio
inputs = processor(
    audio,
    return_tensors="pt",
    truncation=False,
    padding="longest",
    return_attention_mask=True,
    sampling_rate=16_000,
)
inputs = inputs.to("cuda", torch.float32)

# transcribe audio to ids
generated_ids = model.generate(
    **inputs,
    temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
    # MODIFIED: Results in `torch.Tensor` as the return type, as expected
    logprob_threshold=-1,
    # MODIFIED: Comment out this line to test. Results in `dict` as return type, as expected
    return_segments=True,
)
print(type(generated_ids))
print(generated_ids)

transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)
transcription[0]
Expected behavior

I noticed that setting logprob_threshold resulted in return_dict_in_generate being set to True, which in turn impacts the return type for shortform generation only. Since shortform generation is now supported with fallback, I don't think the parameter's documentation, which says that logprob_threshold only applies to longform audio, is accurate.
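A minimal side-by-side check, reusing the shortform setup above (a sketch; the printed types are the ones observed in this issue, not documented behavior):

# Same inputs, with and without logprob_threshold.
out_plain = model.generate(inputs=input_features)
out_thresh = model.generate(
    inputs=input_features,
    temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
    logprob_threshold=-1,
)
print(type(out_plain))   # torch.Tensor
print(type(out_thresh))  # GenerateEncoderDecoderOutput, unexpected for shortform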
I think there's also now some inconsistency in how the return_* parameters are handled by Whisper.generate(). return_segments, which should only apply to longform audio, actually applies to both shortform and longform audio.
Likewise, the return type when return_dict_in_generate=True is GenerateEncoderDecoderOutput, but for return_segments=True, it is dict, and without either (or logprob_threshold), it is torch.Tensor.
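For downstream code this currently means branching on the return type before decoding. A minimal sketch (the helper name is mine, and the "sequences" dict key is an assumption based on the outputs printed above):

import torch
from transformers.generation import GenerateEncoderDecoderOutput

def to_token_ids(output):
    # Plain shortform path: already a tensor of token ids
    if isinstance(output, torch.Tensor):
        return output
    # return_dict_in_generate=True path
    if isinstance(output, GenerateEncoderDecoderOutput):
        return output.sequences
    # return_segments=True path: a plain dict (assumed "sequences" key)
    return output["sequences"]

transcription = processor.batch_decode(to_token_ids(generated_ids), skip_special_tokens=True)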
I'm willing to contribute a PR cleaning this up a little, but I wasn't sure how to proceed.
It looks like return_segments could be removed as a parameter, instead relying on return_dict_in_generate for both longform and shortform audios. The return handling could then be combined so that parameters such as return_token_timestamps would apply to longform audio as well, and the return type for both longform and shortform would be GenerateEncoderDecoderOutput when return_dict_in_generate=True (sketched below).
Likewise, should setting logprob_threshold still result in return_dict_in_generate=True?
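From the caller's side, the consolidation might look like this (a hypothetical sketch of the proposal, not the current API):

# Hypothetical: one switch controls the return type for both audio lengths.
out = model.generate(
    **inputs,
    return_dict_in_generate=True,  # would replace return_segments
)
# Same structured type regardless of shortform vs. longform input; longform
# segment info would live on the output object instead of a bare dict.
assert type(out).__name__ == "GenerateEncoderDecoderOutput"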
Edit: Opening a separate issue for this. Finally, I noticed that the attention_mask is not actually passed down to generate_with_fallback, so it doesn't get passed to the underlying super().generate() call, resulting in this error being shown even when attention_mask is set:
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
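A quick check against the longform reproduction above (a sketch; it only confirms the mask exists before generate() is called):

# The processor did return a mask, yet the warning above still fires,
# suggesting it is dropped before reaching the inner super().generate() call.
print("attention_mask" in inputs)      # True
print(inputs["attention_mask"].shape)  # padded mask over feature frames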