Skip to content

Commit

Permalink
Add support for additional tasks via HuggingFace hub
Browse files Browse the repository at this point in the history
  • Loading branch information
samruds committed Mar 4, 2024
1 parent 80634b3 commit 12e6788
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 4 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@


extras["test"] = [
"pytest",
"pytest>=7.2.0,<8.0.0",
"pytest-xdist",
"parameterized",
"psutil",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,15 +231,18 @@ def infer_task_from_model_architecture(model_config_path: str, architecture_inde
architecture = config.get("architectures", [None])[architecture_index]

task = None
if os.environ["HF_TASK"]:
return os.environ["HF_TASK"]

for arch_options in ARCHITECTURES_2_TASK:
if architecture.endswith(arch_options):
task = ARCHITECTURES_2_TASK[arch_options]

if task is None:
raise ValueError(
f"Task couldn't be inferenced from {architecture}."
f"Inference Toolkit can only inference tasks from architectures ending with {list(ARCHITECTURES_2_TASK.keys())}."
"Use env `HF_TASK` to define your task."
f"Task not supported via Transformers supported tasks or"
f" {list(ARCHITECTURES_2_TASK.keys())}."
"Use inference.py to install unsupported task separately"
)
# set env to work with
os.environ["HF_TASK"] = task
Expand Down
9 changes: 9 additions & 0 deletions tests/unit/test_transformers_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,15 @@ def test_infer_task_from_model_architecture():
assert task == "token-classification"


@require_torch
def test_infer_task_from_model_architecture_from_env_variable():
os.environ["HF_TASK"] = "image-classification"
with tempfile.TemporaryDirectory() as tmpdirname:
storage_dir = _load_model_from_hub(TASK_MODEL, tmpdirname)
task = infer_task_from_model_architecture(f"{storage_dir}/config.json")
assert task == "image-classification"


@require_torch
def test_wrap_conversation_pipeline():
init_pipeline = pipeline(
Expand Down

0 comments on commit 12e6788

Please sign in to comment.