Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tokenizer-only flag to only download tokenizers from HF or oras #895

Merged
merged 10 commits into from
Jan 23, 2024
Merged
26 changes: 20 additions & 6 deletions llmfoundry/utils/model_download_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@
]
PYTORCH_WEIGHTS_PATTERN = 'pytorch_model*.bin*'
SAFE_WEIGHTS_PATTERN = 'model*.safetensors*'
TOKENIZER_FILES = [
'special_tokens_map.json',
'tokenizer.json',
'tokenizer.model',
'tokenizer_config.json',
]

ORAS_PASSWD_PLACEHOLDER = '<placeholder_for_passwd>'
ORAS_CLI = 'oras'
Expand All @@ -45,6 +51,7 @@ def download_from_hf_hub(
model: str,
save_dir: str,
prefer_safetensors: bool = True,
tokenizers_only: bool = False,
token: Optional[str] = None,
):
"""Downloads model files from a Hugging Face Hub model repo.
Expand All @@ -57,6 +64,7 @@ def download_from_hf_hub(
save_dir (str, optional): The local path to the directory where the model files will be downloaded.
prefer_safetensors (bool): Whether to prefer Safetensors weights over PyTorch weights if both are
available. Defaults to True.
tokenizers_only (bool): If true, only download tokenzier files.
irenedea marked this conversation as resolved.
Show resolved Hide resolved
token (str, optional): The HuggingFace API token. If not provided, the token will be read from the
`HUGGING_FACE_HUB_TOKEN` environment variable.

Expand Down Expand Up @@ -94,11 +102,14 @@ def download_from_hf_hub(
f'No supported model weights found in repo {model}.' +
' Please make sure the repo contains either safetensors or pytorch weights.'
)

allow_patterns = TOKENIZER_FILES if tokenizers_only else None
irenedea marked this conversation as resolved.
Show resolved Hide resolved

download_start = time.time()
hf_hub.snapshot_download(model,
local_dir=save_dir,
ignore_patterns=ignore_patterns,
allow_patterns=allow_patterns,
token=token)
download_duration = time.time() - download_start
log.info(
Expand Down Expand Up @@ -221,16 +232,18 @@ def download_from_oras(model: str,
config_file: str,
credentials_dir: str,
save_dir: str,
tokenizer_only: bool,
irenedea marked this conversation as resolved.
Show resolved Hide resolved
concurrency: int = 10):
"""Download from an OCI-compliant registry using oras.

Args:
model: The name of the model to download.
config_file: Path to a YAML config file that maps model names to registry paths.
credentials_dir: Path to a directory containing credentials for the registry. It is expected to contain three
model (str): The name of the model to download.
config_file (str): Path to a YAML config file that maps model and tokenizer names to registry paths.
credentials_dir (str): Path to a directory containing credentials for the registry. It is expected to contain three
files: `username`, `password`, and `registry`, each of which contains the corresponding credential.
save_dir: Path to the directory where files will be downloaded.
concurrency: The number of concurrent downloads to run.
save_dir (str): Path to the directory where files will be downloaded.
tokenizer_only (bool): If true, only download the tokenzier files.
concurrency (int): The number of concurrent downloads to run.
"""
if shutil.which(ORAS_CLI) is None:
raise Exception(
Expand All @@ -253,7 +266,8 @@ def _read_secrets_file(secret_file_path: str,):
with open(config_file, 'r', encoding='utf-8') as f:
configs = yaml.safe_load(f.read())

path = configs['models'][model]
config_type = 'tokenizers' if tokenizer_only else 'models'
path = configs[config_type][model]
registry = secrets['registry']

def get_oras_cmd(username: Optional[str] = None,
Expand Down
14 changes: 11 additions & 3 deletions scripts/misc/download_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
python download_model.py hf --model mosaicml/mpt-7b --save-dir <save_dir> --token <token>

Download from ORAS registry:
python download_model.py oras --registry <registry> --path mosaicml/mpt-7b --save-dir <save_dir>
python download_model.py oras --model mosaicml/mpt-7b --config-file <config_file> \
--credentials-dir <credentials_dir> --save-dir <save_dir>

Download from an HTTP file server:
python download_model.py http --host https://server.com --path mosaicml/mpt-7b --save-dir <save_dir>
python download_model.py http --url https://server.com/path --save-dir <save_dir>
irenedea marked this conversation as resolved.
Show resolved Hide resolved

Download from an HTTP file server with fallback to Hugging Face Hub:
python download_model.py http --host https://server.com --path mosaicml/mpt-7b --save-dir <save_dir> \
Expand Down Expand Up @@ -56,6 +57,9 @@ def parse_args() -> argparse.Namespace:

base_parser = argparse.ArgumentParser(add_help=False)
base_parser.add_argument('--save-dir', type=str, required=True)
base_parser.add_argument('--tokenizer-only',
default=False,
action='store_true')
irenedea marked this conversation as resolved.
Show resolved Hide resolved

# Add subparser for downloading from Hugging Face Hub.
hf_parser = subparsers.add_parser('hf', parents=[base_parser])
Expand Down Expand Up @@ -85,6 +89,9 @@ def parse_args() -> argparse.Namespace:
download_from = args.download_from

if download_from == 'http':
if args.tokenizer_only == True:
irenedea marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError(
'tokenizer-only is not currently supported for http.')
try:
download_from_http_fileserver(args.url, args.save_dir,
args.ignore_cert)
Expand All @@ -109,7 +116,8 @@ def parse_args() -> argparse.Namespace:
download_from_hf_hub(args.model,
save_dir=args.save_dir,
token=args.token,
tokenizers_only=args.tokenizer_only,
irenedea marked this conversation as resolved.
Show resolved Hide resolved
prefer_safetensors=args.prefer_safetensors)
elif download_from == 'oras':
download_from_oras(args.model, args.config_file, args.credentials_dir,
args.save_dir, args.concurrency)
args.save_dir, args.tokenizer_only, args.concurrency)
irenedea marked this conversation as resolved.
Show resolved Hide resolved
1 change: 1 addition & 0 deletions tests/utils/test_model_download_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def test_download_from_hf_hub_weights_pref(mock_list_repo_files: MagicMock,
mock_snapshot_download.assert_called_once_with(
test_repo_id,
local_dir=save_dir,
allow_patterns=None,
ignore_patterns=expected_ignore_patterns,
token=None)

Expand Down
Loading