From 806b691d02253e675db5dce599a007eed94fa07c Mon Sep 17 00:00:00 2001 From: Chidi Williams Date: Tue, 16 May 2023 23:24:39 +0100 Subject: [PATCH] Change os.rename to shutil.move when downloading models (#459) --- buzz/model_loader.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/buzz/model_loader.py b/buzz/model_loader.py index 99d673119..c3db7e06f 100644 --- a/buzz/model_loader.py +++ b/buzz/model_loader.py @@ -6,6 +6,7 @@ import warnings from dataclasses import dataclass from typing import Optional +import shutil import faster_whisper import huggingface_hub @@ -58,7 +59,8 @@ def get_whisper_cpp_file_path(size: WhisperModelSize) -> str: def get_whisper_file_path(size: WhisperModelSize) -> str: - root_dir = os.getenv("XDG_CACHE_HOME", os.path.join(os.path.expanduser("~"), ".cache", "whisper")) + root_dir = os.getenv("XDG_CACHE_HOME", os.path.join( + os.path.expanduser("~"), ".cache", "whisper")) url = whisper._MODELS[size.value] return os.path.join(root_dir, os.path.basename(url)) @@ -97,7 +99,8 @@ def get_local_model_path(model: TranscriptionModel) -> Optional[str]: def download_faster_whisper_model(size: str, local_files_only=False, tqdm_class: Optional[tqdm] = None): if size not in faster_whisper.utils._MODELS: raise ValueError( - "Invalid model size '%s', expected one of: %s" % (size, ", ".join(faster_whisper.utils._MODELS)) + "Invalid model size '%s', expected one of: %s" % ( + size, ", ".join(faster_whisper.utils._MODELS)) ) repo_id = "guillaumekln/faster-whisper-%s" % size @@ -131,13 +134,15 @@ def run(self) -> None: model_name = self.model.whisper_model_size.value url = get_hugging_face_dataset_file_url(author='ggerganov', repository_name='whisper.cpp', filename=f'ggml-{model_name}.bin') - file_path = get_whisper_cpp_file_path(size=self.model.whisper_model_size) + file_path = get_whisper_cpp_file_path( + size=self.model.whisper_model_size) expected_sha256 = WHISPER_CPP_MODELS_SHA256[model_name] return self.download_model_to_path(url=url, file_path=file_path, expected_sha256=expected_sha256) if self.model.model_type == ModelType.WHISPER: url = whisper._MODELS[self.model.whisper_model_size.value] - file_path = get_whisper_file_path(size=self.model.whisper_model_size) + file_path = get_whisper_file_path( + size=self.model.whisper_model_size) expected_sha256 = url.split('/')[-2] return self.download_model_to_path(url=url, file_path=file_path, expected_sha256=expected_sha256) @@ -154,12 +159,14 @@ def close(self): return super().close() if self.model.model_type == ModelType.FASTER_WHISPER: - model_path = download_faster_whisper_model(size=self.model.whisper_model_size.value, tqdm_class=_tqdm) + model_path = download_faster_whisper_model( + size=self.model.whisper_model_size.value, tqdm_class=_tqdm) self.signals.finished.emit(model_path) return if self.model.model_type == ModelType.HUGGING_FACE: - model_path = huggingface_hub.snapshot_download(self.model.hugging_face_model_id, tqdm_class=_tqdm) + model_path = huggingface_hub.snapshot_download( + self.model.hugging_face_model_id, tqdm_class=_tqdm) self.signals.finished.emit(model_path) return @@ -228,7 +235,8 @@ def download_model(self, url: str, file_path: str, expected_sha256: Optional[str logging.debug('Downloaded model') - os.rename(tmp_file, file_path) + # https://github.com/chidiwilliams/buzz/issues/454 + shutil.move(tmp_file, file_path) logging.debug('Moved file from %s to %s', tmp_file, file_path) return True