forked from ScrapeGraphAI/Scrapegraph-ai
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: implement ScrapeGraph class for web-scraping-only automation (fetch and parse, no LLM answer generation)
- Loading branch information
1 parent
e0fc457
commit 612c644
Showing
1 changed file
with
98 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
""" | ||
SmartScraperGraph Module | ||
""" | ||
from typing import Optional | ||
from pydantic import BaseModel | ||
from .base_graph import BaseGraph | ||
from .abstract_graph import AbstractGraph | ||
from ..nodes import ( | ||
FetchNode, | ||
ParseNode, | ||
) | ||
|
||
class ScrapeGraph(AbstractGraph):
    """
    ScrapeGraph is a scraping pipeline that automates the process of
    extracting information from web pages. It only fetches and parses the
    page content (no LLM answer-generation node is attached).

    Attributes:
        prompt (str): The prompt for the graph.
        source (str): The source of the graph (a URL or a local directory path).
        config (dict): Configuration parameters for the graph.
        schema (BaseModel): The schema for the graph output.
        verbose (bool): A flag indicating whether to show print statements during execution.
        headless (bool): A flag indicating whether to run the graph in headless mode.

    Args:
        source (str): The source of the graph.
        config (dict): Configuration parameters for the graph.
        prompt (str): The prompt for the graph (optional; this pipeline does not
            generate an LLM answer, the prompt is only passed through the state).
        schema (BaseModel): The schema for the graph output (optional).

    Example:
        >>> scraper = ScrapeGraph(
        ...     "https://en.wikipedia.org/wiki/Chioggia",
        ...     {"llm": {"model": "openai/gpt-3.5-turbo"}}
        ... )
        >>> result = scraper.run()
    """

    def __init__(self, source: str, config: dict, prompt: str = "", schema: Optional[BaseModel] = None):
        """
        Initialize the pipeline and decide how the source will be fetched.

        Args:
            source: URL or local directory path to scrape.
            config: Configuration parameters forwarded to AbstractGraph.
            prompt: Optional user prompt stored in the initial graph state.
            schema: Optional pydantic schema describing the expected output.
        """
        super().__init__(prompt, config, source, schema)

        # Sources starting with "http" are fetched over the network; anything
        # else is treated as a path on the local filesystem.
        self.input_key = "url" if source.startswith("http") else "local_dir"

    def _create_graph(self) -> BaseGraph:
        """
        Creates the graph of nodes representing the workflow for web scraping.

        Returns:
            BaseGraph: A graph instance representing the web scraping workflow
            (FetchNode -> ParseNode).
        """
        # NOTE(review): the input expression "url| local_dir" lacks a space
        # after "url" — kept byte-identical since the expression parser's
        # whitespace handling is defined elsewhere; confirm before changing.
        fetch_node = FetchNode(
            input="url| local_dir",
            output=["doc"],
            node_config={
                "llm_model": self.llm_model,
                "force": self.config.get("force", False),
                "cut": self.config.get("cut", True),
                "loader_kwargs": self.config.get("loader_kwargs", {}),
                "browser_base": self.config.get("browser_base"),
                "scrape_do": self.config.get("scrape_do")
            }
        )

        parse_node = ParseNode(
            input="doc",
            output=["parsed_doc"],
            node_config={
                "llm_model": self.llm_model,
                # Chunking is sized to the model's token limit.
                "chunk_size": self.model_token
            }
        )

        return BaseGraph(
            nodes=[
                fetch_node,
                parse_node,
            ],
            edges=[
                (fetch_node, parse_node),
            ],
            entry_point=fetch_node,
            graph_name=self.__class__.__name__
        )

    def run(self) -> str:
        """
        Executes the scraping process and returns the scraping content.

        Returns:
            str: The parsed document content, or "No document found." when the
            graph produced no "parsed_doc" entry in its final state.
        """
        inputs = {"user_prompt": self.prompt, self.input_key: self.source}
        self.final_state, self.execution_info = self.graph.execute(inputs)

        return self.final_state.get("parsed_doc", "No document found.")