forked from langchain-ai/langchain
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
move serpapi wrapper (langchain-ai#1199)
Co-authored-by: Tim Asp <707699+timothyasp@users.noreply.github.com>
- Loading branch information
1 parent
7e4cd62
commit a4b2ec2
Showing
7 changed files
with
174 additions
and
170 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,166 +1,4 @@ | ||
"""Chain that calls SerpAPI. | ||
"""For backwards compatibility.""" | ||
from langchain.utilities.serpapi import SerpAPIWrapper | ||
|
||
Heavily borrowed from https://github.com/ofirpress/self-ask | ||
""" | ||
import os | ||
import sys | ||
from typing import Any, Dict, Optional, Tuple | ||
|
||
import aiohttp | ||
from pydantic import BaseModel, Extra, Field, root_validator | ||
|
||
from langchain.utils import get_from_dict_or_env | ||
|
||
|
||
class HiddenPrints:
    """Context manager that hides prints by pointing stdout at ``os.devnull``.

    While the ``with`` block is active, ``sys.stdout`` is replaced by a
    write handle on ``os.devnull``; on exit that handle is closed and the
    stream that was installed before entry is put back.
    """

    def __enter__(self) -> None:
        """Open file to pipe stdout to."""
        self._original_stdout = sys.stdout
        # Keep our own reference to the handle so __exit__ closes exactly
        # what was opened here, even if the body rebinds sys.stdout.
        self._devnull = open(os.devnull, "w")
        sys.stdout = self._devnull

    def __exit__(self, *_: Any) -> None:
        """Close file that stdout was piped to and restore the old stdout."""
        try:
            self._devnull.close()
        finally:
            # Restore unconditionally so a failing close() cannot leave the
            # process with a redirected stdout.
            sys.stdout = self._original_stdout
|
||
|
||
def _get_default_params() -> dict:
    """Build and return a new dict holding the default search parameters."""
    defaults = dict(
        engine="google",
        google_domain="google.com",
        gl="us",
        hl="en",
    )
    return defaults
|
||
|
||
def process_response(res: dict, return_organic_results: bool = False) -> str:
    """Process response from SerpAPI.

    Args:
        res: Decoded JSON response from SerpAPI.
        return_organic_results: When true and ``res`` has an
            ``"organic_results"`` entry, return every organic result
            formatted as a title/snippet pair instead of a single answer.

    Returns:
        The first value found, in this order: answer box ``answer``, answer
        box ``snippet``, first answer box ``snippet_highlighted_words``
        entry, sports ``game_spotlight``, knowledge graph ``description``,
        snippet of the first organic result. If none is present, the string
        ``"No good search result found"``.

    Raises:
        ValueError: If the response contains an ``"error"`` key.
    """
    if "error" in res:
        raise ValueError(f"Got error from SerpAPI: {res['error']}")

    if return_organic_results and "organic_results" in res:
        return "\n\n".join(
            f"* Title: {r.get('title')} \n * Snippet: {r.get('snippet')}"
            for r in res["organic_results"]
        )

    # A response may have no organic results at all (key missing or an
    # empty list); treat both as "nothing to fall back on" instead of
    # raising KeyError/IndexError.
    organic_results = res.get("organic_results") or []

    if "answer_box" in res and "answer" in res["answer_box"]:
        toret = res["answer_box"]["answer"]
    elif "answer_box" in res and "snippet" in res["answer_box"]:
        toret = res["answer_box"]["snippet"]
    elif (
        "answer_box" in res
        and "snippet_highlighted_words" in res["answer_box"]
    ):
        toret = res["answer_box"]["snippet_highlighted_words"][0]
    elif (
        "sports_results" in res
        and "game_spotlight" in res["sports_results"]
    ):
        toret = res["sports_results"]["game_spotlight"]
    elif (
        "knowledge_graph" in res
        and "description" in res["knowledge_graph"]
    ):
        toret = res["knowledge_graph"]["description"]
    elif organic_results and "snippet" in organic_results[0]:
        toret = organic_results[0]["snippet"]
    else:
        toret = "No good search result found"
    return toret
|
||
|
||
class SerpAPIWrapper(BaseModel):
    """Wrapper around SerpAPI.

    To use, you should have the ``google-search-results`` python package installed,
    and the environment variable ``SERPAPI_API_KEY`` set with your API key, or pass
    `serpapi_api_key` as a named parameter to the constructor.

    Example:
        .. code-block:: python

            from langchain import SerpAPIWrapper
            serpapi = SerpAPIWrapper()
    """

    search_engine: Any  #: :meta private:
    # Base query parameters; merged with the API key and query in get_params.
    params: dict = Field(default_factory=_get_default_params)
    serpapi_api_key: Optional[str] = None
    # Forwarded to process_response by both run and arun.
    return_organic_results: bool = False
    # Optional caller-owned session used by arun; when unset, arun opens a
    # temporary session for the single request.
    aiosession: Optional[aiohttp.ClientSession] = None

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment.

        Raises:
            ValueError: If the ``serpapi`` package cannot be imported.
        """
        serpapi_api_key = get_from_dict_or_env(
            values, "serpapi_api_key", "SERPAPI_API_KEY"
        )
        values["serpapi_api_key"] = serpapi_api_key
        try:
            from serpapi import GoogleSearch
        except ImportError as err:
            # Keep the ValueError type callers already see, but chain the
            # original ImportError so the real cause stays in the traceback.
            raise ValueError(
                "Could not import serpapi python package. "
                "Please install it with `pip install google-search-results`."
            ) from err
        values["search_engine"] = GoogleSearch
        return values

    async def arun(self, query: str) -> str:
        """Use aiohttp to run query through SerpAPI and parse result."""

        def construct_url_and_params() -> Tuple[str, Dict[str, str]]:
            params = self.get_params(query)
            params["source"] = "python"
            if self.serpapi_api_key:
                # NOTE(review): get_params already sets "api_key"; this sends
                # the key a second time as "serp_api_key" - confirm which
                # name the endpoint expects.
                params["serp_api_key"] = self.serpapi_api_key
            params["output"] = "json"
            url = "https://serpapi.com/search"
            return url, params

        url, params = construct_url_and_params()
        if not self.aiosession:
            # No shared session configured: open one just for this request.
            async with aiohttp.ClientSession() as session:
                async with session.get(url, params=params) as response:
                    res = await response.json()
        else:
            async with self.aiosession.get(url, params=params) as response:
                res = await response.json()

        return process_response(
            res, return_organic_results=self.return_organic_results
        )

    def run(self, query: str) -> str:
        """Run query through SerpAPI and parse result."""
        params = self.get_params(query)
        # stdout is silenced while the search client is built and queried.
        with HiddenPrints():
            search = self.search_engine(params)
            res = search.get_dict()
        return process_response(
            res, return_organic_results=self.return_organic_results
        )

    def get_params(self, query: str) -> Dict[str, str]:
        """Get parameters for SerpAPI.

        The returned dict is ``self.params`` overlaid with ``"api_key"`` and
        ``"q"``, so those two keys win over same-named entries in
        ``self.params``.
        """
        _params = {
            "api_key": self.serpapi_api_key,
            "q": query,
        }
        params = {**self.params, **_params}
        return params
|
||
|
||
# For backwards compatibility | ||
|
||
SerpAPIChain = SerpAPIWrapper | ||
__all__ = ["SerpAPIWrapper"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,163 @@ | ||
"""Chain that calls SerpAPI. | ||
Heavily borrowed from https://github.com/ofirpress/self-ask | ||
""" | ||
import os | ||
import sys | ||
from typing import Any, Dict, Optional, Tuple | ||
|
||
import aiohttp | ||
from pydantic import BaseModel, Extra, Field, root_validator | ||
|
||
from langchain.utils import get_from_dict_or_env | ||
|
||
|
||
class HiddenPrints:
    """Context manager that hides prints by pointing stdout at ``os.devnull``.

    While the ``with`` block is active, ``sys.stdout`` is replaced by a
    write handle on ``os.devnull``; on exit that handle is closed and the
    stream that was installed before entry is put back.
    """

    def __enter__(self) -> None:
        """Open file to pipe stdout to."""
        self._original_stdout = sys.stdout
        # Keep our own reference to the handle so __exit__ closes exactly
        # what was opened here, even if the body rebinds sys.stdout.
        self._devnull = open(os.devnull, "w")
        sys.stdout = self._devnull

    def __exit__(self, *_: Any) -> None:
        """Close file that stdout was piped to and restore the old stdout."""
        try:
            self._devnull.close()
        finally:
            # Restore unconditionally so a failing close() cannot leave the
            # process with a redirected stdout.
            sys.stdout = self._original_stdout
|
||
|
||
class SerpAPIWrapper(BaseModel):
    """Wrapper around SerpAPI.

    To use, you should have the ``google-search-results`` python package installed,
    and the environment variable ``SERPAPI_API_KEY`` set with your API key, or pass
    `serpapi_api_key` as a named parameter to the constructor.

    Example:
        .. code-block:: python

            from langchain import SerpAPIWrapper
            serpapi = SerpAPIWrapper()
    """

    search_engine: Any  #: :meta private:
    # Base query parameters; merged with the API key and query in get_params.
    params: dict = Field(
        default={
            "engine": "google",
            "google_domain": "google.com",
            "gl": "us",
            "hl": "en",
        }
    )
    serpapi_api_key: Optional[str] = None
    # Forwarded to _process_response by both run and arun.
    return_organic_results: bool = False
    # Optional caller-owned session used by arun; when unset, arun opens a
    # temporary session for the single request.
    aiosession: Optional[aiohttp.ClientSession] = None

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment.

        Raises:
            ValueError: If the ``serpapi`` package cannot be imported.
        """
        serpapi_api_key = get_from_dict_or_env(
            values, "serpapi_api_key", "SERPAPI_API_KEY"
        )
        values["serpapi_api_key"] = serpapi_api_key
        try:
            from serpapi import GoogleSearch
        except ImportError as err:
            # Keep the ValueError type callers already see, but chain the
            # original ImportError so the real cause stays in the traceback.
            raise ValueError(
                "Could not import serpapi python package. "
                "Please install it with `pip install google-search-results`."
            ) from err
        values["search_engine"] = GoogleSearch
        return values

    async def arun(self, query: str) -> str:
        """Use aiohttp to run query through SerpAPI and parse result."""

        def construct_url_and_params() -> Tuple[str, Dict[str, str]]:
            params = self.get_params(query)
            params["source"] = "python"
            if self.serpapi_api_key:
                # NOTE(review): get_params already sets "api_key"; this sends
                # the key a second time as "serp_api_key" - confirm which
                # name the endpoint expects.
                params["serp_api_key"] = self.serpapi_api_key
            params["output"] = "json"
            url = "https://serpapi.com/search"
            return url, params

        url, params = construct_url_and_params()
        if not self.aiosession:
            # No shared session configured: open one just for this request.
            async with aiohttp.ClientSession() as session:
                async with session.get(url, params=params) as response:
                    res = await response.json()
        else:
            async with self.aiosession.get(url, params=params) as response:
                res = await response.json()

        return self._process_response(
            res, return_organic_results=self.return_organic_results
        )

    def run(self, query: str) -> str:
        """Run query through SerpAPI and parse result."""
        params = self.get_params(query)
        # stdout is silenced while the search client is built and queried.
        with HiddenPrints():
            search = self.search_engine(params)
            res = search.get_dict()
        return self._process_response(
            res, return_organic_results=self.return_organic_results
        )

    def get_params(self, query: str) -> Dict[str, str]:
        """Get parameters for SerpAPI.

        The returned dict is ``self.params`` overlaid with ``"api_key"`` and
        ``"q"``, so those two keys win over same-named entries in
        ``self.params``.
        """
        _params = {
            "api_key": self.serpapi_api_key,
            "q": query,
        }
        params = {**self.params, **_params}
        return params

    @staticmethod
    def _process_response(res: dict, return_organic_results: bool = False) -> str:
        """Process response from SerpAPI.

        Args:
            res: Decoded JSON response from SerpAPI.
            return_organic_results: When true and ``res`` has an
                ``"organic_results"`` entry, return every organic result
                formatted as a title/snippet pair instead of a single answer.

        Returns:
            The first value found, in this order: answer box ``answer``,
            answer box ``snippet``, first answer box
            ``snippet_highlighted_words`` entry, sports ``game_spotlight``,
            knowledge graph ``description``, snippet of the first organic
            result. If none is present, the string
            ``"No good search result found"``.

        Raises:
            ValueError: If the response contains an ``"error"`` key.
        """
        if "error" in res:
            raise ValueError(f"Got error from SerpAPI: {res['error']}")

        if return_organic_results and "organic_results" in res:
            return "\n\n".join(
                f"* Title: {r.get('title')} \n * Snippet: {r.get('snippet')}"
                for r in res["organic_results"]
            )

        # A response may have no organic results at all (key missing or an
        # empty list); treat both as "nothing to fall back on" instead of
        # raising KeyError/IndexError.
        organic_results = res.get("organic_results") or []

        if "answer_box" in res and "answer" in res["answer_box"]:
            toret = res["answer_box"]["answer"]
        elif "answer_box" in res and "snippet" in res["answer_box"]:
            toret = res["answer_box"]["snippet"]
        elif (
            "answer_box" in res
            and "snippet_highlighted_words" in res["answer_box"]
        ):
            toret = res["answer_box"]["snippet_highlighted_words"][0]
        elif (
            "sports_results" in res
            and "game_spotlight" in res["sports_results"]
        ):
            toret = res["sports_results"]["game_spotlight"]
        elif (
            "knowledge_graph" in res
            and "description" in res["knowledge_graph"]
        ):
            toret = res["knowledge_graph"]["description"]
        elif organic_results and "snippet" in organic_results[0]:
            toret = organic_results[0]["snippet"]
        else:
            toret = "No good search result found"
        return toret
Oops, something went wrong.