Skip to content

Commit

Permalink
fix: fix to pass test code after apply Numeric Constraints
Browse files Browse the repository at this point in the history
  • Loading branch information
hyper-clova committed Sep 24, 2024
1 parent 134d18c commit d3b2de6
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 64 deletions.
65 changes: 31 additions & 34 deletions libs/community/langchain_community/chat_models/naver.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
Iterator,
List,
Optional,
Self,
Tuple,
Type,
Union,
Expand All @@ -34,8 +35,8 @@
SystemMessageChunk,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from pydantic import Field, SecretStr
from langchain_core.utils import convert_to_secret_str, get_from_env
from pydantic import AliasChoices, Field, SecretStr, model_validator

_DEFAULT_BASE_URL = "https://clovastudio.stream.ntruss.com"

Expand Down Expand Up @@ -174,7 +175,9 @@ class ChatClovaX(BaseChatModel):
async_client: httpx.AsyncClient = Field(default=None) #: :meta private:

model_name: str = Field(
default="HCX-003", alias="model", description="NCP ClovaStudio chat model name"
default="HCX-003",
validation_alias=AliasChoices("model_name", "model"),
description="NCP ClovaStudio chat model name",
)
task_id: Optional[str] = Field(
default=None, description="NCP Clova Studio chat model tuning task ID"
Expand All @@ -195,11 +198,11 @@ class ChatClovaX(BaseChatModel):
Automatically inferred from env are `NCP_CLOVASTUDIO_API_BASE_URL` if not provided.
"""

temperature: Optional[float] = Field(gt=0.0, le=1.0)
top_k: Optional[int] = Field(ge=0, le=128)
top_p: Optional[float] = Field(ge=0, le=1.0)
repeat_penalty: Optional[float] = Field(gt=0.0, le=10)
max_tokens: Optional[int] = Field(ge=0, le=4096)
temperature: Optional[float] = Field(gt=0.0, le=1.0, default=0.5)
top_k: Optional[int] = Field(ge=0, le=128, default=0)
top_p: Optional[float] = Field(ge=0, le=1.0, default=0.8)
repeat_penalty: Optional[float] = Field(gt=0.0, le=10, default=5.0)
max_tokens: Optional[int] = Field(ge=0, le=4096, default=100)
stop_before: Optional[list[str]] = Field(default=None, alias="stop")
include_ai_filters: Optional[bool] = Field(default=False)
seed: Optional[int] = Field(ge=0, le=4294967295, default=0)
Expand Down Expand Up @@ -271,48 +274,42 @@ def _api_url(self) -> str:
else:
return f"{self.base_url}/{app_type}/v1/chat-completions/{self.model_name}"

def __init__(
self,
client: Optional[httpx.Client] = None,
async_client: Optional[httpx.AsyncClient] = None,
**kwargs: Any,
) -> None:
"""Validate that api key and python package exists in environment."""
kwargs["ncp_clovastudio_api_key"] = convert_to_secret_str(
get_from_dict_or_env(kwargs, "api_key", "NCP_CLOVASTUDIO_API_KEY")
)
kwargs["ncp_apigw_api_key"] = convert_to_secret_str(
get_from_dict_or_env(
kwargs, "apigw_api_key", "NCP_APIGW_API_KEY", "ncp_apigw_api_key"
@model_validator(mode="after")
def validate_model_after(self) -> Self:
if not (self.model_name or self.task_id):
raise ValueError("either model_name or task_id must be assigned a value.")

if not self.ncp_clovastudio_api_key:
self.ncp_clovastudio_api_key = convert_to_secret_str(
get_from_env("ncp_clovastudio_api_key", "NCP_CLOVASTUDIO_API_KEY")
)
)
kwargs["base_url"] = get_from_dict_or_env(
kwargs, "base_url", "NCP_CLOVASTUDIO_API_BASE_URL", _DEFAULT_BASE_URL
)

super().__init__(**kwargs)
if not self.ncp_apigw_api_key:
self.ncp_apigw_api_key = convert_to_secret_str(
get_from_env("ncp_apigw_api_key", "NCP_APIGW_API_KEY")
)

if not (self.model_name or self.task_id):
raise ValueError("either model_name or task_id must be assigned a value.")
if not self.base_url:
self.base_url = get_from_env(
"base_url", "NCP_CLOVASTUDIO_API_BASE_URL", _DEFAULT_BASE_URL
)

if client:
self.client = client
else:
if not self.client:
self.client = httpx.Client(
base_url=self.base_url,
headers=self.default_headers(),
timeout=self.timeout,
)

if async_client:
self.async_client = async_client
else:
if not self.async_client:
self.async_client = httpx.AsyncClient(
base_url=self.base_url,
headers=self.default_headers(),
timeout=self.timeout,
)

return self

def default_headers(self) -> Dict[str, Any]:
clovastudio_api_key = (
self.ncp_clovastudio_api_key.get_secret_value()
Expand Down
50 changes: 23 additions & 27 deletions libs/community/langchain_community/embeddings/naver.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import logging
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Self

import httpx
from langchain_core.embeddings import Embeddings
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from langchain_core.utils import convert_to_secret_str, get_from_env
from pydantic import (
BaseModel,
Field,
SecretStr,
model_validator,
)

_DEFAULT_BASE_URL = "https://clovastudio.apigw.ntruss.com"
Expand Down Expand Up @@ -103,47 +104,42 @@ def _api_url(self) -> str:
f"/v1/api-tools/embedding/{model_name}/{self.app_id}"
)

def __init__(
self,
client: Optional[httpx.Client] = None,
async_client: Optional[httpx.AsyncClient] = None,
**kwargs: Any,
) -> None:
"""Validate that api key and python package exists in environment."""
kwargs["ncp_clovastudio_api_key"] = convert_to_secret_str(
get_from_dict_or_env(kwargs, "api_key", "NCP_CLOVASTUDIO_API_KEY")
)
kwargs["ncp_apigw_api_key"] = convert_to_secret_str(
get_from_dict_or_env(
kwargs, "apigw_api_key", "NCP_APIGW_API_KEY", "ncp_apigw_api_key"
@model_validator(mode="after")
def validate_model_after(self) -> Self:
if not self.ncp_clovastudio_api_key:
self.ncp_clovastudio_api_key = convert_to_secret_str(
get_from_env("ncp_clovastudio_api_key", "NCP_CLOVASTUDIO_API_KEY")
)
)
kwargs["base_url"] = get_from_dict_or_env(
kwargs, "base_url", "NCP_CLOVASTUDIO_API_BASE_URL", _DEFAULT_BASE_URL
)

super().__init__(**kwargs)
if not self.ncp_apigw_api_key:
self.ncp_apigw_api_key = convert_to_secret_str(
get_from_env("ncp_apigw_api_key", "NCP_APIGW_API_KEY")
)

self.app_id = get_from_dict_or_env(kwargs, "app_id", "NCP_CLOVASTUDIO_APP_ID")
if not self.base_url:
self.base_url = get_from_env(
"base_url", "NCP_CLOVASTUDIO_API_BASE_URL", _DEFAULT_BASE_URL
)

if not self.app_id:
self.app_id = get_from_env("app_id", "NCP_CLOVASTUDIO_APP_ID")

if client is not None:
self.client = client
else:
if not self.client:
self.client = httpx.Client(
base_url=self.base_url,
headers=self.default_headers(),
timeout=self.timeout,
)

if async_client is not None:
self.async_client = async_client
else:
if not self.async_client:
self.async_client = httpx.AsyncClient(
base_url=self.base_url,
headers=self.default_headers(),
timeout=self.timeout,
)

return self

def default_headers(self) -> Dict[str, Any]:
clovastudio_api_key = (
self.ncp_clovastudio_api_key.get_secret_value()
Expand Down
4 changes: 2 additions & 2 deletions libs/community/tests/unit_tests/chat_models/test_naver.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,15 @@

def test_initialization_api_key() -> None:
"""Test chat model initialization."""
chat_model = ChatClovaX(api_key="foo", apigw_api_key="bar")
chat_model = ChatClovaX(api_key="foo", apigw_api_key="bar") # type: ignore[arg-type]
assert (
cast(SecretStr, chat_model.ncp_clovastudio_api_key).get_secret_value() == "foo"
)
assert cast(SecretStr, chat_model.ncp_apigw_api_key).get_secret_value() == "bar"


def test_initialization_model_name() -> None:
llm = ChatClovaX(model="HCX-DASH-001")
llm = ChatClovaX(model="HCX-DASH-001") # type: ignore[call-arg]
assert llm.model_name == "HCX-DASH-001"
llm = ChatClovaX(model_name="HCX-DASH-001")
assert llm.model_name == "HCX-DASH-001"
Expand Down
2 changes: 1 addition & 1 deletion libs/community/tests/unit_tests/embeddings/test_naver.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,6 @@


def test_initialization_api_key() -> None:
llm = ClovaXEmbeddings(api_key="foo", apigw_api_key="bar")
llm = ClovaXEmbeddings(api_key="foo", apigw_api_key="bar") # type: ignore[arg-type]
assert cast(SecretStr, llm.ncp_clovastudio_api_key).get_secret_value() == "foo"
assert cast(SecretStr, llm.ncp_apigw_api_key).get_secret_value() == "bar"

0 comments on commit d3b2de6

Please sign in to comment.