diff --git a/libs/community/langchain_community/chat_models/naver.py b/libs/community/langchain_community/chat_models/naver.py index 1201901e9b8e7..593cad8fa3232 100644 --- a/libs/community/langchain_community/chat_models/naver.py +++ b/libs/community/langchain_community/chat_models/naver.py @@ -8,6 +8,7 @@ Iterator, List, Optional, + Self, Tuple, Type, Union, @@ -34,8 +35,8 @@ SystemMessageChunk, ) from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult -from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env -from pydantic import Field, SecretStr +from langchain_core.utils import convert_to_secret_str, get_from_env +from pydantic import AliasChoices, Field, SecretStr, model_validator _DEFAULT_BASE_URL = "https://clovastudio.stream.ntruss.com" @@ -174,7 +175,9 @@ class ChatClovaX(BaseChatModel): async_client: httpx.AsyncClient = Field(default=None) #: :meta private: model_name: str = Field( - default="HCX-003", alias="model", description="NCP ClovaStudio chat model name" + default="HCX-003", + validation_alias=AliasChoices("model_name", "model"), + description="NCP ClovaStudio chat model name", ) task_id: Optional[str] = Field( default=None, description="NCP Clova Studio chat model tuning task ID" @@ -195,11 +198,11 @@ class ChatClovaX(BaseChatModel): Automatically inferred from env are `NCP_CLOVASTUDIO_API_BASE_URL` if not provided. """ - temperature: Optional[float] = Field(gt=0.0, le=1.0) - top_k: Optional[int] = Field(ge=0, le=128) - top_p: Optional[float] = Field(ge=0, le=1.0) - repeat_penalty: Optional[float] = Field(gt=0.0, le=10) - max_tokens: Optional[int] = Field(ge=0, le=4096) + temperature: Optional[float] = Field(gt=0.0, le=1.0, default=0.5) + top_k: Optional[int] = Field(ge=0, le=128, default=0) + top_p: Optional[float] = Field(ge=0, le=1.0, default=0.8) + repeat_penalty: Optional[float] = Field(gt=0.0, le=10, default=5.0) + max_tokens: Optional[int] = Field(ge=0, le=4096, default=100) stop_before: Optional[list[str]] = Field(default=None, alias="stop") include_ai_filters: Optional[bool] = Field(default=False) seed: Optional[int] = Field(ge=0, le=4294967295, default=0) @@ -271,48 +274,42 @@ def _api_url(self) -> str: else: return f"{self.base_url}/{app_type}/v1/chat-completions/{self.model_name}" - def __init__( - self, - client: Optional[httpx.Client] = None, - async_client: Optional[httpx.AsyncClient] = None, - **kwargs: Any, - ) -> None: - """Validate that api key and python package exists in environment.""" - kwargs["ncp_clovastudio_api_key"] = convert_to_secret_str( - get_from_dict_or_env(kwargs, "api_key", "NCP_CLOVASTUDIO_API_KEY") - ) - kwargs["ncp_apigw_api_key"] = convert_to_secret_str( - get_from_dict_or_env( - kwargs, "apigw_api_key", "NCP_APIGW_API_KEY", "ncp_apigw_api_key" + @model_validator(mode="after") + def validate_model_after(self) -> Self: + if not (self.model_name or self.task_id): + raise ValueError("either model_name or task_id must be assigned a value.") + + if not self.ncp_clovastudio_api_key: + self.ncp_clovastudio_api_key = convert_to_secret_str( + get_from_env("ncp_clovastudio_api_key", "NCP_CLOVASTUDIO_API_KEY") ) - ) - kwargs["base_url"] = get_from_dict_or_env( - kwargs, "base_url", "NCP_CLOVASTUDIO_API_BASE_URL", _DEFAULT_BASE_URL - ) - super().__init__(**kwargs) + if not self.ncp_apigw_api_key: + self.ncp_apigw_api_key = convert_to_secret_str( + get_from_env("ncp_apigw_api_key", "NCP_APIGW_API_KEY") + ) - if not (self.model_name or self.task_id): - raise ValueError("either model_name or task_id must be assigned a value.") + if not self.base_url: + self.base_url = get_from_env( + "base_url", "NCP_CLOVASTUDIO_API_BASE_URL", _DEFAULT_BASE_URL + ) - if client: - self.client = client - else: + if not self.client: self.client = httpx.Client( base_url=self.base_url, headers=self.default_headers(), timeout=self.timeout, ) - if async_client: - self.async_client = async_client - else: + if not self.async_client: self.async_client = httpx.AsyncClient( base_url=self.base_url, headers=self.default_headers(), timeout=self.timeout, ) + return self + def default_headers(self) -> Dict[str, Any]: clovastudio_api_key = ( self.ncp_clovastudio_api_key.get_secret_value() diff --git a/libs/community/langchain_community/embeddings/naver.py b/libs/community/langchain_community/embeddings/naver.py index a22374ac3d808..927c2849c093b 100644 --- a/libs/community/langchain_community/embeddings/naver.py +++ b/libs/community/langchain_community/embeddings/naver.py @@ -1,13 +1,14 @@ import logging -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Self import httpx from langchain_core.embeddings import Embeddings -from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env +from langchain_core.utils import convert_to_secret_str, get_from_env from pydantic import ( BaseModel, Field, SecretStr, + model_validator, ) _DEFAULT_BASE_URL = "https://clovastudio.apigw.ntruss.com" @@ -103,47 +104,42 @@ def _api_url(self) -> str: f"/v1/api-tools/embedding/{model_name}/{self.app_id}" ) - def __init__( - self, - client: Optional[httpx.Client] = None, - async_client: Optional[httpx.AsyncClient] = None, - **kwargs: Any, - ) -> None: - """Validate that api key and python package exists in environment.""" - kwargs["ncp_clovastudio_api_key"] = convert_to_secret_str( - get_from_dict_or_env(kwargs, "api_key", "NCP_CLOVASTUDIO_API_KEY") - ) - kwargs["ncp_apigw_api_key"] = convert_to_secret_str( - get_from_dict_or_env( - kwargs, "apigw_api_key", "NCP_APIGW_API_KEY", "ncp_apigw_api_key" + @model_validator(mode="after") + def validate_model_after(self) -> Self: + if not self.ncp_clovastudio_api_key: + self.ncp_clovastudio_api_key = convert_to_secret_str( + get_from_env("ncp_clovastudio_api_key", "NCP_CLOVASTUDIO_API_KEY") ) - ) - kwargs["base_url"] = get_from_dict_or_env( - kwargs, "base_url", "NCP_CLOVASTUDIO_API_BASE_URL", _DEFAULT_BASE_URL - ) - super().__init__(**kwargs) + if not self.ncp_apigw_api_key: + self.ncp_apigw_api_key = convert_to_secret_str( + get_from_env("ncp_apigw_api_key", "NCP_APIGW_API_KEY") + ) - self.app_id = get_from_dict_or_env(kwargs, "app_id", "NCP_CLOVASTUDIO_APP_ID") + if not self.base_url: + self.base_url = get_from_env( + "base_url", "NCP_CLOVASTUDIO_API_BASE_URL", _DEFAULT_BASE_URL + ) - if client is not None: - self.client = client - else: + if not self.app_id: + self.app_id = get_from_env("app_id", "NCP_CLOVASTUDIO_APP_ID") + + if not self.client: self.client = httpx.Client( base_url=self.base_url, headers=self.default_headers(), timeout=self.timeout, ) - if async_client is not None: - self.async_client = async_client - else: + if not self.async_client: self.async_client = httpx.AsyncClient( base_url=self.base_url, headers=self.default_headers(), timeout=self.timeout, ) + return self + def default_headers(self) -> Dict[str, Any]: clovastudio_api_key = ( self.ncp_clovastudio_api_key.get_secret_value() @@ -151,8 +147,8 @@ def default_headers(self) -> Dict[str, Any]: else None ) apigw_api_key = ( - self.ncp_clovastudio_api_key.get_secret_value() - if self.ncp_clovastudio_api_key + self.ncp_apigw_api_key.get_secret_value() + if self.ncp_apigw_api_key else None ) return { diff --git a/libs/community/tests/unit_tests/chat_models/test_naver.py b/libs/community/tests/unit_tests/chat_models/test_naver.py index 0648ae647d4dd..a4c48af874e32 100644 --- a/libs/community/tests/unit_tests/chat_models/test_naver.py +++ b/libs/community/tests/unit_tests/chat_models/test_naver.py @@ -27,7 +27,7 @@ def test_initialization_api_key() -> None: """Test chat model initialization.""" - chat_model = ChatClovaX(api_key="foo", apigw_api_key="bar") + chat_model = ChatClovaX(api_key="foo", apigw_api_key="bar") # type: ignore[arg-type] assert ( cast(SecretStr, chat_model.ncp_clovastudio_api_key).get_secret_value() == "foo" ) @@ -35,7 +35,7 @@ def test_initialization_api_key() -> None: def test_initialization_model_name() -> None: - llm = ChatClovaX(model="HCX-DASH-001") + llm = ChatClovaX(model="HCX-DASH-001") # type: ignore[call-arg] assert llm.model_name == "HCX-DASH-001" llm = ChatClovaX(model_name="HCX-DASH-001") assert llm.model_name == "HCX-DASH-001" diff --git a/libs/community/tests/unit_tests/embeddings/test_naver.py b/libs/community/tests/unit_tests/embeddings/test_naver.py index ca2a54f7563fa..7fd8b28665633 100644 --- a/libs/community/tests/unit_tests/embeddings/test_naver.py +++ b/libs/community/tests/unit_tests/embeddings/test_naver.py @@ -13,6 +13,6 @@ def test_initialization_api_key() -> None: - llm = ClovaXEmbeddings(api_key="foo", apigw_api_key="bar") + llm = ClovaXEmbeddings(api_key="foo", apigw_api_key="bar") # type: ignore[arg-type] assert cast(SecretStr, llm.ncp_clovastudio_api_key).get_secret_value() == "foo" assert cast(SecretStr, llm.ncp_apigw_api_key).get_secret_value() == "bar"