diff --git a/scrapegraphai/graphs/scrape_graph.py b/scrapegraphai/graphs/scrape_graph.py
new file mode 100644
index 00000000..a08149aa
--- /dev/null
+++ b/scrapegraphai/graphs/scrape_graph.py
@@ -0,0 +1,97 @@
+"""
+ScrapeGraph Module
+"""
+from typing import Optional
+from pydantic import BaseModel
+from .base_graph import BaseGraph
+from .abstract_graph import AbstractGraph
+from ..nodes import (
+    FetchNode,
+    ParseNode,
+)
+
+class ScrapeGraph(AbstractGraph):
+    """
+    ScrapeGraph is a scraping pipeline that automates the process of
+    extracting information from web pages.
+
+    Attributes:
+        prompt (str): The prompt for the graph.
+        source (str): The source of the graph.
+        config (dict): Configuration parameters for the graph.
+        schema (BaseModel): The schema for the graph output.
+        verbose (bool): A flag indicating whether to show print statements during execution.
+        headless (bool): A flag indicating whether to run the graph in headless mode.
+
+    Args:
+        prompt (str): The prompt for the graph.
+        source (str): The source of the graph.
+        config (dict): Configuration parameters for the graph.
+        schema (BaseModel): The schema for the graph output.
+
+    Example:
+        >>> scraper = ScrapeGraph(
+        ...     "https://en.wikipedia.org/wiki/Chioggia",
+        ...     {"llm": {"model": "openai/gpt-3.5-turbo"}}
+        ... )
+        >>> result = scraper.run()
+    """
+
+    def __init__(self, source: str, config: dict, prompt: str = "", schema: Optional[BaseModel] = None):
+        super().__init__(prompt, config, source, schema)
+
+        self.input_key = "url" if source.startswith("http") else "local_dir"
+
+    def _create_graph(self) -> BaseGraph:
+        """
+        Creates the graph of nodes representing the workflow for web scraping.
+
+        Returns:
+            BaseGraph: A graph instance representing the web scraping workflow.
+        """
+        fetch_node = FetchNode(
+            input="url | local_dir",
+            output=["doc"],
+            node_config={
+                "llm_model": self.llm_model,
+                "force": self.config.get("force", False),
+                "cut": self.config.get("cut", True),
+                "loader_kwargs": self.config.get("loader_kwargs", {}),
+                "browser_base": self.config.get("browser_base"),
+                "scrape_do": self.config.get("scrape_do")
+            }
+        )
+
+        parse_node = ParseNode(
+            input="doc",
+            output=["parsed_doc"],
+            node_config={
+                "llm_model": self.llm_model,
+                "chunk_size": self.model_token
+            }
+        )
+
+        return BaseGraph(
+            nodes=[
+                fetch_node,
+                parse_node,
+            ],
+            edges=[
+                (fetch_node, parse_node),
+            ],
+            entry_point=fetch_node,
+            graph_name=self.__class__.__name__
+        )
+
+    def run(self) -> str:
+        """
+        Executes the scraping process and returns the scraped content.
+
+        Returns:
+            str: The scraped content.
+        """
+
+        inputs = {"user_prompt": self.prompt, self.input_key: self.source}
+        self.final_state, self.execution_info = self.graph.execute(inputs)
+
+        return self.final_state.get("parsed_doc", "No document found.")
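
A minimal usage sketch for the new class (not part of the diff). It assumes ScrapeGraph is re-exported from scrapegraphai.graphs like the other graph classes in this repo, and the config keys ("llm", "verbose") mirror those used by SmartScraperGraph; the model name and api_key are placeholders.

    from scrapegraphai.graphs import ScrapeGraph

    # Hypothetical config, modeled on the other graphs in this package;
    # the api_key value is a placeholder, not part of this PR.
    graph_config = {
        "llm": {
            "model": "openai/gpt-3.5-turbo",
            "api_key": "YOUR_OPENAI_API_KEY",
        },
        "verbose": True,
    }

    # Note the argument order: source first, then config.
    # prompt is optional here since this pipeline only fetches and parses.
    scraper = ScrapeGraph(
        source="https://en.wikipedia.org/wiki/Chioggia",
        config=graph_config,
    )

    # run() executes fetch -> parse and returns final_state["parsed_doc"],
    # falling back to "No document found." when parsing produced nothing.
    result = scraper.run()
    print(result)

Unlike SmartScraperGraph, this pipeline stops after FetchNode and ParseNode, so run() returns the parsed document chunks rather than an LLM-generated answer.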