diff --git a/haystack/nodes/prompt/invocation_layer/azure_chatgpt.py b/haystack/nodes/prompt/invocation_layer/azure_chatgpt.py
index f975a21b36..582875edc2 100644
--- a/haystack/nodes/prompt/invocation_layer/azure_chatgpt.py
+++ b/haystack/nodes/prompt/invocation_layer/azure_chatgpt.py
@@ -1,6 +1,7 @@
 from typing import Dict, Optional
 
 from haystack.nodes.prompt.invocation_layer.chatgpt import ChatGPTInvocationLayer
+from haystack.nodes.prompt.invocation_layer.utils import has_azure_parameters
 
 
 class AzureChatGPTInvocationLayer(ChatGPTInvocationLayer):
@@ -41,7 +42,6 @@ def supports(cls, model_name_or_path: str, **kwargs) -> bool:
         Ensures Azure ChatGPT Invocation Layer is selected when `azure_base_url` and
         `azure_deployment_name` are provided in addition to a list of supported models.
         """
         valid_model = any(m for m in ["gpt-35-turbo", "gpt-4", "gpt-4-32k"] if m in model_name_or_path)
-        return (
-            valid_model and kwargs.get("azure_base_url") is not None and kwargs.get("azure_deployment_name") is not None
-        )
+        return valid_model and has_azure_parameters(**kwargs)
diff --git a/haystack/nodes/prompt/invocation_layer/azure_open_ai.py b/haystack/nodes/prompt/invocation_layer/azure_open_ai.py
index fc14e1eecb..e9a3a73b3f 100644
--- a/haystack/nodes/prompt/invocation_layer/azure_open_ai.py
+++ b/haystack/nodes/prompt/invocation_layer/azure_open_ai.py
@@ -1,6 +1,7 @@
 from typing import Dict, Optional
 
 from haystack.nodes.prompt.invocation_layer.open_ai import OpenAIInvocationLayer
+from haystack.nodes.prompt.invocation_layer.utils import has_azure_parameters
 
 
 class AzureOpenAIInvocationLayer(OpenAIInvocationLayer):
@@ -42,6 +43,4 @@ def supports(cls, model_name_or_path: str, **kwargs) -> bool:
         addition to a list of supported models.
         """
         valid_model = any(m for m in ["ada", "babbage", "davinci", "curie"] if m in model_name_or_path)
-        return (
-            valid_model and kwargs.get("azure_base_url") is not None and kwargs.get("azure_deployment_name") is not None
-        )
+        return valid_model and has_azure_parameters(**kwargs)
diff --git a/haystack/nodes/prompt/invocation_layer/chatgpt.py b/haystack/nodes/prompt/invocation_layer/chatgpt.py
index 6ca056266e..5d5dbbcbf5 100644
--- a/haystack/nodes/prompt/invocation_layer/chatgpt.py
+++ b/haystack/nodes/prompt/invocation_layer/chatgpt.py
@@ -3,6 +3,7 @@
 
 from haystack.nodes.prompt.invocation_layer.handlers import DefaultTokenStreamingHandler, TokenStreamingHandler
 from haystack.nodes.prompt.invocation_layer.open_ai import OpenAIInvocationLayer
+from haystack.nodes.prompt.invocation_layer.utils import has_azure_parameters
 from haystack.utils.openai_utils import openai_request, _check_openai_finish_reason, count_openai_tokens_messages
 
 logger = logging.getLogger(__name__)
@@ -135,4 +136,5 @@ def url(self) -> str:
 
     @classmethod
     def supports(cls, model_name_or_path: str, **kwargs) -> bool:
-        return any(m for m in ["gpt-3.5-turbo", "gpt-4"] if m in model_name_or_path)
+        valid_model = any(m for m in ["gpt-3.5-turbo", "gpt-4"] if m in model_name_or_path)
+        return valid_model and not has_azure_parameters(**kwargs)
diff --git a/haystack/nodes/prompt/invocation_layer/open_ai.py b/haystack/nodes/prompt/invocation_layer/open_ai.py
index 198a54f73c..b530a3aa31 100644
--- a/haystack/nodes/prompt/invocation_layer/open_ai.py
+++ b/haystack/nodes/prompt/invocation_layer/open_ai.py
@@ -5,6 +5,7 @@
 import sseclient
 
 from haystack.errors import OpenAIError
+from haystack.nodes.prompt.invocation_layer.utils import has_azure_parameters
 from haystack.utils.openai_utils import (
     openai_request,
     _openai_text_completion_tokenization_details,
@@ -224,4 +225,4 @@ def _ensure_token_limit(self, prompt: Union[str, List[Dict[str, str]]]) -> Union
     @classmethod
     def supports(cls, model_name_or_path: str, **kwargs) -> bool:
         valid_model = any(m for m in ["ada", "babbage", "davinci", "curie"] if m in model_name_or_path)
-        return valid_model and kwargs.get("azure_base_url") is None
+        return valid_model and not has_azure_parameters(**kwargs)
diff --git a/haystack/nodes/prompt/invocation_layer/utils.py b/haystack/nodes/prompt/invocation_layer/utils.py
new file mode 100644
index 0000000000..07ddac51b7
--- /dev/null
+++ b/haystack/nodes/prompt/invocation_layer/utils.py
@@ -0,0 +1,3 @@
+def has_azure_parameters(**kwargs) -> bool:
+    azure_params = ["azure_base_url", "azure_deployment_name"]
+    return any(kwargs.get(param) for param in azure_params)
diff --git a/haystack/nodes/prompt/prompt_model.py b/haystack/nodes/prompt/prompt_model.py
index af322068a6..f8a406004e 100644
--- a/haystack/nodes/prompt/prompt_model.py
+++ b/haystack/nodes/prompt/prompt_model.py
@@ -1,6 +1,5 @@
 import inspect
 import logging
-import re
 from typing import Any, Dict, List, Optional, Tuple, Type, Union, overload
 
 from haystack.nodes.base import BaseComponent
@@ -87,15 +86,7 @@ def create_invocation_layer(
                 model_name_or_path=self.model_name_or_path, max_length=self.max_length, **all_kwargs
             )
 
-        potential_invocation_layer = PromptModelInvocationLayer.invocation_layer_providers
-        # if azure_base_url exist as an argument, invocation layer classes are filtered to only keep the ones relatives to azure
-        if "azure_base_url" in self.model_kwargs:
-            potential_invocation_layer = [
-                layer for layer in potential_invocation_layer if re.search(r"azure", layer.__name__, re.IGNORECASE)
-            ]
-        # search all invocation layer classes candidates and find the first one that supports the model,
-        # then create an instance of that invocation layer
-        for invocation_layer in potential_invocation_layer:
+        for invocation_layer in PromptModelInvocationLayer.invocation_layer_providers:
             if inspect.isabstract(invocation_layer):
                 continue
             if invocation_layer.supports(self.model_name_or_path, **all_kwargs):
diff --git a/test/prompt/test_prompt_node.py b/test/prompt/test_prompt_node.py
index 0d518be4ba..1a6c4d61c6 100644
--- a/test/prompt/test_prompt_node.py
+++ b/test/prompt/test_prompt_node.py
@@ -10,7 +10,14 @@
 from haystack import Document, Pipeline, BaseComponent, MultiLabel
 from haystack.nodes.prompt import PromptTemplate, PromptNode, PromptModel
 from haystack.nodes.prompt.prompt_template import LEGACY_DEFAULT_TEMPLATES
-from haystack.nodes.prompt.invocation_layer import HFLocalInvocationLayer, DefaultTokenStreamingHandler
+from haystack.nodes.prompt.invocation_layer import (
+    HFLocalInvocationLayer,
+    DefaultTokenStreamingHandler,
+    AzureChatGPTInvocationLayer,
+    AzureOpenAIInvocationLayer,
+    OpenAIInvocationLayer,
+    ChatGPTInvocationLayer,
+)
 
 
 @pytest.fixture
@@ -196,6 +203,36 @@ def test_invalid_template_params(mock_model, mock_prompthub):
         node.prompt("question-answering-per-document", some_crazy_key="Berlin is the capital of Germany.")
 
 
+@pytest.mark.unit
+@patch("haystack.nodes.prompt.invocation_layer.open_ai.load_openai_tokenizer", lambda tokenizer_name: None)
+def test_azure_vs_open_ai_invocation_layer_selection():
+    """
+    Tests that the correct invocation layer is selected based on the model name and additional parameters.
+    As we support both plain OpenAI and Azure-hosted models, we need to make sure that the Azure invocation
+    layers are selected when the Azure parameters are provided, and the plain OpenAI layers otherwise.
+    """
+    azure_model_kwargs = {
+        "azure_base_url": "https://some_unimportant_url",
+        "azure_deployment_name": "https://some_unimportant_url.azurewebsites.net/api/prompt",
+    }
+
+    node = PromptNode("gpt-4", api_key="some_key", model_kwargs=azure_model_kwargs)
+    assert isinstance(node.prompt_model.model_invocation_layer, AzureChatGPTInvocationLayer)
+
+    node = PromptNode("text-davinci-003", api_key="some_key", model_kwargs=azure_model_kwargs)
+    assert isinstance(node.prompt_model.model_invocation_layer, AzureOpenAIInvocationLayer)
+
+    node = PromptNode("gpt-4", api_key="some_key")
+    assert isinstance(node.prompt_model.model_invocation_layer, ChatGPTInvocationLayer) and not isinstance(
+        node.prompt_model.model_invocation_layer, AzureChatGPTInvocationLayer
+    )
+
+    node = PromptNode("text-davinci-003", api_key="some_key")
+    assert isinstance(node.prompt_model.model_invocation_layer, OpenAIInvocationLayer) and not isinstance(
+        node.prompt_model.model_invocation_layer, AzureOpenAIInvocationLayer
+    )
+
+
 @pytest.mark.skip
 @pytest.mark.integration
 @pytest.mark.parametrize("prompt_model", ["hf", "openai", "azure"], indirect=True)