diff --git a/langchain/llms/openai.py b/langchain/llms/openai.py index c713e4d5606d4..90ef6bdda6725 100644 --- a/langchain/llms/openai.py +++ b/langchain/llms/openai.py @@ -23,7 +23,7 @@ wait_exponential, ) -from langchain.llms.base import LLM, BaseLLM +from langchain.llms.base import BaseLLM from langchain.schema import Generation, LLMResult from langchain.utils import get_from_dict_or_env @@ -515,7 +515,7 @@ def _invocation_params(self) -> Dict[str, Any]: return {**{"engine": self.deployment_name}, **super()._invocation_params} -class OpenAIChat(LLM, BaseModel): +class OpenAIChat(BaseLLM, BaseModel): """Wrapper around OpenAI Chat large language models. To use, you should have the ``openai`` python package installed, and the @@ -621,15 +621,30 @@ def _completion_with_retry(**kwargs: Any) -> Any: return _completion_with_retry(**kwargs) - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: - messages = self.prefix_messages + [{"role": "user", "content": prompt}] + def _generate( + self, prompts: List[str], stop: Optional[List[str]] = None + ) -> LLMResult: + if len(prompts) > 1: + raise ValueError(f"OpenAIChat only supports single prompts, got {prompts}") + messages = self.prefix_messages + [{"role": "user", "content": prompts[0]}] params: Dict[str, Any] = {**{"model": self.model_name}, **self._default_params} if stop is not None: if "stop" in params: raise ValueError("`stop` found in both the input and default params.") params["stop"] = stop response = self.completion_with_retry(messages=messages, **params) - return response["choices"][0]["message"]["content"] + return LLMResult( + generations=[ + [Generation(text=response["choices"][0]["message"]["content"])] + ], + llm_output={"token_usage": response["usage"]}, + ) + + async def _agenerate( + self, prompts: List[str], stop: Optional[List[str]] = None + ) -> LLMResult: + """Run the LLM on the given prompt and input.""" + raise NotImplementedError("Async generation not implemented for this LLM.") @property def _identifying_params(self) -> Mapping[str, Any]: