Skip to content

Commit

Permalink
feat: start implementing checkpoint
Browse files Browse the repository at this point in the history
- add BaseCheckpointSaver
- add Checkpoint
- add CheckpointConfig
- add CompileConfig
- add InvokeConfig

work on #11
  • Loading branch information
bsorrentino committed Aug 6, 2024
1 parent 9ed434a commit f9800ec
Show file tree
Hide file tree
Showing 8 changed files with 182 additions and 7 deletions.
45 changes: 45 additions & 0 deletions core-jdk8/src/main/java/org/bsc/langgraph4j/CompileConfig.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package org.bsc.langgraph4j;

import org.bsc.langgraph4j.state.BaseCheckpointSaver;

import java.util.Optional;


public class CompileConfig {

private BaseCheckpointSaver checkpointSaver;
private String[] interruptBefore = {};
private String[] interruptAfter = {};

public Optional<BaseCheckpointSaver> getCheckpointSaver() { return Optional.ofNullable(checkpointSaver); }
public String[] getInterruptBefore() { return interruptBefore; }
public String[] getInterruptAfter() { return interruptAfter; }

public static Builder builder() {
return new Builder();
}

public static class Builder {
private final CompileConfig config = new CompileConfig();

public Builder checkpointSaver(BaseCheckpointSaver checkpointSaver) {
this.config.checkpointSaver = checkpointSaver;
return this;
}
public Builder interruptBefore(String... interruptBefore) {
this.config.interruptBefore = interruptBefore;
return this;
}
public Builder interruptAfter(String... interruptAfter) {
this.config.interruptAfter = interruptAfter;
return this;
}
public CompileConfig build() {
return config;
}
}


private CompileConfig() {}

}
33 changes: 29 additions & 4 deletions core-jdk8/src/main/java/org/bsc/langgraph4j/CompiledGraph.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,16 @@ public class CompiledGraph<State extends AgentState> {
final Map<String, EdgeValue<State>> edges = new LinkedHashMap<>();

private int maxIterations = 25;
private final CompileConfig compileConfig;

/**
* Constructs a CompiledGraph with the given StateGraph.
*
* @param stateGraph the StateGraph to be used in this CompiledGraph
*/
protected CompiledGraph(StateGraph<State> stateGraph) {
protected CompiledGraph(StateGraph<State> stateGraph, CompileConfig compileConfig ) {
this.stateGraph = stateGraph;
this.compileConfig = compileConfig;
stateGraph.nodes.forEach(n ->
nodes.put(n.id(), n.action())
);
Expand Down Expand Up @@ -109,10 +111,11 @@ private String getEntryPoint( State state ) throws Exception {
* Creates an AsyncGenerator stream of NodeOutput based on the provided inputs.
*
* @param inputs the input map
* @param config the invoke configuration
* @return an AsyncGenerator stream of NodeOutput
* @throws Exception if there is an error creating the stream
*/
public AsyncGenerator<NodeOutput<State>> stream(Map<String,Object> inputs ) throws Exception {
public AsyncGenerator<NodeOutput<State>> stream(Map<String,Object> inputs, InvokeConfig config ) throws Exception {

return AsyncGeneratorQueue.of(new LinkedBlockingQueue<>(), queue -> {

Expand Down Expand Up @@ -171,16 +174,27 @@ public AsyncGenerator<NodeOutput<State>> stream(Map<String,Object> inputs ) thro

}

/**
* Creates an AsyncGenerator stream of NodeOutput based on the provided inputs.
*
* @param inputs the input map
* @return an AsyncGenerator stream of NodeOutput
* @throws Exception if there is an error creating the stream
*/
public AsyncGenerator<NodeOutput<State>> stream(Map<String,Object> inputs ) throws Exception {
return this.stream( inputs, InvokeConfig.builder().build() );
}
/**
* Invokes the graph execution with the provided inputs and returns the final state.
*
* @param inputs the input map
* @param config the invoke configuration
* @return an Optional containing the final state if present, otherwise an empty Optional
* @throws Exception if there is an error during invocation
*/
public Optional<State> invoke(Map<String,Object> inputs ) throws Exception {
public Optional<State> invoke(Map<String,Object> inputs, InvokeConfig config ) throws Exception {

var sourceIterator = stream(inputs).iterator();
var sourceIterator = stream(inputs, config).iterator();

var result = StreamSupport.stream(
Spliterators.spliteratorUnknownSize(sourceIterator, Spliterator.ORDERED),
Expand All @@ -189,6 +203,17 @@ public Optional<State> invoke(Map<String,Object> inputs ) throws Exception {
return result.reduce((a, b) -> b).map( NodeOutput::state);
}

/**
* Invokes the graph execution with the provided inputs and returns the final state.
*
* @param inputs the input map
* @return an Optional containing the final state if present, otherwise an empty Optional
* @throws Exception if there is an error during invocation
*/
public Optional<State> invoke(Map<String,Object> inputs ) throws Exception {
return this.invoke( inputs, InvokeConfig.builder().build() );
}

/**
* Generates a drawable graph representation of the state graph.
*
Expand Down
2 changes: 1 addition & 1 deletion core-jdk8/src/main/java/org/bsc/langgraph4j/EdgeValue.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

@Value
@Accessors(fluent = true)
class EdgeValue<State extends AgentState> {
public class EdgeValue<State extends AgentState> {

/**
* The unique identifier for the edge value.
Expand Down
39 changes: 39 additions & 0 deletions core-jdk8/src/main/java/org/bsc/langgraph4j/InvokeConfig.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package org.bsc.langgraph4j;

import org.bsc.langgraph4j.state.CheckpointConfig;

import java.util.Optional;

public class InvokeConfig {

private CheckpointConfig checkpointConfig;

public Optional<CheckpointConfig> getCheckpointConfig() {
return Optional.ofNullable(checkpointConfig);
}

static Builder builder() {
return new Builder();
}

public static class Builder {

private String checkpointThreadId;

public Builder checkpointThreadId(String threadId) {
this.checkpointThreadId = threadId;
return this;
}
public InvokeConfig build() {
InvokeConfig result = new InvokeConfig();

if( checkpointThreadId != null ) {
result.checkpointConfig = CheckpointConfig.of(checkpointThreadId);
}

return result;
}
}

private InvokeConfig() {}
}
16 changes: 14 additions & 2 deletions core-jdk8/src/main/java/org/bsc/langgraph4j/StateGraph.java
Original file line number Diff line number Diff line change
Expand Up @@ -198,10 +198,11 @@ private Node<State> makeFakeNode(String id) {
/**
* Compiles the state graph into a compiled graph.
*
* @param config the compile configuration
* @return a compiled graph
* @throws GraphStateException if there are errors related to the graph state
*/
public CompiledGraph<State> compile() throws GraphStateException {
public CompiledGraph<State> compile( CompileConfig config ) throws GraphStateException {
if (entryPoint == null) {
throw Errors.missingEntryPoint.exception();
}
Expand Down Expand Up @@ -237,6 +238,17 @@ public CompiledGraph<State> compile() throws GraphStateException {
}
}

return new CompiledGraph<>(this);
return new CompiledGraph<>(this, config);
}

/**
* Compiles the state graph into a compiled graph.
*
* @return a compiled graph
* @throws GraphStateException if there are errors related to the graph state
*/
public CompiledGraph<State> compile() throws GraphStateException {
return compile(CompileConfig.builder().build());
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package org.bsc.langgraph4j.state;

import java.util.Collection;
import java.util.Optional;

public interface BaseCheckpointSaver {

Collection<Checkpoint> list();
Optional<Checkpoint> get( String id );
void put( Checkpoint checkpoint );
}
35 changes: 35 additions & 0 deletions core-jdk8/src/main/java/org/bsc/langgraph4j/state/Checkpoint.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package org.bsc.langgraph4j.state;

import lombok.Value;

import java.util.*;

public class Checkpoint {

@lombok.Value(staticConstructor="of")
public static class Value {
AgentState state;
String next;
}

private final String id;
private final Value value;

public final String getId() {
return id;
}
public final Value getValue() {
return value;
}

public Checkpoint( Value value ) {
this(UUID.randomUUID().toString(), value );
}
public Checkpoint(String id, Value value) {
Objects.requireNonNull(id, "id cannot be null");
Objects.requireNonNull(value, "value cannot be null");
this.id = id;
this.value = value;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package org.bsc.langgraph4j.state;

import lombok.Value;

@Value(staticConstructor = "of")
public class CheckpointConfig {
String threadId;
}

0 comments on commit f9800ec

Please sign in to comment.