Skip to content

Commit

Permalink
feat(LangGraphStreamingServer): implementation refinement
Browse files Browse the repository at this point in the history
work on #9
  • Loading branch information
bsorrentino committed Jul 13, 2024
1 parent 3cf8b64 commit 3b8c6cf
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 35 deletions.
32 changes: 26 additions & 6 deletions jetty/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@
<artifactId>jetty</artifactId>

<packaging>war</packaging>
<name>Servlet 3 Webapp with Jetty Plugin</name>
<name>langgraph4j::jetty::server</name>

<properties>
<maven.compiler.source>1.8</maven.compiler.source>
<maven.compiler.target>1.8</maven.compiler.target>
<maven.compiler.source>17</maven.compiler.source>
<maven.compiler.target>17</maven.compiler.target>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<jetty.version>12.0.11</jetty.version>
<jackson.version>2.17.2</jackson.version>
</properties>

<dependencyManagement>
Expand All @@ -27,7 +28,12 @@
</dependencyManagement>

<dependencies>

<dependency>
<groupId>org.bsc.langgraph4j</groupId>
<artifactId>langgraph4j-jdk8</artifactId>
<version>${project.version}</version>
<scope>provided</scope>
</dependency>
<!--
<dependency>
<groupId>javax.servlet</groupId>
Expand All @@ -36,14 +42,18 @@
<scope>provided</scope>
</dependency>
-->
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
<version>${jackson.version}</version>
</dependency>

<dependency>
<groupId>org.eclipse.jetty</groupId>
<artifactId>jetty-server</artifactId>
<version>${jetty.version}</version>
</dependency>

<!-- https://mvnrepository.com/artifact/org.eclipse.jetty.ee10/jetty-ee10-servlet -->
<dependency>
<groupId>org.eclipse.jetty.ee10</groupId>
<artifactId>jetty-ee10-servlet</artifactId>
Expand All @@ -55,6 +65,8 @@
<artifactId>junit-jupiter</artifactId>
<scope>test</scope>
</dependency>


</dependencies>

<build>
Expand All @@ -67,8 +79,16 @@
<execution><goals><goal>java</goal></goals></execution>
</executions>
<configuration>
<mainClass>org.bsc.langgraph4j.JettyStreamingServer</mainClass>
<classpathScope>test</classpathScope>
<mainClass>org.bsc.langgraph4j.LangGraphStreamingServerTest</mainClass>
</configuration>
<dependencies>
<dependency>
<groupId>org.bsc.langgraph4j</groupId>
<artifactId>langgraph4j-jdk8</artifactId>
<version>${project.version}</version>
</dependency>
</dependencies>
</plugin>
</plugins>
</build>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,13 @@
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;

public class Langgraph4jStreamingServer {
public interface LangGraphStreamingServer {

CompletableFuture<Void> start() throws Exception;

public static void main(String[] args) throws Exception {

}

public static <State extends AgentState> void start(CompiledGraph<State> compiledGraph) throws Exception {
public static <State extends AgentState> LangGraphStreamingServer of(CompiledGraph<State> compiledGraph) throws Exception {

Server server = new Server();
ServerConnector connector = new ServerConnector(server);
Expand All @@ -48,39 +46,36 @@ public static <State extends AgentState> void start(CompiledGraph<State> compile
ServletContextHandler context = new ServletContextHandler(ServletContextHandler.SESSIONS);
context.setContextPath("/");
// Add the streaming servlet
context.addServlet(new ServletHolder(new StreamingServlet(compiledGraph)), "/stream");
context.addServlet(new ServletHolder(new StreamingServlet<State>(compiledGraph)), "/stream");
context.addServlet(new ServletHolder(new GraphServlet<State>(compiledGraph)), "/graph");

Handler.Sequence handlerList = new Handler.Sequence(resourceHandler, context );

server.setHandler(handlerList);

server.start();
server.join();
return new LangGraphStreamingServer() {
@Override
public CompletableFuture<Void> start() throws Exception {
return CompletableFuture.runAsync( () -> {
try {
server.start();
} catch( Exception e ) {
throw new RuntimeException(e);
}
}, Runnable::run);
}
};
}

public static class StreamingServlet<State extends AgentState> extends HttpServlet {
class StreamingServlet<State extends AgentState> extends HttpServlet {
final CompiledGraph<State> compiledGraph;
final ObjectMapper objectMapper = new ObjectMapper();

public StreamingServlet( CompiledGraph<State> compiledGraph ) {
Objects.requireNonNull(compiledGraph, "compiledGraph cannot be null");
this.compiledGraph = compiledGraph;
}

private AsyncGenerator<String> streamTest() {

List<String> chunks = Arrays.asList(
"a", "b", "c", "d", "e",
"f", "g", "h", "i", "j",
"k", "l", "m", "n", "o",
"p", "q", "r", "s", "t",
"u", "v", "w", "x", "y", "z");

return AsyncGenerator.map( chunks,
s -> CompletableFuture.completedFuture( "chunk " + s) );

}


@Override
protected void doPost(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
response.setHeader("Accept", "application/json");
Expand All @@ -99,7 +94,13 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
.forEachAsync( s -> {
writer.println(s.node());
writer.flush();
}).thenAccept( v -> {

try {
TimeUnit.SECONDS.sleep(1);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
}).thenAccept( v -> {
writer.close();
});

Expand All @@ -112,7 +113,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
/**
* return the graph representation in mermaid format
*/
public static class GraphServlet<State extends AgentState> extends HttpServlet {
class GraphServlet<State extends AgentState> extends HttpServlet {

final CompiledGraph<State> compiledGraph;

Expand All @@ -131,7 +132,7 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) t
// Start asynchronous processing
request.startAsync();
final PrintWriter writer = response.getWriter();
writer.println(result);
writer.println(result.getContent());
writer.close();
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,2 +1,62 @@
package org.bsc.langgraph4j;public class LangGraphStreamingServerTest {
package org.bsc.langgraph4j;

import org.bsc.langgraph4j.action.AsyncNodeAction;
import org.bsc.langgraph4j.action.EdgeAction;
import org.bsc.langgraph4j.state.AgentState;

import java.util.Map;
import java.util.concurrent.CompletableFuture;

import static org.bsc.langgraph4j.StateGraph.END;
import static org.bsc.langgraph4j.action.AsyncEdgeAction.edge_async;
import static org.bsc.langgraph4j.action.AsyncNodeAction.node_async;
import static org.bsc.langgraph4j.utils.CollectionsUtils.mapOf;

public class LangGraphStreamingServerTest {


public static void main(String[] args) throws Exception {
StateGraph<AgentState> workflow = new StateGraph<>(AgentState::new);

workflow.setEntryPoint("agent_1");

workflow.addNode("agent_1", node_async((state ) -> {
System.out.println("agent_1 ");
System.out.println(state);
return mapOf("prop1", "test");
}) ) ;

workflow.addNode("agent_2", node_async( state -> {

System.out.print( "agent_2: ");
System.out.println( state );

return mapOf("prop2", "test");
}));

workflow.addEdge("agent_2", "agent_1" );


EdgeAction<AgentState> conditionalAge = new EdgeAction<>() {
int steps= 0;
@Override
public String apply(AgentState state) {
if( ++steps == 10 ) {
return END;
}
return "a2";
}
};

workflow.addConditionalEdges("agent_1",
edge_async(conditionalAge), mapOf( "a2", "agent_2", "end", END ) );

CompiledGraph<AgentState> app = workflow.compile();

LangGraphStreamingServer server = LangGraphStreamingServer.of(app);

server.start().join();

}

}

0 comments on commit 3b8c6cf

Please sign in to comment.