From 183a0ceb7069f1f96b2e69590843c9ed08f7818b Mon Sep 17 00:00:00 2001 From: bsorrentino Date: Mon, 26 Aug 2024 22:32:50 +0200 Subject: [PATCH] feat(CompiledGraph): expose state Management - getState - getStateHistory - updateState - update streaming processing to support graph resume work on #14 --- .../org/bsc/langgraph4j/CompiledGraph.java | 224 +++++++++++++++--- 1 file changed, 189 insertions(+), 35 deletions(-) diff --git a/core-jdk8/src/main/java/org/bsc/langgraph4j/CompiledGraph.java b/core-jdk8/src/main/java/org/bsc/langgraph4j/CompiledGraph.java index 0cfa7ae..6a904b1 100644 --- a/core-jdk8/src/main/java/org/bsc/langgraph4j/CompiledGraph.java +++ b/core-jdk8/src/main/java/org/bsc/langgraph4j/CompiledGraph.java @@ -8,17 +8,23 @@ import org.bsc.langgraph4j.action.AsyncNodeAction; import org.bsc.langgraph4j.checkpoint.BaseCheckpointSaver; import org.bsc.langgraph4j.checkpoint.Checkpoint; +import org.bsc.langgraph4j.serializer.MapSerializer; import org.bsc.langgraph4j.state.AgentState; -import org.bsc.langgraph4j.state.Channel; +import org.bsc.langgraph4j.state.StateSnapshot; +import java.io.IOException; import java.util.*; +import java.util.concurrent.BlockingQueue; import java.util.concurrent.LinkedBlockingQueue; -import java.util.function.Supplier; +import java.util.function.Consumer; +import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.StreamSupport; import static java.lang.String.format; import static java.util.concurrent.CompletableFuture.completedFuture; +import static org.bsc.langgraph4j.StateGraph.END; +import static org.bsc.langgraph4j.StateGraph.START; /** * Represents a compiled graph of nodes and edges. @@ -55,6 +61,37 @@ protected CompiledGraph(StateGraph stateGraph, CompileConfig compileConfi ); } + public Collection getStateHistory( RunnableConfig config ) { + var saver = compileConfig.getCheckpointSaver().orElseThrow( () -> (new IllegalStateException("Missing CheckpointSaver!")) ); + + return saver.list(config).stream().map( checkpoint -> StateSnapshot.of( checkpoint, config ) ).collect(Collectors.toList()); + } + + public StateSnapshot getState( RunnableConfig config ) { + var saver = compileConfig.getCheckpointSaver().orElseThrow( () -> (new IllegalStateException("Missing CheckpointSaver!")) ); + + var checkpoint = saver.get(config).orElseThrow( () -> (new IllegalStateException("Missing Checkpoint!")) ); + + return StateSnapshot.of( checkpoint, config ); + } + + public RunnableConfig updateState( RunnableConfig config, Map values, String asNode ) throws Exception { + var saver = compileConfig.getCheckpointSaver().orElseThrow( () -> (new IllegalStateException("Missing CheckpointSaver!")) ); + + // merge values with checkpoint values + var updatedCheckpoint = saver.get(config) + .map( cp -> cp.updateState(values, stateGraph.getChannels()) ) + .orElseThrow( () -> (new IllegalStateException("Missing Checkpoint!")) ); + + // update checkpoint in saver + var newConfig = saver.put( config, updatedCheckpoint ); + + return RunnableConfig.builder(newConfig) + .checkPointId( updatedCheckpoint.getId() ) + .nextNode( asNode ) + .build(); + } + public EdgeValue getEntryPoint() { return stateGraph.getEntryPoint(); } @@ -62,13 +99,14 @@ public EdgeValue getEntryPoint() { public String getFinishPoint() { return stateGraph.getFinishPoint(); } + /** * Sets the maximum number of iterations for the graph execution. * * @param maxIterations the maximum number of iterations * @throws IllegalArgumentException if maxIterations is less than or equal to 0 */ - void setMaxIterations(int maxIterations) { + public void setMaxIterations(int maxIterations) { if( maxIterations <= 0 ) { throw new IllegalArgumentException("maxIterations must be > 0!"); } @@ -112,38 +150,124 @@ private String getEntryPoint( State state ) throws Exception { return nextNodeId(stateGraph.getEntryPoint(), state, "entryPoint"); } - private void addCheckpoint( String nodeId, State state ) throws Exception { + private void addCheckpoint( RunnableConfig config, String nodeId, State state, String nextNodeId ) throws Exception { if( compileConfig.getCheckpointSaver().isPresent() ) { - Checkpoint.Value value = Checkpoint.Value.of(state, nodeId); - compileConfig.getCheckpointSaver().get().put( new Checkpoint(value) ); + Checkpoint cp = Checkpoint.builder() + .nodeId( nodeId ) + .state( state ) + .nextNodeId( nextNodeId ) + .build(); + compileConfig.getCheckpointSaver().get().put( config, cp ); } } Map getInitialStateFromSchema() { - return stateGraph.getChannels().entrySet().stream() + return stateGraph.getChannels().entrySet().stream() .filter( c -> c.getValue().getDefault().isPresent() ) .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().getDefault().get().get() )); } - State getInitialState(Map inputs) { + State getInitialState(Map inputs, RunnableConfig config) { return compileConfig.getCheckpointSaver() - .flatMap(BaseCheckpointSaver::getLast) - .map( cp -> { - var state = cp.getValue().getState().mergeWith(inputs, stateGraph.getChannels()); - return stateGraph.getStateFactory().apply(state); - }) - .orElseGet( () -> { - var initialState = - stateGraph.getStateFactory() - .andThen( state -> state.mergeWith(inputs, stateGraph.getChannels()) ) - .apply( getInitialStateFromSchema() ); - return stateGraph.getStateFactory().apply(initialState); - }); + .flatMap( saver -> saver.get( config ) ) + .map( cp -> + stateGraph.getStateFactory() + .apply( AgentState.updateState( cp.getState(), inputs, stateGraph.getChannels() ) ) + ) + .orElseGet( () -> + stateGraph.getStateFactory() + .apply( AgentState.updateState(getInitialStateFromSchema(), inputs, stateGraph.getChannels() ) ) + ); + } + + private State cloneState( Map data ) throws IOException, ClassNotFoundException { + + Map newData = MapSerializer.INSTANCE.cloneObject(data); + + return stateGraph.getStateFactory().apply(newData); + } + + private State cloneState( State state ) throws IOException, ClassNotFoundException { + return cloneState(state.data()); + } + + private void yieldData(BlockingQueue>> queue, NodeOutput data) { + queue.add( AsyncGenerator.Data.of( completedFuture( data) ) ); + } + + private void streamData( + State initialState, + String startNodeId, + RunnableConfig config, + Consumer> yieldData) throws Exception { + + var currentState = initialState; + + +// if( initialNodeId != null ) { +// +// yieldData.accept( NodeOutput.of(initialNodeId, currentState)); +// +// log.trace("START FROM NODE: {}", initialNodeId); +// } + + + var currentNodeId = startNodeId; +// addCheckpoint( config, initialNodeId, currentState, currentNodeId ); + + + Map partialState; + + int iteration = 0; + + while( !Objects.equals(currentNodeId, END) ) { + + log.trace( "NEXT NODE: {}", currentNodeId); + + var action = nodes.get(currentNodeId); + if (action == null) { + throw StateGraph.RunnableErrors.missingNode.exception(currentNodeId); + } + + partialState = action.apply(currentState).get(); + + currentState = cloneState( AgentState.updateState(currentState, partialState, stateGraph.getChannels()) ); + + yieldData.accept( NodeOutput.of(currentNodeId, currentState) ); + + if ( Objects.equals(currentNodeId, stateGraph.getFinishPoint()) ) { + addCheckpoint( config, currentNodeId, currentState, stateGraph.getFinishPoint() ); + break; + } + + final String nextNodeId = nextNodeId(currentNodeId, currentState); + addCheckpoint( config, currentNodeId, currentState, nextNodeId ); + + currentNodeId = nextNodeId; + + if ( Objects.equals(currentNodeId, END) ) { + break; + } + + if( ++iteration > maxIterations ) { + log.warn( "Maximum number of iterations ({}) reached!", maxIterations); + break; + } + + } + + yieldData.accept( NodeOutput.of(END, currentState) ); + + addCheckpoint( config, END, currentState, END ); + + log.trace( "STOP"); + } + /** * Creates an AsyncGenerator stream of NodeOutput based on the provided inputs. * @@ -152,27 +276,53 @@ State getInitialState(Map inputs) { * @return an AsyncGenerator stream of NodeOutput * @throws Exception if there is an error creating the stream */ - public AsyncGenerator> stream(Map inputs, InvokeConfig config ) throws Exception { + public AsyncGenerator> stream(Map inputs, RunnableConfig config ) throws Exception { Objects.requireNonNull(config, "config cannot be null"); + final boolean isResumeRequest = (inputs == null); + + if( isResumeRequest ) { + + BaseCheckpointSaver saver = compileConfig.getCheckpointSaver().orElseThrow(() -> (new IllegalStateException("inputs cannot be null (ie. resume request) if no checkpoint saver is configured"))); + + Checkpoint startCheckpoint = saver.get( config ).orElseThrow( () -> (new IllegalStateException("Resume request without a saved checkpoint!")) ); + + return AsyncGeneratorQueue.of(new LinkedBlockingQueue<>(), queue -> { + + try { + State startState = stateGraph.getStateFactory().apply( startCheckpoint.getState().data() ); + streamData( startState, + startCheckpoint.getNextNodeId(), + config, + data -> queue.add( AsyncGenerator.Data.of( completedFuture(data) ) ) + ); + } + catch (Throwable e) { + throw new RuntimeException( e ); + } + + }); + + } + return AsyncGeneratorQueue.of(new LinkedBlockingQueue<>(), queue -> { try { - var currentState = getInitialState(inputs); + var currentState = cloneState( getInitialState(inputs, config) ) ; - queue.add( AsyncGenerator.Data.of( completedFuture( NodeOutput.of("start", currentState)) )); - addCheckpoint( "start", currentState ); + yieldData( queue, NodeOutput.of( START, currentState ) ); log.trace( "START"); var currentNodeId = this.getEntryPoint( currentState ); + addCheckpoint( config, START, currentState, currentNodeId ); Map partialState; int iteration = 0; - while( !Objects.equals(currentNodeId, StateGraph.END) ) { + while( !Objects.equals(currentNodeId, END) ) { log.trace( "NEXT NODE: {}", currentNodeId); @@ -183,18 +333,21 @@ public AsyncGenerator> stream(Map inputs, Invok partialState = action.apply(currentState).get(); - currentState = stateGraph.getStateFactory().apply( currentState.mergeWith(partialState, stateGraph.getChannels()) ); + currentState = cloneState( AgentState.updateState(currentState, partialState, stateGraph.getChannels()) ); - queue.add( AsyncGenerator.Data.of( completedFuture( NodeOutput.of(currentNodeId, currentState) ) )); - addCheckpoint( currentNodeId, currentState ); + yieldData( queue, NodeOutput.of(currentNodeId, currentState) ); if ( Objects.equals(currentNodeId, stateGraph.getFinishPoint()) ) { + addCheckpoint( config, currentNodeId, currentState, stateGraph.getFinishPoint() ); break; } - currentNodeId = nextNodeId(currentNodeId, currentState); + final String nextNodeId = nextNodeId(currentNodeId, currentState); + addCheckpoint( config, currentNodeId, currentState, nextNodeId ); - if ( Objects.equals(currentNodeId, StateGraph.END) ) { + currentNodeId = nextNodeId; + + if ( Objects.equals(currentNodeId, END) ) { break; } @@ -205,8 +358,9 @@ public AsyncGenerator> stream(Map inputs, Invok } - queue.add( AsyncGenerator.Data.of( completedFuture( NodeOutput.of("stop", currentState) ) )); - addCheckpoint( "stop", currentState ); + yieldData( queue, NodeOutput.of(END, currentState) ); + + addCheckpoint( config, END, currentState, END ); log.trace( "STOP"); } catch (Exception e) { @@ -225,7 +379,7 @@ public AsyncGenerator> stream(Map inputs, Invok * @throws Exception if there is an error creating the stream */ public AsyncGenerator> stream(Map inputs ) throws Exception { - return this.stream( inputs, InvokeConfig.builder().build() ); + return this.stream( inputs, RunnableConfig.builder().build() ); } /** * Invokes the graph execution with the provided inputs and returns the final state. @@ -235,7 +389,7 @@ public AsyncGenerator> stream(Map inputs ) thro * @return an Optional containing the final state if present, otherwise an empty Optional * @throws Exception if there is an error during invocation */ - public Optional invoke(Map inputs, InvokeConfig config ) throws Exception { + public Optional invoke(Map inputs, RunnableConfig config ) throws Exception { var sourceIterator = stream(inputs, config).iterator(); @@ -254,7 +408,7 @@ public Optional invoke(Map inputs, InvokeConfig config ) t * @throws Exception if there is an error during invocation */ public Optional invoke(Map inputs ) throws Exception { - return this.invoke( inputs, InvokeConfig.builder().build() ); + return this.invoke( inputs, RunnableConfig.builder().build() ); } /**