From f87ffa1d8db32b38c47d9f5aa2ae88f1d7978a04 Mon Sep 17 00:00:00 2001 From: Marco Vinciguerra Date: Fri, 27 Sep 2024 18:31:42 +0200 Subject: [PATCH] fix: integration with html_mode --- scrapegraphai/graphs/smart_scraper_graph.py | 44 +++++++++++++++++---- 1 file changed, 37 insertions(+), 7 deletions(-) diff --git a/scrapegraphai/graphs/smart_scraper_graph.py b/scrapegraphai/graphs/smart_scraper_graph.py index 4ffc6bed..65f03a24 100644 --- a/scrapegraphai/graphs/smart_scraper_graph.py +++ b/scrapegraphai/graphs/smart_scraper_graph.py @@ -70,7 +70,6 @@ def _create_graph(self) -> BaseGraph: "scrape_do": self.config.get("scrape_do") } ) - generate_answer_node = GenerateAnswerNode( input="user_prompt & (relevant_chunks | parsed_doc | doc)", @@ -82,14 +81,15 @@ def _create_graph(self) -> BaseGraph: } ) - if self.config.get("html_mode") is not True: - + if self.config.get("html_mode") is False: parse_node = ParseNode( input="doc", output=["parsed_doc"], node_config={ "llm_model": self.llm_model, "chunk_size": self.model_token + } + ) if self.config.get("reasoning"): reasoning_node = ReasoningNode( @@ -102,17 +102,17 @@ def _create_graph(self) -> BaseGraph: } ) + if self.config.get("html_mode") is False and self.config.get("reasoning") is True: + return BaseGraph( nodes=[ fetch_node, parse_node, - reasoning_node, generate_answer_node, ], edges=[ (fetch_node, parse_node), - (parse_node, generate_answer_node) (parse_node, reasoning_node), (reasoning_node, generate_answer_node) ], @@ -120,18 +120,48 @@ def _create_graph(self) -> BaseGraph: graph_name=self.__class__.__name__ ) - return BaseGraph( + elif self.config.get("html_mode") is True and self.config.get("reasoning") is True: + + return BaseGraph( nodes=[ fetch_node, + reasoning_node, generate_answer_node, ], edges=[ - (fetch_node, generate_answer_node) + (fetch_node, reasoning_node), + (reasoning_node, generate_answer_node) ], entry_point=fetch_node, graph_name=self.__class__.__name__ ) + elif self.config.get("html_mode") is True and self.config.get("reasoning") is False: + return BaseGraph( + nodes=[ + fetch_node, + generate_answer_node, + ], + edges=[ + (fetch_node, generate_answer_node) + ], + entry_point=fetch_node, + graph_name=self.__class__.__name__ + ) + + return BaseGraph( + nodes=[ + fetch_node, + parse_node, + generate_answer_node, + ], + edges=[ + (fetch_node, parse_node), + (parse_node, generate_answer_node) + ], + entry_point=fetch_node, + graph_name=self.__class__.__name__ + ) def run(self) -> str: """