From f193da8f1d3f6326a9011b0c047c29a7877d5513 Mon Sep 17 00:00:00 2001 From: bsorrentino Date: Wed, 11 Sep 2024 19:36:38 +0200 Subject: [PATCH] refactor(samples): expose StateGraph intead of CompiledGraph work on #24 --- .../langchain4j/adaptiverag/AdaptiveRag.java | 6 ++--- .../agentexecutor/AgentExecutor.java | 15 ++--------- .../agentexecutor/AgentExecutorTest.java | 25 +++++++++++-------- 3 files changed, 19 insertions(+), 27 deletions(-) diff --git a/adaptive-rag/src/main/java/dev/langchain4j/adaptiverag/AdaptiveRag.java b/adaptive-rag/src/main/java/dev/langchain4j/adaptiverag/AdaptiveRag.java index 7c63944..ab184f9 100644 --- a/adaptive-rag/src/main/java/dev/langchain4j/adaptiverag/AdaptiveRag.java +++ b/adaptive-rag/src/main/java/dev/langchain4j/adaptiverag/AdaptiveRag.java @@ -246,7 +246,7 @@ private String gradeGeneration_v_documentsAndQuestion( State state ) { return "not supported"; } - public CompiledGraph buildGraph() throws Exception { + public StateGraph buildGraph() throws Exception { return new StateGraph<>(State::new) // Define the nodes .addNode("web_search", node_async(this::webSearch) ) // web search @@ -280,7 +280,7 @@ public CompiledGraph buildGraph() throws Exception { "useful", END, "not useful", "transform_query" )) - .compile(); + ; } public static void main( String[] args ) throws Exception { @@ -290,7 +290,7 @@ public static void main( String[] args ) throws Exception { AdaptiveRag adaptiveRagTest = new AdaptiveRag( System.getenv("OPENAI_API_KEY"), System.getenv("TAVILY_API_KEY")); - var graph = adaptiveRagTest.buildGraph(); + var graph = adaptiveRagTest.buildGraph().compile(); var result = graph.stream( mapOf( "question", "What player at the Bears expected to draft first in the 2024 NFL draft?" ) ); // var result = graph.stream( mapOf( "question", "What kind the agent memory do iu know?" ) ); diff --git a/agent-executor/src/main/java/dev/langchain4j/agentexecutor/AgentExecutor.java b/agent-executor/src/main/java/dev/langchain4j/agentexecutor/AgentExecutor.java index ab255b4..6b66bd7 100644 --- a/agent-executor/src/main/java/dev/langchain4j/agentexecutor/AgentExecutor.java +++ b/agent-executor/src/main/java/dev/langchain4j/agentexecutor/AgentExecutor.java @@ -5,7 +5,6 @@ import dev.langchain4j.model.output.FinishReason; import lombok.var; import org.bsc.langgraph4j.*; -import org.bsc.langgraph4j.checkpoint.BaseCheckpointSaver; import org.bsc.langgraph4j.state.AgentState; import org.bsc.langgraph4j.state.AppenderChannel; import org.bsc.langgraph4j.state.Channel; @@ -22,14 +21,9 @@ public class AgentExecutor { public class GraphBuilder { - private BaseCheckpointSaver checkpointSaver; private ChatLanguageModel chatLanguageModel; private List objectsWithTools; - public GraphBuilder checkpointSaver(BaseCheckpointSaver checkpointSaver) { - this.checkpointSaver = checkpointSaver; - return this; - } public GraphBuilder chatLanguageModel(ChatLanguageModel chatLanguageModel) { this.chatLanguageModel = chatLanguageModel; return this; @@ -39,7 +33,7 @@ public GraphBuilder objectsWithTools(List objectsWithTools) { return this; } - public CompiledGraph build() throws GraphStateException { + public StateGraph build() throws GraphStateException { Objects.requireNonNull(objectsWithTools, "objectsWithTools is required!"); Objects.requireNonNull(chatLanguageModel, "chatLanguageModel is required!"); @@ -55,11 +49,6 @@ public CompiledGraph build() throws GraphStateException { .tools( toolSpecifications ) .build(); - CompileConfig.Builder config = new CompileConfig.Builder(); - - if( checkpointSaver != null ) { - config.checkpointSaver(checkpointSaver); - } return new StateGraph<>(State.SCHEMA,State::new) .addEdge(START,"agent") @@ -75,7 +64,7 @@ public CompiledGraph build() throws GraphStateException { mapOf("continue", "action", "end", END) ) .addEdge("action", "agent") - .compile( config.build() ); + ; } } diff --git a/agent-executor/src/test/java/dev/langchain4j/agentexecutor/AgentExecutorTest.java b/agent-executor/src/test/java/dev/langchain4j/agentexecutor/AgentExecutorTest.java index 2c70a03..3aa659c 100644 --- a/agent-executor/src/test/java/dev/langchain4j/agentexecutor/AgentExecutorTest.java +++ b/agent-executor/src/test/java/dev/langchain4j/agentexecutor/AgentExecutorTest.java @@ -3,10 +3,7 @@ import dev.langchain4j.DotEnvConfig; import dev.langchain4j.model.openai.OpenAiChatModel; import lombok.var; -import org.bsc.langgraph4j.GraphRepresentation; -import org.bsc.langgraph4j.RunnableConfig; -import org.bsc.langgraph4j.NodeOutput; -import org.bsc.langgraph4j.StateGraph; +import org.bsc.langgraph4j.*; import org.bsc.langgraph4j.checkpoint.BaseCheckpointSaver; import org.bsc.langgraph4j.checkpoint.MemorySaver; import org.bsc.langgraph4j.state.AgentState; @@ -31,7 +28,7 @@ public static void loadEnv() { DotEnvConfig.load(); } - private AgentExecutor.GraphBuilder newGraphBuilder() throws Exception { + private StateGraph newGraph() throws Exception { var openApiKey = DotEnvConfig.valueOf("OPENAI_API_KEY") .orElseThrow( () -> new IllegalArgumentException("no APIKEY provided!")); @@ -49,11 +46,13 @@ private AgentExecutor.GraphBuilder newGraphBuilder() throws Exception { return agentExecutor.graphBuilder() .chatLanguageModel(chatLanguageModel) - .objectsWithTools(listOf(new TestTool())); + .objectsWithTools(listOf(new TestTool())) + .build(); } private List executeAgent( String prompt ) throws Exception { - var iterator = newGraphBuilder().build().stream( mapOf( "input", prompt ) ); + + var iterator = newGraph().compile().stream( mapOf( "input", prompt ) ); return iterator.stream() .peek( s -> System.out.println( s.node() ) ) @@ -62,12 +61,16 @@ private List executeAgent( String prompt ) throws Exceptio } private List executeAgent(String prompt, String threadId, BaseCheckpointSaver saver) throws Exception { - var config = RunnableConfig.builder().threadId(threadId).build(); - var iterator = newGraphBuilder() + CompileConfig compileConfig = CompileConfig.builder() .checkpointSaver( saver ) - .build() - .stream( mapOf( "input", prompt ), config ); + .build(); + + var config = RunnableConfig.builder().threadId(threadId).build(); + + var graph = newGraph().compile( compileConfig ); + + var iterator = graph.stream( mapOf( "input", prompt ), config ); return iterator.stream() .peek( s -> System.out.println( s.node() ) )