Skip to content

Commit

Permalink
Merge pull request #707 from ScrapeGraphAI/reasoning-branch
Browse files Browse the repository at this point in the history
Reasoning branch integration
  • Loading branch information
VinciGit00 authored Sep 28, 2024
2 parents 6d8f543 + f87ffa1 commit ac552bc
Showing 1 changed file with 37 additions and 7 deletions.
44 changes: 37 additions & 7 deletions scrapegraphai/graphs/smart_scraper_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ def _create_graph(self) -> BaseGraph:
"scrape_do": self.config.get("scrape_do")
}
)


generate_answer_node = GenerateAnswerNode(
input="user_prompt & (relevant_chunks | parsed_doc | doc)",
Expand All @@ -82,14 +81,15 @@ def _create_graph(self) -> BaseGraph:
}
)

if self.config.get("html_mode") is not True:

if self.config.get("html_mode") is False:
parse_node = ParseNode(
input="doc",
output=["parsed_doc"],
node_config={
"llm_model": self.llm_model,
"chunk_size": self.model_token
}
)

if self.config.get("reasoning"):
reasoning_node = ReasoningNode(
Expand All @@ -102,36 +102,66 @@ def _create_graph(self) -> BaseGraph:
}
)

if self.config.get("html_mode") is False and self.config.get("reasoning") is True:

return BaseGraph(
nodes=[
fetch_node,
parse_node,

reasoning_node,
generate_answer_node,
],
edges=[
(fetch_node, parse_node),
(parse_node, generate_answer_node)
(parse_node, reasoning_node),
(reasoning_node, generate_answer_node)
],
entry_point=fetch_node,
graph_name=self.__class__.__name__
)

return BaseGraph(
elif self.config.get("html_mode") is True and self.config.get("reasoning") is True:

return BaseGraph(
nodes=[
fetch_node,
reasoning_node,
generate_answer_node,
],
edges=[
(fetch_node, generate_answer_node)
(fetch_node, reasoning_node),
(reasoning_node, generate_answer_node)
],
entry_point=fetch_node,
graph_name=self.__class__.__name__
)

elif self.config.get("html_mode") is True and self.config.get("reasoning") is False:
return BaseGraph(
nodes=[
fetch_node,
generate_answer_node,
],
edges=[
(fetch_node, generate_answer_node)
],
entry_point=fetch_node,
graph_name=self.__class__.__name__
)

return BaseGraph(
nodes=[
fetch_node,
parse_node,
generate_answer_node,
],
edges=[
(fetch_node, parse_node),
(parse_node, generate_answer_node)
],
entry_point=fetch_node,
graph_name=self.__class__.__name__
)

def run(self) -> str:
"""
Expand Down

0 comments on commit ac552bc

Please sign in to comment.