Skip to content

Commit

Permalink
feat: add support for 'interruptBeforereak' and 'interruptAfter'
Browse files Browse the repository at this point in the history
  • Loading branch information
bsorrentino committed Aug 28, 2024
1 parent fd072d2 commit 3083d9d
Showing 1 changed file with 26 additions and 5 deletions.
31 changes: 26 additions & 5 deletions core-jdk8/src/main/java/org/bsc/langgraph4j/CompiledGraph.java
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,14 @@ private String getEntryPoint( State state ) throws Exception {
return nextNodeId(stateGraph.getEntryPoint(), state, "entryPoint");
}

private boolean shouldInterruptBefore( String nodeId ) {
return Arrays.asList(compileConfig.getInterruptBefore()).contains(nodeId);
}

private boolean shouldInterruptAfter( String nodeId ) {
return Arrays.asList(compileConfig.getInterruptAfter()).contains(nodeId);
}

private void addCheckpoint( RunnableConfig config, String nodeId, State state, String nextNodeId ) throws Exception {
if( compileConfig.checkpointSaver().isPresent() ) {
Checkpoint cp = Checkpoint.builder()
Expand Down Expand Up @@ -203,8 +211,13 @@ private void streamData( State initialState,
log.trace( "NEXT NODE: {}", currentNodeId);

var action = nodes.get(currentNodeId);
if (action == null) {

if (action == null)
throw StateGraph.RunnableErrors.missingNode.exception(currentNodeId);

if ( shouldInterruptBefore( currentNodeId ) ) {
log.trace("interrupt before node {}", currentNodeId);
return;
}

partialState = action.apply(currentState).get();
Expand All @@ -221,11 +234,16 @@ private void streamData( State initialState,
final String nextNodeId = nextNodeId(currentNodeId, currentState);
addCheckpoint( config, currentNodeId, currentState, nextNodeId );

if ( shouldInterruptAfter( currentNodeId ) ) {
log.trace( "interrupt after node {}", currentNodeId);
return;
}

currentNodeId = nextNodeId;

if ( Objects.equals(currentNodeId, END) ) {
if ( Objects.equals(currentNodeId, END) )
break;
}


if( ++iteration > maxIterations ) {
log.warn( "Maximum number of iterations ({}) reached!", maxIterations);
Expand Down Expand Up @@ -273,7 +291,6 @@ public AsyncGenerator<NodeOutput<State>> stream(Map<String,Object> inputs, Runna
config,
data -> queue.add( AsyncGenerator.Data.of( completedFuture(data) ) )
);

}));

}
Expand All @@ -286,13 +303,17 @@ public AsyncGenerator<NodeOutput<State>> stream(Map<String,Object> inputs, Runna
queue.add( AsyncGenerator.Data.of( NodeOutput.of( START, startState ) ));

String startNodeId = this.getEntryPoint( startState );
if( shouldInterruptBefore( startNodeId ) ) return;

addCheckpoint( config, START, startState, startNodeId );

if( shouldInterruptAfter( startNodeId ) ) return;


streamData( startState,
startNodeId,
config,
data -> queue.add( AsyncGenerator.Data.of( completedFuture(data) ) ) );

}));

}
Expand Down

0 comments on commit 3083d9d

Please sign in to comment.