Skip to content

Commit

Permalink
feat(CompiledGraph): expose state Management
Browse files Browse the repository at this point in the history
- getState
- getStateHistory
- updateState
- update streaming processing to support graph resume

work on #14
  • Loading branch information
bsorrentino committed Aug 26, 2024
1 parent cc86564 commit 183a0ce
Showing 1 changed file with 189 additions and 35 deletions.
224 changes: 189 additions & 35 deletions core-jdk8/src/main/java/org/bsc/langgraph4j/CompiledGraph.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,23 @@
import org.bsc.langgraph4j.action.AsyncNodeAction;
import org.bsc.langgraph4j.checkpoint.BaseCheckpointSaver;
import org.bsc.langgraph4j.checkpoint.Checkpoint;
import org.bsc.langgraph4j.serializer.MapSerializer;
import org.bsc.langgraph4j.state.AgentState;
import org.bsc.langgraph4j.state.Channel;
import org.bsc.langgraph4j.state.StateSnapshot;

import java.io.IOException;
import java.util.*;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.function.Supplier;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;

import static java.lang.String.format;
import static java.util.concurrent.CompletableFuture.completedFuture;
import static org.bsc.langgraph4j.StateGraph.END;
import static org.bsc.langgraph4j.StateGraph.START;

/**
* Represents a compiled graph of nodes and edges.
Expand Down Expand Up @@ -55,20 +61,52 @@ protected CompiledGraph(StateGraph<State> stateGraph, CompileConfig compileConfi
);
}

public Collection<StateSnapshot> getStateHistory( RunnableConfig config ) {
var saver = compileConfig.getCheckpointSaver().orElseThrow( () -> (new IllegalStateException("Missing CheckpointSaver!")) );

return saver.list(config).stream().map( checkpoint -> StateSnapshot.of( checkpoint, config ) ).collect(Collectors.toList());
}

public StateSnapshot getState( RunnableConfig config ) {
var saver = compileConfig.getCheckpointSaver().orElseThrow( () -> (new IllegalStateException("Missing CheckpointSaver!")) );

var checkpoint = saver.get(config).orElseThrow( () -> (new IllegalStateException("Missing Checkpoint!")) );

return StateSnapshot.of( checkpoint, config );
}

public RunnableConfig updateState( RunnableConfig config, Map<String,Object> values, String asNode ) throws Exception {
var saver = compileConfig.getCheckpointSaver().orElseThrow( () -> (new IllegalStateException("Missing CheckpointSaver!")) );

// merge values with checkpoint values
var updatedCheckpoint = saver.get(config)
.map( cp -> cp.updateState(values, stateGraph.getChannels()) )
.orElseThrow( () -> (new IllegalStateException("Missing Checkpoint!")) );

// update checkpoint in saver
var newConfig = saver.put( config, updatedCheckpoint );

return RunnableConfig.builder(newConfig)
.checkPointId( updatedCheckpoint.getId() )
.nextNode( asNode )
.build();
}

public EdgeValue<State> getEntryPoint() {
return stateGraph.getEntryPoint();
}

public String getFinishPoint() {
return stateGraph.getFinishPoint();
}

/**
* Sets the maximum number of iterations for the graph execution.
*
* @param maxIterations the maximum number of iterations
* @throws IllegalArgumentException if maxIterations is less than or equal to 0
*/
void setMaxIterations(int maxIterations) {
public void setMaxIterations(int maxIterations) {
if( maxIterations <= 0 ) {
throw new IllegalArgumentException("maxIterations must be > 0!");
}
Expand Down Expand Up @@ -112,38 +150,124 @@ private String getEntryPoint( State state ) throws Exception {
return nextNodeId(stateGraph.getEntryPoint(), state, "entryPoint");
}

private void addCheckpoint( String nodeId, State state ) throws Exception {
private void addCheckpoint( RunnableConfig config, String nodeId, State state, String nextNodeId ) throws Exception {
if( compileConfig.getCheckpointSaver().isPresent() ) {
Checkpoint.Value value = Checkpoint.Value.of(state, nodeId);
compileConfig.getCheckpointSaver().get().put( new Checkpoint(value) );
Checkpoint cp = Checkpoint.builder()
.nodeId( nodeId )
.state( state )
.nextNodeId( nextNodeId )
.build();
compileConfig.getCheckpointSaver().get().put( config, cp );
}
}

Map<String,Object> getInitialStateFromSchema() {
return stateGraph.getChannels().entrySet().stream()
return stateGraph.getChannels().entrySet().stream()
.filter( c -> c.getValue().getDefault().isPresent() )
.collect(Collectors.toMap(Map.Entry::getKey, e ->
e.getValue().getDefault().get().get()
));
}

State getInitialState(Map<String,Object> inputs) {
State getInitialState(Map<String,Object> inputs, RunnableConfig config) {

return compileConfig.getCheckpointSaver()
.flatMap(BaseCheckpointSaver::getLast)
.map( cp -> {
var state = cp.getValue().getState().mergeWith(inputs, stateGraph.getChannels());
return stateGraph.getStateFactory().apply(state);
})
.orElseGet( () -> {
var initialState =
stateGraph.getStateFactory()
.andThen( state -> state.mergeWith(inputs, stateGraph.getChannels()) )
.apply( getInitialStateFromSchema() );
return stateGraph.getStateFactory().apply(initialState);
});
.flatMap( saver -> saver.get( config ) )
.map( cp ->
stateGraph.getStateFactory()
.apply( AgentState.updateState( cp.getState(), inputs, stateGraph.getChannels() ) )
)
.orElseGet( () ->
stateGraph.getStateFactory()
.apply( AgentState.updateState(getInitialStateFromSchema(), inputs, stateGraph.getChannels() ) )
);
}

private State cloneState( Map<String,Object> data ) throws IOException, ClassNotFoundException {

Map<String,Object> newData = MapSerializer.INSTANCE.cloneObject(data);

return stateGraph.getStateFactory().apply(newData);
}

private State cloneState( State state ) throws IOException, ClassNotFoundException {
return cloneState(state.data());
}

private void yieldData(BlockingQueue<AsyncGenerator.Data<NodeOutput<State>>> queue, NodeOutput<State> data) {
queue.add( AsyncGenerator.Data.of( completedFuture( data) ) );
}

private void streamData(
State initialState,
String startNodeId,
RunnableConfig config,
Consumer<NodeOutput<State>> yieldData) throws Exception {

var currentState = initialState;


// if( initialNodeId != null ) {
//
// yieldData.accept( NodeOutput.of(initialNodeId, currentState));
//
// log.trace("START FROM NODE: {}", initialNodeId);
// }


var currentNodeId = startNodeId;
// addCheckpoint( config, initialNodeId, currentState, currentNodeId );


Map<String, Object> partialState;

int iteration = 0;

while( !Objects.equals(currentNodeId, END) ) {

log.trace( "NEXT NODE: {}", currentNodeId);

var action = nodes.get(currentNodeId);
if (action == null) {
throw StateGraph.RunnableErrors.missingNode.exception(currentNodeId);
}

partialState = action.apply(currentState).get();

currentState = cloneState( AgentState.updateState(currentState, partialState, stateGraph.getChannels()) );

yieldData.accept( NodeOutput.of(currentNodeId, currentState) );

if ( Objects.equals(currentNodeId, stateGraph.getFinishPoint()) ) {
addCheckpoint( config, currentNodeId, currentState, stateGraph.getFinishPoint() );
break;
}

final String nextNodeId = nextNodeId(currentNodeId, currentState);
addCheckpoint( config, currentNodeId, currentState, nextNodeId );

currentNodeId = nextNodeId;

if ( Objects.equals(currentNodeId, END) ) {
break;
}

if( ++iteration > maxIterations ) {
log.warn( "Maximum number of iterations ({}) reached!", maxIterations);
break;
}

}

yieldData.accept( NodeOutput.of(END, currentState) );

addCheckpoint( config, END, currentState, END );

log.trace( "STOP");

}


/**
* Creates an AsyncGenerator stream of NodeOutput based on the provided inputs.
*
Expand All @@ -152,27 +276,53 @@ State getInitialState(Map<String,Object> inputs) {
* @return an AsyncGenerator stream of NodeOutput
* @throws Exception if there is an error creating the stream
*/
public AsyncGenerator<NodeOutput<State>> stream(Map<String,Object> inputs, InvokeConfig config ) throws Exception {
public AsyncGenerator<NodeOutput<State>> stream(Map<String,Object> inputs, RunnableConfig config ) throws Exception {
Objects.requireNonNull(config, "config cannot be null");

final boolean isResumeRequest = (inputs == null);

if( isResumeRequest ) {

BaseCheckpointSaver saver = compileConfig.getCheckpointSaver().orElseThrow(() -> (new IllegalStateException("inputs cannot be null (ie. resume request) if no checkpoint saver is configured")));

Checkpoint startCheckpoint = saver.get( config ).orElseThrow( () -> (new IllegalStateException("Resume request without a saved checkpoint!")) );

return AsyncGeneratorQueue.of(new LinkedBlockingQueue<>(), queue -> {

try {
State startState = stateGraph.getStateFactory().apply( startCheckpoint.getState().data() );
streamData( startState,
startCheckpoint.getNextNodeId(),
config,
data -> queue.add( AsyncGenerator.Data.of( completedFuture(data) ) )
);
}
catch (Throwable e) {
throw new RuntimeException( e );
}

});

}

return AsyncGeneratorQueue.of(new LinkedBlockingQueue<>(), queue -> {

try {

var currentState = getInitialState(inputs);
var currentState = cloneState( getInitialState(inputs, config) ) ;

queue.add( AsyncGenerator.Data.of( completedFuture( NodeOutput.of("start", currentState)) ));
addCheckpoint( "start", currentState );
yieldData( queue, NodeOutput.of( START, currentState ) );

log.trace( "START");

var currentNodeId = this.getEntryPoint( currentState );
addCheckpoint( config, START, currentState, currentNodeId );

Map<String, Object> partialState;

int iteration = 0;

while( !Objects.equals(currentNodeId, StateGraph.END) ) {
while( !Objects.equals(currentNodeId, END) ) {

log.trace( "NEXT NODE: {}", currentNodeId);

Expand All @@ -183,18 +333,21 @@ public AsyncGenerator<NodeOutput<State>> stream(Map<String,Object> inputs, Invok

partialState = action.apply(currentState).get();

currentState = stateGraph.getStateFactory().apply( currentState.mergeWith(partialState, stateGraph.getChannels()) );
currentState = cloneState( AgentState.updateState(currentState, partialState, stateGraph.getChannels()) );

queue.add( AsyncGenerator.Data.of( completedFuture( NodeOutput.of(currentNodeId, currentState) ) ));
addCheckpoint( currentNodeId, currentState );
yieldData( queue, NodeOutput.of(currentNodeId, currentState) );

if ( Objects.equals(currentNodeId, stateGraph.getFinishPoint()) ) {
addCheckpoint( config, currentNodeId, currentState, stateGraph.getFinishPoint() );
break;
}

currentNodeId = nextNodeId(currentNodeId, currentState);
final String nextNodeId = nextNodeId(currentNodeId, currentState);
addCheckpoint( config, currentNodeId, currentState, nextNodeId );

if ( Objects.equals(currentNodeId, StateGraph.END) ) {
currentNodeId = nextNodeId;

if ( Objects.equals(currentNodeId, END) ) {
break;
}

Expand All @@ -205,8 +358,9 @@ public AsyncGenerator<NodeOutput<State>> stream(Map<String,Object> inputs, Invok

}

queue.add( AsyncGenerator.Data.of( completedFuture( NodeOutput.of("stop", currentState) ) ));
addCheckpoint( "stop", currentState );
yieldData( queue, NodeOutput.of(END, currentState) );

addCheckpoint( config, END, currentState, END );
log.trace( "STOP");

} catch (Exception e) {
Expand All @@ -225,7 +379,7 @@ public AsyncGenerator<NodeOutput<State>> stream(Map<String,Object> inputs, Invok
* @throws Exception if there is an error creating the stream
*/
public AsyncGenerator<NodeOutput<State>> stream(Map<String,Object> inputs ) throws Exception {
return this.stream( inputs, InvokeConfig.builder().build() );
return this.stream( inputs, RunnableConfig.builder().build() );
}
/**
* Invokes the graph execution with the provided inputs and returns the final state.
Expand All @@ -235,7 +389,7 @@ public AsyncGenerator<NodeOutput<State>> stream(Map<String,Object> inputs ) thro
* @return an Optional containing the final state if present, otherwise an empty Optional
* @throws Exception if there is an error during invocation
*/
public Optional<State> invoke(Map<String,Object> inputs, InvokeConfig config ) throws Exception {
public Optional<State> invoke(Map<String,Object> inputs, RunnableConfig config ) throws Exception {

var sourceIterator = stream(inputs, config).iterator();

Expand All @@ -254,7 +408,7 @@ public Optional<State> invoke(Map<String,Object> inputs, InvokeConfig config ) t
* @throws Exception if there is an error during invocation
*/
public Optional<State> invoke(Map<String,Object> inputs ) throws Exception {
return this.invoke( inputs, InvokeConfig.builder().build() );
return this.invoke( inputs, RunnableConfig.builder().build() );
}

/**
Expand Down

0 comments on commit 183a0ce

Please sign in to comment.