diff --git a/core-jdk8/src/main/java/org/bsc/langgraph4j/CompileConfig.java b/core-jdk8/src/main/java/org/bsc/langgraph4j/CompileConfig.java index 2ca5415..d6d2f79 100644 --- a/core-jdk8/src/main/java/org/bsc/langgraph4j/CompileConfig.java +++ b/core-jdk8/src/main/java/org/bsc/langgraph4j/CompileConfig.java @@ -1,6 +1,6 @@ package org.bsc.langgraph4j; -import org.bsc.langgraph4j.state.BaseCheckpointSaver; +import org.bsc.langgraph4j.checkpoint.BaseCheckpointSaver; import java.util.Optional; diff --git a/core-jdk8/src/main/java/org/bsc/langgraph4j/CompiledGraph.java b/core-jdk8/src/main/java/org/bsc/langgraph4j/CompiledGraph.java index e205b89..2eb1a70 100644 --- a/core-jdk8/src/main/java/org/bsc/langgraph4j/CompiledGraph.java +++ b/core-jdk8/src/main/java/org/bsc/langgraph4j/CompiledGraph.java @@ -116,6 +116,7 @@ private String getEntryPoint( State state ) throws Exception { * @throws Exception if there is an error creating the stream */ public AsyncGenerator> stream(Map inputs, InvokeConfig config ) throws Exception { + Objects.requireNonNull(config, "config cannot be null"); return AsyncGeneratorQueue.of(new LinkedBlockingQueue<>(), queue -> { diff --git a/core-jdk8/src/main/java/org/bsc/langgraph4j/InvokeConfig.java b/core-jdk8/src/main/java/org/bsc/langgraph4j/InvokeConfig.java index 1588a35..3fad8d9 100644 --- a/core-jdk8/src/main/java/org/bsc/langgraph4j/InvokeConfig.java +++ b/core-jdk8/src/main/java/org/bsc/langgraph4j/InvokeConfig.java @@ -1,6 +1,6 @@ package org.bsc.langgraph4j; -import org.bsc.langgraph4j.state.CheckpointConfig; +import org.bsc.langgraph4j.checkpoint.CheckpointConfig; import java.util.Optional; diff --git a/core-jdk8/src/main/java/org/bsc/langgraph4j/StateGraph.java b/core-jdk8/src/main/java/org/bsc/langgraph4j/StateGraph.java index 30ffa4d..c9fbbb2 100644 --- a/core-jdk8/src/main/java/org/bsc/langgraph4j/StateGraph.java +++ b/core-jdk8/src/main/java/org/bsc/langgraph4j/StateGraph.java @@ -203,6 +203,8 @@ private Node makeFakeNode(String id) { * @throws GraphStateException if there are errors related to the graph state */ public CompiledGraph compile( CompileConfig config ) throws GraphStateException { + Objects.requireNonNull(config, "config cannot be null"); + if (entryPoint == null) { throw Errors.missingEntryPoint.exception(); } diff --git a/core-jdk8/src/main/java/org/bsc/langgraph4j/state/BaseCheckpointSaver.java b/core-jdk8/src/main/java/org/bsc/langgraph4j/checkpoint/BaseCheckpointSaver.java similarity index 58% rename from core-jdk8/src/main/java/org/bsc/langgraph4j/state/BaseCheckpointSaver.java rename to core-jdk8/src/main/java/org/bsc/langgraph4j/checkpoint/BaseCheckpointSaver.java index 1877c76..1e65fc1 100644 --- a/core-jdk8/src/main/java/org/bsc/langgraph4j/state/BaseCheckpointSaver.java +++ b/core-jdk8/src/main/java/org/bsc/langgraph4j/checkpoint/BaseCheckpointSaver.java @@ -1,11 +1,13 @@ -package org.bsc.langgraph4j.state; +package org.bsc.langgraph4j.checkpoint; +import java.io.Externalizable; import java.util.Collection; import java.util.Optional; public interface BaseCheckpointSaver { + Collection list(); Optional get( String id ); - void put( Checkpoint checkpoint ); + void put( Checkpoint checkpoint ) throws Exception; } diff --git a/core-jdk8/src/main/java/org/bsc/langgraph4j/state/Checkpoint.java b/core-jdk8/src/main/java/org/bsc/langgraph4j/checkpoint/Checkpoint.java similarity index 52% rename from core-jdk8/src/main/java/org/bsc/langgraph4j/state/Checkpoint.java rename to core-jdk8/src/main/java/org/bsc/langgraph4j/checkpoint/Checkpoint.java index dad10fb..8b5f8e7 100644 --- a/core-jdk8/src/main/java/org/bsc/langgraph4j/state/Checkpoint.java +++ b/core-jdk8/src/main/java/org/bsc/langgraph4j/checkpoint/Checkpoint.java @@ -1,9 +1,26 @@ -package org.bsc.langgraph4j.state; +package org.bsc.langgraph4j.checkpoint; -import lombok.Value; +import lombok.Data; +import org.bsc.langgraph4j.state.AgentState; +import java.io.Externalizable; +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; import java.util.*; + +/** + * Represents a checkpoint of an agent state. + * + * The checkpoint is an immutable object that holds an {@link AgentState} + * and a {@code String} that represents the next state. + * + * The checkpoint is serializable and can be persisted and restored. + * + * @see AgentState + * @see Externalizable + */ public class Checkpoint { @lombok.Value(staticConstructor="of") @@ -12,8 +29,8 @@ public static class Value { String next; } - private final String id; - private final Value value; + String id; + Value value; public final String getId() { return id; diff --git a/core-jdk8/src/main/java/org/bsc/langgraph4j/state/CheckpointConfig.java b/core-jdk8/src/main/java/org/bsc/langgraph4j/checkpoint/CheckpointConfig.java similarity index 73% rename from core-jdk8/src/main/java/org/bsc/langgraph4j/state/CheckpointConfig.java rename to core-jdk8/src/main/java/org/bsc/langgraph4j/checkpoint/CheckpointConfig.java index 0b33249..77de6e7 100644 --- a/core-jdk8/src/main/java/org/bsc/langgraph4j/state/CheckpointConfig.java +++ b/core-jdk8/src/main/java/org/bsc/langgraph4j/checkpoint/CheckpointConfig.java @@ -1,4 +1,4 @@ -package org.bsc.langgraph4j.state; +package org.bsc.langgraph4j.checkpoint; import lombok.Value; diff --git a/core-jdk8/src/main/java/org/bsc/langgraph4j/checkpoint/MemorySaver.java b/core-jdk8/src/main/java/org/bsc/langgraph4j/checkpoint/MemorySaver.java new file mode 100644 index 0000000..4f3a093 --- /dev/null +++ b/core-jdk8/src/main/java/org/bsc/langgraph4j/checkpoint/MemorySaver.java @@ -0,0 +1,31 @@ +package org.bsc.langgraph4j.checkpoint; + +import org.bsc.langgraph4j.serializer.CheckpointSerializer; + +import java.util.*; + +import static java.util.Collections.unmodifiableSet; + +public class MemorySaver implements BaseCheckpointSaver { + + private final Set checkpoints = new LinkedHashSet<>(); + + + @Override + public Collection list() { + return unmodifiableSet(checkpoints); // immutable checkpoints; + } + + @Override + public Optional get(String id) { + return checkpoints.stream() + .filter( checkpoint -> checkpoint.getId().equals(id) ) + .findFirst(); + } + + @Override + public void put(Checkpoint checkpoint) throws Exception { + checkpoints.add( CheckpointSerializer.INSTANCE.cloneObject(checkpoint) ); + } + +} diff --git a/core-jdk8/src/main/java/org/bsc/langgraph4j/serializer/AgentStateSerializer.java b/core-jdk8/src/main/java/org/bsc/langgraph4j/serializer/AgentStateSerializer.java new file mode 100644 index 0000000..08c1414 --- /dev/null +++ b/core-jdk8/src/main/java/org/bsc/langgraph4j/serializer/AgentStateSerializer.java @@ -0,0 +1,74 @@ +package org.bsc.langgraph4j.serializer; + +import lombok.extern.log4j.Log4j; +import lombok.extern.slf4j.Slf4j; +import org.bsc.langgraph4j.state.AgentState; + +import java.io.*; +import java.util.HashMap; +import java.util.Map; + +@Slf4j +public class AgentStateSerializer implements Serializer { + public static final AgentStateSerializer INSTANCE = new AgentStateSerializer(); + private AgentStateSerializer() {} + + @Override + public void write(AgentState object, ObjectOutput out) throws IOException { + try( ByteArrayOutputStream baos = new ByteArrayOutputStream() ) { + int actuoalSize = 0; + + final ObjectOutputStream tupleStream = new ObjectOutputStream( baos ); + for( Map.Entry e : object.data().entrySet() ) { + try { + tupleStream.writeUTF(e.getKey()); + tupleStream.writeObject(e.getValue()); + ++actuoalSize; + } catch (IOException ex) { + log.error( "Error writing state key '{}' - {}", e.getKey(), ex.getMessage() ); + throw ex; + } + } + + out.writeInt( object.data().size() ); + out.writeInt( actuoalSize ); // actual size + byte[] data = baos.toByteArray(); + out.writeInt( data.length ); + out.write( data ); + + } + + } + + @Override + public AgentState read(ObjectInput in) throws IOException, ClassNotFoundException { + Map data = new HashMap<>(); + + int expectedSize = in.readInt(); + int actualSize = in.readInt(); + if( expectedSize > 0 && actualSize > 0 ) { + + if( expectedSize != actualSize ) { + final String message = String.format( "Deserialize State: Expected size %d and actual size %d do not match!", expectedSize, actualSize ) ; + log.error( message ) ; + throw new IOException( message ) ; + } + + int byteLen = in.readInt(); + byte[] bytes = new byte[byteLen]; + in.readFully(bytes); + + try( ByteArrayInputStream bais = new ByteArrayInputStream( bytes ) ) { + ObjectInputStream ois = new ObjectInputStream( bais ); + + for( int i = 0; i < actualSize; i++ ) { + String key = ois.readUTF(); + Object value = ois.readObject(); + data.put(key, value); + } + } + + } + return new AgentState(data); + } +} diff --git a/core-jdk8/src/main/java/org/bsc/langgraph4j/serializer/CheckpointSerializer.java b/core-jdk8/src/main/java/org/bsc/langgraph4j/serializer/CheckpointSerializer.java new file mode 100644 index 0000000..cac4f8e --- /dev/null +++ b/core-jdk8/src/main/java/org/bsc/langgraph4j/serializer/CheckpointSerializer.java @@ -0,0 +1,31 @@ +package org.bsc.langgraph4j.serializer; + +import org.bsc.langgraph4j.checkpoint.Checkpoint; +import org.bsc.langgraph4j.state.AgentState; + +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; + +public class CheckpointSerializer implements Serializer { + + public static final CheckpointSerializer INSTANCE = new CheckpointSerializer(); + + private CheckpointSerializer() {} + + public void write( Checkpoint object, ObjectOutput out) throws IOException { + out.writeUTF(object.getId()); + Checkpoint.Value value = object.getValue(); + AgentStateSerializer.INSTANCE.write( value.getState(), out ); + out.writeUTF( value.getNext() ); + } + + public Checkpoint read(ObjectInput in) throws IOException, ClassNotFoundException { + String id = in.readUTF(); + AgentState state = AgentStateSerializer.INSTANCE.read( in ); + String next = in.readUTF(); + Checkpoint.Value value = Checkpoint.Value.of( state, next ); + return new Checkpoint(id, value); + } + +} diff --git a/core-jdk8/src/main/java/org/bsc/langgraph4j/serializer/Serializer.java b/core-jdk8/src/main/java/org/bsc/langgraph4j/serializer/Serializer.java new file mode 100644 index 0000000..15404f6 --- /dev/null +++ b/core-jdk8/src/main/java/org/bsc/langgraph4j/serializer/Serializer.java @@ -0,0 +1,40 @@ +package org.bsc.langgraph4j.serializer; + +import org.bsc.langgraph4j.checkpoint.Checkpoint; + +import java.io.*; +import java.util.Objects; + +public interface Serializer { + + void write(T object, ObjectOutput out) throws IOException; + + + T read(ObjectInput in) throws IOException, ClassNotFoundException ; + + default byte[] writeObject(T object) throws IOException { + Objects.requireNonNull( object, "object cannot be null" ); + try( ByteArrayOutputStream baos = new ByteArrayOutputStream() ) { + ObjectOutputStream oas = new ObjectOutputStream(baos); + write(object, oas); + oas.flush(); + return baos.toByteArray(); + } + } + + default T readObject(byte[] bytes) throws IOException, ClassNotFoundException { + Objects.requireNonNull( bytes, "bytes cannot be null" ); + if( bytes.length == 0 ) { + throw new IllegalArgumentException("bytes cannot be empty"); + } + try( ByteArrayInputStream bais = new ByteArrayInputStream( bytes ) ) { + ObjectInputStream ois = new ObjectInputStream(bais); + return read(ois); + } + } + + default T cloneObject(T object) throws IOException, ClassNotFoundException { + Objects.requireNonNull( object, "object cannot be null" ); + return readObject(writeObject(object)); + } +} diff --git a/core-jdk8/src/test/java/org/bsc/langgraph4j/SerializeTest.java b/core-jdk8/src/test/java/org/bsc/langgraph4j/SerializeTest.java new file mode 100644 index 0000000..4721afe --- /dev/null +++ b/core-jdk8/src/test/java/org/bsc/langgraph4j/SerializeTest.java @@ -0,0 +1,72 @@ +package org.bsc.langgraph4j; + +import org.bsc.langgraph4j.serializer.AgentStateSerializer; +import org.bsc.langgraph4j.state.AgentState; +import org.junit.jupiter.api.Test; + +import java.io.*; +import java.util.HashMap; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.*; + +public class SerializeTest { + + + private byte[] serializeState(AgentState state) throws Exception { + try( ByteArrayOutputStream baos = new ByteArrayOutputStream() ) { + ObjectOutputStream oas = new ObjectOutputStream(baos); + AgentStateSerializer.INSTANCE.write(state, oas); + oas.flush(); + return baos.toByteArray(); + } + } + private AgentState deserializeState( byte[] bytes ) throws Exception { + try(ByteArrayInputStream bais = new ByteArrayInputStream( bytes ) ) { + ObjectInputStream ois = new ObjectInputStream( bais ); + return AgentStateSerializer.INSTANCE.read( ois ); + } + } + + @Test + public void serializeStateTest() throws Exception { + + Map data = new HashMap<>(); + data.put("a", "b"); + data.put("f", null); + data.put("c", "d"); + + final AgentState state = new AgentState(data); + + byte[] bytes = serializeState(state); + + assertNotNull(bytes); + AgentState deserializeState = deserializeState( bytes ); + + assertEquals( 3, deserializeState.data().size() ); + assertEquals( "b", deserializeState.data().get("a") ); + assertEquals( "d", deserializeState.data().get("c") ); + } + + static class NonSerializableElement { + String value = "TEST"; + public NonSerializableElement() { + } + } + @Test + public void partiallySerializeStateTest() throws Exception { + + Map data = new HashMap<>(); + data.put("a", "b"); + data.put("f", new NonSerializableElement() ); + data.put("c", "d"); + + final AgentState state = new AgentState(data); + + assertThrows(IOException.class, () -> { + serializeState(state); + }); + + } + +}