Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Paddle-Pipelines] Serper.dev search engine provider (Google Search API) #6192

Merged
merged 2 commits into from
Jun 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pipelines/examples/agents/ReAct_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@

# yapf: disable
parser = argparse.ArgumentParser()
parser.add_argument("--search_api_key", default=None, type=str, help="The SerpAPI key.")
parser.add_argument("--search_api_key", default=None, type=str, help="The Serper.dev or SerpAPI key.")
parser.add_argument('--llm_name', choices=['THUDM/chatglm-6b', "THUDM/chatglm-6b-v1.1", "gpt-3.5-turbo", "gpt-4"], default="THUDM/chatglm-6b-v1.1", help="The chatbot models ")
parser.add_argument("--api_key", default=None, type=str, help="The API Key.")
args = parser.parse_args()
Expand All @@ -97,7 +97,7 @@ def search_and_action_example():
default_prompt_template="question-answering-with-document-scores",
)

# https://serpapi.com/dashboard
# https://serper.dev
web_retriever = WebRetriever(api_key=args.search_api_key, top_search_results=2)
pipeline = WebQAPipeline(retriever=web_retriever, prompt_node=pn)

Expand Down
8 changes: 4 additions & 4 deletions pipelines/examples/agents/ReAct_example_cn.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,15 @@
parser.add_argument('--device', choices=['cpu', 'gpu'], default="gpu", help="Select which device to run dense_qa system, defaults to gpu.")
parser.add_argument("--index_name", default='dureader_index', type=str, help="The ann index name of ANN.")
parser.add_argument("--search_engine", choices=['faiss', 'milvus'], default="faiss", help="The type of ANN search engine.")
parser.add_argument("--retriever", choices=['dense', 'SerpAPI'], default="dense", help="The type of Retriever.")
parser.add_argument("--retriever", choices=['dense', 'SerperDev', 'SerpAPI'], default="dense", help="The type of Retriever.")
parser.add_argument("--max_seq_len_query", default=64, type=int, help="The maximum total length of query after tokenization.")
parser.add_argument("--max_seq_len_passage", default=256, type=int, help="The maximum total length of passage after tokenization.")
parser.add_argument("--retriever_batch_size", default=16, type=int, help="The batch size of retriever to extract passage embedding for building ANN index.")
parser.add_argument("--query_embedding_model", default="rocketqa-zh-base-query-encoder", type=str, help="The query_embedding_model path")
parser.add_argument("--passage_embedding_model", default="rocketqa-zh-base-query-encoder", type=str, help="The passage_embedding_model path")
parser.add_argument("--params_path", default="checkpoints/model_40/model_state.pdparams", type=str, help="The checkpoint path")
parser.add_argument("--embedding_dim", default=768, type=int, help="The embedding_dim of index")
parser.add_argument("--search_api_key", default=None, type=str, help="The SerpAPI key.")
parser.add_argument("--search_api_key", default=None, type=str, help="The Serper.dev or SerpAPI key.")
parser.add_argument('--embed_title', default=False, type=bool, help="The title to be embedded into embedding")
parser.add_argument('--model_type', choices=['ernie_search', 'ernie', 'bert', 'neural_search'], default="ernie", help="the ernie model types")
parser.add_argument('--llm_name', choices=['ernie-bot', 'THUDM/chatglm-6b', "gpt-3.5-turbo", "gpt-4"], default="THUDM/chatglm-6b", help="The chatbot models ")
Expand Down Expand Up @@ -213,6 +213,6 @@ def search_and_action_example(web_retriever):
use_gpu = True if args.device == "gpu" else False
web_retriever = get_faiss_retriever(use_gpu)
else:
# https://serpapi.com/dashboard
web_retriever = WebRetriever(api_key=args.search_api_key, engine="bing", top_search_results=2)
# https://serper.dev
web_retriever = WebRetriever(api_key=args.search_api_key, engine="google", top_search_results=2)
search_and_action_example(web_retriever)
109 changes: 109 additions & 0 deletions pipelines/pipelines/nodes/search_engine/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,112 @@ def search(self, query: str, **kwargs) -> List[Document]:
logger.debug("SerpAPI returned %s documents for the query '%s'", len(documents), query)
result_docs = documents[:top_k]
return self.score_results(result_docs, len(answer_box) > 0)


class SerperDev(SearchEngine):
    """
    Search engine provider backed by Serper.dev, a REST API that serves search results from Google.
    See the [Serper.dev website](https://serper.dev/) for more details.
    """

    def __init__(
        self,
        api_key: str,
        top_k: Optional[int] = 10,
        engine: Optional[str] = "google",
        search_engine_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        :param api_key: API key for Serper.dev API.
        :param top_k: Number of results to return.
        :param engine: Search engine to use, only supports Google.
        :param search_engine_kwargs: Additional parameters passed to the SerperDev API. For example, you can set 'hl' to 'en'
            to set the search results language to English.
            See the [Serper.dev documentation](https://serper.dev/playground) for the full list of supported parameters.
        """
        super().__init__()
        self.params_dict: Dict[str, Union[str, int, float]] = {}
        self.api_key = api_key
        self.kwargs = search_engine_kwargs if search_engine_kwargs else {}
        # Stored for callers that inspect it; `search` always queries the Google endpoint.
        self.engine = engine
        self.top_k = top_k

    @staticmethod
    def _titled_results(json_result: Dict[str, Any], section: str) -> List[Document]:
        """
        Convert one titled section of the response (e.g. "peopleAlsoSearchFor" or "peopleAlsoAsk") into documents.

        :param json_result: Decoded JSON body of the Serper.dev response.
        :param section: Top-level key of the section to convert.
        :return: One document per usable entry; the snippet is used as content, falling back to the title.
        """
        documents = []
        # `or []` covers both a missing section and an explicit null.
        for result in json_result.get(section) or []:
            title = result.get("title", "")
            content = result.get("snippet") or title
            if not content:
                # An entry with neither a snippet nor a title carries nothing to retrieve.
                continue
            documents.append(
                Document.from_dict(
                    {
                        "title": title,
                        "content": content,
                        "link": result.get("link", ""),
                    }
                )
            )
        return documents

    def search(self, query: str, **kwargs) -> List[Document]:
        """
        Query the Serper.dev search endpoint and convert the response into documents.

        :param query: Query string.
        :param kwargs: Additional parameters passed to the Serper.dev API. For example, you can set 'hl' to 'en'
            to set the search results language to English. A 'top_k' entry overrides the instance-level `top_k`
            and is not forwarded to the API.
            See the [Serper.dev documentation](https://serper.dev/playground) for the full list of supported parameters.
        :return: List[Document] holding at most `top_k` documents, ordered as: answer box, organic results,
            "people also search for", related searches, "people also ask".
        :raises Exception: If the API responds with a status code other than 200.
        """
        # Per-call kwargs take precedence over the ones configured at construction time.
        kwargs = {**self.kwargs, **kwargs}
        top_k = kwargs.pop("top_k", self.top_k)
        url = "https://google.serper.dev/search"

        params = {"q": query, **kwargs}

        headers = {"X-API-KEY": self.api_key, "Content-Type": "application/json"}

        response = requests.post(url, headers=headers, json=params, timeout=30)

        if response.status_code != 200:
            raise Exception(f"Error while querying {self.__class__.__name__}: {response.text}")

        json_result = json.loads(response.text)

        # A response without an "organic" section must yield no organic documents instead of a KeyError.
        organic = [
            Document.from_dict(d, field_map={"snippet": "content"})
            for d in json_result.get("organic") or []
            if "snippet" in d
        ]

        answer_box: List[Document] = []
        answer_dict = json_result.get("answerBox") or {}
        # The first usable field, in this priority order, becomes the answer-box content.
        for key in ["answer", "snippetHighlighted", "snippet", "title"]:
            if key not in answer_dict:
                continue
            answer_box_content = answer_dict[key]
            if isinstance(answer_box_content, list):
                if not answer_box_content:
                    # An empty list has no first element; try the next candidate field.
                    continue
                answer_box_content = answer_box_content[0]
            answer_box = [
                Document.from_dict(
                    {
                        "title": answer_dict.get("title", ""),
                        "content": answer_box_content,
                        "link": answer_dict.get("link", ""),
                    }
                )
            ]
            break

        people_also_search = self._titled_results(json_result, "peopleAlsoSearchFor")

        related_searches = [
            Document.from_dict({"content": result.get("query", "")})
            for result in json_result.get("relatedSearches") or []
        ]

        related_questions = self._titled_results(json_result, "peopleAlsoAsk")

        documents = answer_box + organic + people_also_search + related_searches + related_questions

        logger.debug("Serper.dev API returned %s documents for the query '%s'", len(documents), query)
        result_docs = documents[:top_k]
        return self.score_results(result_docs, len(answer_box) > 0)