-
Notifications
You must be signed in to change notification settings - Fork 21
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: finalize checkpoint implementation
add AgentState and Checkpointer serializer add support for MemorySaver work on #11
- Loading branch information
1 parent
b8ab321
commit 1564efc
Showing
12 changed files
with
279 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
6 changes: 4 additions & 2 deletions
6
...anggraph4j/state/BaseCheckpointSaver.java → ...aph4j/checkpoint/BaseCheckpointSaver.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,13 @@ | ||
package org.bsc.langgraph4j.state; | ||
package org.bsc.langgraph4j.checkpoint; | ||
|
||
import java.io.Externalizable; | ||
import java.util.Collection; | ||
import java.util.Optional; | ||
|
||
public interface BaseCheckpointSaver { | ||
|
||
|
||
Collection<Checkpoint> list(); | ||
Optional<Checkpoint> get( String id ); | ||
void put( Checkpoint checkpoint ); | ||
void put( Checkpoint checkpoint ) throws Exception; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
2 changes: 1 addition & 1 deletion
2
...c/langgraph4j/state/CheckpointConfig.java → ...ggraph4j/checkpoint/CheckpointConfig.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
package org.bsc.langgraph4j.state; | ||
package org.bsc.langgraph4j.checkpoint; | ||
|
||
import lombok.Value; | ||
|
||
|
31 changes: 31 additions & 0 deletions
31
core-jdk8/src/main/java/org/bsc/langgraph4j/checkpoint/MemorySaver.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
package org.bsc.langgraph4j.checkpoint; | ||
|
||
import org.bsc.langgraph4j.serializer.CheckpointSerializer; | ||
|
||
import java.util.*; | ||
|
||
import static java.util.Collections.unmodifiableSet; | ||
|
||
public class MemorySaver implements BaseCheckpointSaver { | ||
|
||
private final Set<Checkpoint> checkpoints = new LinkedHashSet<>(); | ||
|
||
|
||
@Override | ||
public Collection<Checkpoint> list() { | ||
return unmodifiableSet(checkpoints); // immutable checkpoints; | ||
} | ||
|
||
@Override | ||
public Optional<Checkpoint> get(String id) { | ||
return checkpoints.stream() | ||
.filter( checkpoint -> checkpoint.getId().equals(id) ) | ||
.findFirst(); | ||
} | ||
|
||
@Override | ||
public void put(Checkpoint checkpoint) throws Exception { | ||
checkpoints.add( CheckpointSerializer.INSTANCE.cloneObject(checkpoint) ); | ||
} | ||
|
||
} |
74 changes: 74 additions & 0 deletions
74
core-jdk8/src/main/java/org/bsc/langgraph4j/serializer/AgentStateSerializer.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
package org.bsc.langgraph4j.serializer; | ||
|
||
import lombok.extern.log4j.Log4j; | ||
import lombok.extern.slf4j.Slf4j; | ||
import org.bsc.langgraph4j.state.AgentState; | ||
|
||
import java.io.*; | ||
import java.util.HashMap; | ||
import java.util.Map; | ||
|
||
@Slf4j | ||
public class AgentStateSerializer implements Serializer<AgentState> { | ||
public static final AgentStateSerializer INSTANCE = new AgentStateSerializer(); | ||
private AgentStateSerializer() {} | ||
|
||
@Override | ||
public void write(AgentState object, ObjectOutput out) throws IOException { | ||
try( ByteArrayOutputStream baos = new ByteArrayOutputStream() ) { | ||
int actuoalSize = 0; | ||
|
||
final ObjectOutputStream tupleStream = new ObjectOutputStream( baos ); | ||
for( Map.Entry<String,Object> e : object.data().entrySet() ) { | ||
try { | ||
tupleStream.writeUTF(e.getKey()); | ||
tupleStream.writeObject(e.getValue()); | ||
++actuoalSize; | ||
} catch (IOException ex) { | ||
log.error( "Error writing state key '{}' - {}", e.getKey(), ex.getMessage() ); | ||
throw ex; | ||
} | ||
} | ||
|
||
out.writeInt( object.data().size() ); | ||
out.writeInt( actuoalSize ); // actual size | ||
byte[] data = baos.toByteArray(); | ||
out.writeInt( data.length ); | ||
out.write( data ); | ||
|
||
} | ||
|
||
} | ||
|
||
@Override | ||
public AgentState read(ObjectInput in) throws IOException, ClassNotFoundException { | ||
Map<String, Object> data = new HashMap<>(); | ||
|
||
int expectedSize = in.readInt(); | ||
int actualSize = in.readInt(); | ||
if( expectedSize > 0 && actualSize > 0 ) { | ||
|
||
if( expectedSize != actualSize ) { | ||
final String message = String.format( "Deserialize State: Expected size %d and actual size %d do not match!", expectedSize, actualSize ) ; | ||
log.error( message ) ; | ||
throw new IOException( message ) ; | ||
} | ||
|
||
int byteLen = in.readInt(); | ||
byte[] bytes = new byte[byteLen]; | ||
in.readFully(bytes); | ||
|
||
try( ByteArrayInputStream bais = new ByteArrayInputStream( bytes ) ) { | ||
ObjectInputStream ois = new ObjectInputStream( bais ); | ||
|
||
for( int i = 0; i < actualSize; i++ ) { | ||
String key = ois.readUTF(); | ||
Object value = ois.readObject(); | ||
data.put(key, value); | ||
} | ||
} | ||
|
||
} | ||
return new AgentState(data); | ||
} | ||
} |
31 changes: 31 additions & 0 deletions
31
core-jdk8/src/main/java/org/bsc/langgraph4j/serializer/CheckpointSerializer.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
package org.bsc.langgraph4j.serializer; | ||
|
||
import org.bsc.langgraph4j.checkpoint.Checkpoint; | ||
import org.bsc.langgraph4j.state.AgentState; | ||
|
||
import java.io.IOException; | ||
import java.io.ObjectInput; | ||
import java.io.ObjectOutput; | ||
|
||
public class CheckpointSerializer implements Serializer<Checkpoint> { | ||
|
||
public static final CheckpointSerializer INSTANCE = new CheckpointSerializer(); | ||
|
||
private CheckpointSerializer() {} | ||
|
||
public void write( Checkpoint object, ObjectOutput out) throws IOException { | ||
out.writeUTF(object.getId()); | ||
Checkpoint.Value value = object.getValue(); | ||
AgentStateSerializer.INSTANCE.write( value.getState(), out ); | ||
out.writeUTF( value.getNext() ); | ||
} | ||
|
||
public Checkpoint read(ObjectInput in) throws IOException, ClassNotFoundException { | ||
String id = in.readUTF(); | ||
AgentState state = AgentStateSerializer.INSTANCE.read( in ); | ||
String next = in.readUTF(); | ||
Checkpoint.Value value = Checkpoint.Value.of( state, next ); | ||
return new Checkpoint(id, value); | ||
} | ||
|
||
} |
40 changes: 40 additions & 0 deletions
40
core-jdk8/src/main/java/org/bsc/langgraph4j/serializer/Serializer.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
package org.bsc.langgraph4j.serializer; | ||
|
||
import org.bsc.langgraph4j.checkpoint.Checkpoint; | ||
|
||
import java.io.*; | ||
import java.util.Objects; | ||
|
||
public interface Serializer<T> { | ||
|
||
void write(T object, ObjectOutput out) throws IOException; | ||
|
||
|
||
T read(ObjectInput in) throws IOException, ClassNotFoundException ; | ||
|
||
default byte[] writeObject(T object) throws IOException { | ||
Objects.requireNonNull( object, "object cannot be null" ); | ||
try( ByteArrayOutputStream baos = new ByteArrayOutputStream() ) { | ||
ObjectOutputStream oas = new ObjectOutputStream(baos); | ||
write(object, oas); | ||
oas.flush(); | ||
return baos.toByteArray(); | ||
} | ||
} | ||
|
||
default T readObject(byte[] bytes) throws IOException, ClassNotFoundException { | ||
Objects.requireNonNull( bytes, "bytes cannot be null" ); | ||
if( bytes.length == 0 ) { | ||
throw new IllegalArgumentException("bytes cannot be empty"); | ||
} | ||
try( ByteArrayInputStream bais = new ByteArrayInputStream( bytes ) ) { | ||
ObjectInputStream ois = new ObjectInputStream(bais); | ||
return read(ois); | ||
} | ||
} | ||
|
||
default T cloneObject(T object) throws IOException, ClassNotFoundException { | ||
Objects.requireNonNull( object, "object cannot be null" ); | ||
return readObject(writeObject(object)); | ||
} | ||
} |
72 changes: 72 additions & 0 deletions
72
core-jdk8/src/test/java/org/bsc/langgraph4j/SerializeTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
package org.bsc.langgraph4j; | ||
|
||
import org.bsc.langgraph4j.serializer.AgentStateSerializer; | ||
import org.bsc.langgraph4j.state.AgentState; | ||
import org.junit.jupiter.api.Test; | ||
|
||
import java.io.*; | ||
import java.util.HashMap; | ||
import java.util.Map; | ||
|
||
import static org.junit.jupiter.api.Assertions.*; | ||
|
||
public class SerializeTest { | ||
|
||
|
||
private byte[] serializeState(AgentState state) throws Exception { | ||
try( ByteArrayOutputStream baos = new ByteArrayOutputStream() ) { | ||
ObjectOutputStream oas = new ObjectOutputStream(baos); | ||
AgentStateSerializer.INSTANCE.write(state, oas); | ||
oas.flush(); | ||
return baos.toByteArray(); | ||
} | ||
} | ||
private AgentState deserializeState( byte[] bytes ) throws Exception { | ||
try(ByteArrayInputStream bais = new ByteArrayInputStream( bytes ) ) { | ||
ObjectInputStream ois = new ObjectInputStream( bais ); | ||
return AgentStateSerializer.INSTANCE.read( ois ); | ||
} | ||
} | ||
|
||
@Test | ||
public void serializeStateTest() throws Exception { | ||
|
||
Map<String,Object> data = new HashMap<>(); | ||
data.put("a", "b"); | ||
data.put("f", null); | ||
data.put("c", "d"); | ||
|
||
final AgentState state = new AgentState(data); | ||
|
||
byte[] bytes = serializeState(state); | ||
|
||
assertNotNull(bytes); | ||
AgentState deserializeState = deserializeState( bytes ); | ||
|
||
assertEquals( 3, deserializeState.data().size() ); | ||
assertEquals( "b", deserializeState.data().get("a") ); | ||
assertEquals( "d", deserializeState.data().get("c") ); | ||
} | ||
|
||
static class NonSerializableElement { | ||
String value = "TEST"; | ||
public NonSerializableElement() { | ||
} | ||
} | ||
@Test | ||
public void partiallySerializeStateTest() throws Exception { | ||
|
||
Map<String,Object> data = new HashMap<>(); | ||
data.put("a", "b"); | ||
data.put("f", new NonSerializableElement() ); | ||
data.put("c", "d"); | ||
|
||
final AgentState state = new AgentState(data); | ||
|
||
assertThrows(IOException.class, () -> { | ||
serializeState(state); | ||
}); | ||
|
||
} | ||
|
||
} |