diff --git a/ultravox/data/datasets.py b/ultravox/data/datasets.py
index 8944ea1c..e1e0a602 100644
--- a/ultravox/data/datasets.py
+++ b/ultravox/data/datasets.py
@@ -353,16 +353,18 @@ def __init__(self, args: VoiceDatasetArgs) -> None:
         # The last 7 samples are missing audio files, so we exclude them.
         NUM_SAMPLES = 108193 - 7
         super().__init__(args)
-        dataset = (
-            datasets.load_dataset(
-                "json",
-                "anyinstruct",
-                data_files="https://huggingface.co/datasets/fnlp/AnyInstruct/resolve/main/speech_conv/metadata.jsonl",
-                split="train",
-            ).select(range(NUM_SAMPLES))
-            # TODO: make num_shards configurable if need be
-            .to_iterable_dataset(num_shards=16)
+        dataset = datasets.load_dataset(
+            "json",
+            "anyinstruct",
+            data_files="https://huggingface.co/datasets/fnlp/AnyInstruct/resolve/main/speech_conv/metadata.jsonl",
+            split="train",
+        ).select(range(NUM_SAMPLES))
+        dataset = dataset.train_test_split(
+            test_size=0.01, seed=args.shuffle_seed, shuffle=True
         )
+        dataset = dataset["train" if args.split == DatasetSplit.TRAIN else "test"]
+        # TODO: make num_shards configurable if need be
+        dataset = dataset.to_iterable_dataset(num_shards=16)
         if args.shuffle:
             dataset = dataset.shuffle(seed=args.shuffle_seed)
         self._init_dataset(dataset)
diff --git a/ultravox/training/config_base.py b/ultravox/training/config_base.py
index 42ce4f17..ef26877a 100644
--- a/ultravox/training/config_base.py
+++ b/ultravox/training/config_base.py
@@ -25,6 +25,7 @@ class TrainConfig:
     data_dir: Optional[str] = None
     mds: bool = False
     num_samples: Optional[int] = None
+    val_num_samples: int = 100
     eval_num_samples: int = 100
     eval_max_new_tokens: Optional[int] = None
     eval_num_procs: int = 8
@@ -54,11 +55,11 @@ class TrainConfig:
     optimizer: str = "adamw_torch"
     num_epochs: int = 1
     max_steps: int = 0
-    eval_steps: Optional[int] = None
+    val_steps: Optional[int] = None
     save_steps: float = 0
     logging_steps: int = 1
     grad_accum_steps: int = 1
-    eval_accum_steps: int = 1
+    val_accum_steps: int = 1
     batch_size: int = 2
     lr: float = 1e-5
     lr_scheduler: str = "cosine"
diff --git a/ultravox/training/configs/meta_config.yaml b/ultravox/training/configs/meta_config.yaml
index 6ef66da1..08cc3224 100644
--- a/ultravox/training/configs/meta_config.yaml
+++ b/ultravox/training/configs/meta_config.yaml
@@ -8,7 +8,9 @@ train_on_inputs: False
 shuffle_data: True
 max_audio_duration_secs: 16
 
-eval_num_samples: 64
+val_num_samples: 128
+val_steps: 500
+eval_num_samples: 256
 eval_max_new_tokens: 32
 eval_num_procs: 16
 
diff --git a/ultravox/training/ddp_utils.py b/ultravox/training/ddp_utils.py
index 171ff4a2..ba0382a5 100644
--- a/ultravox/training/ddp_utils.py
+++ b/ultravox/training/ddp_utils.py
@@ -1,6 +1,8 @@
 import contextlib
+from typing import List, TypeVar
 
 import torch.distributed
+from torch.utils import data
 
 
 @contextlib.contextmanager
@@ -18,3 +20,26 @@ def run_on_master_first(is_master: bool):
         if torch.distributed.is_initialized():
             torch.distributed.barrier()
         yield
+
+
+T = TypeVar("T")
+
+
+def flatten(data: List[List[T]]) -> List[T]:
+    return [item for sublist in data for item in sublist]
+
+
+def all_gather_list(data: List[T]) -> List[T]:
+    if not torch.distributed.is_initialized():
+        return data
+    world_size = torch.distributed.get_world_size()
+    data_list = [None] * world_size
+    torch.distributed.all_gather_object(data_list, data)
+    return flatten(data_list)  # type: ignore
+
+
+def sharded_iterator(ds: data.IterableDataset, num_shards: int, shard_index: int):
+    # TODO: handle drop last gracefully
+    for i, sample in enumerate(ds):
+        if i % num_shards == shard_index:
+            yield sample
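Note on the new `ddp_utils` helpers: `sharded_iterator` gives each rank a strided slice of an iterable dataset, and `all_gather_list` merges the per-rank results afterwards; the two are meant to compose. A minimal usage sketch (commentary, not part of the patch; `my_dataset` and `process` are hypothetical stand-ins, and the process group is assumed to already be initialized, e.g. by torchrun):

```python
import os

from ultravox.training import ddp_utils

# Rank/world size as set by the launcher; fall back to single-process values.
world_size = int(os.environ.get("WORLD_SIZE", 1))
local_rank = int(os.environ.get("LOCAL_RANK", 0))

# Rank r sees samples r, r + world_size, r + 2 * world_size, ...
local_results = [
    process(sample)  # hypothetical per-sample work
    for sample in ddp_utils.sharded_iterator(my_dataset, world_size, local_rank)
]

# Collective call: every rank must reach this line, and every rank receives
# the concatenation of all per-rank lists (in rank order, so the original
# dataset order is not preserved).
all_results = ddp_utils.all_gather_list(local_results)
```

Since `all_gather_object` pickles arbitrary Python objects, the gathered items need not be tensors; the cost is serializing the payloads through CPU memory.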
diff --git a/ultravox/training/ddp_utils_test.py b/ultravox/training/ddp_utils_test.py
new file mode 100644
index 00000000..5575a56c
--- /dev/null
+++ b/ultravox/training/ddp_utils_test.py
@@ -0,0 +1,27 @@
+import os
+
+import torch.distributed
+from torch import multiprocessing as mp
+
+from ultravox.training import ddp_utils
+
+os.environ["MASTER_ADDR"] = "localhost"
+os.environ["MASTER_PORT"] = "12355"
+
+
+def test_all_gather_list():
+    # Test without DDP
+    verify_all_gather(0, 1)
+    # Test with DDP: world_size = 2, 4
+    mp.spawn(verify_all_gather, args=(2,), nprocs=2, join=True)
+    mp.spawn(verify_all_gather, args=(4,), nprocs=4, join=True)
+
+
+def verify_all_gather(rank: int, world_size: int, k: int = 4):
+    if world_size > 1:
+        torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)
+    d = [rank * k + i for i in range(k)]
+    all_d = ddp_utils.all_gather_list(d)
+    assert all_d == list(range(world_size * k))
+    if world_size > 1:
+        torch.distributed.destroy_process_group()
diff --git a/ultravox/training/evaluation.py b/ultravox/training/evaluation.py
index 4c06550a..7c4468b0 100644
--- a/ultravox/training/evaluation.py
+++ b/ultravox/training/evaluation.py
@@ -1,5 +1,6 @@
 import concurrent.futures
 import functools
+import os
 from typing import List, Optional
 
 import numpy as np
@@ -8,18 +9,21 @@
 from ultravox.data import datasets
 from ultravox.evaluation import eval
 from ultravox.evaluation import eval_types
-from ultravox.inference import base
+from ultravox.inference import infer
+from ultravox.training import ddp_utils
 
 
 def dataset_infer(
-    inference: base.VoiceInference,
+    inference: infer.LocalInference,
     ds: data.IterableDataset,
+    world_size: int = 1,
+    local_rank: int = 0,
     max_new_tokens: Optional[int] = None,
     temperature: Optional[float] = None,
 ) -> List[eval_types.Sample]:
     eval_samples = []
-    # TODO for multiprocessing: ds -> split_batches or sharded reader
-    for sample in ds:
+
+    for sample in ddp_utils.sharded_iterator(ds, world_size, local_rank):
         # Store the original question and answer for JSON output.
         question_text = sample.audio_transcript or sample.messages[0]["content"]
         expected_answer = sample.messages[1]["content"]
@@ -36,9 +40,8 @@ def dataset_infer(
         )
         eval_samples.append(eval_sample)
 
-    # TODO for multiprocess: gather eval_samples
-
-    return eval_samples
+    # Gather all the samples from all the processes.
+    return ddp_utils.all_gather_list(eval_samples)
 
 
 def get_metric_name(ds_name: str, metric: str) -> str:
@@ -52,7 +55,7 @@
 
 
 def evaluate(
-    inference: base.VoiceInference,
+    inference: infer.LocalInference,
     data_dir: Optional[str] = None,
     num_samples: int = 200,
     num_procs: int = 8,
@@ -62,6 +65,9 @@
 ):
     metrics = {}
 
+    world_size = int(os.environ.get("WORLD_SIZE", 1))
+    local_rank = int(os.environ.get("LOCAL_RANK", 0))
+
     ds_args = datasets.VoiceDatasetArgs(
         data_dir=data_dir, split=datasets.DatasetSplit.VALIDATION
     )
@@ -74,9 +80,18 @@
         ds = datasets.Range(datasets.create_dataset(ds_name, ds_args), num_samples)
 
         output_samples = dataset_infer(
-            inference, ds=ds, max_new_tokens=max_new_tokens, temperature=temperature
+            inference,
+            ds=ds,
+            max_new_tokens=max_new_tokens,
+            temperature=temperature,
+            world_size=world_size,
+            local_rank=local_rank,
         )
 
+        if local_rank != 0:
+            # Only the master process should evaluate the samples.
+            continue
+
         eval_per_sample = functools.partial(eval.evaluate_answer, metric=metric)
 
         with concurrent.futures.ThreadPoolExecutor(max_workers=num_procs) as executor:
@@ -99,10 +114,10 @@
             print(f"X: {sample.expected_answer} [score: {score:.2f}]")
 
         average = np.mean(scores)
-        std = np.std(scores)
+        std = np.std(scores) / np.sqrt(len(scores))
         metric_name = get_metric_name(ds_name, metric)
         metrics[f"eval_{metric_name}"] = average
-        metrics[f"eval_{metric_name}_std"] = std / np.sqrt(len(scores))
+        metrics[f"eval_{metric_name}_std"] = std
 
         print(f"Aggregate {metric} score for {ds_name}: {average:.2f} ± {std:.2f}")
 
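Two behavioral notes on the `evaluate` changes (commentary, not part of the patch). First, non-zero ranks must still call `dataset_infer` before hitting the `continue`, because `all_gather_list` is a collective operation: if any rank skipped it, the other ranks would block forever inside `all_gather_object`. Second, the value logged as `eval_{metric_name}_std` is numerically unchanged; what changes is the printed `±` figure, which now shows the standard error of the mean rather than the raw standard deviation. A small illustration with made-up scores:

```python
import numpy as np

scores = np.array([0.8, 0.6, 0.9, 0.7])
raw_std = np.std(scores)              # ~0.112, what was printed before
sem = raw_std / np.sqrt(len(scores))  # ~0.056, what is printed (and logged) now
```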
diff --git a/ultravox/training/train.py b/ultravox/training/train.py
index 85dd2236..a5e68e62 100644
--- a/ultravox/training/train.py
+++ b/ultravox/training/train.py
@@ -5,6 +5,7 @@
 import re
 import sys
 from datetime import datetime
+from typing import List, Optional
 
 import datasets as hf_datasets
 import mlflow
@@ -15,6 +16,7 @@
 import transformers
 import wandb
 from torch.distributed.elastic.multiprocessing.errors import record
+from torch.utils import data
 
 from ultravox.data import datasets
 from ultravox.inference import infer
@@ -46,6 +48,24 @@
 def fix_hyphens(arg: str):
     return re.sub(r"^--([^=]+)", lambda m: "--" + m.group(1).replace("-", "_"), arg)
 
 
+def prepare_dataset(
+    dataset_names: List[str],
+    data_args: datasets.VoiceDatasetArgs,
+    processor: ultravox_processing.UltravoxProcessor,
+    train_on_inputs: bool,
+    repeat_data: bool,
+    num_samples: Optional[int] = None,
+) -> data.IterableDataset:
+
+    data_sets = [datasets.create_dataset(ds, data_args) for ds in dataset_names]
+    interleave = datasets.InterleaveDataset(data_sets, repeat=repeat_data)
+    ds_with_proc = ultravox_processing.UltravoxDataproc(
+        interleave, processor=processor, train_on_inputs=train_on_inputs
+    )
+    limited_ds = datasets.Range(ds_with_proc, num_samples=num_samples)
+    return limited_ds
+
+
 @record
 def main() -> None:
@@ -160,31 +180,48 @@
     # TODO: check if the whole model can now be moved to dtype instead
 
     # Prepare dataset, subsetting if needed
+    train_dataset: data.IterableDataset
+    val_dataset: data.IterableDataset
     if is_master:
-        data_args = datasets.VoiceDatasetArgs(
-            num_prompts=args.num_prompts,
-            data_dir=args.data_dir,
-            shuffle=args.shuffle_data,
-            shuffle_seed=args.shuffle_seed,
-            max_audio_duration_secs=args.max_audio_duration_secs,
-            use_mds=args.mds,
-            mds_batch_size=args.batch_size,
+        train_dataset = prepare_dataset(
+            dataset_names=args.data_sets,
+            train_on_inputs=args.train_on_inputs,
+            repeat_data=args.repeat_data,
+            processor=processor,
+            num_samples=args.num_samples,
+            data_args=datasets.VoiceDatasetArgs(
+                num_prompts=args.num_prompts,
+                data_dir=args.data_dir,
+                shuffle=args.shuffle_data,
+                shuffle_seed=args.shuffle_seed,
+                max_audio_duration_secs=args.max_audio_duration_secs,
+                use_mds=args.mds,
+                mds_batch_size=args.batch_size,
+            ),
         )
-        data_sets = [datasets.create_dataset(ds, data_args) for ds in args.data_sets]
-        interleaved = datasets.InterleaveDataset(data_sets, repeat=args.repeat_data)
-        train_dataset: torch.utils.data.IterableDataset = (
-            ultravox_processing.UltravoxDataproc(
-                interleaved, processor=processor, train_on_inputs=args.train_on_inputs
-            )
+        val_dataset = prepare_dataset(
+            dataset_names=args.data_sets,
+            train_on_inputs=args.train_on_inputs,
+            repeat_data=args.repeat_data,
+            processor=processor,
+            num_samples=args.val_num_samples,
+            data_args=datasets.VoiceDatasetArgs(
+                num_prompts=1,
+                data_dir=args.data_dir,
+                shuffle=False,
+                max_audio_duration_secs=16,
+                use_mds=args.mds,
+                mds_batch_size=args.batch_size,
+            ),
         )
-        train_dataset = datasets.Range(train_dataset, args.num_samples)
         logging.info(
-            f"Loaded {args.data_sets} data sets, sample limit: {args.num_samples}"
+            f"Loaded {args.data_sets} data sets, sample limit: {args.num_samples} (val sample limit: {args.val_num_samples})"
         )
     else:
         # When using DDP with split_batches=True, the primary process will distribute the batches to the workers
        # The point of this is to avoid unnecessary data processing/downloading in the workers.
         train_dataset = datasets.EmptyDataset()
+        val_dataset = datasets.EmptyDataset()
 
     # Set up the data loader
     data_collator = datasets.DataCollatorForSeq2SeqWithAudio(tokenizer=text_tokenizer)
@@ -198,6 +235,7 @@
     trainer = transformers.Seq2SeqTrainer(
         model,
         train_dataset=train_dataset,
+        eval_dataset=val_dataset,
         data_collator=data_collator,
         tokenizer=text_tokenizer,
         args=transformers.Seq2SeqTrainingArguments(
@@ -207,7 +245,8 @@
             optim=args.optimizer,
             num_train_epochs=args.num_epochs,
             max_steps=args.max_steps,
-            eval_steps=args.eval_steps,
+            evaluation_strategy="steps",
+            eval_steps=args.val_steps,
             save_strategy="steps",
             save_steps=args.save_steps,
             logging_first_step=True,
@@ -218,7 +257,7 @@
             per_device_train_batch_size=args.batch_size * world_size,
             accelerator_config={"split_batches": True},
             gradient_accumulation_steps=args.grad_accum_steps,
-            eval_accumulation_steps=args.eval_accum_steps,
+            eval_accumulation_steps=args.val_accum_steps,
             # tf32=dtype == torch.float32 and device.type == "cuda",  # TODO: check for Ampere GPU not just CUDA
             ddp_find_unused_parameters=False,
             learning_rate=args.lr,
@@ -255,25 +294,25 @@
     logging.info(f"end time: {t_end}")
     logging.info(f"elapsed: {t_end - t_start}")
 
+    # Merge LoRA weights for better inference performance.
+    # Note: this is irreversible and changes model saving format
+    model.merge_and_unload()
+    inference = infer.LocalInference(
+        model=model,
+        processor=processor,
+        tokenizer=text_tokenizer,
+        device=args.device,
+        dtype=dtype,
+    )
+    metrics = evaluation.evaluate(
+        inference,
+        data_dir=args.data_dir,
+        num_procs=args.eval_num_procs,
+        num_samples=args.eval_num_samples,
+        max_new_tokens=args.eval_max_new_tokens,
+        verbose=True,
+    )
     if is_master:
-        # Merge LoRA weights for better performance.
-        # Note: this is irreversible and changes model saving format
-        model.merge_and_unload()
-        inference = infer.LocalInference(
-            model=model,
-            processor=processor,
-            tokenizer=text_tokenizer,
-            device=args.device,
-            dtype=dtype,
-        )
-        metrics = evaluation.evaluate(
-            inference,
-            data_dir=args.data_dir,
-            num_procs=args.eval_num_procs,
-            num_samples=args.eval_num_samples,
-            max_new_tokens=args.eval_max_new_tokens,
-            verbose=True,
-        )
         trainer.log(metrics)
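With this split, the `val_*` options drive the loss-based validation that the HF trainer now runs during training (`evaluation_strategy="steps"`), while the `eval_*` options drive the generation-based evaluation that runs once after training on all ranks, with only rank 0 scoring and logging. A configuration sketch showing how the knobs relate (the file name is hypothetical; the keys mirror `TrainConfig` and `meta_config.yaml` above):

```yaml
# my_experiment.yaml (hypothetical override of meta_config.yaml)
val_num_samples: 128    # samples for the in-training validation loss
val_steps: 500          # validate every 500 steps
val_accum_steps: 1      # forwarded as eval_accumulation_steps to the trainer

eval_num_samples: 256   # samples per dataset for the final generation eval
eval_max_new_tokens: 32 # generation budget per sample
eval_num_procs: 16      # thread pool size for metric scoring on rank 0
```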