diff --git a/pipelines/examples/agents/ReAct_example.py b/pipelines/examples/agents/ReAct_example.py index 95ad00aba0d2..4496ad66eda6 100644 --- a/pipelines/examples/agents/ReAct_example.py +++ b/pipelines/examples/agents/ReAct_example.py @@ -82,7 +82,7 @@ # yapf: disable parser = argparse.ArgumentParser() -parser.add_argument("--search_api_key", default=None, type=str, help="The SerpAPI key.") +parser.add_argument("--search_api_key", default=None, type=str, help="The Serper.dev or SerpAPI key.") parser.add_argument('--llm_name', choices=['THUDM/chatglm-6b', "THUDM/chatglm-6b-v1.1", "gpt-3.5-turbo", "gpt-4"], default="THUDM/chatglm-6b-v1.1", help="The chatbot models ") parser.add_argument("--api_key", default=None, type=str, help="The API Key.") args = parser.parse_args() @@ -97,7 +97,7 @@ def search_and_action_example(): default_prompt_template="question-answering-with-document-scores", ) - # https://serpapi.com/dashboard + # https://serper.dev web_retriever = WebRetriever(api_key=args.search_api_key, top_search_results=2) pipeline = WebQAPipeline(retriever=web_retriever, prompt_node=pn) diff --git a/pipelines/examples/agents/ReAct_example_cn.py b/pipelines/examples/agents/ReAct_example_cn.py index 801816db8987..967381e0e104 100644 --- a/pipelines/examples/agents/ReAct_example_cn.py +++ b/pipelines/examples/agents/ReAct_example_cn.py @@ -60,7 +60,7 @@ parser.add_argument('--device', choices=['cpu', 'gpu'], default="gpu", help="Select which device to run dense_qa system, defaults to gpu.") parser.add_argument("--index_name", default='dureader_index', type=str, help="The ann index name of ANN.") parser.add_argument("--search_engine", choices=['faiss', 'milvus'], default="faiss", help="The type of ANN search engine.") -parser.add_argument("--retriever", choices=['dense', 'SerpAPI'], default="dense", help="The type of Retriever.") +parser.add_argument("--retriever", choices=['dense', 'SerperDev', 'SerpAPI'], default="dense", help="The type of Retriever.") 
parser.add_argument("--max_seq_len_query", default=64, type=int, help="The maximum total length of query after tokenization.") parser.add_argument("--max_seq_len_passage", default=256, type=int, help="The maximum total length of passage after tokenization.") parser.add_argument("--retriever_batch_size", default=16, type=int, help="The batch size of retriever to extract passage embedding for building ANN index.") @@ -68,7 +68,7 @@ parser.add_argument("--passage_embedding_model", default="rocketqa-zh-base-query-encoder", type=str, help="The passage_embedding_model path") parser.add_argument("--params_path", default="checkpoints/model_40/model_state.pdparams", type=str, help="The checkpoint path") parser.add_argument("--embedding_dim", default=768, type=int, help="The embedding_dim of index") -parser.add_argument("--search_api_key", default=None, type=str, help="The SerpAPI key.") +parser.add_argument("--search_api_key", default=None, type=str, help="The Serper.dev or SerpAPI key.") parser.add_argument('--embed_title', default=False, type=bool, help="The title to be embedded into embedding") parser.add_argument('--model_type', choices=['ernie_search', 'ernie', 'bert', 'neural_search'], default="ernie", help="the ernie model types") parser.add_argument('--llm_name', choices=['ernie-bot', 'THUDM/chatglm-6b', "gpt-3.5-turbo", "gpt-4"], default="THUDM/chatglm-6b", help="The chatbot models ") @@ -213,6 +213,6 @@ def search_and_action_example(web_retriever): use_gpu = True if args.device == "gpu" else False web_retriever = get_faiss_retriever(use_gpu) else: - # https://serpapi.com/dashboard - web_retriever = WebRetriever(api_key=args.search_api_key, engine="bing", top_search_results=2) + # https://serper.dev + web_retriever = WebRetriever(api_key=args.search_api_key, engine="google", top_search_results=2) search_and_action_example(web_retriever) diff --git a/pipelines/pipelines/nodes/search_engine/providers.py b/pipelines/pipelines/nodes/search_engine/providers.py index 
f2eb382e8a98..9e8833968bbe 100644 --- a/pipelines/pipelines/nodes/search_engine/providers.py +++ b/pipelines/pipelines/nodes/search_engine/providers.py @@ -130,3 +130,112 @@ def search(self, query: str, **kwargs) -> List[Document]: logger.debug("SerpAPI returned %s documents for the query '%s'", len(documents), query) result_docs = documents[:top_k] return self.score_results(result_docs, len(answer_box) > 0) + + +class SerperDev(SearchEngine): + """ + Serper.dev is a search engine that provides a REST API to access search results from Google. See the [Serper.dev website](https://serper.dev/) for more details. + """ + + def __init__( + self, + api_key: str, + top_k: Optional[int] = 10, + engine: Optional[str] = "google", + search_engine_kwargs: Optional[Dict[str, Any]] = None, + ): + """ + :param api_key: API key for Serper.dev API. + :param top_k: Number of results to return. + :param engine: Search engine to use, only supports Google. + :param search_engine_kwargs: Additional parameters passed to the SerperDev API. For example, you can set 'hl' to 'en' + to set the search results language to English. + See the [Serper.dev documentation](https://serper.dev/playground) for the full list of supported parameters. + """ + super().__init__() + self.params_dict: Dict[str, Union[str, int, float]] = {} + self.api_key = api_key + self.kwargs = search_engine_kwargs if search_engine_kwargs else {} + self.engine = engine + self.top_k = top_k + + def search(self, query: str, **kwargs) -> List[Document]: + """ + :param query: Query string. + :param kwargs: Additional parameters passed to the Serper.dev API. For example, you can set 'hl' to 'en' + to set the search results language to English. + See the [Serper.dev documentation](https://serper.dev/playground) for the full list of supported parameters. 
+ :return: List[Document] + """ + kwargs = {**self.kwargs, **kwargs} + top_k = kwargs.pop("top_k", self.top_k) + url = "https://google.serper.dev/search" + + params = {"q": query, **kwargs} + + headers = {"X-API-KEY": self.api_key, "Content-Type": "application/json"} + + response = requests.post(url, headers=headers, json=params, timeout=30) + + if response.status_code != 200: + raise Exception(f"Error while querying {self.__class__.__name__}: {response.text}") + + json_result = json.loads(response.text) + organic = [ + Document.from_dict(d, field_map={"snippet": "content"}) for d in json_result["organic"] if "snippet" in d + ] + answer_box = [] + if "answerBox" in json_result: + answer_dict = json_result["answerBox"] + for key in ["answer", "snippetHighlighted", "snippet", "title"]: + if key in answer_dict: + answer_box_content = answer_dict[key] + if isinstance(answer_box_content, list): + answer_box_content = answer_box_content[0] + answer_box = [ + Document.from_dict( + { + "title": answer_dict.get("title", ""), + "content": answer_box_content, + "link": answer_dict.get("link", ""), + } + ) + ] + break + + people_also_search = [] + if "peopleAlsoSearchFor" in json_result: + for result in json_result["peopleAlsoSearchFor"]: + people_also_search.append( + Document.from_dict( + { + "title": result["title"], + "content": result["snippet"] if result.get("snippet") else result["title"], + "link": result["link"], + } + ) + ) + + related_searches = [] + if "relatedSearches" in json_result: + for result in json_result["relatedSearches"]: + related_searches.append(Document.from_dict({"content": result.get("query", "")})) + + related_questions = [] + if "peopleAlsoAsk" in json_result: + for result in json_result["peopleAlsoAsk"]: + related_questions.append( + Document.from_dict( + { + "title": result["title"], + "content": result["snippet"] if result.get("snippet") else result["title"], + "link": result["link"], + } + ) + ) + + documents = answer_box + organic + 
people_also_search + related_searches + related_questions + + logger.debug("Serper.dev API returned %s documents for the query '%s'", len(documents), query) + result_docs = documents[:top_k] + return self.score_results(result_docs, len(answer_box) > 0)