Skip to content

Commit

Permalink
fix: remove result event
Browse files Browse the repository at this point in the history
  • Loading branch information
hyper-clova committed Aug 28, 2024
1 parent 89e97bd commit 0797c7e
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 12 deletions.
12 changes: 8 additions & 4 deletions libs/community/langchain_community/chat_models/naver.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@ async def _aiter_sse(
event_data = sse.json()
if sse.event == "signal" and event_data.get("data", {}) == "[DONE]":
return
if sse.event == "result":
return
yield sse


Expand Down Expand Up @@ -197,7 +199,7 @@ class ChatClovaX(BaseChatModel):
"""Automatically inferred from env are `NCP_APIGW_API_KEY` if not provided."""

base_url: Optional[str] = Field(
default=DEFAULT_BASE_URL, alias="ncp_clovastudio_api_base_url"
default=None, alias="clovastudio_api_base_url"
)
"""
Automatically inferred from env are `NCP_CLOVASTUDIO_API_BASE_URL` if not provided.
Expand All @@ -213,7 +215,7 @@ class ChatClovaX(BaseChatModel):
seed: Optional[int] = None

timeout: int = 90
max_retries: int = 3
max_retries: int = 2

class Config:
"""Configuration for this pydantic object."""
Expand Down Expand Up @@ -274,7 +276,7 @@ def _api_url(self) -> str:

if self.task_id:
return (
f"{self.base_url}/{app_type}/v2/tasks/{self.task_id}/chat-completions"
f"{self.base_url}/{app_type}/v1/tasks/{self.task_id}/chat-completions"
)
else:
return f"{self.base_url}/{app_type}/v1/chat-completions/{self.model_name}"
Expand Down Expand Up @@ -315,7 +317,7 @@ def validate_environment(cls, values: Dict) -> Dict:
get_from_dict_or_env(values, "ncp_apigw_api_key", "NCP_APIGW_API_KEY")
)
values["base_url"] = get_from_dict_or_env(
values, "base_url", "NCP_CLOVASTUDIO_API_BASE_URL"
values, "base_url", "NCP_CLOVASTUDIO_API_BASE_URL", DEFAULT_BASE_URL
)

if not values.get("client"):
Expand Down Expand Up @@ -380,6 +382,8 @@ def iter_sse() -> Iterator[ServerSentEvent]:
and event_data.get("data", {}) == "[DONE]"
):
return
if sse.event == "result":
return
if sse.event == "error":
raise SSEError(message=sse.data)
yield sse
Expand Down
15 changes: 8 additions & 7 deletions libs/community/langchain_community/embeddings/naver.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,13 @@ class ClovaXEmbeddings(BaseModel, Embeddings):
"""Automatically inferred from env are `NCP_APIGW_API_KEY` if not provided."""

base_url: Optional[str] = Field(
default=DEFAULT_BASE_URL, alias="ncp_clovastudio_api_base_url"
default=None, alias="clovastudio_api_base_url"
)
"""
Automatically inferred from env are `NCP_CLOVASTUDIO_API_BASE_URL` if not provided.
"""

app_id: Optional[str] = Field(default=None, alias="ncp_clovastudio_app_id")
app_id: Optional[str] = Field(default=None, alias="clovastudio_app_id")
service_app: bool = Field(
default=False,
description="false: use testapp, true: use service app on NCP Clova Studio",
Expand Down Expand Up @@ -125,7 +125,7 @@ def validate_environment(cls, values: Dict) -> Dict:
get_from_dict_or_env(values, "ncp_apigw_api_key", "NCP_APIGW_API_KEY")
)
values["base_url"] = get_from_dict_or_env(
values, "base_url", "NCP_CLOVASTUDIO_API_BASE_URL"
values, "base_url", "NCP_CLOVASTUDIO_API_BASE_URL", DEFAULT_BASE_URL
)

values["app_id"] = get_from_dict_or_env(
Expand Down Expand Up @@ -169,13 +169,13 @@ def _embed_text(self, text: str) -> List[float]:
payload = {"text": text}
response = self.client.post(url=self._api_url, json=payload)
_raise_on_error(response)
return response.json()
return response.json()["result"]["embedding"]

async def _aembed_text(self, text: str) -> List[float]:
payload = {"text": text}
response = await self.async_client.post(url=self._api_url, json=payload)
await _araise_on_error(response)
return response.json()
return response.json()["result"]["embedding"]

def embed_documents(self, texts: List[str]) -> List[List[float]]:
embeddings = []
Expand All @@ -189,8 +189,9 @@ def embed_query(self, text: str) -> List[float]:
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
embeddings = []
for text in texts:
embeddings.append(self._aembed_text(text))
embedding = await self._aembed_text(text)
embeddings.append(embedding)
return embeddings

async def aembed_query(self, text: str) -> List[float]:
return self._aembed_text(text)
return await self._aembed_text(text)
16 changes: 15 additions & 1 deletion libs/community/tests/integration_tests/embeddings/test_naver.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,25 @@ def test_embedding_documents() -> None:
output = embedding.embed_documents(documents)
assert len(output) == 1
assert len(output[0]) > 0


async def test_aembedding_documents() -> None:
"""Test cohere embeddings."""
documents = ["foo bar"]
embedding = ClovaXEmbeddings()
output = await embedding.aembed_documents(documents)
assert len(output) == 1
assert len(output[0]) > 0

def test_embedding_query() -> None:
"""Test cohere embeddings."""
document = "foo bar"
embedding = ClovaXEmbeddings()
output = embedding.embed_query(document)
assert len(output) > 0

async def test_aembedding_query() -> None:
"""Test cohere embeddings."""
document = "foo bar"
embedding = ClovaXEmbeddings()
output = await embedding.aembed_query(document)
assert len(output) > 0

0 comments on commit 0797c7e

Please sign in to comment.