Skip to content

Commit

Permalink
Merge branch 'feature/#20_threadid' into develop
Browse files Browse the repository at this point in the history
  • Loading branch information
bsorrentino committed Aug 26, 2024
2 parents deabff6 + c2a8e87 commit bc7cfaf
Show file tree
Hide file tree
Showing 16 changed files with 786 additions and 369 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import dev.langchain4j.model.openai.OpenAiChatModel;
import lombok.var;
import org.bsc.langgraph4j.GraphRepresentation;
import org.bsc.langgraph4j.InvokeConfig;
import org.bsc.langgraph4j.RunnableConfig;
import org.bsc.langgraph4j.NodeOutput;
import org.bsc.langgraph4j.StateGraph;
import org.bsc.langgraph4j.checkpoint.BaseCheckpointSaver;
Expand Down Expand Up @@ -45,13 +45,11 @@ private AgentExecutor.GraphBuilder newGraphBuilder() throws Exception {
.maxTokens(2000)
.build();


var agentExecutor = new AgentExecutor();

return agentExecutor.graphBuilder()
.chatLanguageModel(chatLanguageModel)
.objectsWithTools(listOf(new TestTool()));

}

private List<AgentExecutor.State> executeAgent( String prompt ) throws Exception {
Expand All @@ -64,7 +62,7 @@ private List<AgentExecutor.State> executeAgent( String prompt ) throws Exceptio
}

private List<AgentExecutor.State> executeAgent(String prompt, String threadId, BaseCheckpointSaver saver) throws Exception {
var config = InvokeConfig.builder().checkpointThreadId(threadId).build();
var config = RunnableConfig.builder().threadId(threadId).build();

var iterator = newGraphBuilder()
.checkpointSaver( saver )
Expand Down
205 changes: 152 additions & 53 deletions core-jdk8/src/main/java/org/bsc/langgraph4j/CompiledGraph.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,41 @@
import org.bsc.langgraph4j.action.AsyncNodeAction;
import org.bsc.langgraph4j.checkpoint.BaseCheckpointSaver;
import org.bsc.langgraph4j.checkpoint.Checkpoint;
import org.bsc.langgraph4j.serializer.MapSerializer;
import org.bsc.langgraph4j.state.AgentState;
import org.bsc.langgraph4j.state.Channel;
import org.bsc.langgraph4j.state.StateSnapshot;

import java.io.IOException;
import java.util.*;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.function.Supplier;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;

import static java.lang.String.format;
import static java.util.concurrent.CompletableFuture.completedFuture;
import static org.bsc.langgraph4j.StateGraph.END;
import static org.bsc.langgraph4j.StateGraph.START;


@FunctionalInterface
interface TryConsumer<T, Ex extends Throwable> extends Consumer<T> {

void tryAccept( T t ) throws Ex;

default void accept( T t ) {
try {
tryAccept(t);
} catch (Throwable ex) {
throw new RuntimeException(ex);
}
}

static <T,Ex extends Throwable> Consumer<T> Try( TryConsumer<T, Ex> consumer ) {
return consumer;
}
}


/**
* Represents a compiled graph of nodes and edges.
Expand Down Expand Up @@ -55,20 +79,54 @@ protected CompiledGraph(StateGraph<State> stateGraph, CompileConfig compileConfi
);
}

public Collection<StateSnapshot<State>> getStateHistory( RunnableConfig config ) {
var saver = compileConfig.getCheckpointSaver().orElseThrow( () -> (new IllegalStateException("Missing CheckpointSaver!")) );

return saver.list(config).stream()
.map( checkpoint -> StateSnapshot.of( checkpoint, config, stateGraph.getStateFactory() ) )
.collect(Collectors.toList());
}

public StateSnapshot<State> getState( RunnableConfig config ) {
var saver = compileConfig.getCheckpointSaver().orElseThrow( () -> (new IllegalStateException("Missing CheckpointSaver!")) );

var checkpoint = saver.get(config).orElseThrow( () -> (new IllegalStateException("Missing Checkpoint!")) );

return StateSnapshot.of( checkpoint, config, stateGraph.getStateFactory() );
}

public RunnableConfig updateState( RunnableConfig config, Map<String,Object> values, String asNode ) throws Exception {
var saver = compileConfig.getCheckpointSaver().orElseThrow( () -> (new IllegalStateException("Missing CheckpointSaver!")) );

// merge values with checkpoint values
var updatedCheckpoint = saver.get(config)
.map( cp -> cp.updateState(values, stateGraph.getChannels()) )
.orElseThrow( () -> (new IllegalStateException("Missing Checkpoint!")) );

// update checkpoint in saver
var newConfig = saver.put( config, updatedCheckpoint );

return RunnableConfig.builder(newConfig)
.checkPointId( updatedCheckpoint.getId() )
.nextNode( asNode )
.build();
}

public EdgeValue<State> getEntryPoint() {
return stateGraph.getEntryPoint();
}

public String getFinishPoint() {
return stateGraph.getFinishPoint();
}

/**
* Sets the maximum number of iterations for the graph execution.
*
* @param maxIterations the maximum number of iterations
* @throws IllegalArgumentException if maxIterations is less than or equal to 0
*/
void setMaxIterations(int maxIterations) {
public void setMaxIterations(int maxIterations) {
if( maxIterations <= 0 ) {
throw new IllegalArgumentException("maxIterations must be > 0!");
}
Expand Down Expand Up @@ -112,67 +170,54 @@ private String getEntryPoint( State state ) throws Exception {
return nextNodeId(stateGraph.getEntryPoint(), state, "entryPoint");
}

private void addCheckpoint( String nodeId, State state ) throws Exception {
private void addCheckpoint( RunnableConfig config, String nodeId, State state, String nextNodeId ) throws Exception {
if( compileConfig.getCheckpointSaver().isPresent() ) {
Checkpoint.Value value = Checkpoint.Value.of(state, nodeId);
compileConfig.getCheckpointSaver().get().put( new Checkpoint(value) );
Checkpoint cp = Checkpoint.builder()
.nodeId( nodeId )
.state( state )
.nextNodeId( nextNodeId )
.build();
compileConfig.getCheckpointSaver().get().put( config, cp );
}
}

Map<String,Object> getInitialStateFromSchema() {
return stateGraph.getChannels().entrySet().stream()
return stateGraph.getChannels().entrySet().stream()
.filter( c -> c.getValue().getDefault().isPresent() )
.collect(Collectors.toMap(Map.Entry::getKey, e ->
e.getValue().getDefault().get().get()
));
}

State getInitialState(Map<String,Object> inputs) {
Map<String,Object> getInitialState(Map<String,Object> inputs, RunnableConfig config) {

return compileConfig.getCheckpointSaver()
.flatMap(BaseCheckpointSaver::getLast)
.map( cp -> {
var state = cp.getValue().getState().mergeWith(inputs, stateGraph.getChannels());
return stateGraph.getStateFactory().apply(state);
})
.orElseGet( () -> {
var initialState =
stateGraph.getStateFactory()
.andThen( state -> state.mergeWith(inputs, stateGraph.getChannels()) )
.apply( getInitialStateFromSchema() );
return stateGraph.getStateFactory().apply(initialState);
});
.flatMap( saver -> saver.get( config ) )
.map( cp -> AgentState.updateState( cp.getState(), inputs, stateGraph.getChannels() ))
.orElseGet( () -> AgentState.updateState(getInitialStateFromSchema(), inputs, stateGraph.getChannels() ));
}

/**
* Creates an AsyncGenerator stream of NodeOutput based on the provided inputs.
*
* @param inputs the input map
* @param config the invoke configuration
* @return an AsyncGenerator stream of NodeOutput
* @throws Exception if there is an error creating the stream
*/
public AsyncGenerator<NodeOutput<State>> stream(Map<String,Object> inputs, InvokeConfig config ) throws Exception {
Objects.requireNonNull(config, "config cannot be null");

return AsyncGeneratorQueue.of(new LinkedBlockingQueue<>(), queue -> {
State cloneState( Map<String,Object> data ) throws IOException, ClassNotFoundException {

try {
Map<String,Object> newData = MapSerializer.INSTANCE.cloneObject(data);

var currentState = getInitialState(inputs);
return stateGraph.getStateFactory().apply(newData);
}

queue.add( AsyncGenerator.Data.of( completedFuture( NodeOutput.of("start", currentState)) ));
addCheckpoint( "start", currentState );
private void streamData( State initialState,
String startNodeId,
RunnableConfig config,
Consumer<NodeOutput<State>> yieldData) throws Exception {

log.trace( "START");
var currentState = initialState;

var currentNodeId = this.getEntryPoint( currentState );
var currentNodeId = startNodeId;

Map<String, Object> partialState;

int iteration = 0;

while( !Objects.equals(currentNodeId, StateGraph.END) ) {
while( !Objects.equals(currentNodeId, END) ) {

log.trace( "NEXT NODE: {}", currentNodeId);

Expand All @@ -183,18 +228,21 @@ public AsyncGenerator<NodeOutput<State>> stream(Map<String,Object> inputs, Invok

partialState = action.apply(currentState).get();

currentState = stateGraph.getStateFactory().apply( currentState.mergeWith(partialState, stateGraph.getChannels()) );
currentState = cloneState( AgentState.updateState(currentState, partialState, stateGraph.getChannels()) );

queue.add( AsyncGenerator.Data.of( completedFuture( NodeOutput.of(currentNodeId, currentState) ) ));
addCheckpoint( currentNodeId, currentState );
yieldData.accept( NodeOutput.of(currentNodeId, currentState) );

if ( Objects.equals(currentNodeId, stateGraph.getFinishPoint()) ) {
addCheckpoint( config, currentNodeId, currentState, stateGraph.getFinishPoint() );
break;
}

currentNodeId = nextNodeId(currentNodeId, currentState);
final String nextNodeId = nextNodeId(currentNodeId, currentState);
addCheckpoint( config, currentNodeId, currentState, nextNodeId );

currentNodeId = nextNodeId;

if ( Objects.equals(currentNodeId, StateGraph.END) ) {
if ( Objects.equals(currentNodeId, END) ) {
break;
}

Expand All @@ -205,15 +253,66 @@ public AsyncGenerator<NodeOutput<State>> stream(Map<String,Object> inputs, Invok

}

queue.add( AsyncGenerator.Data.of( completedFuture( NodeOutput.of("stop", currentState) ) ));
addCheckpoint( "stop", currentState );
yieldData.accept( NodeOutput.of(END, currentState) );

// addCheckpoint( config, END, currentState, null );

log.trace( "STOP");

} catch (Exception e) {
throw new RuntimeException( e );
}
}


/**
* Creates an AsyncGenerator stream of NodeOutput based on the provided inputs.
*
* @param inputs the input map
* @param config the invoke configuration
* @return an AsyncGenerator stream of NodeOutput
* @throws Exception if there is an error creating the stream
*/
public AsyncGenerator<NodeOutput<State>> stream(Map<String,Object> inputs, RunnableConfig config ) throws Exception {
Objects.requireNonNull(config, "config cannot be null");

final boolean isResumeRequest = (inputs == null);

if( isResumeRequest ) {

BaseCheckpointSaver saver = compileConfig.getCheckpointSaver().orElseThrow(() -> (new IllegalStateException("inputs cannot be null (ie. resume request) if no checkpoint saver is configured")));

Checkpoint startCheckpoint = saver.get( config ).orElseThrow( () -> (new IllegalStateException("Resume request without a saved checkpoint!")) );

return AsyncGeneratorQueue.of(new LinkedBlockingQueue<>(), TryConsumer.Try(queue -> {

log.trace( "RESUME FROM {}", startCheckpoint.getNodeId() );

State startState = stateGraph.getStateFactory().apply( startCheckpoint.getState() );

streamData( startState,
startCheckpoint.getNextNodeId(),
config,
data -> queue.add( AsyncGenerator.Data.of( completedFuture(data) ) )
);

}));

}

return AsyncGeneratorQueue.of(new LinkedBlockingQueue<>(), TryConsumer.Try(queue -> {

log.trace( "START" );

State startState = cloneState( getInitialState(inputs, config) ) ;
queue.add( AsyncGenerator.Data.of( NodeOutput.of( START, startState ) ));

String startNodeId = this.getEntryPoint( startState );
addCheckpoint( config, START, startState, startNodeId );

streamData( startState,
startNodeId,
config,
data -> queue.add( AsyncGenerator.Data.of( completedFuture(data) ) ) );

});
}));

}

Expand All @@ -225,7 +324,7 @@ public AsyncGenerator<NodeOutput<State>> stream(Map<String,Object> inputs, Invok
* @throws Exception if there is an error creating the stream
*/
public AsyncGenerator<NodeOutput<State>> stream(Map<String,Object> inputs ) throws Exception {
return this.stream( inputs, InvokeConfig.builder().build() );
return this.stream( inputs, RunnableConfig.builder().build() );
}
/**
* Invokes the graph execution with the provided inputs and returns the final state.
Expand All @@ -235,7 +334,7 @@ public AsyncGenerator<NodeOutput<State>> stream(Map<String,Object> inputs ) thro
* @return an Optional containing the final state if present, otherwise an empty Optional
* @throws Exception if there is an error during invocation
*/
public Optional<State> invoke(Map<String,Object> inputs, InvokeConfig config ) throws Exception {
public Optional<State> invoke(Map<String,Object> inputs, RunnableConfig config ) throws Exception {

var sourceIterator = stream(inputs, config).iterator();

Expand All @@ -254,7 +353,7 @@ public Optional<State> invoke(Map<String,Object> inputs, InvokeConfig config ) t
* @throws Exception if there is an error during invocation
*/
public Optional<State> invoke(Map<String,Object> inputs ) throws Exception {
return this.invoke( inputs, InvokeConfig.builder().build() );
return this.invoke( inputs, RunnableConfig.builder().build() );
}

/**
Expand Down
39 changes: 0 additions & 39 deletions core-jdk8/src/main/java/org/bsc/langgraph4j/InvokeConfig.java

This file was deleted.

Loading

0 comments on commit bc7cfaf

Please sign in to comment.