Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bugfix: Properly output a Tool from Glean Search #3851

Merged
merged 4 commits into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import datetime
from typing import Dict, List

from langflow.custom import Component
from langflow.io import DataInput, Output
Expand Down Expand Up @@ -49,7 +48,7 @@ def parse_transcription(self) -> Data:
self.status = error_message
return Data(data={"error": error_message})

def parse_with_speakers(self, utterances: List[Dict]) -> str:
def parse_with_speakers(self, utterances: list[dict]) -> str:
parsed_result = []
for utterance in utterances:
speaker = utterance["speaker"]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import List

import assemblyai as aai

from langflow.custom import Component
Expand Down Expand Up @@ -48,7 +46,7 @@ class AssemblyAIListTranscripts(Component):
Output(display_name="Transcript List", name="transcript_list", method="list_transcripts"),
]

def list_transcripts(self) -> List[Data]:
def list_transcripts(self) -> list[Data]:
aai.settings.api_key = self.api_key

params = aai.ListTranscriptParameters()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

from langchain_core.prompts import HumanMessagePromptTemplate

from langchain_core.prompts import HumanMessagePromptTemplate

from langflow.custom import Component
from langflow.inputs import DefaultPromptField, SecretStrInput, StrInput
from langflow.io import Output
Expand Down
187 changes: 113 additions & 74 deletions src/backend/base/langflow/components/tools/GleanSearchAPI.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
from urllib.parse import urljoin

import httpx
from langchain.tools import StructuredTool
from langchain_core.pydantic_v1 import BaseModel
from pydantic.v1 import Field

from langflow.base.langchain_utilities.model import LCToolComponent
from langflow.field_typing import Tool
Expand All @@ -28,90 +30,127 @@ class GleanSearchAPIComponent(LCToolComponent):
NestedDictInput(name="request_options", display_name="Request Options", required=False),
]

class GleanAPIWrapper(BaseModel):
    """
    Wrapper around Glean API.

    Holds the endpoint base URL and access token, builds the search request,
    posts it with ``httpx`` and returns the list found under the response's
    ``"results"`` key.
    """

    glean_api_url: str
    glean_access_token: str
    act_as: str = "langflow-component@datastax.com"  # TODO: Detect this

    def _prepare_request(
        self,
        query: str,
        page_size: int = 10,
        request_options: dict[str, Any] | None = None,
    ) -> dict:
        """Return the ``url``, ``headers`` and JSON ``payload`` for a search call.

        Args:
            query: Search text placed in the payload's ``query`` field.
            page_size: Value sent as ``pageSize``.
            request_options: Value sent verbatim as ``requestOptions``
                (``None`` is forwarded as-is).
        """
        # Ensure there's a trailing slash: urljoin would otherwise replace the
        # last path segment of the base URL instead of appending "search".
        url = self.glean_api_url
        if not url.endswith("/"):
            url += "/"

        return {
            "url": urljoin(url, "search"),
            "headers": {
                "Authorization": f"Bearer {self.glean_access_token}",
                "X-Scio-ActAs": self.act_as,
            },
            "payload": {
                "query": query,
                "pageSize": page_size,
                "requestOptions": request_options,
            },
        }

    def results(self, query: str, **kwargs: Any) -> list[dict[str, Any]]:
        """Run the search and return the raw result list.

        Raises:
            AssertionError: If the API returned no results.
        """
        results = self._search_api_results(query, **kwargs)

        if not results:
            raise AssertionError("No good Glean Search Result was found")

        return results

    def run(self, query: str, **kwargs: Any) -> list[dict[str, Any]]:
        """Run the search and make sure titled results expose ``snippets[0]["text"]``.

        For every result that has a ``title``: if ``snippets`` is missing or
        empty, a single placeholder snippet built from the title is installed;
        if the first snippet has no ``text`` key, the title is stored there.
        Results without a ``title`` are passed through unchanged. The result
        dicts are updated in place and returned in a new list.
        """
        results = self.results(query, **kwargs)

        processed_results = []
        for result in results:
            if "title" in result:
                title = result["title"]
                # `or` rather than a .get() default: an explicit empty (or null)
                # "snippets" value would otherwise make the [0] lookup below
                # raise IndexError/TypeError.
                snippets = result.get("snippets") or [{"snippet": {"text": title}}]
                result["snippets"] = snippets
                if "text" not in snippets[0]:
                    snippets[0]["text"] = title

            processed_results.append(result)

        return processed_results

    def _search_api_results(self, query: str, **kwargs: Any) -> list[dict[str, Any]]:
        """POST the prepared request and return the response's ``"results"`` list.

        Raises whatever ``response.raise_for_status()`` raises for an error
        status. A response without a ``"results"`` key yields an empty list.
        """
        request_details = self._prepare_request(query, **kwargs)

        response = httpx.post(
            request_details["url"],
            json=request_details["payload"],
            headers=request_details["headers"],
        )

        response.raise_for_status()
        response_json = response.json()

        return response_json.get("results", [])

    @staticmethod
    def _result_as_string(result: dict) -> str:
        """Serialize *result* as JSON indented by four spaces."""
        return json.dumps(result, indent=4)

class GleanSearchAPISchema(BaseModel):
    # Argument schema handed to StructuredTool.from_function as `args_schema`.
    # Documented with comments rather than a docstring so the generated
    # schema is left exactly as before.

    # Required: no usable default, hence the Ellipsis.
    query: str = Field(default=..., description="The search query")
    page_size: int = Field(default=10, description="Maximum number of results to return")
    # A fresh empty dict is created per instance when the caller omits this.
    request_options: dict[str, Any] | None = Field(description="Request Options", default_factory=dict)

def build_tool(self) -> Tool:
wrapper = self._build_wrapper()
wrapper = self._build_wrapper(
glean_api_url=self.glean_api_url,
glean_access_token=self.glean_access_token,
)

return Tool(name="glean_search_api", description="Search with the Glean API", func=wrapper.run)
tool = StructuredTool.from_function(
name="glean_search_api",
description="Search Glean for relevant results.",
func=wrapper.run,
args_schema=self.GleanSearchAPISchema,
)

def run_model(self) -> Data | list[Data]:
wrapper = self._build_wrapper()
self.status = "Glean Search API Tool for Langchain"

results = wrapper.results(
query=self.query,
page_size=self.page_size,
request_options=self.request_options,
)
return tool

list_results = results.get("results", [])
def run_model(self) -> list[Data]:
tool = self.build_tool()

results = tool.run(
{
"query": self.query,
"page_size": self.page_size,
"request_options": self.request_options,
}
)

# Build the data
data = []
for result in list_results:
data.append(Data(data=result))
for result in results:
data.append(Data(data=result, text=result["snippets"][0]["text"]))

self.status = data

return data

# Pre-change implementation as recorded by the diff: the wrapper class is
# declared inside the method on every call, and an instance configured from
# the component's own URL/token attributes is returned.
def _build_wrapper(self):
class GleanAPIWrapper(BaseModel):
"""
Wrapper around Glean API.
"""

glean_api_url: str
glean_access_token: str
act_as: str = "langflow-component@datastax.com" # TODO: Detect this

# Assembles the URL, headers and JSON payload for one search request.
def _prepare_request(
self,
query: str,
page_size: int = 10,
request_options: dict[str, Any] | None = None,
) -> dict:
# Ensure there's a trailing slash
url = self.glean_api_url
if not url.endswith("/"):
url += "/"

return {
"url": urljoin(url, "search"),
"headers": {
"Authorization": f"Bearer {self.glean_access_token}",
"X-Scio-ActAs": self.act_as,
},
"payload": {
"query": query,
"pageSize": page_size,
"requestOptions": request_options,
},
}

# Returns the search response serialized as an indented JSON string.
def run(self, query: str, **kwargs: Any) -> str:
results = self.results(query, **kwargs)

return self._result_as_string(results)

# Returns the decoded response body unchanged (no filtering or validation).
def results(self, query: str, **kwargs: Any) -> dict:
results = self._search_api_results(query, **kwargs)

return results

# Posts the prepared request and returns the decoded JSON body.
def _search_api_results(self, query: str, **kwargs: Any) -> dict[str, Any]:
request_details = self._prepare_request(query, **kwargs)

response = httpx.post(
request_details["url"],
json=request_details["payload"],
headers=request_details["headers"],
)

# Raise on an error status before attempting to decode the body.
response.raise_for_status()

return response.json()

@staticmethod
def _result_as_string(result: dict) -> str:
return json.dumps(result, indent=4)

return GleanAPIWrapper(glean_api_url=self.glean_api_url, glean_access_token=self.glean_access_token)
def _build_wrapper(
    self,
    glean_api_url: str,
    glean_access_token: str,
):
    """Create a ``GleanAPIWrapper`` bound to the given base URL and access token."""
    wrapper_kwargs = {
        "glean_api_url": glean_api_url,
        "glean_access_token": glean_access_token,
    }
    return self.GleanAPIWrapper(**wrapper_kwargs)
Loading