From f9800ec98253f5015775ff088b5486ce4ea3c4d7 Mon Sep 17 00:00:00 2001 From: bsorrentino Date: Tue, 6 Aug 2024 13:03:02 +0200 Subject: [PATCH] feat: start implementing checkpoint - add BaseCheckpointSaver - add Checkpoint - add CheckpointConfig - add CompileConfig - add InvokeConfig work on #11 --- .../org/bsc/langgraph4j/CompileConfig.java | 45 +++++++++++++++++++ .../org/bsc/langgraph4j/CompiledGraph.java | 33 ++++++++++++-- .../java/org/bsc/langgraph4j/EdgeValue.java | 2 +- .../org/bsc/langgraph4j/InvokeConfig.java | 39 ++++++++++++++++ .../java/org/bsc/langgraph4j/StateGraph.java | 16 ++++++- .../state/BaseCheckpointSaver.java | 11 +++++ .../org/bsc/langgraph4j/state/Checkpoint.java | 35 +++++++++++++++ .../langgraph4j/state/CheckpointConfig.java | 8 ++++ 8 files changed, 182 insertions(+), 7 deletions(-) create mode 100644 core-jdk8/src/main/java/org/bsc/langgraph4j/CompileConfig.java create mode 100644 core-jdk8/src/main/java/org/bsc/langgraph4j/InvokeConfig.java create mode 100644 core-jdk8/src/main/java/org/bsc/langgraph4j/state/BaseCheckpointSaver.java create mode 100644 core-jdk8/src/main/java/org/bsc/langgraph4j/state/Checkpoint.java create mode 100644 core-jdk8/src/main/java/org/bsc/langgraph4j/state/CheckpointConfig.java diff --git a/core-jdk8/src/main/java/org/bsc/langgraph4j/CompileConfig.java b/core-jdk8/src/main/java/org/bsc/langgraph4j/CompileConfig.java new file mode 100644 index 0000000..2ca5415 --- /dev/null +++ b/core-jdk8/src/main/java/org/bsc/langgraph4j/CompileConfig.java @@ -0,0 +1,45 @@ +package org.bsc.langgraph4j; + +import org.bsc.langgraph4j.state.BaseCheckpointSaver; + +import java.util.Optional; + + +public class CompileConfig { + + private BaseCheckpointSaver checkpointSaver; + private String[] interruptBefore = {}; + private String[] interruptAfter = {}; + + public Optional getCheckpointSaver() { return Optional.ofNullable(checkpointSaver); } + public String[] getInterruptBefore() { return interruptBefore; } + public String[] getInterruptAfter() { return interruptAfter; } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private final CompileConfig config = new CompileConfig(); + + public Builder checkpointSaver(BaseCheckpointSaver checkpointSaver) { + this.config.checkpointSaver = checkpointSaver; + return this; + } + public Builder interruptBefore(String... interruptBefore) { + this.config.interruptBefore = interruptBefore; + return this; + } + public Builder interruptAfter(String... interruptAfter) { + this.config.interruptAfter = interruptAfter; + return this; + } + public CompileConfig build() { + return config; + } + } + + + private CompileConfig() {} + +} diff --git a/core-jdk8/src/main/java/org/bsc/langgraph4j/CompiledGraph.java b/core-jdk8/src/main/java/org/bsc/langgraph4j/CompiledGraph.java index 6336fca..e205b89 100644 --- a/core-jdk8/src/main/java/org/bsc/langgraph4j/CompiledGraph.java +++ b/core-jdk8/src/main/java/org/bsc/langgraph4j/CompiledGraph.java @@ -31,14 +31,16 @@ public class CompiledGraph { final Map> edges = new LinkedHashMap<>(); private int maxIterations = 25; + private final CompileConfig compileConfig; /** * Constructs a CompiledGraph with the given StateGraph. * * @param stateGraph the StateGraph to be used in this CompiledGraph */ - protected CompiledGraph(StateGraph stateGraph) { + protected CompiledGraph(StateGraph stateGraph, CompileConfig compileConfig ) { this.stateGraph = stateGraph; + this.compileConfig = compileConfig; stateGraph.nodes.forEach(n -> nodes.put(n.id(), n.action()) ); @@ -109,10 +111,11 @@ private String getEntryPoint( State state ) throws Exception { * Creates an AsyncGenerator stream of NodeOutput based on the provided inputs. * * @param inputs the input map + * @param config the invoke configuration * @return an AsyncGenerator stream of NodeOutput * @throws Exception if there is an error creating the stream */ - public AsyncGenerator> stream(Map inputs ) throws Exception { + public AsyncGenerator> stream(Map inputs, InvokeConfig config ) throws Exception { return AsyncGeneratorQueue.of(new LinkedBlockingQueue<>(), queue -> { @@ -171,16 +174,27 @@ public AsyncGenerator> stream(Map inputs ) thro } + /** + * Creates an AsyncGenerator stream of NodeOutput based on the provided inputs. + * + * @param inputs the input map + * @return an AsyncGenerator stream of NodeOutput + * @throws Exception if there is an error creating the stream + */ + public AsyncGenerator> stream(Map inputs ) throws Exception { + return this.stream( inputs, InvokeConfig.builder().build() ); + } /** * Invokes the graph execution with the provided inputs and returns the final state. * * @param inputs the input map + * @param config the invoke configuration * @return an Optional containing the final state if present, otherwise an empty Optional * @throws Exception if there is an error during invocation */ - public Optional invoke(Map inputs ) throws Exception { + public Optional invoke(Map inputs, InvokeConfig config ) throws Exception { - var sourceIterator = stream(inputs).iterator(); + var sourceIterator = stream(inputs, config).iterator(); var result = StreamSupport.stream( Spliterators.spliteratorUnknownSize(sourceIterator, Spliterator.ORDERED), @@ -189,6 +203,17 @@ public Optional invoke(Map inputs ) throws Exception { return result.reduce((a, b) -> b).map( NodeOutput::state); } + /** + * Invokes the graph execution with the provided inputs and returns the final state. + * + * @param inputs the input map + * @return an Optional containing the final state if present, otherwise an empty Optional + * @throws Exception if there is an error during invocation + */ + public Optional invoke(Map inputs ) throws Exception { + return this.invoke( inputs, InvokeConfig.builder().build() ); + } + /** * Generates a drawable graph representation of the state graph. * diff --git a/core-jdk8/src/main/java/org/bsc/langgraph4j/EdgeValue.java b/core-jdk8/src/main/java/org/bsc/langgraph4j/EdgeValue.java index 2677ba2..21ef364 100644 --- a/core-jdk8/src/main/java/org/bsc/langgraph4j/EdgeValue.java +++ b/core-jdk8/src/main/java/org/bsc/langgraph4j/EdgeValue.java @@ -7,7 +7,7 @@ @Value @Accessors(fluent = true) -class EdgeValue { +public class EdgeValue { /** * The unique identifier for the edge value. diff --git a/core-jdk8/src/main/java/org/bsc/langgraph4j/InvokeConfig.java b/core-jdk8/src/main/java/org/bsc/langgraph4j/InvokeConfig.java new file mode 100644 index 0000000..1588a35 --- /dev/null +++ b/core-jdk8/src/main/java/org/bsc/langgraph4j/InvokeConfig.java @@ -0,0 +1,39 @@ +package org.bsc.langgraph4j; + +import org.bsc.langgraph4j.state.CheckpointConfig; + +import java.util.Optional; + +public class InvokeConfig { + + private CheckpointConfig checkpointConfig; + + public Optional getCheckpointConfig() { + return Optional.ofNullable(checkpointConfig); + } + + static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private String checkpointThreadId; + + public Builder checkpointThreadId(String threadId) { + this.checkpointThreadId = threadId; + return this; + } + public InvokeConfig build() { + InvokeConfig result = new InvokeConfig(); + + if( checkpointThreadId != null ) { + result.checkpointConfig = CheckpointConfig.of(checkpointThreadId); + } + + return result; + } + } + + private InvokeConfig() {} +} diff --git a/core-jdk8/src/main/java/org/bsc/langgraph4j/StateGraph.java b/core-jdk8/src/main/java/org/bsc/langgraph4j/StateGraph.java index 362528b..30ffa4d 100644 --- a/core-jdk8/src/main/java/org/bsc/langgraph4j/StateGraph.java +++ b/core-jdk8/src/main/java/org/bsc/langgraph4j/StateGraph.java @@ -198,10 +198,11 @@ private Node makeFakeNode(String id) { /** * Compiles the state graph into a compiled graph. * + * @param config the compile configuration * @return a compiled graph * @throws GraphStateException if there are errors related to the graph state */ - public CompiledGraph compile() throws GraphStateException { + public CompiledGraph compile( CompileConfig config ) throws GraphStateException { if (entryPoint == null) { throw Errors.missingEntryPoint.exception(); } @@ -237,6 +238,17 @@ public CompiledGraph compile() throws GraphStateException { } } - return new CompiledGraph<>(this); + return new CompiledGraph<>(this, config); } + + /** + * Compiles the state graph into a compiled graph. + * + * @return a compiled graph + * @throws GraphStateException if there are errors related to the graph state + */ + public CompiledGraph compile() throws GraphStateException { + return compile(CompileConfig.builder().build()); + } + } diff --git a/core-jdk8/src/main/java/org/bsc/langgraph4j/state/BaseCheckpointSaver.java b/core-jdk8/src/main/java/org/bsc/langgraph4j/state/BaseCheckpointSaver.java new file mode 100644 index 0000000..1877c76 --- /dev/null +++ b/core-jdk8/src/main/java/org/bsc/langgraph4j/state/BaseCheckpointSaver.java @@ -0,0 +1,11 @@ +package org.bsc.langgraph4j.state; + +import java.util.Collection; +import java.util.Optional; + +public interface BaseCheckpointSaver { + + Collection list(); + Optional get( String id ); + void put( Checkpoint checkpoint ); +} diff --git a/core-jdk8/src/main/java/org/bsc/langgraph4j/state/Checkpoint.java b/core-jdk8/src/main/java/org/bsc/langgraph4j/state/Checkpoint.java new file mode 100644 index 0000000..dad10fb --- /dev/null +++ b/core-jdk8/src/main/java/org/bsc/langgraph4j/state/Checkpoint.java @@ -0,0 +1,35 @@ +package org.bsc.langgraph4j.state; + +import lombok.Value; + +import java.util.*; + +public class Checkpoint { + + @lombok.Value(staticConstructor="of") + public static class Value { + AgentState state; + String next; + } + + private final String id; + private final Value value; + + public final String getId() { + return id; + } + public final Value getValue() { + return value; + } + + public Checkpoint( Value value ) { + this(UUID.randomUUID().toString(), value ); + } + public Checkpoint(String id, Value value) { + Objects.requireNonNull(id, "id cannot be null"); + Objects.requireNonNull(value, "value cannot be null"); + this.id = id; + this.value = value; + } + +} diff --git a/core-jdk8/src/main/java/org/bsc/langgraph4j/state/CheckpointConfig.java b/core-jdk8/src/main/java/org/bsc/langgraph4j/state/CheckpointConfig.java new file mode 100644 index 0000000..0b33249 --- /dev/null +++ b/core-jdk8/src/main/java/org/bsc/langgraph4j/state/CheckpointConfig.java @@ -0,0 +1,8 @@ +package org.bsc.langgraph4j.state; + +import lombok.Value; + +@Value(staticConstructor = "of") +public class CheckpointConfig { + String threadId; +}