Skip to content

Commit

Permalink
feat(server): add builder
Browse files Browse the repository at this point in the history
with support of:
- port
- inputArg metadata

work on #9
  • Loading branch information
bsorrentino committed Jul 15, 2024
1 parent 4ba26dd commit 9e8109d
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 87 deletions.
206 changes: 120 additions & 86 deletions jetty/src/main/java/org/bsc/langgraph4j/LangGraphStreamingServer.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,92 +6,121 @@
import jakarta.servlet.http.HttpServlet;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.bsc.async.AsyncGenerator;
import org.bsc.langgraph4j.state.AgentState;
import org.eclipse.jetty.ee10.servlet.ServletContextHandler;
import org.eclipse.jetty.ee10.servlet.ServletHolder;
import org.eclipse.jetty.server.Handler;
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.ServerConnector;
import org.eclipse.jetty.server.handler.ResourceHandler;
import org.eclipse.jetty.util.resource.Resource;
import org.eclipse.jetty.util.resource.ResourceFactory;

import java.io.IOException;
import java.io.PrintWriter;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;


public interface LangGraphStreamingServer {

CompletableFuture<Void> start() throws Exception;

public static <State extends AgentState> LangGraphStreamingServer of(CompiledGraph<State> compiledGraph) throws Exception {

Server server = new Server();
ServerConnector connector = new ServerConnector(server);
connector.setPort(8080);
server.addConnector(connector);

ResourceHandler resourceHandler = new ResourceHandler();
Path publicResourcesPath = Paths.get( "jetty", "src", "main", "webapp" );
resourceHandler.setBaseResource(ResourceFactory.of(resourceHandler).newResource(publicResourcesPath));
resourceHandler.setDirAllowed(true);

ServletContextHandler context = new ServletContextHandler(ServletContextHandler.SESSIONS);
context.setContextPath("/");
// Add the streaming servlet
context.addServlet(new ServletHolder(new StreamingServlet<State>(compiledGraph)), "/stream");
context.addServlet(new ServletHolder(new GraphServlet<State>(compiledGraph)), "/graph");

Handler.Sequence handlerList = new Handler.Sequence(resourceHandler, context );

server.setHandler(handlerList);

return new LangGraphStreamingServer() {
@Override
public CompletableFuture<Void> start() throws Exception {
return CompletableFuture.runAsync( () -> {
try {
server.start();
} catch( Exception e ) {
throw new RuntimeException(e);
}
}, Runnable::run);
}
};
static Builder builder() {
return new Builder();
}

class StreamingServlet<State extends AgentState> extends HttpServlet {
final CompiledGraph<State> compiledGraph;
final ObjectMapper objectMapper = new ObjectMapper();
class Builder {
private int port = 8080;
private Map<String,ArgumentMetadata> inputArgs = new HashMap<>();

public Builder port(int port) {
this.port = port;
return this;
}

public StreamingServlet( CompiledGraph<State> compiledGraph ) {
Objects.requireNonNull(compiledGraph, "compiledGraph cannot be null");
this.compiledGraph = compiledGraph;
public Builder addInputStringArg(String name, boolean required) {
inputArgs.put(name, new ArgumentMetadata("string", required) );
return this;
}
public Builder addInputStringArg(String name) {
inputArgs.put(name, new ArgumentMetadata("string", true) );
return this;
}

@Override
protected void doPost(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
response.setHeader("Accept", "application/json");
response.setContentType("text/plain");
response.setCharacterEncoding("UTF-8");

final PrintWriter writer = response.getWriter();
public <State extends AgentState> LangGraphStreamingServer build(CompiledGraph<State> compiledGraph) {
Server server = new Server();
ServerConnector connector = new ServerConnector(server);
connector.setPort(port);
server.addConnector(connector);

ResourceHandler resourceHandler = new ResourceHandler();

// Path publicResourcesPath = Paths.get("jetty", "src", "main", "webapp");
// Resource baseResource = ResourceFactory.of(resourceHandler).newResource(publicResourcesPath));
Resource baseResource = ResourceFactory.of(resourceHandler).newClassLoaderResource("webapp");
resourceHandler.setBaseResource(baseResource);

resourceHandler.setDirAllowed(true);

ServletContextHandler context = new ServletContextHandler(ServletContextHandler.SESSIONS);
context.setContextPath("/");
// Add the streaming servlet
context.addServlet(new ServletHolder(new GraphExecutionServlet<State>(compiledGraph)), "/stream");
context.addServlet(new ServletHolder(new GraphInitServlet<State>(compiledGraph, inputArgs)), "/init");

Map<String, Object> dataMap = objectMapper.readValue(request.getInputStream(), new TypeReference<Map<String, Object>>() {});
Handler.Sequence handlerList = new Handler.Sequence(resourceHandler, context);

// Start asynchronous processing
request.startAsync();
server.setHandler(handlerList);

try {
compiledGraph.stream(dataMap)
.forEachAsync( s -> {
return new LangGraphStreamingServer() {
@Override
public CompletableFuture<Void> start() throws Exception {
return CompletableFuture.runAsync(() -> {
try {
server.start();
} catch (Exception e) {
throw new RuntimeException(e);
}
}, Runnable::run);

}
};

}
}
}


class GraphExecutionServlet<State extends AgentState> extends HttpServlet {
final CompiledGraph<State> compiledGraph;
final ObjectMapper objectMapper = new ObjectMapper();

public GraphExecutionServlet(CompiledGraph<State> compiledGraph) {
Objects.requireNonNull(compiledGraph, "compiledGraph cannot be null");
this.compiledGraph = compiledGraph;
}

@Override
protected void doPost(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
response.setHeader("Accept", "application/json");
response.setContentType("text/plain");
response.setCharacterEncoding("UTF-8");

final PrintWriter writer = response.getWriter();

Map<String, Object> dataMap = objectMapper.readValue(request.getInputStream(), new TypeReference<Map<String, Object>>() {
});

// Start asynchronous processing
request.startAsync();

try {
compiledGraph.stream(dataMap)
.forEachAsync(s -> {
writer.println(s.node());
writer.flush();

Expand All @@ -100,41 +129,46 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
}).thenAccept( v -> {
writer.close();
});
}).thenAccept(v -> {
writer.close();
});

} catch (Exception e) {
throw new RuntimeException(e);
}
} catch (Exception e) {
throw new RuntimeException(e);
}
}
}

/**
* return the graph representation in mermaid format
*/
class GraphServlet<State extends AgentState> extends HttpServlet {
record ArgumentMetadata (
String type,
boolean required
) {}

final CompiledGraph<State> compiledGraph;
/**
* return the graph representation in mermaid format
*/
class GraphInitServlet<State extends AgentState> extends HttpServlet {

public GraphServlet( CompiledGraph<State> compiledGraph ) {
Objects.requireNonNull(compiledGraph, "compiledGraph cannot be null");
this.compiledGraph = compiledGraph;
}
final CompiledGraph<State> compiledGraph;
final Map<String, ArgumentMetadata> inputArgs;

@Override
protected void doGet(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
response.setContentType("text/plain");
response.setCharacterEncoding("UTF-8");
public GraphInitServlet(CompiledGraph<State> compiledGraph, Map<String, ArgumentMetadata> inputArgs) {
Objects.requireNonNull(compiledGraph, "compiledGraph cannot be null");
this.compiledGraph = compiledGraph;
this.inputArgs = inputArgs;
}

GraphRepresentation result = compiledGraph.getGraph(GraphRepresentation.Type.MERMAID);
@Override
protected void doGet(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
response.setContentType("text/plain");
response.setCharacterEncoding("UTF-8");

// Start asynchronous processing
request.startAsync();
final PrintWriter writer = response.getWriter();
writer.println(result.getContent());
writer.close();
}
}
GraphRepresentation result = compiledGraph.getGraph(GraphRepresentation.Type.MERMAID);

}
// Start asynchronous processing
request.startAsync();
final PrintWriter writer = response.getWriter();
writer.println(result.getContent());
writer.close();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ public String apply(AgentState state) {

CompiledGraph<AgentState> app = workflow.compile();

LangGraphStreamingServer server = LangGraphStreamingServer.of(app);
LangGraphStreamingServer server = LangGraphStreamingServer.builder()
.port(8080)
.build(app);

server.start().join();

Expand Down

0 comments on commit 9e8109d

Please sign in to comment.