Skip to content

Commit

Permalink
refactor: Simplify selection of Azure vs OpenAI invocation layers (#5271)
Browse files Browse the repository at this point in the history
  • Loading branch information
vblagoje authored Jul 6, 2023
1 parent a1a3900 commit ac41219
Show file tree
Hide file tree
Showing 7 changed files with 52 additions and 19 deletions.
6 changes: 3 additions & 3 deletions haystack/nodes/prompt/invocation_layer/azure_chatgpt.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Dict, Optional

from haystack.nodes.prompt.invocation_layer.chatgpt import ChatGPTInvocationLayer
from haystack.nodes.prompt.invocation_layer.utils import has_azure_parameters


class AzureChatGPTInvocationLayer(ChatGPTInvocationLayer):
Expand Down Expand Up @@ -41,7 +42,6 @@ def supports(cls, model_name_or_path: str, **kwargs) -> bool:
Ensures Azure ChatGPT Invocation Layer is selected when `azure_base_url` and `azure_deployment_name` are provided in
addition to a list of supported models.
"""

valid_model = any(m for m in ["gpt-35-turbo", "gpt-4", "gpt-4-32k"] if m in model_name_or_path)
return (
valid_model and kwargs.get("azure_base_url") is not None and kwargs.get("azure_deployment_name") is not None
)
return valid_model and has_azure_parameters(**kwargs)
5 changes: 2 additions & 3 deletions haystack/nodes/prompt/invocation_layer/azure_open_ai.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Dict, Optional

from haystack.nodes.prompt.invocation_layer.open_ai import OpenAIInvocationLayer
from haystack.nodes.prompt.invocation_layer.utils import has_azure_parameters


class AzureOpenAIInvocationLayer(OpenAIInvocationLayer):
Expand Down Expand Up @@ -42,6 +43,4 @@ def supports(cls, model_name_or_path: str, **kwargs) -> bool:
addition to a list of supported models.
"""
valid_model = any(m for m in ["ada", "babbage", "davinci", "curie"] if m in model_name_or_path)
return (
valid_model and kwargs.get("azure_base_url") is not None and kwargs.get("azure_deployment_name") is not None
)
return valid_model and has_azure_parameters(**kwargs)
4 changes: 3 additions & 1 deletion haystack/nodes/prompt/invocation_layer/chatgpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from haystack.nodes.prompt.invocation_layer.handlers import DefaultTokenStreamingHandler, TokenStreamingHandler
from haystack.nodes.prompt.invocation_layer.open_ai import OpenAIInvocationLayer
from haystack.nodes.prompt.invocation_layer.utils import has_azure_parameters
from haystack.utils.openai_utils import openai_request, _check_openai_finish_reason, count_openai_tokens_messages

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -135,4 +136,5 @@ def url(self) -> str:

@classmethod
def supports(cls, model_name_or_path: str, **kwargs) -> bool:
return any(m for m in ["gpt-3.5-turbo", "gpt-4"] if m in model_name_or_path)
valid_model = any(m for m in ["gpt-3.5-turbo", "gpt-4"] if m in model_name_or_path)
return valid_model and not has_azure_parameters(**kwargs)
3 changes: 2 additions & 1 deletion haystack/nodes/prompt/invocation_layer/open_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import sseclient

from haystack.errors import OpenAIError
from haystack.nodes.prompt.invocation_layer.utils import has_azure_parameters
from haystack.utils.openai_utils import (
openai_request,
_openai_text_completion_tokenization_details,
Expand Down Expand Up @@ -224,4 +225,4 @@ def _ensure_token_limit(self, prompt: Union[str, List[Dict[str, str]]]) -> Union
@classmethod
def supports(cls, model_name_or_path: str, **kwargs) -> bool:
valid_model = any(m for m in ["ada", "babbage", "davinci", "curie"] if m in model_name_or_path)
return valid_model and kwargs.get("azure_base_url") is None
return valid_model and not has_azure_parameters(**kwargs)
3 changes: 3 additions & 0 deletions haystack/nodes/prompt/invocation_layer/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
def has_azure_parameters(**kwargs) -> bool:
    """
    Return True if any Azure-specific parameter (`azure_base_url` or
    `azure_deployment_name`) is present in *kwargs* with a truthy value.

    Used by invocation layers to decide between Azure and plain OpenAI endpoints.
    """
    return bool(kwargs.get("azure_base_url") or kwargs.get("azure_deployment_name"))
11 changes: 1 addition & 10 deletions haystack/nodes/prompt/prompt_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import inspect
import logging
import re
from typing import Any, Dict, List, Optional, Tuple, Type, Union, overload

from haystack.nodes.base import BaseComponent
Expand Down Expand Up @@ -87,15 +86,7 @@ def create_invocation_layer(
model_name_or_path=self.model_name_or_path, max_length=self.max_length, **all_kwargs
)

potential_invocation_layer = PromptModelInvocationLayer.invocation_layer_providers
# if azure_base_url exist as an argument, invocation layer classes are filtered to only keep the ones relatives to azure
if "azure_base_url" in self.model_kwargs:
potential_invocation_layer = [
layer for layer in potential_invocation_layer if re.search(r"azure", layer.__name__, re.IGNORECASE)
]
# search all invocation layer classes candidates and find the first one that supports the model,
# then create an instance of that invocation layer
for invocation_layer in potential_invocation_layer:
for invocation_layer in PromptModelInvocationLayer.invocation_layer_providers:
if inspect.isabstract(invocation_layer):
continue
if invocation_layer.supports(self.model_name_or_path, **all_kwargs):
Expand Down
39 changes: 38 additions & 1 deletion test/prompt/test_prompt_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,14 @@
from haystack import Document, Pipeline, BaseComponent, MultiLabel
from haystack.nodes.prompt import PromptTemplate, PromptNode, PromptModel
from haystack.nodes.prompt.prompt_template import LEGACY_DEFAULT_TEMPLATES
from haystack.nodes.prompt.invocation_layer import HFLocalInvocationLayer, DefaultTokenStreamingHandler
from haystack.nodes.prompt.invocation_layer import (
HFLocalInvocationLayer,
DefaultTokenStreamingHandler,
AzureChatGPTInvocationLayer,
AzureOpenAIInvocationLayer,
OpenAIInvocationLayer,
ChatGPTInvocationLayer,
)


@pytest.fixture
Expand Down Expand Up @@ -196,6 +203,36 @@ def test_invalid_template_params(mock_model, mock_prompthub):
node.prompt("question-answering-per-document", some_crazy_key="Berlin is the capital of Germany.")


@pytest.mark.unit
@patch("haystack.nodes.prompt.invocation_layer.open_ai.load_openai_tokenizer", lambda tokenizer_name: None)
def test_azure_vs_open_ai_invocation_layer_selection():
    """
    Verify that PromptNode selects the Azure invocation layers when the
    Azure-specific kwargs (azure_base_url, azure_deployment_name) are supplied,
    and the plain OpenAI invocation layers when they are not.
    """
    azure_kwargs = {
        "azure_base_url": "https://some_unimportant_url",
        "azure_deployment_name": "https://some_unimportant_url.azurewebsites.net/api/prompt",
    }

    # Azure kwargs present: chat model -> Azure ChatGPT layer.
    azure_chat_node = PromptNode("gpt-4", api_key="some_key", model_kwargs=azure_kwargs)
    assert isinstance(azure_chat_node.prompt_model.model_invocation_layer, AzureChatGPTInvocationLayer)

    # Azure kwargs present: completion model -> Azure OpenAI layer.
    azure_completion_node = PromptNode("text-davinci-003", api_key="some_key", model_kwargs=azure_kwargs)
    assert isinstance(azure_completion_node.prompt_model.model_invocation_layer, AzureOpenAIInvocationLayer)

    # No Azure kwargs: chat model -> plain ChatGPT layer, not the Azure subclass.
    chat_layer = PromptNode("gpt-4", api_key="some_key").prompt_model.model_invocation_layer
    assert isinstance(chat_layer, ChatGPTInvocationLayer)
    assert not isinstance(chat_layer, AzureChatGPTInvocationLayer)

    # No Azure kwargs: completion model -> plain OpenAI layer, not the Azure subclass.
    completion_layer = PromptNode("text-davinci-003", api_key="some_key").prompt_model.model_invocation_layer
    assert isinstance(completion_layer, OpenAIInvocationLayer)
    assert not isinstance(completion_layer, AzureChatGPTInvocationLayer)


@pytest.mark.skip
@pytest.mark.integration
@pytest.mark.parametrize("prompt_model", ["hf", "openai", "azure"], indirect=True)
Expand Down

0 comments on commit ac41219

Please sign in to comment.