Skip to content

Commit

Permalink
refactor(samples): expose StateGraph intead of CompiledGraph
Browse files Browse the repository at this point in the history
work on #24
  • Loading branch information
bsorrentino committed Sep 11, 2024
1 parent 7474b86 commit f193da8
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ private String gradeGeneration_v_documentsAndQuestion( State state ) {
return "not supported";
}

public CompiledGraph<State> buildGraph() throws Exception {
public StateGraph<State> buildGraph() throws Exception {
return new StateGraph<>(State::new)
// Define the nodes
.addNode("web_search", node_async(this::webSearch) ) // web search
Expand Down Expand Up @@ -280,7 +280,7 @@ public CompiledGraph<State> buildGraph() throws Exception {
"useful", END,
"not useful", "transform_query"
))
.compile();
;
}

public static void main( String[] args ) throws Exception {
Expand All @@ -290,7 +290,7 @@ public static void main( String[] args ) throws Exception {

AdaptiveRag adaptiveRagTest = new AdaptiveRag( System.getenv("OPENAI_API_KEY"), System.getenv("TAVILY_API_KEY"));

var graph = adaptiveRagTest.buildGraph();
var graph = adaptiveRagTest.buildGraph().compile();

var result = graph.stream( mapOf( "question", "What player at the Bears expected to draft first in the 2024 NFL draft?" ) );
// var result = graph.stream( mapOf( "question", "What kind the agent memory do iu know?" ) );
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import dev.langchain4j.model.output.FinishReason;
import lombok.var;
import org.bsc.langgraph4j.*;
import org.bsc.langgraph4j.checkpoint.BaseCheckpointSaver;
import org.bsc.langgraph4j.state.AgentState;
import org.bsc.langgraph4j.state.AppenderChannel;
import org.bsc.langgraph4j.state.Channel;
Expand All @@ -22,14 +21,9 @@
public class AgentExecutor {

public class GraphBuilder {
private BaseCheckpointSaver checkpointSaver;
private ChatLanguageModel chatLanguageModel;
private List<Object> objectsWithTools;

public GraphBuilder checkpointSaver(BaseCheckpointSaver checkpointSaver) {
this.checkpointSaver = checkpointSaver;
return this;
}
public GraphBuilder chatLanguageModel(ChatLanguageModel chatLanguageModel) {
this.chatLanguageModel = chatLanguageModel;
return this;
Expand All @@ -39,7 +33,7 @@ public GraphBuilder objectsWithTools(List<Object> objectsWithTools) {
return this;
}

public CompiledGraph<State> build() throws GraphStateException {
public StateGraph<State> build() throws GraphStateException {
Objects.requireNonNull(objectsWithTools, "objectsWithTools is required!");
Objects.requireNonNull(chatLanguageModel, "chatLanguageModel is required!");

Expand All @@ -55,11 +49,6 @@ public CompiledGraph<State> build() throws GraphStateException {
.tools( toolSpecifications )
.build();

CompileConfig.Builder config = new CompileConfig.Builder();

if( checkpointSaver != null ) {
config.checkpointSaver(checkpointSaver);
}

return new StateGraph<>(State.SCHEMA,State::new)
.addEdge(START,"agent")
Expand All @@ -75,7 +64,7 @@ public CompiledGraph<State> build() throws GraphStateException {
mapOf("continue", "action", "end", END)
)
.addEdge("action", "agent")
.compile( config.build() );
;

}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,7 @@
import dev.langchain4j.DotEnvConfig;
import dev.langchain4j.model.openai.OpenAiChatModel;
import lombok.var;
import org.bsc.langgraph4j.GraphRepresentation;
import org.bsc.langgraph4j.RunnableConfig;
import org.bsc.langgraph4j.NodeOutput;
import org.bsc.langgraph4j.StateGraph;
import org.bsc.langgraph4j.*;
import org.bsc.langgraph4j.checkpoint.BaseCheckpointSaver;
import org.bsc.langgraph4j.checkpoint.MemorySaver;
import org.bsc.langgraph4j.state.AgentState;
Expand All @@ -31,7 +28,7 @@ public static void loadEnv() {
DotEnvConfig.load();
}

private AgentExecutor.GraphBuilder newGraphBuilder() throws Exception {
private StateGraph<AgentExecutor.State> newGraph() throws Exception {

var openApiKey = DotEnvConfig.valueOf("OPENAI_API_KEY")
.orElseThrow( () -> new IllegalArgumentException("no APIKEY provided!"));
Expand All @@ -49,11 +46,13 @@ private AgentExecutor.GraphBuilder newGraphBuilder() throws Exception {

return agentExecutor.graphBuilder()
.chatLanguageModel(chatLanguageModel)
.objectsWithTools(listOf(new TestTool()));
.objectsWithTools(listOf(new TestTool()))
.build();
}

private List<AgentExecutor.State> executeAgent( String prompt ) throws Exception {
var iterator = newGraphBuilder().build().stream( mapOf( "input", prompt ) );

var iterator = newGraph().compile().stream( mapOf( "input", prompt ) );

return iterator.stream()
.peek( s -> System.out.println( s.node() ) )
Expand All @@ -62,12 +61,16 @@ private List<AgentExecutor.State> executeAgent( String prompt ) throws Exceptio
}

private List<AgentExecutor.State> executeAgent(String prompt, String threadId, BaseCheckpointSaver saver) throws Exception {
var config = RunnableConfig.builder().threadId(threadId).build();

var iterator = newGraphBuilder()
CompileConfig compileConfig = CompileConfig.builder()
.checkpointSaver( saver )
.build()
.stream( mapOf( "input", prompt ), config );
.build();

var config = RunnableConfig.builder().threadId(threadId).build();

var graph = newGraph().compile( compileConfig );

var iterator = graph.stream( mapOf( "input", prompt ), config );

return iterator.stream()
.peek( s -> System.out.println( s.node() ) )
Expand Down

0 comments on commit f193da8

Please sign in to comment.