diff --git a/libs/langchain/langchain/chat_models/mlflow.py b/libs/langchain/langchain/chat_models/mlflow.py
index e1c1ad1542bd7..4aa42f7a96351 100644
--- a/libs/langchain/langchain/chat_models/mlflow.py
+++ b/libs/langchain/langchain/chat_models/mlflow.py
@@ -115,13 +115,13 @@ def _generate(
             "messages": message_dicts,
             "temperature": self.temperature,
             "n": self.n,
-            "stop": stop or self.stop,
-            "max_tokens": self.max_tokens,
             **self.extra_params,
             **kwargs,
         }
         if stop := self.stop or stop:
             data["stop"] = stop
+        if self.max_tokens is not None:
+            data["max_tokens"] = self.max_tokens
         resp = self._client.predict(endpoint=self.endpoint, inputs=data)
         return ChatMlflow._create_chat_result(resp)
 
diff --git a/libs/langchain/langchain/llms/databricks.py b/libs/langchain/langchain/llms/databricks.py
index d83e67a6cc8bf..a3f505b5c2ee4 100644
--- a/libs/langchain/langchain/llms/databricks.py
+++ b/libs/langchain/langchain/llms/databricks.py
@@ -334,13 +334,14 @@ class Config:
 
     @property
     def _llm_params(self) -> Dict[str, Any]:
-        params = {
+        params: Dict[str, Any] = {
             "temperature": self.temperature,
             "n": self.n,
-            "stop": self.stop,
-            "max_tokens": self.max_tokens,
-            **(self.model_kwargs or self.extra_params),
         }
+        if self.stop:
+            params["stop"] = self.stop
+        if self.max_tokens is not None:
+            params["max_tokens"] = self.max_tokens
         return params
 
     @validator("cluster_id", always=True)
@@ -457,11 +458,9 @@ def _call(
         request: Dict[str, Any] = {"prompt": prompt}
         if self._client.llm:
             request.update(self._llm_params)
-            request.update(self.model_kwargs or self.extra_params)
-        else:
-            request.update(self.model_kwargs or self.extra_params)
+        request.update(self.model_kwargs or self.extra_params)
         request.update(kwargs)
-        if stop := self.stop or stop:
+        if stop:
             request["stop"] = stop
 
         if self.transform_input_fn:
diff --git a/libs/langchain/langchain/llms/mlflow.py b/libs/langchain/langchain/llms/mlflow.py
index 565a4b3a363f0..00e16bcb4651b 100644
--- a/libs/langchain/langchain/llms/mlflow.py
+++ b/libs/langchain/langchain/llms/mlflow.py
@@ -106,12 +106,14 @@ def _call(
             "prompt": prompt,
             "temperature": self.temperature,
             "n": self.n,
-            "max_tokens": self.max_tokens,
             **self.extra_params,
             **kwargs,
         }
         if stop := self.stop or stop:
             data["stop"] = stop
+        if self.max_tokens is not None:
+            data["max_tokens"] = self.max_tokens
+
         resp = self._client.predict(endpoint=self.endpoint, inputs=data)
         return resp["choices"][0]["text"]
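
For illustration, a minimal sketch of the change from the caller's side (the target URI and endpoint name below are hypothetical, not part of this patch):

    from langchain.llms.mlflow import Mlflow

    # Neither `stop` nor `max_tokens` is set here, so after this patch the
    # request body sent to the serving endpoint omits both keys entirely,
    # instead of sending "stop": null / "max_tokens": null, which some
    # serving backends reject.
    llm = Mlflow(target_uri="http://127.0.0.1:5000", endpoint="completions")
    llm.invoke("Tell me a joke")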