Eliminate short-circuiting for loading from local #3600

Merged · 4 commits · Sep 15, 2023
ludwig/schema/llms/base_model.py (8 changes: 6 additions & 2 deletions)
@@ -1,3 +1,4 @@
+import os
 from dataclasses import field

 from marshmallow import fields, ValidationError
@@ -45,17 +46,20 @@ def validate(model_name: str):
     """Validates and upgrades the given model name to its full path, if applicable.

     If the name exists in `MODEL_PRESETS`, returns the corresponding value from the dict; otherwise checks if the
-    given name (which should be a full path) exists in the transformers library.
+    given name (which should be a full path) exists locally or in the transformers library.
     """
     if isinstance(model_name, str):
         if model_name in MODEL_PRESETS:
             return MODEL_PRESETS[model_name]
+        if os.path.isdir(model_name):
+            return model_name
Comment on lines +54 to +55
@arnavgarg1 (Contributor) · Sep 15, 2023:

Hmm, wondering if we should add some more checks here, the most basic one being that the directory should not be empty. In an ideal world, we would also add validation to ensure that the model config exists in this directory and that it can be initialized correctly from this directory using the same code block from line 57. What do you think?

Contributor:

I think we can do this in a fast-follow, but for completeness these additional checks are important. What do you think? @Infernaught @justinxzhao

Contributor (PR author):

Cool. Justin and I expect HF to give us a failure message if the model objects are bad, so it's probably not that bad if we don't have these checks for the time being. I'm going to merge this for now, but I definitely think we should keep thinking about how to properly verify this.

Contributor:

Okay with the merge for now as well, but I do think we should follow up with custom validation here and surface clear guidance on what went wrong and how to fix it.

I do think we should add these checks now because they're honestly so trivial to add, and we want to fail fast, always. There's nothing wrong with leaving it up to HF, but I guarantee we are going to get Ludwig users messaging us asking why they're seeing cryptic import or not-found errors from HF when the real problem is a user-side error. The reason I personally really like these kinds of validation checks is that they let us get ahead of the problem and provide clear, crisp resolution steps so users can unblock themselves, and that is what I would lean towards.
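
For illustration, here is a minimal sketch of the fail-fast checks discussed above (hypothetical, not part of this PR; the helper name, the error text, and the `ludwig.error` import path are assumptions):

import os

from transformers import AutoConfig

from ludwig.error import ConfigValidationError  # import path is an assumption


def _validate_local_model_dir(model_name: str) -> str:
    """Hypothetical fail-fast checks for a local base model directory."""
    if not os.listdir(model_name):
        # An empty directory can never hold a valid pretrained model.
        raise ConfigValidationError(
            f"Local base model directory `{model_name}` is empty. It should contain a "
            "config.json and the model weights downloaded from huggingface."
        )
    if not os.path.isfile(os.path.join(model_name, "config.json")):
        # Without a model config, AutoConfig initialization is guaranteed to fail.
        raise ConfigValidationError(
            f"Local base model directory `{model_name}` contains no config.json. "
            "Re-download the model, e.g. with huggingface_hub.snapshot_download."
        )
    # Reuse the same initialization path as hub models so errors surface early.
    AutoConfig.from_pretrained(model_name)
    return model_name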

         try:
             AutoConfig.from_pretrained(model_name)
             return model_name
         except OSError:
             raise ConfigValidationError(
-                f"Specified base model `{model_name}` is not a valid pretrained CausalLM listed on huggingface. "
+                f"Specified base model `{model_name}` is not a valid pretrained CausalLM listed on huggingface "
+                "or a valid local directory containing the weights for a pretrained CausalLM from huggingface. "
                 "Please see: https://huggingface.co/models?pipeline_tag=text-generation&sort=downloads"
             )
     raise ValidationError(
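
For context, a quick sketch of the resolution order `validate` implements after this change (the preset key and local path below are illustrative, not real entries):

# 1) Preset name: upgraded to its full path via MODEL_PRESETS.
validate("llama-2-7b")  # hypothetical preset key
# 2) Existing local directory: returned as-is, with no hub lookup.
validate("/models/my-llama")  # assumes this directory exists on disk
# 3) Anything else: must be loadable from the hub via AutoConfig.from_pretrained.
validate("HuggingFaceH4/tiny-random-LlamaForCausalLM")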
tests/integration_tests/test_llm.py (35 changes: 35 additions & 0 deletions)
@@ -623,3 +623,38 @@ def test_global_max_sequence_length_for_llms():

    # Check that the value can never be larger than the model's context_len
    assert model.global_max_sequence_length == 2048


def test_local_path_loading():
    """Tests that local paths can be used to load models."""

    from huggingface_hub import snapshot_download

    # Download the model to a local directory.
    # Expand "~" explicitly: os.makedirs and snapshot_download do not expand it.
    LOCAL_PATH = os.path.expanduser("~/test_local_path_loading")
    REPO_ID = "HuggingFaceH4/tiny-random-LlamaForCausalLM"
    os.makedirs(LOCAL_PATH, exist_ok=True)
    snapshot_download(repo_id=REPO_ID, local_dir=LOCAL_PATH)

    # Load the model using the local path.
    config1 = {
        MODEL_TYPE: MODEL_LLM,
        BASE_MODEL: LOCAL_PATH,
        INPUT_FEATURES: [text_feature(name="input", encoder={"type": "passthrough"})],
        OUTPUT_FEATURES: [text_feature(name="output")],
    }
    config_obj1 = ModelConfig.from_dict(config1)
    model1 = LLM(config_obj1)

    # Load the model using the repo id.
    config2 = {
        MODEL_TYPE: MODEL_LLM,
        BASE_MODEL: REPO_ID,
        INPUT_FEATURES: [text_feature(name="input", encoder={"type": "passthrough"})],
        OUTPUT_FEATURES: [text_feature(name="output")],
    }
    config_obj2 = ModelConfig.from_dict(config2)
    model2 = LLM(config_obj2)

    # Check that the two models are identical.
    assert _compare_models(model1.model, model2.model)
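
As an aside, a more hermetic variant of the download step could use a temporary directory so the snapshot is cleaned up automatically (a sketch, not part of this PR):

import tempfile

from huggingface_hub import snapshot_download

# Download into an isolated, auto-removed directory instead of a fixed path.
with tempfile.TemporaryDirectory() as local_dir:
    snapshot_download(repo_id="HuggingFaceH4/tiny-random-LlamaForCausalLM", local_dir=local_dir)
    # ... build the config with BASE_MODEL set to local_dir and run the same assertions here.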