Skip to content

Commit

Permalink
feat: finalize checkpoint implementation
Browse files Browse the repository at this point in the history
add AgentState and Checkpointer serializer
add support for MemorySaver

work on #11
  • Loading branch information
bsorrentino committed Aug 6, 2024
1 parent b8ab321 commit 1564efc
Show file tree
Hide file tree
Showing 12 changed files with 279 additions and 9 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package org.bsc.langgraph4j;

import org.bsc.langgraph4j.state.BaseCheckpointSaver;
import org.bsc.langgraph4j.checkpoint.BaseCheckpointSaver;

import java.util.Optional;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ private String getEntryPoint( State state ) throws Exception {
* @throws Exception if there is an error creating the stream
*/
public AsyncGenerator<NodeOutput<State>> stream(Map<String,Object> inputs, InvokeConfig config ) throws Exception {
Objects.requireNonNull(config, "config cannot be null");

return AsyncGeneratorQueue.of(new LinkedBlockingQueue<>(), queue -> {

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package org.bsc.langgraph4j;

import org.bsc.langgraph4j.state.CheckpointConfig;
import org.bsc.langgraph4j.checkpoint.CheckpointConfig;

import java.util.Optional;

Expand Down
2 changes: 2 additions & 0 deletions core-jdk8/src/main/java/org/bsc/langgraph4j/StateGraph.java
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,8 @@ private Node<State> makeFakeNode(String id) {
* @throws GraphStateException if there are errors related to the graph state
*/
public CompiledGraph<State> compile( CompileConfig config ) throws GraphStateException {
Objects.requireNonNull(config, "config cannot be null");

if (entryPoint == null) {
throw Errors.missingEntryPoint.exception();
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
package org.bsc.langgraph4j.state;
package org.bsc.langgraph4j.checkpoint;

import java.io.Externalizable;
import java.util.Collection;
import java.util.Optional;

public interface BaseCheckpointSaver {


Collection<Checkpoint> list();
Optional<Checkpoint> get( String id );
void put( Checkpoint checkpoint );
void put( Checkpoint checkpoint ) throws Exception;
}
Original file line number Diff line number Diff line change
@@ -1,9 +1,26 @@
package org.bsc.langgraph4j.state;
package org.bsc.langgraph4j.checkpoint;

import lombok.Value;
import lombok.Data;
import org.bsc.langgraph4j.state.AgentState;

import java.io.Externalizable;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.*;


/**
* Represents a checkpoint of an agent state.
*
* The checkpoint is an immutable object that holds an {@link AgentState}
* and a {@code String} that represents the next state.
*
* The checkpoint is serializable and can be persisted and restored.
*
* @see AgentState
* @see Externalizable
*/
public class Checkpoint {

@lombok.Value(staticConstructor="of")
Expand All @@ -12,8 +29,8 @@ public static class Value {
String next;
}

private final String id;
private final Value value;
String id;
Value value;

public final String getId() {
return id;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package org.bsc.langgraph4j.state;
package org.bsc.langgraph4j.checkpoint;

import lombok.Value;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package org.bsc.langgraph4j.checkpoint;

import org.bsc.langgraph4j.serializer.CheckpointSerializer;

import java.util.*;

import static java.util.Collections.unmodifiableSet;

public class MemorySaver implements BaseCheckpointSaver {

private final Set<Checkpoint> checkpoints = new LinkedHashSet<>();


@Override
public Collection<Checkpoint> list() {
return unmodifiableSet(checkpoints); // immutable checkpoints;
}

@Override
public Optional<Checkpoint> get(String id) {
return checkpoints.stream()
.filter( checkpoint -> checkpoint.getId().equals(id) )
.findFirst();
}

@Override
public void put(Checkpoint checkpoint) throws Exception {
checkpoints.add( CheckpointSerializer.INSTANCE.cloneObject(checkpoint) );
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package org.bsc.langgraph4j.serializer;

import lombok.extern.log4j.Log4j;
import lombok.extern.slf4j.Slf4j;
import org.bsc.langgraph4j.state.AgentState;

import java.io.*;
import java.util.HashMap;
import java.util.Map;

@Slf4j
public class AgentStateSerializer implements Serializer<AgentState> {
public static final AgentStateSerializer INSTANCE = new AgentStateSerializer();
private AgentStateSerializer() {}

@Override
public void write(AgentState object, ObjectOutput out) throws IOException {
try( ByteArrayOutputStream baos = new ByteArrayOutputStream() ) {
int actuoalSize = 0;

final ObjectOutputStream tupleStream = new ObjectOutputStream( baos );
for( Map.Entry<String,Object> e : object.data().entrySet() ) {
try {
tupleStream.writeUTF(e.getKey());
tupleStream.writeObject(e.getValue());
++actuoalSize;
} catch (IOException ex) {
log.error( "Error writing state key '{}' - {}", e.getKey(), ex.getMessage() );
throw ex;
}
}

out.writeInt( object.data().size() );
out.writeInt( actuoalSize ); // actual size
byte[] data = baos.toByteArray();
out.writeInt( data.length );
out.write( data );

}

}

@Override
public AgentState read(ObjectInput in) throws IOException, ClassNotFoundException {
Map<String, Object> data = new HashMap<>();

int expectedSize = in.readInt();
int actualSize = in.readInt();
if( expectedSize > 0 && actualSize > 0 ) {

if( expectedSize != actualSize ) {
final String message = String.format( "Deserialize State: Expected size %d and actual size %d do not match!", expectedSize, actualSize ) ;
log.error( message ) ;
throw new IOException( message ) ;
}

int byteLen = in.readInt();
byte[] bytes = new byte[byteLen];
in.readFully(bytes);

try( ByteArrayInputStream bais = new ByteArrayInputStream( bytes ) ) {
ObjectInputStream ois = new ObjectInputStream( bais );

for( int i = 0; i < actualSize; i++ ) {
String key = ois.readUTF();
Object value = ois.readObject();
data.put(key, value);
}
}

}
return new AgentState(data);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package org.bsc.langgraph4j.serializer;

import org.bsc.langgraph4j.checkpoint.Checkpoint;
import org.bsc.langgraph4j.state.AgentState;

import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;

public class CheckpointSerializer implements Serializer<Checkpoint> {

public static final CheckpointSerializer INSTANCE = new CheckpointSerializer();

private CheckpointSerializer() {}

public void write( Checkpoint object, ObjectOutput out) throws IOException {
out.writeUTF(object.getId());
Checkpoint.Value value = object.getValue();
AgentStateSerializer.INSTANCE.write( value.getState(), out );
out.writeUTF( value.getNext() );
}

public Checkpoint read(ObjectInput in) throws IOException, ClassNotFoundException {
String id = in.readUTF();
AgentState state = AgentStateSerializer.INSTANCE.read( in );
String next = in.readUTF();
Checkpoint.Value value = Checkpoint.Value.of( state, next );
return new Checkpoint(id, value);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package org.bsc.langgraph4j.serializer;

import org.bsc.langgraph4j.checkpoint.Checkpoint;

import java.io.*;
import java.util.Objects;

public interface Serializer<T> {

void write(T object, ObjectOutput out) throws IOException;


T read(ObjectInput in) throws IOException, ClassNotFoundException ;

default byte[] writeObject(T object) throws IOException {
Objects.requireNonNull( object, "object cannot be null" );
try( ByteArrayOutputStream baos = new ByteArrayOutputStream() ) {
ObjectOutputStream oas = new ObjectOutputStream(baos);
write(object, oas);
oas.flush();
return baos.toByteArray();
}
}

default T readObject(byte[] bytes) throws IOException, ClassNotFoundException {
Objects.requireNonNull( bytes, "bytes cannot be null" );
if( bytes.length == 0 ) {
throw new IllegalArgumentException("bytes cannot be empty");
}
try( ByteArrayInputStream bais = new ByteArrayInputStream( bytes ) ) {
ObjectInputStream ois = new ObjectInputStream(bais);
return read(ois);
}
}

default T cloneObject(T object) throws IOException, ClassNotFoundException {
Objects.requireNonNull( object, "object cannot be null" );
return readObject(writeObject(object));
}
}
72 changes: 72 additions & 0 deletions core-jdk8/src/test/java/org/bsc/langgraph4j/SerializeTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package org.bsc.langgraph4j;

import org.bsc.langgraph4j.serializer.AgentStateSerializer;
import org.bsc.langgraph4j.state.AgentState;
import org.junit.jupiter.api.Test;

import java.io.*;
import java.util.HashMap;
import java.util.Map;

import static org.junit.jupiter.api.Assertions.*;

public class SerializeTest {


private byte[] serializeState(AgentState state) throws Exception {
try( ByteArrayOutputStream baos = new ByteArrayOutputStream() ) {
ObjectOutputStream oas = new ObjectOutputStream(baos);
AgentStateSerializer.INSTANCE.write(state, oas);
oas.flush();
return baos.toByteArray();
}
}
private AgentState deserializeState( byte[] bytes ) throws Exception {
try(ByteArrayInputStream bais = new ByteArrayInputStream( bytes ) ) {
ObjectInputStream ois = new ObjectInputStream( bais );
return AgentStateSerializer.INSTANCE.read( ois );
}
}

@Test
public void serializeStateTest() throws Exception {

Map<String,Object> data = new HashMap<>();
data.put("a", "b");
data.put("f", null);
data.put("c", "d");

final AgentState state = new AgentState(data);

byte[] bytes = serializeState(state);

assertNotNull(bytes);
AgentState deserializeState = deserializeState( bytes );

assertEquals( 3, deserializeState.data().size() );
assertEquals( "b", deserializeState.data().get("a") );
assertEquals( "d", deserializeState.data().get("c") );
}

static class NonSerializableElement {
String value = "TEST";
public NonSerializableElement() {
}
}
@Test
public void partiallySerializeStateTest() throws Exception {

Map<String,Object> data = new HashMap<>();
data.put("a", "b");
data.put("f", new NonSerializableElement() );
data.put("c", "d");

final AgentState state = new AgentState(data);

assertThrows(IOException.class, () -> {
serializeState(state);
});

}

}

0 comments on commit 1564efc

Please sign in to comment.