Skip to content

Commit

Permalink
feat(CompiledGraph): optimize code
Browse files Browse the repository at this point in the history
minimize cloneState() calls
  • Loading branch information
bsorrentino committed Sep 6, 2024
1 parent e8a2f07 commit 4e9a763
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 20 deletions.
39 changes: 20 additions & 19 deletions core-jdk8/src/main/java/org/bsc/langgraph4j/CompiledGraph.java
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ public RunnableConfig updateState( RunnableConfig config, Map<String,Object> val

String nextNodeId = null;
if( asNode != null ) {
nextNodeId = nextNodeId( asNode, stateGraph.getStateFactory().apply(updatedCheckpoint.getState()) );
nextNodeId = nextNodeId( asNode, updatedCheckpoint.getState() );
}
// update checkpoint in saver
var newConfig = saver.put( config, updatedCheckpoint );
Expand Down Expand Up @@ -120,7 +120,7 @@ public void setMaxIterations(int maxIterations) {
this.maxIterations = maxIterations;
}

private String nextNodeId( EdgeValue<State> route , State state, String nodeId ) throws Exception {
private String nextNodeId( EdgeValue<State> route , Map<String,Object> state, String nodeId ) throws Exception {

if( route == null ) {
throw StateGraph.RunnableErrors.missingEdge.exception(nodeId);
Expand All @@ -129,8 +129,9 @@ private String nextNodeId( EdgeValue<State> route , State state, String nodeId )
return route.id();
}
if( route.value() != null ) {
var derefState = stateGraph.getStateFactory().apply(state);
var condition = route.value().action();
var newRoute = condition.apply(state).get();
var newRoute = condition.apply(derefState).get();
var result = route.value().mappings().get(newRoute);
if( result == null ) {
throw StateGraph.RunnableErrors.missingNodeInEdgeMapping.exception(nodeId, newRoute);
Expand All @@ -148,12 +149,12 @@ private String nextNodeId( EdgeValue<State> route , State state, String nodeId )
* @return the next node ID
* @throws Exception if there is an error determining the next node ID
*/
private String nextNodeId(String nodeId, State state) throws Exception {
private String nextNodeId(String nodeId, Map<String,Object> state) throws Exception {
return nextNodeId(edges.get(nodeId), state, nodeId);

}

private String getEntryPoint( State state ) throws Exception {
private String getEntryPoint( Map<String,Object> state ) throws Exception {
return nextNodeId(stateGraph.getEntryPoint(), state, "entryPoint");
}

Expand All @@ -168,11 +169,11 @@ private boolean shouldInterruptAfter( String nodeId ) {
return Arrays.asList(compileConfig.getInterruptAfter()).contains(nodeId);
}

private void addCheckpoint( RunnableConfig config, String nodeId, State state, String nextNodeId ) throws Exception {
private void addCheckpoint( RunnableConfig config, String nodeId, Map<String,Object> state, String nextNodeId ) throws Exception {
if( compileConfig.checkpointSaver().isPresent() ) {
Checkpoint cp = Checkpoint.builder()
.nodeId( nodeId )
.state( state.data() )
.state( cloneState(state) )
.nextNodeId( nextNodeId )
.build();
compileConfig.checkpointSaver().get().put( config, cp );
Expand Down Expand Up @@ -203,7 +204,7 @@ State cloneState( Map<String,Object> data ) throws IOException, ClassNotFoundExc
return stateGraph.getStateFactory().apply(newData);
}

private void streamData( State initialState,
private void streamData( Map<String,Object> initialState,
String startNodeId,
RunnableConfig config,
Consumer<NodeOutput<State>> yieldData) throws Exception {
Expand All @@ -226,23 +227,23 @@ private void streamData( State initialState,

if ( shouldInterruptBefore( currentNodeId, startNodeId )) {
log.trace("interrupt before node {}", currentNodeId);
addCheckpoint( config, currentNodeId, cloneState(currentState.data()), currentNodeId );
addCheckpoint( config, currentNodeId, currentState, currentNodeId );
return;
}

partialState = action.apply( cloneState(currentState.data())).get();
partialState = action.apply( cloneState(currentState)).get();

currentState = stateGraph.getStateFactory().apply(AgentState.updateState(currentState, partialState, stateGraph.getChannels()));
currentState = AgentState.updateState(currentState, partialState, stateGraph.getChannels());

yieldData.accept( NodeOutput.of(currentNodeId, cloneState(currentState.data())) );
yieldData.accept( NodeOutput.of(currentNodeId, cloneState(currentState)) );

if ( Objects.equals(currentNodeId, stateGraph.getFinishPoint()) ) {
addCheckpoint( config, currentNodeId, cloneState(currentState.data()), stateGraph.getFinishPoint() );
addCheckpoint( config, currentNodeId, currentState, stateGraph.getFinishPoint() );
break;
}

final String nextNodeId = nextNodeId(currentNodeId, currentState);
addCheckpoint( config, currentNodeId, cloneState(currentState.data()), nextNodeId );
addCheckpoint( config, currentNodeId, currentState, nextNodeId );

if ( shouldInterruptAfter( currentNodeId ) ) {
log.trace( "interrupt after node {}", currentNodeId);
Expand All @@ -262,7 +263,7 @@ private void streamData( State initialState,

}

yieldData.accept( NodeOutput.of(END, cloneState(currentState.data())) );
yieldData.accept( NodeOutput.of(END, cloneState(currentState)) );

// addCheckpoint( config, END, currentState, null );

Expand Down Expand Up @@ -294,7 +295,7 @@ public AsyncGenerator<NodeOutput<State>> stream(Map<String,Object> inputs, Runna

log.trace( "RESUME FROM {}", startCheckpoint.getNodeId() );

State startState = stateGraph.getStateFactory().apply( startCheckpoint.getState() );
Map<String,Object> startState = startCheckpoint.getState();

// Reset checkpoint id
RunnableConfig resumeConfig = RunnableConfig.builder(config)
Expand All @@ -318,14 +319,14 @@ public AsyncGenerator<NodeOutput<State>> stream(Map<String,Object> inputs, Runna

queue.add( AsyncGenerator.Data.of( NodeOutput.of( START, cloneState(startState.data()) ) ));

String startNodeId = this.getEntryPoint( startState );
String startNodeId = this.getEntryPoint( startState.data() );
if( shouldInterruptBefore( startNodeId, null ) ) return;

addCheckpoint( config, START, cloneState(startState.data()), startNodeId );
addCheckpoint( config, START, startState.data(), startNodeId );

if( shouldInterruptAfter( startNodeId ) ) return;

streamData( startState,
streamData( startState.data(),
startNodeId,
config,
data -> queue.add( AsyncGenerator.Data.of( completedFuture(data) ) ) );
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,8 @@ void testWithAppenderDeprecated() throws Exception {

assertTrue( result.isPresent() );
System.out.println( result.get().data() );
assertEquals( 3, result.get().steps() );
assertEquals( 3, result.get().messages().size() );
assertEquals( 3, result.get().steps() );
assertIterableEquals( listOf( "message1", "message2", "message3"), result.get().messages().values() );
}

Expand Down

0 comments on commit 4e9a763

Please sign in to comment.