Skip to content

Commit

Permalink
feat: support OpenAI-Organization for authentication (#5292)
Browse files Browse the repository at this point in the history
* add openai_organization to invocation layer, generator and retriever

* added tests
  • Loading branch information
anakin87 authored Jul 7, 2023
1 parent 0697f5c commit 90ff381
Show file tree
Hide file tree
Showing 7 changed files with 94 additions and 2 deletions.
6 changes: 6 additions & 0 deletions haystack/nodes/answer_generator/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def __init__(
context_join_str: str = " ",
moderate_content: bool = False,
api_base: str = "https://api.openai.com/v1",
openai_organization: Optional[str] = None,
):
"""
:param api_key: Your API key from OpenAI. It is required for this node to work.
Expand Down Expand Up @@ -105,6 +106,8 @@ def __init__(
using the [OpenAI Moderation API](https://platform.openai.com/docs/guides/moderation). If the input or
answers are flagged, an empty list is returned in place of the answers.
:param api_base: The base URL for the OpenAI API, defaults to `"https://api.openai.com/v1"`.
:param openai_organization: The OpenAI-Organization ID, defaults to `None`. For more details, see OpenAI
[documentation](https://platform.openai.com/docs/api-reference/requesting-organization).
"""
super().__init__(progress_bar=progress_bar)
if (examples is None and examples_context is not None) or (examples is not None and examples_context is None):
Expand Down Expand Up @@ -165,6 +168,7 @@ def __init__(
self.context_join_str = context_join_str
self.using_azure = self.azure_deployment_name is not None and self.azure_base_url is not None
self.moderate_content = moderate_content
self.openai_organization = openai_organization

tokenizer_name, max_tokens_limit = _openai_text_completion_tokenization_details(model_name=self.model)

Expand Down Expand Up @@ -233,6 +237,8 @@ def predict(
headers["api-key"] = self.api_key
else:
headers["Authorization"] = f"Bearer {self.api_key}"
if self.openai_organization:
headers["OpenAI-Organization"] = self.openai_organization

if self.moderate_content and check_openai_policy_violation(input=prompt, headers=headers):
logger.info("Prompt '%s' will not be sent to OpenAI due to potential policy violation.", prompt)
Expand Down
9 changes: 8 additions & 1 deletion haystack/nodes/prompt/invocation_layer/open_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def __init__(
model_name_or_path: str = "text-davinci-003",
max_length: Optional[int] = 100,
api_base: str = "https://api.openai.com/v1",
openai_organization: Optional[str] = None,
**kwargs,
):
"""
Expand All @@ -43,6 +44,8 @@ def __init__(
:param max_length: The maximum number of tokens the output text can have.
:param api_key: The OpenAI API key.
:param api_base: The OpenAI API Base url, defaults to `https://api.openai.com/v1`.
:param openai_organization: The OpenAI-Organization ID, defaults to `None`. For more details, see OpenAI
[documentation](https://platform.openai.com/docs/api-reference/requesting-organization).
:param kwargs: Additional keyword arguments passed to the underlying model. Due to reflective construction of
all PromptModelInvocationLayer instances, this instance of OpenAIInvocationLayer might receive some unrelated
kwargs. Only the kwargs relevant to OpenAIInvocationLayer are considered. The list of OpenAI-relevant
Expand All @@ -60,6 +63,7 @@ def __init__(
)
self.api_key = api_key
self.api_base = api_base
self.openai_organization = openai_organization

# 16 is the default length for answers from OpenAI shown in the docs
# here, https://platform.openai.com/docs/api-reference/completions/create.
Expand Down Expand Up @@ -103,7 +107,10 @@ def url(self) -> str:

@property
def headers(self) -> Dict[str, str]:
return {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
if self.openai_organization:
headers["OpenAI-Organization"] = self.openai_organization
return headers

def invoke(self, *args, **kwargs):
"""
Expand Down
3 changes: 3 additions & 0 deletions haystack/nodes/retriever/_openai_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def __init__(self, retriever: "EmbeddingRetriever"):
self.url = f"{retriever.api_base}/embeddings"

self.api_key = retriever.api_key
self.openai_organization = retriever.openai_organization
self.batch_size = min(64, retriever.batch_size)
self.progress_bar = retriever.progress_bar
model_class: str = next(
Expand Down Expand Up @@ -113,6 +114,8 @@ def azure_get_embedding(input: str):
else:
payload: Dict[str, Union[List[str], str]] = {"model": model, "input": text}
headers["Authorization"] = f"Bearer {self.api_key}"
if self.openai_organization:
headers["OpenAI-Organization"] = self.openai_organization

res = openai_request(url=self.url, headers=headers, payload=payload, timeout=OPENAI_TIMEOUT)

Expand Down
6 changes: 5 additions & 1 deletion haystack/nodes/retriever/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -1462,6 +1462,7 @@ def __init__(
azure_base_url: Optional[str] = None,
azure_deployment_name: Optional[str] = None,
api_base: str = "https://api.openai.com/v1",
openai_organization: Optional[str] = None,
):
"""
:param document_store: An instance of DocumentStore from which to retrieve documents.
Expand Down Expand Up @@ -1521,7 +1522,9 @@ def __init__(
This parameter is an OpenAI Azure endpoint, usually in the form `https://<your-endpoint>.openai.azure.com'
:param azure_deployment_name: The name of the Azure OpenAI API deployment. If not supplied, Azure OpenAI API
will not be used.
:param api_base: The OpenAI API base URL, defaults to `"https://api.openai.com/v1"`
:param api_base: The OpenAI API base URL, defaults to `"https://api.openai.com/v1"`.
:param openai_organization: The OpenAI-Organization ID, defaults to `None`. For more details, see OpenAI
[documentation](https://platform.openai.com/docs/api-reference/requesting-organization).
"""
torch_and_transformers_import.check()

Expand Down Expand Up @@ -1551,6 +1554,7 @@ def __init__(
self.api_version = azure_api_version
self.azure_base_url = azure_base_url
self.azure_deployment_name = azure_deployment_name
self.openai_organization = openai_organization
self.model_format = (
self._infer_model_format(model_name_or_path=embedding_model, use_auth_token=use_auth_token)
if model_format is None
Expand Down
22 changes: 22 additions & 0 deletions test/nodes/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,28 @@
import logging


@pytest.mark.unit
@patch("haystack.nodes.answer_generator.openai.openai_request")
def test_no_openai_organization(mock_request):
    """Without an organization, no OpenAI-Organization header is sent."""
    # Patch the tokenizer loader so construction needs no network/model access.
    with patch("haystack.nodes.answer_generator.openai.load_openai_tokenizer"):
        answer_generator = OpenAIAnswerGenerator(api_key="fake_api_key")
    assert answer_generator.openai_organization is None

    answer_generator.predict(query="test query", documents=[Document(content="test document")])
    sent_headers = mock_request.call_args.kwargs["headers"]
    assert "OpenAI-Organization" not in sent_headers


@pytest.mark.unit
@patch("haystack.nodes.answer_generator.openai.openai_request")
def test_openai_organization(mock_request):
    """A configured organization must be stored and forwarded as a request header."""
    # Patch the tokenizer loader so construction needs no network/model access.
    with patch("haystack.nodes.answer_generator.openai.load_openai_tokenizer"):
        answer_generator = OpenAIAnswerGenerator(api_key="fake_api_key", openai_organization="fake_organization")
    assert answer_generator.openai_organization == "fake_organization"

    answer_generator.predict(query="test query", documents=[Document(content="test document")])
    sent_headers = mock_request.call_args.kwargs["headers"]
    assert sent_headers["OpenAI-Organization"] == "fake_organization"


@pytest.mark.unit
@patch("haystack.nodes.answer_generator.openai.openai_request")
def test_openai_answer_generator_default_api_base(mock_request):
Expand Down
24 changes: 24 additions & 0 deletions test/nodes/test_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -1154,3 +1154,27 @@ def test_openai_custom_api_base(mock_request):

retriever.embed_documents(documents=[Document(content="test document")])
assert mock_request.call_args.kwargs["url"] == "https://fake_api_base.com/embeddings"


@pytest.mark.unit
@patch("haystack.nodes.retriever._openai_encoder.openai_request")
def test_openai_no_openai_organization(mock_request):
    """Without an organization, the embeddings request carries no OpenAI-Organization header."""
    # Patch the tokenizer loader so construction needs no network/model access.
    with patch("haystack.nodes.retriever._openai_encoder.load_openai_tokenizer"):
        embedding_retriever = EmbeddingRetriever(embedding_model="text-embedding-ada-002", api_key="fake_api_key")
    assert embedding_retriever.openai_organization is None

    embedding_retriever.embed_queries(queries=["test query"])
    sent_headers = mock_request.call_args.kwargs["headers"]
    assert "OpenAI-Organization" not in sent_headers


@pytest.mark.unit
@patch("haystack.nodes.retriever._openai_encoder.openai_request")
def test_openai_openai_organization(mock_request):
    """A configured organization must be stored and forwarded on embeddings requests."""
    # Patch the tokenizer loader so construction needs no network/model access.
    with patch("haystack.nodes.retriever._openai_encoder.load_openai_tokenizer"):
        embedding_retriever = EmbeddingRetriever(
            embedding_model="text-embedding-ada-002", api_key="fake_api_key", openai_organization="fake_organization"
        )
    assert embedding_retriever.openai_organization == "fake_organization"

    embedding_retriever.embed_queries(queries=["test query"])
    sent_headers = mock_request.call_args.kwargs["headers"]
    assert sent_headers["OpenAI-Organization"] == "fake_organization"
26 changes: 26 additions & 0 deletions test/prompt/invocation_layer/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,29 @@ def test_openai_token_limit_warning(mock_openai_tokenizer, caplog):
_ = invocation_layer._ensure_token_limit(prompt="This is a test for a mock openai tokenizer.")
assert "The prompt has been truncated from" in caplog.text
assert "and answer length (2045 tokens) fit within the max token limit (2049 tokens)." in caplog.text


@pytest.mark.unit
@patch("haystack.nodes.prompt.invocation_layer.open_ai.openai_request")
def test_no_openai_organization(mock_request):
    """Without an organization, neither the layer's headers nor the outgoing request include OpenAI-Organization."""
    # Patch the tokenizer loader so construction needs no network/model access.
    with patch("haystack.nodes.prompt.invocation_layer.open_ai.load_openai_tokenizer"):
        layer = OpenAIInvocationLayer(api_key="fake_api_key")

    assert layer.openai_organization is None
    assert "OpenAI-Organization" not in layer.headers

    layer.invoke(prompt="dummy_prompt")
    sent_headers = mock_request.call_args.kwargs["headers"]
    assert "OpenAI-Organization" not in sent_headers


@pytest.mark.unit
@patch("haystack.nodes.prompt.invocation_layer.open_ai.openai_request")
def test_openai_organization(mock_request):
    """A configured organization appears in both the layer's headers and the outgoing request."""
    # Patch the tokenizer loader so construction needs no network/model access.
    with patch("haystack.nodes.prompt.invocation_layer.open_ai.load_openai_tokenizer"):
        layer = OpenAIInvocationLayer(api_key="fake_api_key", openai_organization="fake_organization")

    assert layer.openai_organization == "fake_organization"
    assert layer.headers["OpenAI-Organization"] == "fake_organization"

    layer.invoke(prompt="dummy_prompt")
    sent_headers = mock_request.call_args.kwargs["headers"]
    assert sent_headers["OpenAI-Organization"] == "fake_organization"

0 comments on commit 90ff381

Please sign in to comment.