Skip to content

Commit

Permalink
feat: Add support for OpenAI's gpt-3.5-turbo-instruct model (#5837)
Browse files: browse the repository at this point in the history
* support gpt-3.5-turbo-instruct

* add release note
  • Loading branch information
tholor authored Sep 19, 2023
1 parent 4112639 commit aa3cc3d
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 3 deletions.
5 changes: 4 additions & 1 deletion haystack/nodes/prompt/invocation_layer/chatgpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,5 +136,8 @@ def url(self) -> str:

@classmethod
def supports(cls, model_name_or_path: str, **kwargs) -> bool:
valid_model = any(m for m in ["gpt-3.5-turbo", "gpt-4"] if m in model_name_or_path)
valid_model = (
any(m for m in ["gpt-3.5-turbo", "gpt-4"] if m in model_name_or_path)
and not "gpt-3.5-turbo-instruct" in model_name_or_path
)
return valid_model and not has_azure_parameters(**kwargs)
2 changes: 1 addition & 1 deletion haystack/nodes/prompt/invocation_layer/open_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def _ensure_token_limit(self, prompt: Union[str, List[Dict[str, str]]]) -> Union

@classmethod
def supports(cls, model_name_or_path: str, **kwargs) -> bool:
valid_model = model_name_or_path in ["ada", "babbage", "davinci", "curie"] or any(
valid_model = model_name_or_path in ["ada", "babbage", "davinci", "curie", "gpt-3.5-turbo-instruct"] or any(
m in model_name_or_path for m in ["-ada-", "-babbage-", "-davinci-", "-curie-"]
)
return valid_model and not has_azure_parameters(**kwargs)
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
enhancements:
- |
Support OpenAI's new `gpt-3.5-turbo-instruct` model
2 changes: 1 addition & 1 deletion test/prompt/invocation_layer/test_chatgpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def test_supports_correct_model_names():

@pytest.mark.unit
def test_does_not_support_wrong_model_names():
for model_name in ["got-3.5-turbo", "wrong_model_name"]:
for model_name in ["got-3.5-turbo", "wrong_model_name", "gpt-3.5-turbo-instruct"]:
assert not ChatGPTInvocationLayer.supports(model_name)


Expand Down
1 change: 1 addition & 0 deletions test/prompt/invocation_layer/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def test_supports(load_openai_tokenizer):
assert layer.supports("davinci")
assert layer.supports("text-ada-001")
assert layer.supports("text-davinci-002")
assert layer.supports("gpt-3.5-turbo-instruct")

# the following model contains "ada" in the name, but it's not from OpenAI
assert not layer.supports("ybelkada/mpt-7b-bf16-sharded")

0 comments on commit aa3cc3d

Please sign in to comment.