diff --git a/haystack/nodes/prompt/invocation_layer/__init__.py b/haystack/nodes/prompt/invocation_layer/__init__.py
index 72df2edfda..a50aafc082 100644
--- a/haystack/nodes/prompt/invocation_layer/__init__.py
+++ b/haystack/nodes/prompt/invocation_layer/__init__.py
@@ -3,10 +3,10 @@
 from haystack.nodes.prompt.invocation_layer.chatgpt import ChatGPTInvocationLayer
 from haystack.nodes.prompt.invocation_layer.azure_chatgpt import AzureChatGPTInvocationLayer
 from haystack.nodes.prompt.invocation_layer.handlers import TokenStreamingHandler, DefaultTokenStreamingHandler
-from haystack.nodes.prompt.invocation_layer.hugging_face import HFLocalInvocationLayer
-from haystack.nodes.prompt.invocation_layer.hugging_face_inference import HFInferenceEndpointInvocationLayer
 from haystack.nodes.prompt.invocation_layer.open_ai import OpenAIInvocationLayer
 from haystack.nodes.prompt.invocation_layer.anthropic_claude import AnthropicClaudeInvocationLayer
 from haystack.nodes.prompt.invocation_layer.cohere import CohereInvocationLayer
+from haystack.nodes.prompt.invocation_layer.hugging_face import HFLocalInvocationLayer
+from haystack.nodes.prompt.invocation_layer.hugging_face_inference import HFInferenceEndpointInvocationLayer
 from haystack.nodes.prompt.invocation_layer.sagemaker_hf_infer import SageMakerHFInferenceInvocationLayer
 from haystack.nodes.prompt.invocation_layer.sagemaker_hf_text_gen import SageMakerHFTextGenerationInvocationLayer
diff --git a/haystack/nodes/prompt/invocation_layer/hugging_face.py b/haystack/nodes/prompt/invocation_layer/hugging_face.py
index f7728b53fc..86d93b6f1a 100644
--- a/haystack/nodes/prompt/invocation_layer/hugging_face.py
+++ b/haystack/nodes/prompt/invocation_layer/hugging_face.py
@@ -21,7 +21,7 @@
         GenerationConfig,
         Pipeline,
     )
-    from transformers.pipelines import get_task
+    from huggingface_hub import model_info
 
     from haystack.modeling.utils import initialize_device_settings  # pylint: disable=ungrouped-imports
     from haystack.nodes.prompt.invocation_layer.handlers import HFTokenStreamingHandler
@@ -43,6 +43,15 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa
             stop_result = torch.isin(self.stop_words["input_ids"], input_ids[-1])
             return any(all(stop_word) for stop_word in stop_result)
 
+    def get_task(model: str, use_auth_token: Optional[Union[str, bool]] = None, timeout: float = 3.0) -> Optional[str]:
+        """
+        Simplified version of transformers.pipelines.get_task with support for timeouts
+        """
+        try:
+            return model_info(model, token=use_auth_token, timeout=timeout).pipeline_tag
+        except Exception as e:
+            raise RuntimeError(f"The task of {model} could not be checked because of the following error: {e}") from e
+
 
 class HFLocalInvocationLayer(PromptModelInvocationLayer):
     """
diff --git a/test/prompt/invocation_layer/test_hugging_face.py b/test/prompt/invocation_layer/test_hugging_face.py
index 5c30dd6b9c..5152dac954 100644
--- a/test/prompt/invocation_layer/test_hugging_face.py
+++ b/test/prompt/invocation_layer/test_hugging_face.py
@@ -368,7 +368,9 @@ def test_supports(tmp_path):
         assert HFLocalInvocationLayer.supports("google/flan-t5-base")
         assert HFLocalInvocationLayer.supports("mosaicml/mpt-7b")
         assert HFLocalInvocationLayer.supports("CarperAI/stable-vicuna-13b-delta")
-        assert mock_get_task.call_count == 3
+        mock_get_task.side_effect = RuntimeError
+        assert not HFLocalInvocationLayer.supports("google/flan-t5-base")
+        assert mock_get_task.call_count == 4
 
     # some HF local model directory, let's use the one from test/prompt/invocation_layer
     assert HFLocalInvocationLayer.supports(str(tmp_path))
diff --git a/test/prompt/invocation_layer/test_invocation_layers.py b/test/prompt/invocation_layer/test_invocation_layers.py
new file mode 100644
index 0000000000..6004a9f7db
--- /dev/null
+++ b/test/prompt/invocation_layer/test_invocation_layers.py
@@ -0,0 +1,14 @@
+import pytest
+
+from haystack.nodes.prompt.prompt_model import PromptModelInvocationLayer
+from haystack.nodes.prompt.invocation_layer import HFLocalInvocationLayer, HFInferenceEndpointInvocationLayer
+
+
+@pytest.mark.unit
+def test_invocation_layer_order():
+    """
+    Checks that the huggingface invocation layer is checked late because it can timeout/be slow to respond.
+    """
+    last_invocation_layers = set(PromptModelInvocationLayer.invocation_layer_providers[-5:])
+    assert HFLocalInvocationLayer in last_invocation_layers
+    assert HFInferenceEndpointInvocationLayer in last_invocation_layers
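
Taken together, the two changes make provider selection resilient to a slow or unreachable Hugging Face Hub: the HF layers are now tried last, and the new `get_task` gives up after a short timeout and raises `RuntimeError`, which `supports` treats as "model not supported" (the updated `test_supports` asserts exactly this). A minimal sketch of that contract follows; the patch target is inferred from the mocked `get_task` in `test_supports`, and the model id is a placeholder:

```python
from unittest.mock import patch

from haystack.nodes.prompt.invocation_layer import HFLocalInvocationLayer

# Simulate a Hub outage: get_task raising RuntimeError (e.g. after its
# 3-second default timeout) should surface as supports() returning False,
# not as a hang or an unhandled exception during provider selection.
with patch("haystack.nodes.prompt.invocation_layer.hugging_face.get_task") as mock_get_task:
    mock_get_task.side_effect = RuntimeError("Hub lookup timed out")
    assert not HFLocalInvocationLayer.supports("some-org/some-model")  # placeholder model id
```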