Skip to content

Commit

Permalink
📻 (AST) Audio data classification optimization and data pre-process (#…
Browse files Browse the repository at this point in the history
…762)

## Describe your changes

For issue #735, this change adds an example of AST model optimization that uses
Olive data configs with the Hugging Face examples, so that the model can be
optimized without writing any user script.

## Checklist before requesting a review
- [ ] Add unit tests for this change.
- [ ] Make sure all tests can pass.
- [ ] Update documents if necessary.
- [ ] Lint and apply fixes to your code by running `lintrunner -a`
- [ ] Is this a user-facing change? If yes, give a description of this
change to be included in the release notes.

## (Optional) Issue link
  • Loading branch information
trajepl committed Dec 1, 2023
1 parent 9bf2df9 commit 32322a8
Show file tree
Hide file tree
Showing 4 changed files with 207 additions and 1 deletion.
28 changes: 28 additions & 0 deletions examples/AST/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# AST Optimization
This folder contains examples of AST (Audio Spectrogram Transformer) optimization using Olive workflows.

- Model: https://huggingface.co/MIT/ast-finetuned-speech-commands-v2
- Dataset: https://huggingface.co/datasets/speech_commands

### Run example using config

The `ast.json` config targets CPU optimization: it quantizes the model and tunes the inference configuration for better performance.

First, install the packages required by the passes.
```sh
python -m olive.workflows.run --config ast.json --setup
```

Then, optimize the model
```sh
python -m olive.workflows.run --config ast.json
```

Or simply run it from Python code:
```python
from olive.workflows import run as olive_run
olive_run("ast.json")
```

After running the above command, the model candidates and corresponding config will be saved in the output directory.
You can then select the best model and config from the candidates and run the model with the selected config.
106 changes: 106 additions & 0 deletions examples/AST/ast.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
{
"input_model":{
"type": "PyTorchModel",
"config": {
"hf_config": {
"model_class": "ASTForAudioClassification",
"model_name": "MIT/ast-finetuned-speech-commands-v2",
"task": "audio-classification",
"dataset": {
"data_name":"speech_commands",
"subset": "v0.02",
"split": "validation",
"input_cols": ["audio"],
"label_cols": ["label"],
"max_samples": 100,
"batch_size": 1,
"component_kwargs": {
"pre_process_data": {
"labels_to_filter": ["_silence_"]
}
}
}
},
"io_config": {
"input_names": ["input_values"],
"output_names": ["logits"],
"dynamic_axes": {
"input_values": {
"0": "batch_size", "1": "max_length", "2": "num_mel_bins"
},
"logits": {
"0": "batch_size"
}
}

}
}
},
"evaluators": {
"common_evaluator": {
"metrics":[
{
"name": "accuracy",
"type": "accuracy",
"backend": "huggingface_metrics",
"sub_types": [
{"name": "accuracy", "priority": 1, "goal": {"type": "max-degradation", "value": 0.05}},
{"name": "f1", "metric_config": {"compute_params": {"average": "macro"}}}
]
},
{
"name": "latency",
"type": "latency",
"sub_types": [
{"name": "avg", "priority": 2, "goal": {"type": "percent-min-improvement", "value": 5}},
{"name": "max"},
{"name": "min"}
]
}
]
}
},
"passes": {
"conversion": {
"type": "OnnxConversion"
},
"transformers_optimization": {
"type": "OrtTransformersOptimization",
"disable_search": true,
"config": {
"model_type": "vit"
}
},
"quantization": {
"type": "OnnxQuantization",
"disable_search": true,
"config": {
"quant_mode": "static",
"quant_preprocess": true,
"per_channel": false,
"reduce_range": false,
"data_config": "__input_model_data_config__"
}
},
"perf_tuning": {
"type": "OrtPerfTuning",
"config": {
"data_config": "__input_model_data_config__"
}
}
},
"engine": {
"search_strategy": {
"execution_order": "joint",
"search_algorithm": "tpe",
"search_algorithm_config": {
"num_samples": 3,
"seed": 0
}
},
"evaluator": "common_evaluator",
"execution_providers": ["CPUExecutionProvider"],
"cache_dir": "cache",
"output_dir" : "models/ast_cpu"
}
}
70 changes: 69 additions & 1 deletion olive/data/component/pre_process_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


from copy import deepcopy
from typing import Optional
from typing import Any, Dict, List, Optional

from olive.data.component.dataset import BaseDataset
from olive.data.component.text_generation import (
Expand Down Expand Up @@ -197,3 +197,71 @@ def text_generation_huggingface_pre_process(
return text_gen_corpus_pre_process(dataset, tokenizer, all_kwargs)
else:
return text_gen_pair_pre_process(dataset, tokenizer, all_kwargs)


@Registry.register_pre_process()
def audio_classification_pre_process(
    dataset,
    model_name: str,
    input_cols: List,
    label_cols: List,
    max_samples: Optional[int] = None,
    trust_remote_code: Optional[bool] = None,
    feature_extractor_args: Optional[Dict[str, Any]] = None,
    **kwargs
):
    """Pre-process data for audio classification task.

    The steps are: optionally drop rows whose label is listed in ``labels_to_filter``,
    align the dataset label ids with the model's ``label2id`` mapping, resample the audio
    column to the feature extractor's sampling rate, and run the feature extractor over
    the audio arrays.

    Args:
        dataset (object): Data to be pre-processed, reserved for internal dataset assignment.
            Presumably a Hugging Face ``datasets.Dataset`` (``filter``, ``cast_column`` and
            ``align_labels_with_mapping`` are called on it).
        model_name (str): Name of the huggingface model.
        input_cols (list): List of input columns. Exactly one column is supported.
        label_cols (list): List of label columns. Only the first one is used.
        max_samples (int, optional): Max number of samples to use. Defaults to None.
        trust_remote_code (bool, optional): Whether or not to allow for custom models defined on the Hub in their own
            modeling files. Defaults to None.
        feature_extractor_args (dict, optional): Additional arguments for feature extractor.
        **kwargs: Additional arguments, also forwarded to ``_huggingface_pre_process_helper``.
            Extra arguments read here:
                - max_duration (int, optional): Max duration of audio in seconds. Defaults to 30.
                - labels_to_filter (list, optional): List of labels to filter. Defaults to None.

    Returns:
        BaseDataset: Wrapper around the processed dataset.

    Raises:
        AssertionError: If ``input_cols`` does not contain exactly one column.
    """
    from datasets import Audio
    from transformers import AutoConfig, AutoFeatureExtractor

    assert len(input_cols) == 1, "Only one input column is supported for audio classification task."

    audio_col = input_cols[0]
    label_col = label_cols[0]

    # align labels with model configs
    model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=trust_remote_code)
    labels_to_filter = kwargs.get("labels_to_filter", None) or []
    if labels_to_filter:
        # Resolve the label names to ids once, against the same column the filter reads,
        # instead of re-running str2int for every row.
        ids_to_drop = set(dataset.features[label_col].str2int(labels_to_filter))
        dataset = dataset.filter(lambda label_id: label_id not in ids_to_drop, input_columns=label_col)
    dataset = dataset.align_labels_with_mapping(model_config.label2id, label_col)

    fe_args = feature_extractor_args or {}
    fea_extractor = AutoFeatureExtractor.from_pretrained(model_name, trust_remote_code=trust_remote_code, **fe_args)
    sampling_rate = fea_extractor.sampling_rate
    # cast_column returns a new dataset; the result must be kept or the audio is never
    # resampled to the rate the feature extractor expects.
    dataset = dataset.cast_column(audio_col, Audio(sampling_rate=sampling_rate))

    max_duration = kwargs.get("max_duration", 30)
    max_length = int(sampling_rate * max_duration)

    def _tokenizer_and_align_labels(examples):
        audio_arrays = [x["array"] for x in examples[audio_col]]
        tokenized_inputs = fea_extractor(
            audio_arrays,
            sampling_rate=sampling_rate,
            max_length=max_length,
            truncation=True,
            return_attention_mask=True,
        )

        # NOTE(review): labels are stored under the fixed key "label" while BaseDataset below
        # receives label_cols; confirm this is intended when label_cols[0] != "label".
        tokenized_inputs["label"] = examples[label_col]

        return tokenized_inputs

    tokenized_datasets = _huggingface_pre_process_helper(
        dataset, model_name, input_cols, label_cols, _tokenizer_and_align_labels, **kwargs
    )
    return BaseDataset(tokenized_datasets, label_cols=label_cols, max_samples=max_samples)
4 changes: 4 additions & 0 deletions olive/data/container/huggingface_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,8 @@ class HuggingfaceContainer(DataContainer):
DataComponentType.PRE_PROCESS_DATA.value: "text_generation_huggingface_pre_process",
DataComponentType.POST_PROCESS_DATA.value: "text_generation_post_process",
},
"audio-classification": {
DataComponentType.PRE_PROCESS_DATA.value: "audio_classification_pre_process",
DataComponentType.POST_PROCESS_DATA.value: "text_classification_post_process",
},
}

0 comments on commit 32322a8

Please sign in to comment.