Enable azure openai engines (#20)
Love the library!

I had to make modifications to get it to work with the Azure instances
of OpenAI models that I use.

Primarily, I made it possible to pass an `engine` parameter to the
`LLMClassifier` instead of, or in addition to, the `model` parameter,
[which is what the Azure API
requires](https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/chatgpt?pivots=programming-language-chat-completions).

I tried to do it in a way that minimizes the number of edits to the existing
codebase. If this is something you'd be interested in including, I'm happy
to make any changes to fit your design principles.

```python
import openai

from autoevals import Factuality

openai.api_type = "azure"
openai.api_version = "2023-07-01-preview"
openai.api_base = "https://azure-openai-xxxxx.openai.azure.com/"  # Your Azure OpenAI resource's endpoint value.
openai.api_key = "<AZURE-OPENAI_API-KEY>"

# Pass the Azure deployment name as `engine`; `model` is not used.
evaluator = Factuality(engine="<Azure-Deployment-Name>", model=None)
```
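
For completeness, the configured evaluator can then be called like any other autoevals scorer. A minimal usage sketch, assuming the standard scorer call signature; the question and answer strings are placeholder values:

```python
# Placeholder example inputs, for illustration only.
question = "Which country has the highest population?"
output = "People's Republic of China"
expected = "China"

result = evaluator(output, expected, input=question)
print(result.score)  # 1 if the output is judged factually consistent with `expected`
```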

---------

Co-authored-by: Ankur Goyal <ankrgyl@gmail.com>
ecatkins and ankrgyl authored Oct 16, 2023
1 parent 9f160ac commit 6ac56c9
Showing 2 changed files with 9 additions and 10 deletions.
Binary file modified .testcache/oai.sqlite
19 changes: 9 additions & 10 deletions py/autoevals/llm.py
```diff
@@ -78,18 +78,11 @@ def __init__(
         render_args=None,
         max_tokens=None,
         temperature=None,
+        engine=None,
     ):
-        found = False
-        for m in SUPPORTED_MODELS:
-            # Prefixes are ok, because they are just time snapshots
-            if model.startswith(m):
-                found = True
-                break
-        if not found:
-            raise ValueError(f"Unsupported model: {model}. Currently only supports OpenAI chat models.")
-
         self.name = name
         self.model = model
+        self.engine = engine
         self.messages = messages
         self.choice_scores = choice_scores
         self.classification_functions = classification_functions
@@ -136,6 +129,7 @@ def _request_args(self, output, expected, **kwargs):
         return dict(
             Completion=openai.ChatCompletion,
             model=self.model,
+            engine=self.engine,
             messages=self._render_messages(output=output, expected=expected, **kwargs),
             functions=self.classification_functions,
             function_call={"name": "select_choice"},
@@ -176,6 +170,7 @@ class ModelGradedSpec:
     prompt: str
     choice_scores: dict[str, float]
     model: Optional[str] = None
+    engine: Optional[str] = None
     use_cot: Optional[bool] = None
     temperature: Optional[float] = None
@@ -195,6 +190,7 @@ def __init__(
         use_cot=True,
         max_tokens=512,
         temperature=0,
+        engine=None,
     ):
         choice_strings = list(choice_scores.keys())
 
@@ -214,6 +210,7 @@ def __init__(
             classification_functions=build_classification_functions(use_cot),
             max_tokens=max_tokens,
             temperature=temperature,
+            engine=engine,
             render_args={"__choices": choice_strings},
         )
 
@@ -229,10 +226,12 @@ def from_spec_file(cls, name: str, path: str, **kwargs):
 
 
 class SpecFileClassifier(LLMClassifier):
-    def __new__(cls, model=None, use_cot=None, max_tokens=None, temperature=None):
+    def __new__(cls, model=None, engine=None, use_cot=None, max_tokens=None, temperature=None):
         kwargs = {}
         if model is not None:
             kwargs["model"] = model
+        if engine is not None:
+            kwargs["engine"] = engine
         if use_cot is not None:
             kwargs["use_cot"] = use_cot
         if max_tokens is not None:
```
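
For context on where the new parameter lands: `_request_args` now forwards `engine` alongside `model` into the keyword arguments for the chat-completion call, and the pre-1.0 OpenAI SDK uses `engine` to select an Azure deployment when `api_type` is set to `"azure"`. A minimal sketch of the equivalent raw request, assuming the legacy (`openai<1.0`) SDK and placeholder credentials:

```python
import openai

openai.api_type = "azure"
openai.api_version = "2023-07-01-preview"
openai.api_base = "https://azure-openai-xxxxx.openai.azure.com/"
openai.api_key = "<AZURE-OPENAI_API-KEY>"

# With api_type="azure", the legacy SDK addresses the deployment via
# `engine`; no `model` argument is needed.
response = openai.ChatCompletion.create(
    engine="<Azure-Deployment-Name>",
    messages=[{"role": "user", "content": "Say hello."}],
)
print(response["choices"][0]["message"]["content"])
```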
