Skip to content

Commit

Permalink
refactor: rename getCheckpointSaver() to checkpointSaver() because it…
Browse files Browse the repository at this point in the history
… returns an Optional
  • Loading branch information
bsorrentino committed Aug 28, 2024
1 parent a00054b commit fd072d2
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 10 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package org.bsc.langgraph4j;

import lombok.Getter;
import org.bsc.langgraph4j.checkpoint.BaseCheckpointSaver;

import java.util.Optional;
Expand All @@ -8,12 +9,12 @@
public class CompileConfig {

private BaseCheckpointSaver checkpointSaver;
@Getter
private String[] interruptBefore = {};
@Getter
private String[] interruptAfter = {};

public Optional<BaseCheckpointSaver> getCheckpointSaver() { return Optional.ofNullable(checkpointSaver); }
public String[] getInterruptBefore() { return interruptBefore; }
public String[] getInterruptAfter() { return interruptAfter; }
public Optional<BaseCheckpointSaver> checkpointSaver() { return Optional.ofNullable(checkpointSaver); }

public static Builder builder() {
return new Builder();
Expand Down
14 changes: 7 additions & 7 deletions core-jdk8/src/main/java/org/bsc/langgraph4j/CompiledGraph.java
Original file line number Diff line number Diff line change
Expand Up @@ -61,23 +61,23 @@ protected CompiledGraph(StateGraph<State> stateGraph, CompileConfig compileConfi
}

public Collection<StateSnapshot<State>> getStateHistory( RunnableConfig config ) {
var saver = compileConfig.getCheckpointSaver().orElseThrow( () -> (new IllegalStateException("Missing CheckpointSaver!")) );
var saver = compileConfig.checkpointSaver().orElseThrow( () -> (new IllegalStateException("Missing CheckpointSaver!")) );

return saver.list(config).stream()
.map( checkpoint -> StateSnapshot.of( checkpoint, config, stateGraph.getStateFactory() ) )
.collect(Collectors.toList());
}

public StateSnapshot<State> getState( RunnableConfig config ) {
var saver = compileConfig.getCheckpointSaver().orElseThrow( () -> (new IllegalStateException("Missing CheckpointSaver!")) );
var saver = compileConfig.checkpointSaver().orElseThrow( () -> (new IllegalStateException("Missing CheckpointSaver!")) );

var checkpoint = saver.get(config).orElseThrow( () -> (new IllegalStateException("Missing Checkpoint!")) );

return StateSnapshot.of( checkpoint, config, stateGraph.getStateFactory() );
}

public RunnableConfig updateState( RunnableConfig config, Map<String,Object> values, String asNode ) throws Exception {
var saver = compileConfig.getCheckpointSaver().orElseThrow( () -> (new IllegalStateException("Missing CheckpointSaver!")) );
var saver = compileConfig.checkpointSaver().orElseThrow( () -> (new IllegalStateException("Missing CheckpointSaver!")) );

// merge values with checkpoint values
var updatedCheckpoint = saver.get(config)
Expand Down Expand Up @@ -152,13 +152,13 @@ private String getEntryPoint( State state ) throws Exception {
}

private void addCheckpoint( RunnableConfig config, String nodeId, State state, String nextNodeId ) throws Exception {
if( compileConfig.getCheckpointSaver().isPresent() ) {
if( compileConfig.checkpointSaver().isPresent() ) {
Checkpoint cp = Checkpoint.builder()
.nodeId( nodeId )
.state( state )
.nextNodeId( nextNodeId )
.build();
compileConfig.getCheckpointSaver().get().put( config, cp );
compileConfig.checkpointSaver().get().put( config, cp );
}
}

Expand All @@ -172,7 +172,7 @@ Map<String,Object> getInitialStateFromSchema() {

Map<String,Object> getInitialState(Map<String,Object> inputs, RunnableConfig config) {

return compileConfig.getCheckpointSaver()
return compileConfig.checkpointSaver()
.flatMap( saver -> saver.get( config ) )
.map( cp -> AgentState.updateState( cp.getState(), inputs, stateGraph.getChannels() ))
.orElseGet( () -> AgentState.updateState(getInitialStateFromSchema(), inputs, stateGraph.getChannels() ));
Expand Down Expand Up @@ -258,7 +258,7 @@ public AsyncGenerator<NodeOutput<State>> stream(Map<String,Object> inputs, Runna

if( isResumeRequest ) {

BaseCheckpointSaver saver = compileConfig.getCheckpointSaver().orElseThrow(() -> (new IllegalStateException("inputs cannot be null (ie. resume request) if no checkpoint saver is configured")));
BaseCheckpointSaver saver = compileConfig.checkpointSaver().orElseThrow(() -> (new IllegalStateException("inputs cannot be null (ie. resume request) if no checkpoint saver is configured")));

Checkpoint startCheckpoint = saver.get( config ).orElseThrow( () -> (new IllegalStateException("Resume request without a saved checkpoint!")) );

Expand Down

0 comments on commit fd072d2

Please sign in to comment.