Commit 2689eb4

Add support for additional tasks via HuggingFace hub
samruds committed Mar 4, 2024
1 parent 80634b3 commit 2689eb4
Showing 3 changed files with 17 additions and 5 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -68,7 +68,7 @@


 extras["test"] = [
-    "pytest",
+    "pytest<=8.0.0",
     "pytest-xdist",
     "parameterized",
     "psutil",
11 changes: 7 additions & 4 deletions src/sagemaker_huggingface_inference_toolkit/transformers_utils.py
@@ -21,7 +21,7 @@
 from huggingface_hub import HfApi, login, snapshot_download
 from transformers import AutoTokenizer, pipeline
 from transformers.file_utils import is_tf_available, is_torch_available
-from transformers.pipelines import Conversation, Pipeline
+from transformers.pipelines import SUPPORTED_TASKS, Conversation, Pipeline

 from sagemaker_huggingface_inference_toolkit.diffusers_utils import get_diffusers_pipeline, is_diffusers_available

@@ -231,15 +231,18 @@ def infer_task_from_model_architecture(model_config_path: str, architecture_index
     architecture = config.get("architectures", [None])[architecture_index]

     task = None
+    if "HF_TASK" in os.environ:
+        return os.environ["HF_TASK"]
+
     for arch_options in ARCHITECTURES_2_TASK:
         if architecture.endswith(arch_options):
             task = ARCHITECTURES_2_TASK[arch_options]

     if task is None:
         raise ValueError(
-            f"Task couldn't be inferenced from {architecture}."
-            f"Inference Toolkit can only inference tasks from architectures ending with {list(ARCHITECTURES_2_TASK.keys())}."
-            "Use env `HF_TASK` to define your task."
+            f"Task not supported via {list(SUPPORTED_TASKS.keys())} or"
+            f" {list(ARCHITECTURES_2_TASK.keys())}."
+            "Use inference.py to install unsupported task separately"
         )
     # set env to work with
     os.environ["HF_TASK"] = task
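With this change, a task supplied via the HF_TASK environment variable takes precedence over inference from the model's architecture name, and the error for unrecognized architectures now points at SUPPORTED_TASKS. A minimal sketch of the new behavior, assuming the function lives in sagemaker_huggingface_inference_toolkit.transformers_utils (as the test module name suggests) and that a model config.json has already been downloaded locally; the path and task value below are illustrative, not taken from this commit:

import os

from sagemaker_huggingface_inference_toolkit.transformers_utils import infer_task_from_model_architecture

# With HF_TASK set, the function returns the env value immediately and skips
# the ARCHITECTURES_2_TASK suffix lookup (hypothetical task choice).
os.environ["HF_TASK"] = "image-classification"
task = infer_task_from_model_architecture("/tmp/model/config.json")  # hypothetical local config path
assert task == "image-classification"

# With HF_TASK unset and no matching architecture suffix, the new ValueError
# lists SUPPORTED_TASKS and ARCHITECTURES_2_TASK keys instead of the old message.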
9 changes: 9 additions & 0 deletions tests/unit/test_transformers_utils.py
@@ -131,6 +131,15 @@ def test_infer_task_from_model_architecture():
         assert task == "token-classification"


+@require_torch
+def test_infer_task_from_model_architecture_from_env_variable():
+    os.environ["HF_TASK"] = "image-classification"
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        storage_dir = _load_model_from_hub(TASK_MODEL, tmpdirname)
+        task = infer_task_from_model_architecture(f"{storage_dir}/config.json")
+        assert task == "image-classification"
+
+
 @require_torch
 def test_wrap_conversation_pipeline():
     init_pipeline = pipeline(
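For context on how the override is used in practice: on SageMaker, HF_TASK is typically passed through the model's environment variables at deployment time. A rough sketch with the SageMaker Python SDK; the model location, role, framework versions, and instance type below are placeholders, not values taken from this commit:

from sagemaker.huggingface import HuggingFaceModel

model = HuggingFaceModel(
    model_data="s3://my-bucket/model.tar.gz",              # hypothetical model archive
    role="arn:aws:iam::111122223333:role/SageMakerRole",   # hypothetical execution role
    transformers_version="4.26",                           # example DLC versions, adjust to your setup
    pytorch_version="1.13",
    py_version="py39",
    env={"HF_TASK": "image-classification"},               # task override read by the toolkit
)
predictor = model.deploy(initial_instance_count=1, instance_type="ml.m5.xlarge")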
