Skip to content

Commit

Permalink
premai[patch]: Standardize premai params (#21513)
Browse files Browse the repository at this point in the history
Thank you for contributing to LangChain!

community:premai[patch]: standardize init args

- updated `temperature` with Pydantic Field, updated the unit test.
- updated `max_tokens` with Pydantic Field, updated the unit test.
- updated `max_retries` with Pydantic Field, updated the unit test.

Related to #20085

---------

Co-authored-by: Isaac Francisco <78627776+isahers1@users.noreply.github.com>
Co-authored-by: ccurme <chester.curme@gmail.com>
  • Loading branch information
3 people authored Aug 29, 2024
1 parent fcf9230 commit 69f9acb
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 3 deletions.
15 changes: 12 additions & 3 deletions libs/community/langchain_community/chat_models/premai.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,13 +271,22 @@ class ChatPremAI(BaseChatModel, BaseModel):
If model name is other than default model then it will override the calls
from the model deployed from launchpad."""

temperature: Optional[float] = None
session_id: Optional[str] = None
"""The ID of the session to use. It helps to track the chat history."""

temperature: Optional[float] = Field(default=None)
"""Model temperature. Value should be >= 0 and <= 1.0"""

max_tokens: Optional[int] = None
top_p: Optional[float] = None
"""top_p adjusts the number of choices for each predicted tokens based on
cumulative probabilities. Value should be ranging between 0.0 and 1.0.
"""

max_tokens: Optional[int] = Field(default=None)

"""The maximum number of tokens to generate"""

max_retries: int = 1
max_retries: int = Field(default=1)
"""Max number of retries to call the API"""

system_prompt: Optional[str] = ""
Expand Down
3 changes: 3 additions & 0 deletions libs/community/tests/unit_tests/chat_models/test_premai.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,4 +67,7 @@ def test_premai_initialization() -> None:
ChatPremAI(model_name="prem-ai-model", api_key="xyz", project_id=8), # type: ignore[arg-type, call-arg]
]:
assert model.model == "prem-ai-model"
assert model.temperature is None
assert model.max_tokens is None
assert model.max_retries == 1
assert cast(SecretStr, model.premai_api_key).get_secret_value() == "xyz"

0 comments on commit 69f9acb

Please sign in to comment.