Skip to content

Commit

Permalink
mypy fixes for openai_models.py
Browse files Browse the repository at this point in the history
I am unhappy with this, had to duplicate some code.
  • Loading branch information
simonw committed Nov 7, 2024
1 parent b3a6ec7 commit 91732d0
Showing 1 changed file with 20 additions and 10 deletions.
30 changes: 20 additions & 10 deletions llm/default_plugins/openai_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,16 +299,6 @@ def _attachment(attachment):


class _Shared:
needs_key = "openai"
key_env_var = "OPENAI_API_KEY"
default_max_tokens = None

class Options(SharedOptions):
json_object: Optional[bool] = Field(
description="Output a valid JSON object {...}. Prompt must mention JSON.",
default=None,
)

def __init__(
self,
model_id,
Expand Down Expand Up @@ -437,6 +427,16 @@ def build_kwargs(self, prompt, stream):


class Chat(_Shared, Model):
needs_key = "openai"
key_env_var = "OPENAI_API_KEY"
default_max_tokens = None

class Options(SharedOptions):
json_object: Optional[bool] = Field(
description="Output a valid JSON object {...}. Prompt must mention JSON.",
default=None,
)

def execute(self, prompt, stream, response, conversation=None):
if prompt.system and not self.allows_system_prompt:
raise NotImplementedError("Model does not support system prompts")
Expand Down Expand Up @@ -473,6 +473,16 @@ def execute(self, prompt, stream, response, conversation=None):


class AsyncChat(_Shared, AsyncModel):
needs_key = "openai"
key_env_var = "OPENAI_API_KEY"
default_max_tokens = None

class Options(SharedOptions):
json_object: Optional[bool] = Field(
description="Output a valid JSON object {...}. Prompt must mention JSON.",
default=None,
)

async def execute(self, prompt, stream, response, conversation=None):
if prompt.system and not self.allows_system_prompt:
raise NotImplementedError("Model does not support system prompts")
Expand Down

0 comments on commit 91732d0

Please sign in to comment.