Skip to content

Commit

Permalink
Prevent Java deserialization of internal classes (#1991)
Browse files Browse the repository at this point in the history
Adversaries might be able to forge data which can be abused for DoS attacks.
These classes are already writing a replacement JDK object during serialization
for a long time, so this change should not cause any issues.
  • Loading branch information
Marcono1234 authored Oct 13, 2021
1 parent bda2e3d commit e6fae59
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
*/
package com.google.gson.internal;

import java.io.IOException;
import java.io.InvalidObjectException;
import java.io.ObjectInputStream;
import java.io.ObjectStreamException;
import java.math.BigDecimal;

Expand Down Expand Up @@ -77,6 +80,11 @@ private Object writeReplace() throws ObjectStreamException {
return new BigDecimal(value);
}

private void readObject(ObjectInputStream in) throws IOException {
// Don't permit directly deserializing this class; writeReplace() should have written a replacement
throw new InvalidObjectException("Deserialization is unsupported");
}

@Override
public int hashCode() {
return value.hashCode();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@

package com.google.gson.internal;

import java.io.IOException;
import java.io.InvalidObjectException;
import java.io.ObjectInputStream;
import java.io.ObjectStreamException;
import java.io.Serializable;
import java.util.AbstractMap;
Expand Down Expand Up @@ -861,4 +864,9 @@ public K next() {
private Object writeReplace() throws ObjectStreamException {
return new LinkedHashMap<K, V>(this);
}

private void readObject(ObjectInputStream in) throws IOException {
// Don't permit directly deserializing this class; writeReplace() should have written a replacement
throw new InvalidObjectException("Deserialization is unsupported");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@

package com.google.gson.internal;

import java.io.IOException;
import java.io.InvalidObjectException;
import java.io.ObjectInputStream;
import java.io.ObjectStreamException;
import java.io.Serializable;
import java.util.AbstractMap;
Expand Down Expand Up @@ -627,4 +630,9 @@ public K next() {
private Object writeReplace() throws ObjectStreamException {
return new LinkedHashMap<K, V>(this);
}

private void readObject(ObjectInputStream in) throws IOException {
// Don't permit directly deserializing this class; writeReplace() should have written a replacement
throw new InvalidObjectException("Deserialization is unsupported");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@
*/
package com.google.gson.internal;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.math.BigDecimal;

import junit.framework.TestCase;

public class LazilyParsedNumberTest extends TestCase {
Expand All @@ -29,4 +36,15 @@ public void testEquals() {
LazilyParsedNumber n1Another = new LazilyParsedNumber("1");
assertTrue(n1.equals(n1Another));
}

public void testJavaSerialization() throws IOException, ClassNotFoundException {
ByteArrayOutputStream out = new ByteArrayOutputStream();
ObjectOutputStream objOut = new ObjectOutputStream(out);
objOut.writeObject(new LazilyParsedNumber("123"));
objOut.close();

ObjectInputStream objIn = new ObjectInputStream(new ByteArrayInputStream(out.toByteArray()));
Number deserialized = (Number) objIn.readObject();
assertEquals(new BigDecimal("123"), deserialized);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,15 @@
import com.google.gson.internal.LinkedHashTreeMap.AvlBuilder;
import com.google.gson.internal.LinkedHashTreeMap.AvlIterator;
import com.google.gson.internal.LinkedHashTreeMap.Node;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.Map;
import java.util.Random;
Expand Down Expand Up @@ -224,6 +231,20 @@ public void testDoubleCapacityAllNodesOnLeft() {
}
}

public void testJavaSerialization() throws IOException, ClassNotFoundException {
ByteArrayOutputStream out = new ByteArrayOutputStream();
ObjectOutputStream objOut = new ObjectOutputStream(out);
Map<String, Integer> map = new LinkedHashTreeMap<String, Integer>();
map.put("a", 1);
objOut.writeObject(map);
objOut.close();

ObjectInputStream objIn = new ObjectInputStream(new ByteArrayInputStream(out.toByteArray()));
@SuppressWarnings("unchecked")
Map<String, Integer> deserialized = (Map<String, Integer>) objIn.readObject();
assertEquals(Collections.singletonMap("a", 1), deserialized);
}

private static final Node<String, String> head = new Node<String, String>();

private Node<String, String> node(String value) {
Expand Down
20 changes: 20 additions & 0 deletions gson/src/test/java/com/google/gson/internal/LinkedTreeMapTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,14 @@

package com.google.gson.internal;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.Map;
import java.util.Random;
Expand Down Expand Up @@ -140,6 +146,20 @@ public void testEqualsAndHashCode() throws Exception {
MoreAsserts.assertEqualsAndHashCode(map1, map2);
}

public void testJavaSerialization() throws IOException, ClassNotFoundException {
ByteArrayOutputStream out = new ByteArrayOutputStream();
ObjectOutputStream objOut = new ObjectOutputStream(out);
Map<String, Integer> map = new LinkedTreeMap<String, Integer>();
map.put("a", 1);
objOut.writeObject(map);
objOut.close();

ObjectInputStream objIn = new ObjectInputStream(new ByteArrayInputStream(out.toByteArray()));
@SuppressWarnings("unchecked")
Map<String, Integer> deserialized = (Map<String, Integer>) objIn.readObject();
assertEquals(Collections.singletonMap("a", 1), deserialized);
}

@SafeVarargs
private <T> void assertIterationOrder(Iterable<T> actual, T... expected) {
ArrayList<T> actualList = new ArrayList<T>();
Expand Down

0 comments on commit e6fae59

Please sign in to comment.