Update default config to ultravox_v0.3 #84

Merged
merged 5 commits on Aug 16, 2024
Changes from 4 commits
4 changes: 2 additions & 2 deletions mcloud.yaml
@@ -13,5 +13,5 @@ command: >-
   cd ultravox && poetry install --no-dev && poetry run torchrun --nproc_per_node=8 -m ultravox.training.train $TRAIN_ARGS
 env_variables:
   MLFLOW_TRACKING_URI: databricks
-  UV_BRANCH: main
-  TRAIN_ARGS: --config_path ultravox/training/configs/llama3_whisper_kd.yaml
+  UV_BRANCH: update_default_config_to_ultravox_v0.3
+  TRAIN_ARGS: --config_path ultravox/training/configs/release_config.yaml
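For reference, a minimal sketch of how the --config_path passed via TRAIN_ARGS could be layered over the shared meta_config.yaml defaults. The loader below is hypothetical and only illustrates the merge order; it is not the actual ultravox.training.train implementation:

# config_loading_demo.py -- hypothetical sketch, not the actual ultravox.training.train code.
# Illustrates how a --config_path argument could be layered over the shared meta_config.yaml.
import argparse

import yaml  # PyYAML


def load_config(config_path: str,
                meta_path: str = "ultravox/training/configs/meta_config.yaml") -> dict:
    """Merge a run-specific config over the shared defaults (run-specific keys win)."""
    with open(meta_path) as f:
        merged = yaml.safe_load(f) or {}
    with open(config_path) as f:
        merged.update(yaml.safe_load(f) or {})
    return merged


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config_path",
                        default="ultravox/training/configs/release_config.yaml")
    args = parser.parse_args()
    config = load_config(args.config_path)
    print(config["exp_name"], config["text_model"], config["stop_strategy"])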
2 changes: 1 addition & 1 deletion ultravox/training/configs/meta_config.yaml
@@ -3,7 +3,7 @@ audio_model: "facebook/wav2vec2-base-960h"
 
 data_sets: ["gigaspeech"]
 val_sets: ["heysquad_human", "anyinstruct", "soda", "peoplespeech"]
-stop_strategy: "last_exhausted"
+stop_strategy: "LAST_EXHAUSTED"
 
 train_on_inputs: False
 shuffle_data: True
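The change from "last_exhausted" to "LAST_EXHAUSTED" suggests the value is matched against an enum member name. A hypothetical sketch of such an enum for interleaving multiple datasets; the member names besides LAST_EXHAUSTED and the parsing behavior are assumptions, not the actual ultravox code:

# stop_strategy_demo.py -- hypothetical sketch of how "LAST_EXHAUSTED" might be interpreted.
from enum import Enum


class StopStrategy(Enum):
    FIRST_EXHAUSTED = "FIRST_EXHAUSTED"  # stop interleaving once the smallest dataset runs out
    LAST_EXHAUSTED = "LAST_EXHAUSTED"    # keep drawing until every dataset has been exhausted


def parse_stop_strategy(value: str) -> StopStrategy:
    # Tolerate either "last_exhausted" or "LAST_EXHAUSTED" coming from YAML.
    return StopStrategy[value.upper()]


print(parse_stop_strategy("LAST_EXHAUSTED").name)  # LAST_EXHAUSTED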
45 changes: 45 additions & 0 deletions ultravox/training/configs/release_config.yaml
@@ -0,0 +1,45 @@
# SLM with ultravox & llama3.1, trained with knowledge distillation.
exp_name: "ultravox-v0_3"

# Make sure to accept the license agreement on huggingface hub
text_model: "meta-llama/Meta-Llama-3.1-8B-Instruct"
audio_model: "openai/whisper-small"


loss_config:
# Choose from ["KL_Divergence", "CrossEntropy"], default is "KL_Divergence"
loss_function: "KL_Divergence"

# Temporarily remove heysquad_human from val_sets as it causes the training to fail.
val_sets: ["anyinstruct", "soda", "peoplespeech"]

batch_size: 24
max_steps: 7200 # x8x24 = 1,382,400 samples

data_sets: []
data_dicts:
- path: "fixie-ai/librispeech_asr"
name: "clean"
splits:
- "train.100" # 28_539 samples
- "train.360" # 104_014 samples
user_template: "Continue the following text using less than 50 words:\n\n<|audio|>"
assistant_template: "{{ continuation }}"
transcript_template: "{{ text }}"
weight: 1
- path: "fixie-ai/librispeech_asr"
name: "other"
splits:
- "train.500" # 148_688 samples
user_template: "Continue the following text using less than 50 words:\n\n<|audio|>"
assistant_template: "{{ continuation }}"
transcript_template: "{{ text }}"
weight: 1
- path: "fixie-ai/common_voice_17_0"
name: "en"
splits:
- "train" # 1_101_170 samples
user_template: "Continue the following text using less than 50 words:\n\n<|audio|>"
assistant_template: "{{ continuation }}"
transcript_template: "{{ text_proc.format_asr_text(sentence) }}"
weight: 8
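Sanity check on the new config: 7200 steps x 8 GPUs x a per-device batch size of 24 is 1,382,400 samples, matching the inline comment on max_steps. The *_template fields look like Jinja expressions; below is a hypothetical sketch of how one data_dicts entry could be rendered into a training example. The sample fields and the use of jinja2 are assumptions for illustration, not the actual ultravox data pipeline:

# template_render_demo.py -- hypothetical rendering of one data_dicts entry above.
import jinja2

# Made-up LibriSpeech-style row; the real column names come from the HF dataset.
sample = {"text": "BEFORE SHE HAD TIME TO ANSWER", "continuation": "she heard the door open."}

user_template = "Continue the following text using less than 50 words:\n\n<|audio|>"
assistant_template = "{{ continuation }}"
transcript_template = "{{ text }}"

env = jinja2.Environment(undefined=jinja2.StrictUndefined)
user_msg = env.from_string(user_template).render(**sample)            # prompt containing the <|audio|> placeholder
assistant_msg = env.from_string(assistant_template).render(**sample)  # target text for the assistant turn
transcript = env.from_string(transcript_template).render(**sample)    # text corresponding to the audio clip

print(user_msg)
print(assistant_msg)
print(transcript)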