Skip to content

Commit

Permalink
Audio Encoder to bfloat16 (#4)
Browse files Browse the repository at this point in the history
* audio enc to bfloat16

* increase transformers min required version
  • Loading branch information
farzadab authored Jun 3, 2024
1 parent d2df97b commit 56a8209
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 4 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Core
transformers[torch]>=4.39.3
transformers[torch]>=4.40.2
bitsandbytes>=0.42.0
peft
simple_parsing
Expand Down
4 changes: 1 addition & 3 deletions ultravox/training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,7 @@ def main() -> None:
logging.info(
f"Using dtype and device (world_size): {dtype}, {device} ({world_size})"
)
model.to(device)
model.language_model.to(dtype)
model.multi_modal_projector.to(dtype)
model.to(device=device, dtype=dtype)
# TODO: check if the whole model can now be moved to dtype instead

# Prepare dataset, subsetting if needed
Expand Down

0 comments on commit 56a8209

Please sign in to comment.