diff --git a/setup.py b/setup.py
index 5e4f72a..f183c1e 100644
--- a/setup.py
+++ b/setup.py
@@ -68,7 +68,7 @@ extras["test"] = [
-    "pytest",
+    "pytest>=7.2.0,<8.0.0",
     "pytest-xdist",
     "parameterized",
     "psutil",
diff --git a/src/sagemaker_huggingface_inference_toolkit/transformers_utils.py b/src/sagemaker_huggingface_inference_toolkit/transformers_utils.py
index ba8141a..65a487a 100644
--- a/src/sagemaker_huggingface_inference_toolkit/transformers_utils.py
+++ b/src/sagemaker_huggingface_inference_toolkit/transformers_utils.py
@@ -21,7 +21,7 @@
 from huggingface_hub import HfApi, login, snapshot_download
 from transformers import AutoTokenizer, pipeline
 from transformers.file_utils import is_tf_available, is_torch_available
-from transformers.pipelines import Conversation, Pipeline
+from transformers.pipelines import SUPPORTED_TASKS, Conversation, Pipeline
 
 from sagemaker_huggingface_inference_toolkit.diffusers_utils import get_diffusers_pipeline, is_diffusers_available
@@ -231,15 +231,18 @@ def infer_task_from_model_architecture(model_config_path: str, architecture_index
     architecture = config.get("architectures", [None])[architecture_index]
 
     task = None
+    if "HF_TASK" in os.environ:
+        return os.environ["HF_TASK"]
+
     for arch_options in ARCHITECTURES_2_TASK:
         if architecture.endswith(arch_options):
             task = ARCHITECTURES_2_TASK[arch_options]
 
     if task is None:
         raise ValueError(
-            f"Task couldn't be inferenced from {architecture}."
-            f"Inference Toolkit can only inference tasks from architectures ending with {list(ARCHITECTURES_2_TASK.keys())}."
-            "Use env `HF_TASK` to define your task."
+            f"Task not supported via {list(SUPPORTED_TASKS.keys())} or"
+            f" {list(ARCHITECTURES_2_TASK.keys())}. "
+            "Use a custom inference.py to handle unsupported tasks separately."
         )
     # set env to work with
     os.environ["HF_TASK"] = task
diff --git a/tests/unit/test_handler_service_with_context.py b/tests/unit/test_handler_service_with_context.py
index a8b5b71..7eec52c 100644
--- a/tests/unit/test_handler_service_with_context.py
+++ b/tests/unit/test_handler_service_with_context.py
@@ -96,7 +96,8 @@ def test_load(inference_handler):
     assert hf_pipeline_without_task.task == "token-classification"
 
     # test with automatic infer
-    os.environ["HF_TASK"] = TASK
+    os.environ["HF_TASK"] = "text-classification"
+    inference_handler = handler_service.HuggingFaceHandlerService()
     hf_pipeline_with_task = inference_handler.load(storage_folder, context)
     assert hf_pipeline_with_task.task == TASK
diff --git a/tests/unit/test_transformers_utils.py b/tests/unit/test_transformers_utils.py
index 902a074..7571980 100644
--- a/tests/unit/test_transformers_utils.py
+++ b/tests/unit/test_transformers_utils.py
@@ -131,6 +131,15 @@ def test_infer_task_from_model_architecture():
     assert task == "token-classification"
 
 
+@require_torch
+def test_infer_task_from_model_architecture_from_env_variable():
+    os.environ["HF_TASK"] = "image-classification"
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        storage_dir = _load_model_from_hub(TASK_MODEL, tmpdirname)
+        task = infer_task_from_model_architecture(f"{storage_dir}/config.json")
+        assert task == "image-classification"
+
+
 @require_torch
 def test_wrap_conversation_pipeline():
     init_pipeline = pipeline(
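
A minimal sketch of the HF_TASK short-circuit this patch introduces, for reference; the /tmp/model path is illustrative and it assumes a model snapshot with a valid config.json has already been downloaded (e.g. via snapshot_download):

    import os

    from sagemaker_huggingface_inference_toolkit.transformers_utils import infer_task_from_model_architecture

    # With HF_TASK set, the function returns the env value directly and the
    # architecture-to-task lookup is bypassed (config.json is still parsed).
    os.environ["HF_TASK"] = "image-classification"
    task = infer_task_from_model_architecture("/tmp/model/config.json")
    assert task == "image-classification"

    # Without HF_TASK, the task is inferred from the "architectures" field of
    # config.json; a ValueError is raised if no supported mapping is found.
    os.environ.pop("HF_TASK")
    task = infer_task_from_model_architecture("/tmp/model/config.json")

Note that on the inference path the function also writes the resolved task back to os.environ["HF_TASK"] (the "# set env to work with" branch), so a successful call leaves the variable set for subsequent pipeline construction.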