fix: Improve robustness of get_task HF pipeline invocations #5284

Merged 8 commits on Jul 6, 2023
4 changes: 2 additions & 2 deletions haystack/nodes/prompt/invocation_layer/__init__.py
@@ -3,10 +3,10 @@
 from haystack.nodes.prompt.invocation_layer.chatgpt import ChatGPTInvocationLayer
 from haystack.nodes.prompt.invocation_layer.azure_chatgpt import AzureChatGPTInvocationLayer
 from haystack.nodes.prompt.invocation_layer.handlers import TokenStreamingHandler, DefaultTokenStreamingHandler
-from haystack.nodes.prompt.invocation_layer.hugging_face import HFLocalInvocationLayer
-from haystack.nodes.prompt.invocation_layer.hugging_face_inference import HFInferenceEndpointInvocationLayer
 from haystack.nodes.prompt.invocation_layer.open_ai import OpenAIInvocationLayer
 from haystack.nodes.prompt.invocation_layer.anthropic_claude import AnthropicClaudeInvocationLayer
 from haystack.nodes.prompt.invocation_layer.cohere import CohereInvocationLayer
+from haystack.nodes.prompt.invocation_layer.hugging_face import HFLocalInvocationLayer
+from haystack.nodes.prompt.invocation_layer.hugging_face_inference import HFInferenceEndpointInvocationLayer
 from haystack.nodes.prompt.invocation_layer.sagemaker_hf_infer import SageMakerHFInferenceInvocationLayer
 from haystack.nodes.prompt.invocation_layer.sagemaker_hf_text_gen import SageMakerHFTextGenerationInvocationLayer
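
The reorder above is the substance of this file's change: providers appear to be registered in the order their modules are imported, so moving the Hugging Face imports to the bottom makes those layers the last ones tried. A minimal sketch of that registration pattern, assuming subclasses self-register at class-definition time (the exact Haystack mechanism is an assumption here):

    from typing import List, Type

    class Base:
        providers: List[Type["Base"]] = []

        def __init_subclass__(cls, **kwargs):
            super().__init_subclass__(**kwargs)
            Base.providers.append(cls)  # runs at class-definition (i.e. import) time

    class OpenAILayer(Base): ...
    class HFLayer(Base): ...  # defined/imported last, so tried last

    assert Base.providers == [OpenAILayer, HFLayer]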
11 changes: 10 additions & 1 deletion haystack/nodes/prompt/invocation_layer/hugging_face.py
@@ -21,7 +21,7 @@
     GenerationConfig,
     Pipeline,
 )
-from transformers.pipelines import get_task
+from huggingface_hub import model_info
 from haystack.modeling.utils import initialize_device_settings  # pylint: disable=ungrouped-imports
 from haystack.nodes.prompt.invocation_layer.handlers import HFTokenStreamingHandler

@@ -43,6 +43,15 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs)
         stop_result = torch.isin(self.stop_words["input_ids"], input_ids[-1])
         return any(all(stop_word) for stop_word in stop_result)

+def get_task(model: str, use_auth_token: Optional[Union[str, bool]] = None, timeout: float = 3.0) -> Optional[str]:
+    """
+    Simplified version of transformers.pipelines.get_task with support for timeouts
+    """
+    try:
+        return model_info(model, token=use_auth_token, timeout=timeout).pipeline_tag
+    except Exception as e:
+        raise RuntimeError(f"The task of {model} could not be checked because of the following error: {e}") from e
+
 
 class HFLocalInvocationLayer(PromptModelInvocationLayer):
     """
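A short usage sketch of the new helper (the model name and returned tag are illustrative; model_info is the huggingface_hub call shown in the diff):

    from haystack.nodes.prompt.invocation_layer.hugging_face import get_task

    # Resolves the model's pipeline tag via the Hub, failing fast after the
    # timeout instead of hanging the way transformers.pipelines.get_task could.
    task = get_task("google/flan-t5-base", timeout=3.0)  # e.g. "text2text-generation"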
4 changes: 3 additions & 1 deletion test/prompt/invocation_layer/test_hugging_face.py
@@ -368,7 +368,9 @@ def test_supports(tmp_path):
     assert HFLocalInvocationLayer.supports("google/flan-t5-base")
     assert HFLocalInvocationLayer.supports("mosaicml/mpt-7b")
     assert HFLocalInvocationLayer.supports("CarperAI/stable-vicuna-13b-delta")
-    assert mock_get_task.call_count == 3
+    mock_get_task.side_effect = RuntimeError
+    assert not HFLocalInvocationLayer.supports("google/flan-t5-base")
+    assert mock_get_task.call_count == 4
 
     # some HF local model directory, let's use the one from test/prompt/invocation_layer
     assert HFLocalInvocationLayer.supports(str(tmp_path))
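For reference, a minimal sketch of the mocking pattern this test relies on (the patch target assumes the new module-level get_task shown above; the test's fixtures are omitted):

    from unittest.mock import patch

    from haystack.nodes.prompt.invocation_layer import HFLocalInvocationLayer

    with patch("haystack.nodes.prompt.invocation_layer.hugging_face.get_task") as mock_get_task:
        mock_get_task.return_value = "text2text-generation"
        assert HFLocalInvocationLayer.supports("google/flan-t5-base")
        mock_get_task.side_effect = RuntimeError  # simulate a failed or timed-out Hub lookup
        assert not HFLocalInvocationLayer.supports("google/flan-t5-base")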
14 changes: 14 additions & 0 deletions test/prompt/invocation_layer/test_invocation_layers.py
@@ -0,0 +1,14 @@
+import pytest
+
+from haystack.nodes.prompt.prompt_model import PromptModelInvocationLayer
+from haystack.nodes.prompt.invocation_layer import HFLocalInvocationLayer, HFInferenceEndpointInvocationLayer
+
+
+@pytest.mark.unit
+def test_invocation_layer_order():
+    """
+    Checks that the Hugging Face invocation layers are tried late, because they can time out or be slow to respond.
+    """
+    last_invocation_layers = set(PromptModelInvocationLayer.invocation_layer_providers[-5:])
+    assert HFLocalInvocationLayer in last_invocation_layers
+    assert HFInferenceEndpointInvocationLayer in last_invocation_layers
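
To observe the effect of the import reorder at runtime, the provider list can be inspected directly (a quick sketch, not part of the PR):

    from haystack.nodes.prompt.prompt_model import PromptModelInvocationLayer

    # With the reordered imports, the Hub-dependent layers should print near the end.
    for provider in PromptModelInvocationLayer.invocation_layer_providers:
        print(provider.__name__)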