Skip to content

Commit

Permalink
fix: Improve robustness of get_task HF pipeline invocations (#5284)
Browse files Browse the repository at this point in the history
* replace get_task method and change invocation layer order

* add test for invocation layer order

* add test documentation

* make invocation layer test more robust

* fix type annotation

* change hf timeout

* simplify timeout mock and add get_task exception cause

---------

Co-authored-by: Stefano Fiorucci <44616784+anakin87@users.noreply.github.com>
  • Loading branch information
MichelBartels and anakin87 authored Jul 6, 2023
1 parent ac41219 commit 08f1865
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 4 deletions.
4 changes: 2 additions & 2 deletions haystack/nodes/prompt/invocation_layer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from haystack.nodes.prompt.invocation_layer.chatgpt import ChatGPTInvocationLayer
from haystack.nodes.prompt.invocation_layer.azure_chatgpt import AzureChatGPTInvocationLayer
from haystack.nodes.prompt.invocation_layer.handlers import TokenStreamingHandler, DefaultTokenStreamingHandler
from haystack.nodes.prompt.invocation_layer.hugging_face import HFLocalInvocationLayer
from haystack.nodes.prompt.invocation_layer.hugging_face_inference import HFInferenceEndpointInvocationLayer
from haystack.nodes.prompt.invocation_layer.open_ai import OpenAIInvocationLayer
from haystack.nodes.prompt.invocation_layer.anthropic_claude import AnthropicClaudeInvocationLayer
from haystack.nodes.prompt.invocation_layer.cohere import CohereInvocationLayer
from haystack.nodes.prompt.invocation_layer.hugging_face import HFLocalInvocationLayer
from haystack.nodes.prompt.invocation_layer.hugging_face_inference import HFInferenceEndpointInvocationLayer
from haystack.nodes.prompt.invocation_layer.sagemaker_hf_infer import SageMakerHFInferenceInvocationLayer
from haystack.nodes.prompt.invocation_layer.sagemaker_hf_text_gen import SageMakerHFTextGenerationInvocationLayer
11 changes: 10 additions & 1 deletion haystack/nodes/prompt/invocation_layer/hugging_face.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
GenerationConfig,
Pipeline,
)
from transformers.pipelines import get_task
from huggingface_hub import model_info
from haystack.modeling.utils import initialize_device_settings # pylint: disable=ungrouped-imports
from haystack.nodes.prompt.invocation_layer.handlers import HFTokenStreamingHandler

Expand All @@ -43,6 +43,15 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa
stop_result = torch.isin(self.stop_words["input_ids"], input_ids[-1])
return any(all(stop_word) for stop_word in stop_result)

def get_task(model: str, use_auth_token: Optional[Union[str, bool]] = None, timeout: float = 3.0) -> Optional[str]:
    """
    Simplified version of transformers.pipelines.get_task with support for timeouts
    """
    try:
        # Query the Hugging Face Hub for the model card and read its pipeline tag;
        # both the network call and the attribute access stay inside the try block.
        info = model_info(model, token=use_auth_token, timeout=timeout)
        return info.pipeline_tag
    except Exception as e:
        # Chain the original error so callers can inspect the underlying cause.
        raise RuntimeError(f"The task of {model} could not be checked because of the following error: {e}") from e


class HFLocalInvocationLayer(PromptModelInvocationLayer):
"""
Expand Down
4 changes: 3 additions & 1 deletion test/prompt/invocation_layer/test_hugging_face.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,9 @@ def test_supports(tmp_path):
assert HFLocalInvocationLayer.supports("google/flan-t5-base")
assert HFLocalInvocationLayer.supports("mosaicml/mpt-7b")
assert HFLocalInvocationLayer.supports("CarperAI/stable-vicuna-13b-delta")
assert mock_get_task.call_count == 3
mock_get_task.side_effect = RuntimeError
assert not HFLocalInvocationLayer.supports("google/flan-t5-base")
assert mock_get_task.call_count == 4

# some HF local model directory, let's use the one from test/prompt/invocation_layer
assert HFLocalInvocationLayer.supports(str(tmp_path))
Expand Down
14 changes: 14 additions & 0 deletions test/prompt/invocation_layer/test_invocation_layers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import pytest

from haystack.nodes.prompt.prompt_model import PromptModelInvocationLayer
from haystack.nodes.prompt.invocation_layer import HFLocalInvocationLayer, HFInferenceEndpointInvocationLayer


@pytest.mark.unit
def test_invocation_layer_order():
    """
    Checks that the huggingface invocation layer is checked late because it can timeout/be slow to respond.
    """
    # Only the tail of the provider list matters: slow/timeout-prone layers must sit there.
    tail_providers = set(PromptModelInvocationLayer.invocation_layer_providers[-5:])
    for hf_layer in (HFLocalInvocationLayer, HFInferenceEndpointInvocationLayer):
        assert hf_layer in tail_providers

0 comments on commit 08f1865

Please sign in to comment.