Skip to content

Commit

Permalink
Improved Evaluations (#2)
Browse files Browse the repository at this point in the history
* anyinstruct bugfix: bring back val_split code

* evaluation generation with multi-processing + test

* validation loss for training

* refactor prepare_dataset
  • Loading branch information
farzadab authored Jun 1, 2024
1 parent caee028 commit 87eb25d
Show file tree
Hide file tree
Showing 7 changed files with 170 additions and 59 deletions.
20 changes: 11 additions & 9 deletions ultravox/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,16 +353,18 @@ def __init__(self, args: VoiceDatasetArgs) -> None:
# The last 7 samples are missing audio files, so we exclude them.
NUM_SAMPLES = 108193 - 7
super().__init__(args)
dataset = (
datasets.load_dataset(
"json",
"anyinstruct",
data_files="https://huggingface.co/datasets/fnlp/AnyInstruct/resolve/main/speech_conv/metadata.jsonl",
split="train",
).select(range(NUM_SAMPLES))
# TODO: make num_shards configurable if need be
.to_iterable_dataset(num_shards=16)
dataset = datasets.load_dataset(
"json",
"anyinstruct",
data_files="https://huggingface.co/datasets/fnlp/AnyInstruct/resolve/main/speech_conv/metadata.jsonl",
split="train",
).select(range(NUM_SAMPLES))
dataset = dataset.train_test_split(
test_size=0.01, seed=args.shuffle_seed, shuffle=True
)
dataset = dataset["train" if args.split == DatasetSplit.TRAIN else "test"]
# TODO: make num_shards configurable if need be
dataset = dataset.to_iterable_dataset(num_shards=16)
if args.shuffle:
dataset = dataset.shuffle(seed=args.shuffle_seed)
self._init_dataset(dataset)
Expand Down
5 changes: 3 additions & 2 deletions ultravox/training/config_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class TrainConfig:
data_dir: Optional[str] = None
mds: bool = False
num_samples: Optional[int] = None
val_num_samples: int = 100
eval_num_samples: int = 100
eval_max_new_tokens: Optional[int] = None
eval_num_procs: int = 8
Expand Down Expand Up @@ -54,11 +55,11 @@ class TrainConfig:
optimizer: str = "adamw_torch"
num_epochs: int = 1
max_steps: int = 0
eval_steps: Optional[int] = None
val_steps: Optional[int] = None
save_steps: float = 0
logging_steps: int = 1
grad_accum_steps: int = 1
eval_accum_steps: int = 1
val_accum_steps: int = 1
batch_size: int = 2
lr: float = 1e-5
lr_scheduler: str = "cosine"
Expand Down
4 changes: 3 additions & 1 deletion ultravox/training/configs/meta_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ train_on_inputs: False
shuffle_data: True
max_audio_duration_secs: 16

eval_num_samples: 64
val_num_samples: 128
val_steps: 500
eval_num_samples: 256
eval_max_new_tokens: 32
eval_num_procs: 16

Expand Down
25 changes: 25 additions & 0 deletions ultravox/training/ddp_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import contextlib
from typing import List, TypeVar

import torch.distributed
from torch.utils import data


@contextlib.contextmanager
Expand All @@ -18,3 +20,26 @@ def run_on_master_first(is_master: bool):
if torch.distributed.is_initialized():
torch.distributed.barrier()
yield


T = TypeVar("T")


def flatten(data: List[List[T]]) -> List[T]:
    """Concatenate a list of lists into one flat list.

    Items keep their original order: every element of the first sublist,
    then every element of the second, and so on. Empty sublists contribute
    nothing, and an empty outer list yields an empty result.
    """
    flat: List[T] = []
    for sublist in data:
        flat.extend(sublist)
    return flat


def all_gather_list(data: List[T]) -> List[T]:
    """Gather each process's list and return them joined into one list.

    If no torch.distributed process group is initialized, the input list is
    returned as-is (the same object, not a copy). Otherwise every process
    contributes its list via ``all_gather_object`` and the per-process lists
    are flattened into a single list, in the order they were gathered.
    """
    if torch.distributed.is_initialized():
        num_ranks = torch.distributed.get_world_size()
        # One placeholder slot per process; all_gather_object fills them in place.
        per_rank = [None] * num_ranks
        torch.distributed.all_gather_object(per_rank, data)
        return flatten(per_rank)  # type: ignore
    return data


def sharded_iterator(ds: data.IterableDataset, num_shards: int, shard_index: int):
    """Yield only the samples of ``ds`` that belong to one shard.

    Samples are assigned round-robin by position: the sample at position
    ``i`` is yielded when ``i % num_shards == shard_index``. The whole of
    ``ds`` is still iterated; samples for other shards are simply skipped.
    """
    # TODO: handle drop last gracefully
    position = 0
    for sample in ds:
        if position % num_shards == shard_index:
            yield sample
        position += 1
27 changes: 27 additions & 0 deletions ultravox/training/ddp_utils_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import os

import torch.distributed
from torch import multiprocessing as mp

from ultravox.training import ddp_utils

# Rendezvous address/port for the process groups created in verify_all_gather.
# NOTE(review): init_process_group is called there without an explicit
# init_method, so it presumably reads these environment variables -- confirm.
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"


def test_all_gather_list():
    """Check all_gather_list both without DDP and across spawned processes."""
    # Single-process path: runs in this process, no process group is created.
    verify_all_gather(0, 1)
    # DDP path: spawn one worker per rank for each world size.
    for world_size in (2, 4):
        mp.spawn(verify_all_gather, args=(world_size,), nprocs=world_size, join=True)


def verify_all_gather(rank: int, world_size: int, k: int = 4):
    """Assert that gathering ``k`` values per rank yields ``0 .. world_size*k - 1``.

    Each rank contributes the consecutive integers starting at ``rank * k``,
    so the gathered list must equal ``list(range(world_size * k))``. A "gloo"
    process group is created (and destroyed afterwards) only when
    ``world_size`` is greater than 1.
    """
    use_ddp = world_size > 1
    if use_ddp:
        torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)
    first = rank * k
    local_values = list(range(first, first + k))
    gathered = ddp_utils.all_gather_list(local_values)
    assert gathered == list(range(world_size * k))
    if use_ddp:
        torch.distributed.destroy_process_group()
37 changes: 26 additions & 11 deletions ultravox/training/evaluation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import concurrent.futures
import functools
import os
from typing import List, Optional

import numpy as np
Expand All @@ -8,18 +9,21 @@
from ultravox.data import datasets
from ultravox.evaluation import eval
from ultravox.evaluation import eval_types
from ultravox.inference import base
from ultravox.inference import infer
from ultravox.training import ddp_utils


def dataset_infer(
inference: base.VoiceInference,
inference: infer.LocalInference,
ds: data.IterableDataset,
world_size: int = 1,
local_rank: int = 0,
max_new_tokens: Optional[int] = None,
temperature: Optional[float] = None,
) -> List[eval_types.Sample]:
eval_samples = []
# TODO for multiprocessing: ds -> split_batches or sharded reader
for sample in ds:

for sample in ddp_utils.sharded_iterator(ds, world_size, local_rank):
# Store the original question and answer for JSON output.
question_text = sample.audio_transcript or sample.messages[0]["content"]
expected_answer = sample.messages[1]["content"]
Expand All @@ -36,9 +40,8 @@ def dataset_infer(
)
eval_samples.append(eval_sample)

# TODO for multiprocess: gather eval_samples

return eval_samples
# Gather all the samples from all the processes.
return ddp_utils.all_gather_list(eval_samples)


def get_metric_name(ds_name: str, metric: str) -> str:
Expand All @@ -52,7 +55,7 @@ def get_metric_name(ds_name: str, metric: str) -> str:


def evaluate(
inference: base.VoiceInference,
inference: infer.LocalInference,
data_dir: Optional[str] = None,
num_samples: int = 200,
num_procs: int = 8,
Expand All @@ -62,6 +65,9 @@ def evaluate(
):
metrics = {}

world_size = int(os.environ.get("WORLD_SIZE", 1))
local_rank = int(os.environ.get("LOCAL_RANK", 0))

ds_args = datasets.VoiceDatasetArgs(
data_dir=data_dir, split=datasets.DatasetSplit.VALIDATION
)
Expand All @@ -74,9 +80,18 @@ def evaluate(
ds = datasets.Range(datasets.create_dataset(ds_name, ds_args), num_samples)

output_samples = dataset_infer(
inference, ds=ds, max_new_tokens=max_new_tokens, temperature=temperature
inference,
ds=ds,
max_new_tokens=max_new_tokens,
temperature=temperature,
world_size=world_size,
local_rank=local_rank,
)

if local_rank != 0:
# Only the master process should evaluate the samples.
continue

eval_per_sample = functools.partial(eval.evaluate_answer, metric=metric)

with concurrent.futures.ThreadPoolExecutor(max_workers=num_procs) as executor:
Expand All @@ -99,10 +114,10 @@ def evaluate(
print(f"X: {sample.expected_answer} [score: {score:.2f}]")

average = np.mean(scores)
std = np.std(scores)
std = np.std(scores) / np.sqrt(len(scores))
metric_name = get_metric_name(ds_name, metric)
metrics[f"eval_{metric_name}"] = average
metrics[f"eval_{metric_name}_std"] = std / np.sqrt(len(scores))
metrics[f"eval_{metric_name}_std"] = std

print(f"Aggregate {metric} score for {ds_name}: {average:.2f} ± {std:.2f}")

Expand Down
111 changes: 75 additions & 36 deletions ultravox/training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import re
import sys
from datetime import datetime
from typing import List, Optional

import datasets as hf_datasets
import mlflow
Expand All @@ -15,6 +16,7 @@
import transformers
import wandb
from torch.distributed.elastic.multiprocessing.errors import record
from torch.utils import data

from ultravox.data import datasets
from ultravox.inference import infer
Expand Down Expand Up @@ -46,6 +48,24 @@ def fix_hyphens(arg: str):
return re.sub(r"^--([^=]+)", lambda m: "--" + m.group(1).replace("-", "_"), arg)


def prepare_dataset(
    dataset_names: List[str],
    data_args: datasets.VoiceDatasetArgs,
    processor: ultravox_processing.UltravoxProcessor,
    train_on_inputs: bool,
    repeat_data: bool,
    num_samples: Optional[int] = None,
) -> data.IterableDataset:
    """Build the processed, size-limited dataset from the named data sets.

    Pipeline: create one dataset per name with the shared ``data_args``,
    combine them with ``datasets.InterleaveDataset``, wrap the result in
    ``ultravox_processing.UltravoxDataproc``, and finally wrap that in
    ``datasets.Range``.

    Args:
        dataset_names: Names passed one by one to ``datasets.create_dataset``.
        data_args: Arguments shared by every created dataset.
        processor: Forwarded to ``UltravoxDataproc``.
        train_on_inputs: Forwarded to ``UltravoxDataproc``.
        repeat_data: Forwarded to ``InterleaveDataset`` as ``repeat``.
        num_samples: Forwarded to ``datasets.Range``; presumably ``None``
            means no limit -- confirm against ``datasets.Range``.

    Returns:
        The ``datasets.Range`` wrapper around the processed dataset.
    """
    sources = []
    for name in dataset_names:
        sources.append(datasets.create_dataset(name, data_args))
    mixed = datasets.InterleaveDataset(sources, repeat=repeat_data)
    processed = ultravox_processing.UltravoxDataproc(
        mixed, processor=processor, train_on_inputs=train_on_inputs
    )
    return datasets.Range(processed, num_samples=num_samples)


@record
def main() -> None:
# Disable parallelism to avoid deadlocks in DataLoader, apparently
Expand Down Expand Up @@ -160,31 +180,48 @@ def main() -> None:
# TODO: check if the whole model can now be moved to dtype instead

# Prepare dataset, subsetting if needed
train_dataset: data.IterableDataset
val_dataset: data.IterableDataset
if is_master:
data_args = datasets.VoiceDatasetArgs(
num_prompts=args.num_prompts,
data_dir=args.data_dir,
shuffle=args.shuffle_data,
shuffle_seed=args.shuffle_seed,
max_audio_duration_secs=args.max_audio_duration_secs,
use_mds=args.mds,
mds_batch_size=args.batch_size,
train_dataset = prepare_dataset(
dataset_names=args.data_sets,
train_on_inputs=args.train_on_inputs,
repeat_data=args.repeat_data,
processor=processor,
num_samples=args.num_samples,
data_args=datasets.VoiceDatasetArgs(
num_prompts=args.num_prompts,
data_dir=args.data_dir,
shuffle=args.shuffle_data,
shuffle_seed=args.shuffle_seed,
max_audio_duration_secs=args.max_audio_duration_secs,
use_mds=args.mds,
mds_batch_size=args.batch_size,
),
)
data_sets = [datasets.create_dataset(ds, data_args) for ds in args.data_sets]
interleaved = datasets.InterleaveDataset(data_sets, repeat=args.repeat_data)
train_dataset: torch.utils.data.IterableDataset = (
ultravox_processing.UltravoxDataproc(
interleaved, processor=processor, train_on_inputs=args.train_on_inputs
)
val_dataset = prepare_dataset(
dataset_names=args.data_sets,
train_on_inputs=args.train_on_inputs,
repeat_data=args.repeat_data,
processor=processor,
num_samples=args.val_num_samples,
data_args=datasets.VoiceDatasetArgs(
num_prompts=1,
data_dir=args.data_dir,
shuffle=False,
max_audio_duration_secs=16,
use_mds=args.mds,
mds_batch_size=args.batch_size,
),
)
train_dataset = datasets.Range(train_dataset, args.num_samples)
logging.info(
f"Loaded {args.data_sets} data sets, sample limit: {args.num_samples}"
f"Loaded {args.data_sets} data sets, sample limit: {args.num_samples} (val sample limit: {args.val_num_samples})"
)
else:
# When using DDP with split_batches=True, the primary process will distribute the batches to the workers
# The point of this is to avoid unnecessary data processing/downloading in the workers.
train_dataset = datasets.EmptyDataset()
val_dataset = datasets.EmptyDataset()

# Set up the data loader
data_collator = datasets.DataCollatorForSeq2SeqWithAudio(tokenizer=text_tokenizer)
Expand All @@ -198,6 +235,7 @@ def main() -> None:
trainer = transformers.Seq2SeqTrainer(
model,
train_dataset=train_dataset,
eval_dataset=val_dataset,
data_collator=data_collator,
tokenizer=text_tokenizer,
args=transformers.Seq2SeqTrainingArguments(
Expand All @@ -207,7 +245,8 @@ def main() -> None:
optim=args.optimizer,
num_train_epochs=args.num_epochs,
max_steps=args.max_steps,
eval_steps=args.eval_steps,
evaluation_strategy="steps",
eval_steps=args.val_steps,
save_strategy="steps",
save_steps=args.save_steps,
logging_first_step=True,
Expand All @@ -218,7 +257,7 @@ def main() -> None:
per_device_train_batch_size=args.batch_size * world_size,
accelerator_config={"split_batches": True},
gradient_accumulation_steps=args.grad_accum_steps,
eval_accumulation_steps=args.eval_accum_steps,
eval_accumulation_steps=args.val_accum_steps,
# tf32=dtype == torch.float32 and device.type == "cuda", # TODO: check for Ampere GPU not just CUDA
ddp_find_unused_parameters=False,
learning_rate=args.lr,
Expand Down Expand Up @@ -255,25 +294,25 @@ def main() -> None:
logging.info(f"end time: {t_end}")
logging.info(f"elapsed: {t_end - t_start}")

# Merge LoRA weights for better inference performance.
# Note: this is irreversible and changes model saving format
model.merge_and_unload()
inference = infer.LocalInference(
model=model,
processor=processor,
tokenizer=text_tokenizer,
device=args.device,
dtype=dtype,
)
metrics = evaluation.evaluate(
inference,
data_dir=args.data_dir,
num_procs=args.eval_num_procs,
num_samples=args.eval_num_samples,
max_new_tokens=args.eval_max_new_tokens,
verbose=True,
)
if is_master:
# Merge LoRA weights for better performance.
# Note: this is irreversible and changes model saving format
model.merge_and_unload()
inference = infer.LocalInference(
model=model,
processor=processor,
tokenizer=text_tokenizer,
device=args.device,
dtype=dtype,
)
metrics = evaluation.evaluate(
inference,
data_dir=args.data_dir,
num_procs=args.eval_num_procs,
num_samples=args.eval_num_samples,
max_new_tokens=args.eval_max_new_tokens,
verbose=True,
)
trainer.log(metrics)


Expand Down

0 comments on commit 87eb25d

Please sign in to comment.