-
-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Implement SmartScraperMultiParseMergeFirstGraph class that scra…
…pes a list of URLs, merges the content first, and finally generates answers to a given prompt. (The difference from SmartScraperMultiGraph is that here the content is merged before being processed by the LLM.)
- Loading branch information
1 parent
612c644
commit 3e3e1b2
Showing
2 changed files
with
105 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
103 changes: 103 additions & 0 deletions
103
scrapegraphai/graphs/smart_scraper_multi_parse_merge_first_graph.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
""" | ||
SmartScraperMultiGraph Module | ||
""" | ||
from copy import deepcopy | ||
from typing import List, Optional | ||
from pydantic import BaseModel | ||
from .base_graph import BaseGraph | ||
from .abstract_graph import AbstractGraph | ||
from .scrape_graph import ScrapeGraph | ||
from ..nodes import ( | ||
GraphIteratorNode, | ||
MergeAnswersNode, | ||
) | ||
from ..utils.copy import safe_deepcopy | ||
|
||
class SmartScraperMultiParseMergeFirstGraph(AbstractGraph):
    """
    Scraping pipeline that fetches and parses a list of URLs, merges the
    parsed content first, and only then asks the LLM to answer the user
    prompt. This is the key difference from SmartScraperMultiGraph, which
    runs the LLM per-page before merging.

    Attributes:
        prompt (str): The user prompt to answer from the scraped pages.
        llm_model (dict): The configuration for the language model.
        headless (bool): A flag to run the browser in headless mode.
        verbose (bool): A flag to display the execution information.
        model_token (int): The token limit for the language model.

    Args:
        prompt (str): The user prompt to answer from the scraped pages.
        source (List[str]): The list of URLs to scrape.
        config (dict): Configuration parameters for the graph.
        schema (Optional[BaseModel]): The schema for the graph output.

    Example:
        >>> search_graph = SmartScraperMultiParseMergeFirstGraph(
        ...     prompt="Who is Marco Perini?",
        ...     source= [
        ...         "https://perinim.github.io/",
        ...         "https://perinim.github.io/cv/"
        ...     ],
        ...     config={"llm": {"model": "openai/gpt-3.5-turbo"}}
        ... )
        >>> result = search_graph.run()
    """

    def __init__(self, prompt: str, source: List[str],
                 config: dict, schema: Optional[BaseModel] = None):
        # Keep private copies of the caller's config/schema so that any
        # mutation performed by the base class or the nodes cannot leak
        # back into the objects the caller still holds.
        self.copy_config = safe_deepcopy(config)
        self.copy_schema = deepcopy(schema)
        super().__init__(prompt, config, source, schema)

    def _create_graph(self) -> BaseGraph:
        """
        Build the two-node workflow: scrape/parse every URL via an
        iterated ScrapeGraph, then merge all parsed content and generate
        a single answer to the prompt.

        Returns:
            BaseGraph: the executable graph for this pipeline.
        """
        # Runs one ScrapeGraph per URL and collects the parsed documents.
        scraper_iterator = GraphIteratorNode(
            input="user_prompt & urls",
            output=["parsed_doc"],
            node_config={
                "graph_instance": ScrapeGraph,
                "scraper_config": self.copy_config,
            },
            schema=self.copy_schema,
        )

        # Merges every parsed document and asks the LLM once.
        answer_merger = MergeAnswersNode(
            input="user_prompt & parsed_doc",
            output=["answer"],
            node_config={
                "llm_model": self.llm_model,
                "schema": self.copy_schema,
            },
        )

        return BaseGraph(
            nodes=[scraper_iterator, answer_merger],
            edges=[(scraper_iterator, answer_merger)],
            entry_point=scraper_iterator,
            graph_name=self.__class__.__name__,
        )

    def run(self) -> str:
        """
        Execute the scraping-and-parsing stage, then the merge-and-answer
        stage, and return the generated answer.

        Returns:
            str: The answer to the prompt, or "No answer found." when the
            final state carries no "answer" key.
        """
        initial_state = {"user_prompt": self.prompt, "urls": self.source}
        self.final_state, self.execution_info = self.graph.execute(initial_state)
        return self.final_state.get("answer", "No answer found.")