
Speech Commands v2 dataset doesn't match AST-v2 config #6446

Closed
vymao opened this issue Nov 22, 2023 · 3 comments

Comments

@vymao

vymao commented Nov 22, 2023

Describe the bug

According to MIT/ast-finetuned-speech-commands-v2, the model was trained on the Speech Commands v2 dataset. However, the model config lists 35 class labels, while the dataset has 36. Moreover, the ids are mapped to different class names in the model config and in the dataset. This makes it difficult to reproduce the data used to fine-tune MIT/ast-finetuned-speech-commands-v2.

Steps to reproduce the bug

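>>> import torch
>>> from datasets import load_dataset
>>> from transformers import ASTForAudioClassification
>>> model = ASTForAudioClassification.from_pretrained("MIT/ast-finetuned-speech-commands-v2")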
>>> model.config.id2label
{0: 'backward', 1: 'follow', 2: 'five', 3: 'bed', 4: 'zero', 5: 'on', 6: 'learn', 7: 'two', 8: 'house', 9: 'tree', 10: 'dog', 11: 'stop', 12: 'seven', 13: 'eight', 14: 'down', 15: 'six', 16: 'forward', 17: 'cat', 18: 'right', 19: 'visual', 20: 'four', 21: 'wow', 22: 'no', 23: 'nine', 24: 'off', 25: 'three', 26: 'left', 27: 'marvin', 28: 'yes', 29: 'up', 30: 'sheila', 31: 'happy', 32: 'bird', 33: 'go', 34: 'one'}

>>> dataset = load_dataset("speech_commands", "v0.02", split="test")
>>> torch.unique(torch.Tensor(dataset['label']))
tensor([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,
        14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27.,
        28., 29., 30., 31., 32., 33., 34., 35.])

If you explore the dataset itself, you can see that its id-to-label mapping does not match model.config.id2label.
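A minimal sketch of how the two mappings can be compared. The lists below are hypothetical toy data, not the real orderings; with the real objects, `dataset_names` would come from `dataset.features["label"].names` and `model_id2label` from `model.config.id2label`. The helper `compare_label_mappings` is an illustrative name, not a library function.

```python
# Hypothetical toy data: the real speech_commands v0.02 ClassLabel has 36 names
# and the real model config has 35 entries; these short lists only illustrate the check.
dataset_names = ["yes", "no", "up", "_silence_"]   # list index == dataset label id
model_id2label = {0: "up", 1: "yes", 2: "no"}      # model config style mapping (id -> name)


def compare_label_mappings(dataset_names, model_id2label):
    """Return (ids whose class names differ, dataset names unknown to the model)."""
    mismatched_ids = []
    for idx, name in enumerate(dataset_names):
        # Same integer id on both sides, but possibly a different class name
        if model_id2label.get(idx) != name:
            mismatched_ids.append(idx)
    model_names = set(model_id2label.values())
    missing_from_model = [n for n in dataset_names if n not in model_names]
    return mismatched_ids, missing_from_model


mismatched, missing = compare_label_mappings(dataset_names, model_id2label)
print("ids with different names:", mismatched)
print("dataset labels unknown to the model:", missing)
```

On the toy data every id points to a different name, and `_silence_` is the one class the model does not know about.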

Expected behavior

The labels should match exactly, and the model config and the dataset should have the same number of label classes.

Environment info

datasets = 2.14.6, transformers = 4.33.3

@mariosasko
Collaborator

You can use .align_labels_with_mapping on the dataset to align the labels with the model config.

Regarding the number of labels, only the special _silence_ label (background noise) is missing, which is consistent with the model paper (it reports training on 35 labels). You can run a .filter to drop it.
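A plain-Python sketch of what the two steps amount to, on hypothetical toy labels (not the real Speech Commands ordering): drop every example whose id is `_silence_`, then re-encode each remaining id through its class name using a name -> id mapping shaped like `model.config.label2id`. This is a conceptual illustration, not the `datasets` implementation, and `drop_and_align` is a hypothetical helper; the lowercasing mirrors the `k.lower()` call visible in the `align_labels_with_mapping` traceback in this thread.

```python
# Hypothetical toy inputs; not the real label orderings.
dataset_names = ["yes", "no", "up", "_silence_"]  # dataset ClassLabel names (index == id)
model_label2id = {"up": 0, "yes": 1, "no": 2}     # label2id-style mapping (name -> id)
labels = [0, 3, 2, 1, 3, 0]                       # dataset label ids, one per example


def drop_and_align(labels, dataset_names, label2id, drop_name="_silence_"):
    """Drop examples of `drop_name`, then re-encode the remaining ids via class names."""
    drop_id = dataset_names.index(drop_name)
    kept = [i for i in labels if i != drop_id]  # conceptually, the .filter step
    # Match names case-insensitively by lowercasing the mapping's keys
    lowered = {k.lower(): v for k, v in label2id.items()}
    # Conceptually, the alignment step: old id -> class name -> model id
    return [lowered[dataset_names[i].lower()] for i in kept]


print(drop_and_align(labels, dataset_names, model_label2id))
```

In this sketch, skipping the drop step would make the name lookup raise a `KeyError` for `_silence_`, which is why the filter comes before the alignment.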

PS: You should open a discussion on the model or dataset repo (on the Hub) for these kinds of questions.

@vymao
Author

vymao commented Nov 23, 2023

Thanks, will keep that in mind. But I tried running dataset_aligned = dataset.align_labels_with_mapping(model.config.id2label, 'label'), and received this error:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/victor/anaconda3/envs/transformers-v2/lib/python3.9/site-packages/datasets/arrow_dataset.py", line 5928, in align_labels_with_mapping
    label2id = {k.lower(): v for k, v in label2id.items()}
  File "/Users/victor/anaconda3/envs/transformers-v2/lib/python3.9/site-packages/datasets/arrow_dataset.py", line 5928, in <dictcomp>
    label2id = {k.lower(): v for k, v in label2id.items()}
AttributeError: 'int' object has no attribute 'lower'

My guess is that the dataset's label column is purely an integer ID, and I'm not sure there's an easy way to tell which class label an ID belongs to in the dataset.
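A minimal sketch of why the call fails, based on the `k.lower()` line in the traceback: the method lowercases the keys of the mapping it receives, so it expects name -> id (`label2id`), while `id2label` has integer keys. The three entries below are copied from the `model.config.id2label` output earlier in this thread. Separately, a dataset id can be turned into its class name with the `ClassLabel` feature, e.g. `dataset.features["label"].int2str(3)`.

```python
# Excerpt of model.config.id2label (id -> name), taken from the output earlier in this thread
model_id2label = {0: "backward", 1: "follow", 2: "five"}

# Lowercasing the keys only works when the keys are class names (strings);
# with id2label the keys are ints, which reproduces the AttributeError.
try:
    {k.lower(): v for k, v in model_id2label.items()}
except AttributeError as err:
    print("id2label fails:", err)

# Inverting id2label gives the name -> id shape that the method expects
label2id = {name: idx for idx, name in model_id2label.items()}
print(label2id)  # {'backward': 0, 'follow': 1, 'five': 2}
```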

@mariosasko
Collaborator

Replacing model.config.id2label with model.config.label2id should fix the issue.

So, the full code to align the labels with the model config is as follows:

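from datasets import load_dataset
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification

# Optional: the feature extractor is not needed for the label alignment itself
# extractor = AutoFeatureExtractor.from_pretrained("MIT/ast-finetuned-speech-commands-v2")
model = AutoModelForAudioClassification.from_pretrained("MIT/ast-finetuned-speech-commands-v2")

ds = load_dataset("speech_commands", "v0.02")
# Drop the "_silence_" (background noise) class; the model config has no such label
silence_id = ds["train"].features["label"].str2int("_silence_")
ds = ds.filter(lambda label: label != silence_id, input_columns="label")
# Re-encode the integer labels to follow the model's label2id (name -> id) mapping
ds = ds.align_labels_with_mapping(model.config.label2id, "label")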
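A small sanity check one could run afterwards, sketched under the assumption that the aligned label feature lists the model's classes in id order. `labels_aligned` is a hypothetical helper, and the example lists are toy data; with the real objects you would pass `ds["train"].features["label"].names` and `model.config.id2label`.

```python
def labels_aligned(dataset_names, model_id2label):
    """True if dataset id i names the same class as model id i, for every id."""
    if len(dataset_names) != len(model_id2label):
        return False
    # Compare case-insensitively; a missing model id compares as an empty string
    return all(
        dataset_names[i].lower() == str(model_id2label.get(i, "")).lower()
        for i in range(len(dataset_names))
    )


# Hypothetical toy check, not the real Speech Commands labels
print(labels_aligned(["up", "yes", "no"], {0: "up", 1: "yes", 2: "no"}))  # True
print(labels_aligned(["yes", "no", "up"], {0: "up", 1: "yes", 2: "no"}))  # False
```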


2 participants