diff --git a/haystack/nodes/answer_generator/openai.py b/haystack/nodes/answer_generator/openai.py index 4185f4c067..ae1ac362ed 100644 --- a/haystack/nodes/answer_generator/openai.py +++ b/haystack/nodes/answer_generator/openai.py @@ -48,6 +48,7 @@ def __init__( context_join_str: str = " ", moderate_content: bool = False, api_base: str = "https://api.openai.com/v1", + openai_organization: Optional[str] = None, ): """ :param api_key: Your API key from OpenAI. It is required for this node to work. @@ -105,6 +106,8 @@ def __init__( using the [OpenAI Moderation API](https://platform.openai.com/docs/guides/moderation). If the input or answers are flagged, an empty list is returned in place of the answers. :param api_base: The base URL for the OpenAI API, defaults to `"https://api.openai.com/v1"`. + :param openai_organization: The OpenAI-Organization ID, defaults to `None`. For more details, see OpenAI + [documentation](https://platform.openai.com/docs/api-reference/requesting-organization). 
""" super().__init__(progress_bar=progress_bar) if (examples is None and examples_context is not None) or (examples is not None and examples_context is None): @@ -165,6 +168,7 @@ def __init__( self.context_join_str = context_join_str self.using_azure = self.azure_deployment_name is not None and self.azure_base_url is not None self.moderate_content = moderate_content + self.openai_organization = openai_organization tokenizer_name, max_tokens_limit = _openai_text_completion_tokenization_details(model_name=self.model) @@ -233,6 +237,8 @@ def predict( headers["api-key"] = self.api_key else: headers["Authorization"] = f"Bearer {self.api_key}" + if self.openai_organization: + headers["OpenAI-Organization"] = self.openai_organization if self.moderate_content and check_openai_policy_violation(input=prompt, headers=headers): logger.info("Prompt '%s' will not be sent to OpenAI due to potential policy violation.", prompt) diff --git a/haystack/nodes/prompt/invocation_layer/open_ai.py b/haystack/nodes/prompt/invocation_layer/open_ai.py index b530a3aa31..7388d51f75 100644 --- a/haystack/nodes/prompt/invocation_layer/open_ai.py +++ b/haystack/nodes/prompt/invocation_layer/open_ai.py @@ -34,6 +34,7 @@ def __init__( model_name_or_path: str = "text-davinci-003", max_length: Optional[int] = 100, api_base: str = "https://api.openai.com/v1", + openai_organization: Optional[str] = None, **kwargs, ): """ @@ -43,6 +44,8 @@ def __init__( :param max_length: The maximum number of tokens the output text can have. :param api_key: The OpenAI API key. :param api_base: The OpenAI API Base url, defaults to `https://api.openai.com/v1`. + :param openai_organization: The OpenAI-Organization ID, defaults to `None`. For more details, see see OpenAI + [documentation](https://platform.openai.com/docs/api-reference/requesting-organization). :param kwargs: Additional keyword arguments passed to the underlying model. 
Due to reflective construction of all PromptModelInvocationLayer instances, this instance of OpenAIInvocationLayer might receive some unrelated kwargs. Only the kwargs relevant to OpenAIInvocationLayer are considered. The list of OpenAI-relevant @@ -60,6 +63,7 @@ def __init__( ) self.api_key = api_key self.api_base = api_base + self.openai_organization = openai_organization # 16 is the default length for answers from OpenAI shown in the docs # here, https://platform.openai.com/docs/api-reference/completions/create. @@ -103,7 +107,10 @@ def url(self) -> str: @property def headers(self) -> Dict[str, str]: - return {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"} + headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"} + if self.openai_organization: + headers["OpenAI-Organization"] = self.openai_organization + return headers def invoke(self, *args, **kwargs): """ diff --git a/haystack/nodes/retriever/_openai_encoder.py b/haystack/nodes/retriever/_openai_encoder.py index 03c914da37..6079a188a2 100644 --- a/haystack/nodes/retriever/_openai_encoder.py +++ b/haystack/nodes/retriever/_openai_encoder.py @@ -40,6 +40,7 @@ def __init__(self, retriever: "EmbeddingRetriever"): self.url = f"{retriever.api_base}/embeddings" self.api_key = retriever.api_key + self.openai_organization = retriever.openai_organization self.batch_size = min(64, retriever.batch_size) self.progress_bar = retriever.progress_bar model_class: str = next( @@ -113,6 +114,8 @@ def azure_get_embedding(input: str): else: payload: Dict[str, Union[List[str], str]] = {"model": model, "input": text} headers["Authorization"] = f"Bearer {self.api_key}" + if self.openai_organization: + headers["OpenAI-Organization"] = self.openai_organization res = openai_request(url=self.url, headers=headers, payload=payload, timeout=OPENAI_TIMEOUT) diff --git a/haystack/nodes/retriever/dense.py b/haystack/nodes/retriever/dense.py index 29883a5642..4bd5c8b743 100644 
--- a/haystack/nodes/retriever/dense.py +++ b/haystack/nodes/retriever/dense.py @@ -1462,6 +1462,7 @@ def __init__( azure_base_url: Optional[str] = None, azure_deployment_name: Optional[str] = None, api_base: str = "https://api.openai.com/v1", + openai_organization: Optional[str] = None, ): """ :param document_store: An instance of DocumentStore from which to retrieve documents. @@ -1521,7 +1522,9 @@ def __init__( This parameter is an OpenAI Azure endpoint, usually in the form `https://.openai.azure.com' :param azure_deployment_name: The name of the Azure OpenAI API deployment. If not supplied, Azure OpenAI API will not be used. - :param api_base: The OpenAI API base URL, defaults to `"https://api.openai.com/v1"` + :param api_base: The OpenAI API base URL, defaults to `"https://api.openai.com/v1"`. + :param openai_organization: The OpenAI-Organization ID, defaults to `None`. For more details, see OpenAI + [documentation](https://platform.openai.com/docs/api-reference/requesting-organization). 
""" torch_and_transformers_import.check() @@ -1551,6 +1554,7 @@ def __init__( self.api_version = azure_api_version self.azure_base_url = azure_base_url self.azure_deployment_name = azure_deployment_name + self.openai_organization = openai_organization self.model_format = ( self._infer_model_format(model_name_or_path=embedding_model, use_auth_token=use_auth_token) if model_format is None diff --git a/test/nodes/test_generator.py b/test/nodes/test_generator.py index 262bc830e2..989524b638 100644 --- a/test/nodes/test_generator.py +++ b/test/nodes/test_generator.py @@ -9,6 +9,28 @@ import logging +@pytest.mark.unit +@patch("haystack.nodes.answer_generator.openai.openai_request") +def test_no_openai_organization(mock_request): + with patch("haystack.nodes.answer_generator.openai.load_openai_tokenizer"): + generator = OpenAIAnswerGenerator(api_key="fake_api_key") + assert generator.openai_organization is None + + generator.predict(query="test query", documents=[Document(content="test document")]) + assert "OpenAI-Organization" not in mock_request.call_args.kwargs["headers"] + + +@pytest.mark.unit +@patch("haystack.nodes.answer_generator.openai.openai_request") +def test_openai_organization(mock_request): + with patch("haystack.nodes.answer_generator.openai.load_openai_tokenizer"): + generator = OpenAIAnswerGenerator(api_key="fake_api_key", openai_organization="fake_organization") + assert generator.openai_organization == "fake_organization" + + generator.predict(query="test query", documents=[Document(content="test document")]) + assert mock_request.call_args.kwargs["headers"]["OpenAI-Organization"] == "fake_organization" + + @pytest.mark.unit @patch("haystack.nodes.answer_generator.openai.openai_request") def test_openai_answer_generator_default_api_base(mock_request): diff --git a/test/nodes/test_retriever.py b/test/nodes/test_retriever.py index 656ed64c66..07b86adc02 100644 --- a/test/nodes/test_retriever.py +++ b/test/nodes/test_retriever.py @@ -1154,3 +1154,27 @@ 
def test_openai_custom_api_base(mock_request): retriever.embed_documents(documents=[Document(content="test document")]) assert mock_request.call_args.kwargs["url"] == "https://fake_api_base.com/embeddings" + + +@pytest.mark.unit +@patch("haystack.nodes.retriever._openai_encoder.openai_request") +def test_openai_no_openai_organization(mock_request): + with patch("haystack.nodes.retriever._openai_encoder.load_openai_tokenizer"): + retriever = EmbeddingRetriever(embedding_model="text-embedding-ada-002", api_key="fake_api_key") + assert retriever.openai_organization is None + + retriever.embed_queries(queries=["test query"]) + assert "OpenAI-Organization" not in mock_request.call_args.kwargs["headers"] + + +@pytest.mark.unit +@patch("haystack.nodes.retriever._openai_encoder.openai_request") +def test_openai_openai_organization(mock_request): + with patch("haystack.nodes.retriever._openai_encoder.load_openai_tokenizer"): + retriever = EmbeddingRetriever( + embedding_model="text-embedding-ada-002", api_key="fake_api_key", openai_organization="fake_organization" + ) + assert retriever.openai_organization == "fake_organization" + + retriever.embed_queries(queries=["test query"]) + assert mock_request.call_args.kwargs["headers"]["OpenAI-Organization"] == "fake_organization" diff --git a/test/prompt/invocation_layer/test_openai.py b/test/prompt/invocation_layer/test_openai.py index c4960bbf65..31fe7066da 100644 --- a/test/prompt/invocation_layer/test_openai.py +++ b/test/prompt/invocation_layer/test_openai.py @@ -39,3 +39,29 @@ def test_openai_token_limit_warning(mock_openai_tokenizer, caplog): _ = invocation_layer._ensure_token_limit(prompt="This is a test for a mock openai tokenizer.") assert "The prompt has been truncated from" in caplog.text assert "and answer length (2045 tokens) fit within the max token limit (2049 tokens)." 
in caplog.text + + +@pytest.mark.unit +@patch("haystack.nodes.prompt.invocation_layer.open_ai.openai_request") +def test_no_openai_organization(mock_request): + with patch("haystack.nodes.prompt.invocation_layer.open_ai.load_openai_tokenizer"): + invocation_layer = OpenAIInvocationLayer(api_key="fake_api_key") + + assert invocation_layer.openai_organization is None + assert "OpenAI-Organization" not in invocation_layer.headers + + invocation_layer.invoke(prompt="dummy_prompt") + assert "OpenAI-Organization" not in mock_request.call_args.kwargs["headers"] + + +@pytest.mark.unit +@patch("haystack.nodes.prompt.invocation_layer.open_ai.openai_request") +def test_openai_organization(mock_request): + with patch("haystack.nodes.prompt.invocation_layer.open_ai.load_openai_tokenizer"): + invocation_layer = OpenAIInvocationLayer(api_key="fake_api_key", openai_organization="fake_organization") + + assert invocation_layer.openai_organization == "fake_organization" + assert invocation_layer.headers["OpenAI-Organization"] == "fake_organization" + + invocation_layer.invoke(prompt="dummy_prompt") + assert mock_request.call_args.kwargs["headers"]["OpenAI-Organization"] == "fake_organization"