Skip to content

Commit

Permalink
Change os.rename to shutil.move when downloading models (#459)
Browse files Browse the repository at this point in the history
  • Loading branch information
chidiwilliams authored May 16, 2023
1 parent c281684 commit 806b691
Showing 1 changed file with 15 additions and 7 deletions.
22 changes: 15 additions & 7 deletions buzz/model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import warnings
from dataclasses import dataclass
from typing import Optional
import shutil

import faster_whisper
import huggingface_hub
Expand Down Expand Up @@ -58,7 +59,8 @@ def get_whisper_cpp_file_path(size: WhisperModelSize) -> str:


def get_whisper_file_path(size: WhisperModelSize) -> str:
root_dir = os.getenv("XDG_CACHE_HOME", os.path.join(os.path.expanduser("~"), ".cache", "whisper"))
root_dir = os.getenv("XDG_CACHE_HOME", os.path.join(
os.path.expanduser("~"), ".cache", "whisper"))
url = whisper._MODELS[size.value]
return os.path.join(root_dir, os.path.basename(url))

Expand Down Expand Up @@ -97,7 +99,8 @@ def get_local_model_path(model: TranscriptionModel) -> Optional[str]:
def download_faster_whisper_model(size: str, local_files_only=False, tqdm_class: Optional[tqdm] = None):
if size not in faster_whisper.utils._MODELS:
raise ValueError(
"Invalid model size '%s', expected one of: %s" % (size, ", ".join(faster_whisper.utils._MODELS))
"Invalid model size '%s', expected one of: %s" % (
size, ", ".join(faster_whisper.utils._MODELS))
)

repo_id = "guillaumekln/faster-whisper-%s" % size
Expand Down Expand Up @@ -131,13 +134,15 @@ def run(self) -> None:
model_name = self.model.whisper_model_size.value
url = get_hugging_face_dataset_file_url(author='ggerganov', repository_name='whisper.cpp',
filename=f'ggml-{model_name}.bin')
file_path = get_whisper_cpp_file_path(size=self.model.whisper_model_size)
file_path = get_whisper_cpp_file_path(
size=self.model.whisper_model_size)
expected_sha256 = WHISPER_CPP_MODELS_SHA256[model_name]
return self.download_model_to_path(url=url, file_path=file_path, expected_sha256=expected_sha256)

if self.model.model_type == ModelType.WHISPER:
url = whisper._MODELS[self.model.whisper_model_size.value]
file_path = get_whisper_file_path(size=self.model.whisper_model_size)
file_path = get_whisper_file_path(
size=self.model.whisper_model_size)
expected_sha256 = url.split('/')[-2]
return self.download_model_to_path(url=url, file_path=file_path, expected_sha256=expected_sha256)

Expand All @@ -154,12 +159,14 @@ def close(self):
return super().close()

if self.model.model_type == ModelType.FASTER_WHISPER:
model_path = download_faster_whisper_model(size=self.model.whisper_model_size.value, tqdm_class=_tqdm)
model_path = download_faster_whisper_model(
size=self.model.whisper_model_size.value, tqdm_class=_tqdm)
self.signals.finished.emit(model_path)
return

if self.model.model_type == ModelType.HUGGING_FACE:
model_path = huggingface_hub.snapshot_download(self.model.hugging_face_model_id, tqdm_class=_tqdm)
model_path = huggingface_hub.snapshot_download(
self.model.hugging_face_model_id, tqdm_class=_tqdm)
self.signals.finished.emit(model_path)
return

Expand Down Expand Up @@ -228,7 +235,8 @@ def download_model(self, url: str, file_path: str, expected_sha256: Optional[str

logging.debug('Downloaded model')

os.rename(tmp_file, file_path)
# https://github.com/chidiwilliams/buzz/issues/454
shutil.move(tmp_file, file_path)
logging.debug('Moved file from %s to %s', tmp_file, file_path)
return True

Expand Down

0 comments on commit 806b691

Please sign in to comment.