Optionally include "passage" in BoolQ samples (#6)
* v1

* Optionally include "passage" in BoolQ samples
juberti authored Jun 5, 2024
1 parent edc3797 commit 6f387d0
Showing 2 changed files with 71 additions and 39 deletions.
79 changes: 57 additions & 22 deletions ultravox/data/datasets.py
@@ -182,13 +182,26 @@ class DatasetSplit(str, enum.Enum):
@dataclasses.dataclass
class VoiceDatasetArgs:
data_dir: Optional[str] = None
prompt: Optional[str] = None
"""A specific prompt to use for the dataset."""
num_prompts: int = 1
"""If `prompt` is not set, the number of canned prompts to use."""
include_audio: bool = True
"""Whether to include audio in the samples."""
include_context: bool = False
"""Whether to include additional textual context from the dataset to the prompt."""
shuffle: bool = False
"""Whether to shuffle the dataset."""
shuffle_seed: int = 42
"""Seed for shuffling the dataset."""
max_audio_duration_secs: Optional[float] = None
"""Whether to skip samples with audio longer than this duration."""
use_mds: bool = False
"""Whether to load the dataset from GCP (using MDS) or Hugging Face."""
mds_batch_size: int = 32
"""Batch size for MDS."""
split: DatasetSplit = DatasetSplit.TRAIN
"""Which split of the dataset to use."""

def __post_init__(self):
if isinstance(self.split, str):
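Taken together, the new fields give callers per-dataset control over prompting and context. A minimal sketch of their use (values illustrative, not taken from this diff):

args = VoiceDatasetArgs(
    prompt=None,           # fall back to the canned prompt lists
    num_prompts=4,         # rotate through the first four canned prompts
    include_audio=True,    # keep audio in each sample
    include_context=True,  # e.g. prepend the BoolQ "passage" to the prompt
    split=DatasetSplit.VALIDATION,
)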
@@ -261,17 +274,25 @@ def _get_sample(self, idx: int, row: transformers.BatchFeature) -> VoiceSample:
pass

def _get_answer_prompt(self, idx: int) -> str:
if self._args.prompt:
return self._args.prompt
prompt_idx = idx % min(self._args.num_prompts, len(ANSWER_PROMPTS))
return ANSWER_PROMPTS[prompt_idx]

def _get_transcribe_prompt(self, idx: int) -> str:
if self._args.prompt:
return self._args.prompt
prompt_idx = idx % min(self._args.num_prompts, len(TRANSCRIBE_PROMPTS))
return TRANSCRIBE_PROMPTS[prompt_idx]

def _get_answer_messages(self, idx: int, text: str) -> List[Dict[str, str]]:
def _get_answer_messages(
self, idx: int, question: str, answer: str, context: Optional[str] = None
) -> List[Dict[str, str]]:
prompt = self._get_answer_prompt(idx) if self._args.include_audio else question
user_content = f"{context}\n\n{prompt}" if context else prompt
return [
{"role": "user", "content": self._get_answer_prompt(idx)},
{"role": "assistant", "content": text},
{"role": "user", "content": user_content},
{"role": "assistant", "content": answer},
]
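
A standalone rendition of the new context handling, for clarity (not the repo's code; the prompt, answer, and passage strings are placeholders):

from typing import Dict, List, Optional

def answer_messages(prompt: str, answer: str, context: Optional[str] = None) -> List[Dict[str, str]]:
    # Context, when present, is prepended to the prompt in the user turn.
    user_content = f"{context}\n\n{prompt}" if context else prompt
    return [
        {"role": "user", "content": user_content},
        {"role": "assistant", "content": answer},
    ]

# The user turn becomes "<context>\n\n<prompt>" when context is given:
assert answer_messages("Q?", "True", "Some passage.")[0]["content"] == "Some passage.\n\nQ?"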

def _get_transcribe_messages(self, idx: int, text: str) -> List[Dict[str, str]]:
@@ -294,6 +315,19 @@ def _get_audio(self, row: transformers.BatchFeature) -> np.ndarray:
assert sampling_rate == SAMPLE_RATE
return audio

def _load_audio(self, base_url: str, folder: str, filename: str) -> np.ndarray:
if self._args.data_dir:
audio_path = f"{self._args.data_dir}/{folder}/{filename}"
audio = audio_from_file(audio_path)
else:
url = f"{base_url}/{filename}" # hack for GCS bucket naming
if self._session is None:
self._session = requests.Session()
response = self._session.get(url)
response.raise_for_status()
audio = audio_from_buf(response.content)
return audio
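
The HTTP branch keeps one requests.Session per dataset so connections are reused across downloads. A simplified standalone sketch of that pattern (function name hypothetical):

import requests
from typing import Optional

_session: Optional[requests.Session] = None

def fetch_audio_bytes(base_url: str, filename: str) -> bytes:
    global _session
    if _session is None:
        _session = requests.Session()  # reuse TCP connections across files
    response = _session.get(f"{base_url}/{filename}")
    response.raise_for_status()
    return response.content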

def _get_transcribe_sample(
self,
idx: int,
@@ -302,24 +336,21 @@ def _get_transcribe_sample(
tproc: Optional[Callable[[str], str]] = None,
) -> VoiceSample:
text = tproc(row[tcol]) if tproc else row[tcol]
return VoiceSample(
return self._make_sample(
self._get_transcribe_messages(idx, text),
self._get_audio(row),
audio_transcript=text,
)

def _load_audio(self, base_url: str, folder: str, filename: str) -> np.ndarray:
if self._args.data_dir:
audio_path = f"{self._args.data_dir}/{folder}/{filename}"
audio = audio_from_file(audio_path)
else:
url = f"{base_url}/{filename}" # hack for GCS bucket naming
if self._session is None:
self._session = requests.Session()
response = self._session.get(url)
response.raise_for_status()
audio = audio_from_buf(response.content)
return audio
def _make_sample(
self,
messages: List[Dict[str, str]],
audio: np.ndarray,
audio_transcript: Optional[str] = None,
) -> VoiceSample:
if not self._args.include_audio:
return VoiceSample(messages)
return VoiceSample(messages, audio, audio_transcript=audio_transcript)
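
Together with the include_audio branch in _get_answer_messages, this keeps text-only runs coherent: the raw question stands in for the audio, and the sample carries messages only. Illustratively (question text hypothetical, and assuming VoiceSample's audio defaults to None):

sample = VoiceSample([
    {"role": "user", "content": "is the sky blue"},   # raw question replaces the audio prompt
    {"role": "assistant", "content": "True"},
])
assert sample.audio is None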


class LibriSpeechDummyDataset(VoiceDataset):
@@ -385,10 +416,11 @@ def __init__(self, args: VoiceDatasetArgs) -> None:
super().__init__(args)

def _get_sample(self, idx: int, row: transformers.BatchFeature) -> VoiceSample:
return VoiceSample(
self._get_answer_messages(idx, row["chat"][1]["message"]),
self._load_anyinstruct_audio(row["chat"][0]["speech"]),
audio_transcript=row["chat"][0]["message"],
chat = row["chat"]
return self._make_sample(
self._get_answer_messages(idx, chat[0]["message"], chat[1]["message"]),
self._load_anyinstruct_audio(chat[0]["speech"]),
audio_transcript=chat[0]["message"],
)


@@ -428,8 +460,11 @@ def __init__(self, args: VoiceDatasetArgs) -> None:
self._init_dataset(dataset)

def _get_sample(self, idx: int, row: transformers.BatchFeature) -> VoiceSample:
return VoiceSample(
self._get_answer_messages(idx, "True" if row["answer"] else "False"),
question = row["question"]
answer = "True" if row["answer"] else "False"
context = row["passage"] if self._args.include_context else None
return self._make_sample(
self._get_answer_messages(idx, question, answer, context),
self._get_audio(row),
audio_transcript=row["question"],
)
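
This is the heart of the change: with include_context=True, the BoolQ passage rides along in the user turn ahead of the answer prompt. Illustratively (row values hypothetical):

row = {
    "question": "is the sky blue",
    "answer": True,
    "passage": "The sky appears blue because of Rayleigh scattering...",
}
# include_context=True, include_audio=True ->
#   user:      "The sky appears blue because of Rayleigh scattering...\n\n<answer prompt>"
#   assistant: "True"
# include_context=False -> the user turn is just the answer prompt.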
31 changes: 14 additions & 17 deletions ultravox/tools/infer_tool.py
@@ -48,9 +48,13 @@ class InferArgs:
# Data sets to use for inference
data_sets: Optional[List[str]] = simple_parsing.field(default=None, alias="-d")
# Which dataset split to use
data_split: datasets.DatasetSplit = datasets.DatasetSplit.VALIDATION
data_split: datasets.DatasetSplit = simple_parsing.field(
default=datasets.DatasetSplit.VALIDATION, alias="-s"
)
# Directory for existing data
data_dir: Optional[str] = None
# Use dataset context
context: bool = False
# Load datasets using MDS
mds: bool = False
# Number of dataset samples to process
@@ -128,10 +132,9 @@ def run_tui(
eval_str = ""
if scores is not None:
assert args.data_sets
assert sample.audio_transcript is not None, "Query must have transcript"
ds_name = args.data_sets[0]
eval_sample = eval_types.Sample(
sample.audio_transcript,
sample.audio_transcript or sample.messages[0]["content"],
expected_answer=expected_response,
generated_answer=text,
)
@@ -156,7 +159,8 @@ def run_tui(
print(f"X: {expected_response}{eval_str}")


def oneshot_infer(inference: base.VoiceInference, prompt: str, args: InferArgs):
def oneshot_infer(inference: base.VoiceInference, args: InferArgs):
prompt = args.prompt or (DEFAULT_ASR_PROMPT if args.asr else DEFAULT_PROMPT)
if args.audio_file is not None:
sample = datasets.VoiceSample.from_prompt_and_buf(
prompt, args.audio_file.read()
@@ -166,10 +170,13 @@ def oneshot_infer(inference: base.VoiceInference, prompt: str, args: InferArgs):
run_tui(-1, inference, sample, args)


def dataset_infer(inference: base.VoiceInference, prompt: str, args: InferArgs):
def dataset_infer(inference: base.VoiceInference, args: InferArgs):
assert args.data_sets, "At least one data set must be provided"
ds_args = datasets.VoiceDatasetArgs(
data_dir=args.data_dir,
prompt=args.prompt,
include_audio=not args.text_only,
include_context=args.context,
shuffle=args.shuffle,
use_mds=args.mds,
split=args.data_split,
@@ -184,15 +191,6 @@ def dataset_infer(inference: base.VoiceInference, prompt: str, args: InferArgs):
expected_answer = sample.messages[1]["content"]
# Drop any assistant response from the sample.
sample.messages = sample.messages[:1]
# Normally, we overwrite the dataset prompt with our prompt, allowing us to customize
# the inference prompt in this tool.
# If we're using text-only mode though, there's no audio to inference, so we
# just paste the text transcript in as the text prompt.
if not args.text_only:
sample.messages[0]["content"] = prompt
else:
sample.messages[0]["content"] = sample.audio_transcript
sample.audio = sample.audio_transcript = None
if not args.json:
run_tui(i, inference, sample, args, expected_answer, scores)
else:
@@ -222,11 +220,10 @@ def main(args: InferArgs):
device=args.device,
data_type=args.data_type,
)
prompt = args.prompt or (DEFAULT_ASR_PROMPT if args.asr else DEFAULT_PROMPT)
if args.data_sets is None:
oneshot_infer(inference, prompt, args)
oneshot_infer(inference, args)
else:
dataset_infer(inference, prompt, args)
dataset_infer(inference, args)
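
End to end, the new flags thread from the CLI through dataset_infer into VoiceDatasetArgs. A hypothetical invocation (the boolq dataset name and generated flag spellings are assumptions, not confirmed by this diff):

# Run BoolQ validation with passages included in the prompt:
#   python -m ultravox.tools.infer_tool <model args> -d boolq -s validation --context
# --context maps to VoiceDatasetArgs.include_context via dataset_infer above.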


if __name__ == "__main__":
