Skip to content

Commit

Permalink
feat(serializer): add custom serializer support
Browse files Browse the repository at this point in the history
  • Loading branch information
bsorrentino committed Sep 2, 2024
1 parent d3595cb commit f958f0c
Show file tree
Hide file tree
Showing 5 changed files with 180 additions and 86 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import org.bsc.langgraph4j.action.AsyncNodeAction;
import org.bsc.langgraph4j.checkpoint.BaseCheckpointSaver;
import org.bsc.langgraph4j.checkpoint.Checkpoint;
import org.bsc.langgraph4j.serializer.MapSerializer;
import org.bsc.langgraph4j.serializer.StateSerializer;
import org.bsc.langgraph4j.state.AgentState;
import org.bsc.langgraph4j.state.StateSnapshot;
import org.bsc.langgraph4j.utils.TryConsumer;
Expand Down Expand Up @@ -192,7 +192,7 @@ Map<String,Object> getInitialState(Map<String,Object> inputs, RunnableConfig con

State cloneState( Map<String,Object> data ) throws IOException, ClassNotFoundException {

Map<String,Object> newData = MapSerializer.INSTANCE.cloneObject(data);
Map<String,Object> newData = StateSerializer.INSTANCE.cloneObject(data);

return stateGraph.getStateFactory().apply(newData);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@ private CheckpointSerializer() {}

public void write( Checkpoint object, ObjectOutput out) throws IOException {
out.writeUTF(object.getId());
MapSerializer.INSTANCE.write( object.getState(), out );
StateSerializer.INSTANCE.write( object.getState(), out );
out.writeUTF( object.getNodeId() );
Serializer.writeUTFNullable( object.getNextNodeId(), out );
}

public Checkpoint read(ObjectInput in) throws IOException, ClassNotFoundException {
return Checkpoint.builder()
.id(in.readUTF())
.state(MapSerializer.INSTANCE.read( in ))
.state(StateSerializer.INSTANCE.read( in ))
.nodeId(in.readUTF())
.nextNodeId(Serializer.readUTFNullable(in).orElse(null))
.build();
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
package org.bsc.langgraph4j.serializer;

import lombok.extern.slf4j.Slf4j;

import java.io.*;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;

@Slf4j
public class StateSerializer implements Serializer<Map<String,Object>> {

private final Map<Class<?>, Serializer<?> > serializers = new HashMap<>();

public static final StateSerializer INSTANCE = new StateSerializer();
private StateSerializer() {}

public <T> void register( Class<T> clazz, Serializer<T> serializer ) {
Objects.requireNonNull( clazz, "Serializer's class cannot be null" );
Objects.requireNonNull( serializer, "Serializer cannot be null" );

serializers.put( clazz, serializer );
}

public <T> boolean unregister( Class<T> clazz ) {
Objects.requireNonNull( clazz, "Serializer's class cannot be null" );
Serializer<?> serializer = serializers.remove( clazz );

return serializer != null;
}

private Optional<Serializer<Object>> getSerializer(Class<?> clazz ) {
Objects.requireNonNull( clazz, "Serializer's class cannot be null" );
return Optional.ofNullable( (Serializer<Object>)serializers.get( clazz ) );
}

@Override
public void write(Map<String,Object> object, ObjectOutput out) throws IOException {
try( ByteArrayOutputStream baos = new ByteArrayOutputStream() ) {
int actualSize = 0;

final ObjectOutputStream tupleStream = new ObjectOutputStream( baos );
for( Map.Entry<String,Object> e : object.entrySet() ) {
try {
tupleStream.writeUTF(e.getKey());

Object value = e.getValue();
Optional<Serializer<Object>> serializer = (value != null) ?
getSerializer(value.getClass()) :
Optional.empty();
// check if written by serializer
if (serializer.isPresent()) {
tupleStream.writeObject(value.getClass());
serializer.get().write(value, tupleStream);
}
else {
tupleStream.writeObject(value);
}
tupleStream.flush();

++actualSize;
} catch (IOException ex) {
log.error( "Error writing map key '{}'", e.getKey(), ex );
throw ex;
}
}

out.writeInt( object.size() );
out.writeInt( actualSize ); // actual size
byte[] data = baos.toByteArray();
out.writeInt( data.length );
out.write( data );

}

}

@Override
public Map<String, Object> read(ObjectInput in) throws IOException, ClassNotFoundException {
Map<String, Object> data = new HashMap<>();

int expectedSize = in.readInt();
int actualSize = in.readInt();
if( expectedSize > 0 && actualSize > 0 ) {

if( expectedSize != actualSize ) {
final String message = String.format( "Deserialize map: Expected size %d and actual size %d do not match!", expectedSize, actualSize ) ;
log.error( message ) ;
throw new IOException( message ) ;
}

int byteLen = in.readInt();
byte[] bytes = new byte[byteLen];
in.readFully(bytes);

try( ByteArrayInputStream bais = new ByteArrayInputStream( bytes ) ) {
ObjectInputStream ois = new ObjectInputStream( bais );

for( int i = 0; i < actualSize; i++ ) {
String key = ois.readUTF();
Object value = ois.readObject();
// check if it's a serializer
if( value instanceof Class<?>) {
final Class<?> serializerClass = (Class<?>)value;
Serializer<Object> serializer = getSerializer( serializerClass )
.orElseThrow( () -> new IllegalArgumentException( "No serializer found for class " + serializerClass ) );

value = serializer.read(ois);
}

data.put(key, value);

}
}

}
return data;
}

}
64 changes: 55 additions & 9 deletions core-jdk8/src/test/java/org/bsc/langgraph4j/SerializeTest.java
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
package org.bsc.langgraph4j;

import org.bsc.langgraph4j.serializer.MapSerializer;
import org.bsc.langgraph4j.state.AgentState;
import lombok.ToString;
import org.bsc.langgraph4j.serializer.Serializer;
import org.bsc.langgraph4j.serializer.StateSerializer;
import org.junit.jupiter.api.Test;

import java.io.*;
Expand All @@ -16,15 +17,15 @@ public class SerializeTest {
private byte[] serializeState(Map<String,Object> state) throws Exception {
try( ByteArrayOutputStream baos = new ByteArrayOutputStream() ) {
ObjectOutputStream oas = new ObjectOutputStream(baos);
MapSerializer.INSTANCE.write(state, oas);
StateSerializer.INSTANCE.write(state, oas);
oas.flush();
return baos.toByteArray();
}
}
private Map<String,Object> deserializeState( byte[] bytes ) throws Exception {
try(ByteArrayInputStream bais = new ByteArrayInputStream( bytes ) ) {
ObjectInputStream ois = new ObjectInputStream( bais );
return MapSerializer.INSTANCE.read( ois );
return StateSerializer.INSTANCE.read( ois );
}
}

Expand All @@ -33,7 +34,7 @@ public void serializeStateTest() throws Exception {

Map<String,Object> data = new HashMap<>();
data.put("a", "b");
data.put("f", null);
data.put("f", null );
data.put("c", "d");

byte[] bytes = serializeState(data);
Expand All @@ -46,23 +47,68 @@ public void serializeStateTest() throws Exception {
assertEquals( "d", deserializeState.get("c") );
}

static class NonSerializableElement {
String value = "TEST";
@ToString
public static class NonSerializableElement {

String value;
public NonSerializableElement() {
this.value = "default";
}
public NonSerializableElement( String value ) {
this.value = value;
}

}

@Test
public void partiallySerializeStateTest() throws Exception {

Map<String,Object> data = new HashMap<>();
data.put("a", "b");
data.put("f", new NonSerializableElement() );
data.put("f", new NonSerializableElement("I'M NOT SERIALIZABLE") );
data.put("c", "d");

assertThrows(IOException.class, () -> {
assertThrows(java.io.NotSerializableException.class, () -> {
serializeState(data);
});

}

@Test
public void customSerializeStateTest() throws Exception {

StateSerializer.INSTANCE.register(NonSerializableElement.class, new Serializer<NonSerializableElement>() {

@Override
public void write(NonSerializableElement object, ObjectOutput out) throws IOException {
out.writeUTF(object.value);
}

@Override
public NonSerializableElement read(ObjectInput in) throws IOException, ClassNotFoundException {
return new NonSerializableElement(in.readUTF());
}
});

Map<String,Object> data = new HashMap<>();
data.put("a", "b");
data.put("x", new NonSerializableElement("I'M NOT SERIALIZABLE 2") );
data.put("f", 'H' );
data.put("c", "d");

System.out.println( data );

byte[] bytes = serializeState(data);

assertNotNull(bytes);
assertTrue(bytes.length > 0);

Map<String,Object> deserializedData = deserializeState( bytes );

assertNotNull(deserializedData);

System.out.println( deserializedData.get( "x" ).getClass() );
System.out.println( deserializedData );
}

}

0 comments on commit f958f0c

Please sign in to comment.