Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added check and block for external entity references in XMLProtection #1430

Merged
merged 5 commits into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package com.predic8.membrane.core.interceptor.xmlprotection;

public class XMLProtectionException extends Exception {
public XMLProtectionException(String message) {
super(message);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,16 @@

package com.predic8.membrane.core.interceptor.xmlprotection;

import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.util.Iterator;
import org.jetbrains.annotations.*;
import org.slf4j.*;

import javax.xml.stream.XMLEventReader;
import javax.xml.stream.XMLEventWriter;
import javax.xml.stream.XMLInputFactory;
import javax.xml.stream.XMLOutputFactory;
import javax.xml.stream.XMLStreamException;
import javax.xml.stream.events.StartElement;
import javax.xml.stream.events.XMLEvent;
import javax.xml.stream.*;
import javax.xml.stream.events.*;
import java.io.*;
import java.util.*;
import java.util.function.*;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import static javax.xml.stream.XMLInputFactory.*;

/**
* Filters XML streams, removing potentially malicious elements:
Expand All @@ -43,15 +39,15 @@
* an error response should be returned to the requestor.
*/
public class XMLProtector {
private static Logger log = LoggerFactory.getLogger(XMLProtector.class.getName());
private static XMLInputFactory xmlInputFactory = XMLInputFactory.newInstance();
private static final Logger log = LoggerFactory.getLogger(XMLProtector.class.getName());
private static final XMLInputFactory xmlInputFactory = XMLInputFactory.newInstance();
static {
xmlInputFactory.setProperty(XMLInputFactory.IS_REPLACING_ENTITY_REFERENCES, false);
xmlInputFactory.setProperty(XMLInputFactory.IS_SUPPORTING_EXTERNAL_ENTITIES, false);
xmlInputFactory.setProperty(XMLInputFactory.SUPPORT_DTD,false);
xmlInputFactory.setProperty(IS_REPLACING_ENTITY_REFERENCES, false);
xmlInputFactory.setProperty(IS_SUPPORTING_EXTERNAL_ENTITIES, false);
xmlInputFactory.setProperty(SUPPORT_DTD,false);
}

private XMLEventWriter writer;
private final XMLEventWriter writer;
private final int maxAttibuteCount;
private final int maxElementNameLength;
private final boolean removeDTD;
Expand All @@ -63,10 +59,16 @@ public XMLProtector(OutputStreamWriter osw, boolean removeDTD, int maxElementNam
this.maxAttibuteCount = maxAttibuteCount;

if(!removeDTD)
xmlInputFactory.setProperty(XMLInputFactory.SUPPORT_DTD,true);
xmlInputFactory.setProperty(SUPPORT_DTD,true);
}

public boolean protect(InputStreamReader isr) {
/**
* Is XML secure?
* @param isr Stream with XML
* @return false if there is any security problem in the XML
* @throws XMLProtectionException if there are critical issues like external entity references
*/
public boolean protect(InputStreamReader isr) throws XMLProtectionException {
try {
XMLEventReader parser;
synchronized(xmlInputFactory) {
Expand All @@ -91,7 +93,9 @@ public boolean protect(InputStreamReader isr) {
return false;
}
}
} if (event instanceof javax.xml.stream.events.DTD) {
}
if (event instanceof javax.xml.stream.events.DTD dtd) {
checkExternalEntities(dtd);
if (removeDTD) {
log.debug("removed DTD.");
continue;
Expand All @@ -107,4 +111,24 @@ public boolean protect(InputStreamReader isr) {
return true;
}

private static void checkExternalEntities(DTD dtd) throws XMLProtectionException {
if (containsExternalEntityReferences(dtd)) {
String msg = "Possible attack. External entity found in DTD.";
log.warn(msg);
throw new XMLProtectionException(msg);
}
}

private static boolean containsExternalEntityReferences(DTD dtd) {
var entities = dtd.getEntities();
if (entities == null || entities.isEmpty())
return false;

return entities.stream().anyMatch(isExternalEntity());
}

private static @NotNull Predicate<EntityDeclaration> isExternalEntity() {
return ed -> ed.getPublicId() != null || ed.getSystemId() != null;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@

package com.predic8.membrane.core.interceptor.xmlprotection;

import static org.junit.jupiter.api.Assertions.*;

import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import com.predic8.membrane.core.exchange.*;
import com.predic8.membrane.core.http.*;
import com.predic8.membrane.core.interceptor.*;
import com.predic8.membrane.core.util.*;
import org.junit.jupiter.api.*;

import com.predic8.membrane.core.exchange.Exchange;
import com.predic8.membrane.core.interceptor.Outcome;
import com.predic8.membrane.core.util.ByteUtil;
import com.predic8.membrane.core.util.MessageUtil;
import static com.predic8.membrane.core.http.MimeType.*;
import static com.predic8.membrane.core.interceptor.Outcome.*;
import static org.junit.jupiter.api.Assertions.*;

public class XMLProtectionInterceptorTest {
private static Exchange exc;
Expand All @@ -38,21 +38,39 @@ public static void setUp() throws Exception {
}

private void runOn(String resource, boolean expectSuccess) throws Exception {
exc.getRequest().getHeader().setContentType("application/xml");
exc.getRequest().getHeader().setContentType(APPLICATION_XML);
exc.getRequest().setBodyContent(ByteUtil.getByteArrayData(this.getClass().getResourceAsStream(resource)));
Outcome outcome = interceptor.handleRequest(exc);
assertEquals(expectSuccess ? Outcome.CONTINUE : Outcome.ABORT, outcome);
assertEquals(expectSuccess ? CONTINUE : ABORT, outcome);
}

@Test
public void testInvariant() throws Exception {
void testInvariant() throws Exception {
runOn("/customer.xml", true);
}

@Test
public void testNotWellformed() throws Exception {
void testNotWellformed() throws Exception {
runOn("/xml/not-wellformed.xml", false);
}

@Test
void removeDTD() throws Exception {
exc.setRequest(Request.post("/").body("""
<?xml version="1.0" encoding="ISO-8859-1"?>
<!DOCTYPE foo [
<!ELEMENT foo ANY >
]>
<foo/>
""").contentType(APPLICATION_XML).build());

// Should pass
assertEquals(CONTINUE, interceptor.handleRequest(exc));
predic8 marked this conversation as resolved.
Show resolved Hide resolved

// Should still contain the XML
assertTrue(exc.getRequest().getBodyAsStringDecoded().contains("<foo"));

// DTD should be removed
assertFalse(exc.getRequest().getBodyAsStringDecoded().contains("DOCTYPE"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,15 @@
public class XMLProtectorTest {

private static final Logger LOG = LoggerFactory.getLogger(XMLProtectorTest.class);
private XMLProtector xmlProtector;
private byte[] input, output;
private byte[] input, output;

private boolean runOn(String resource) throws Exception {
return runOn(resource, true);
}

private boolean runOn(String resource, boolean removeDTD) throws Exception {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
xmlProtector = new XMLProtector(new OutputStreamWriter(baos, UTF_8), removeDTD, 1000, 1000);
XMLProtector xmlProtector = new XMLProtector(new OutputStreamWriter(baos, UTF_8), removeDTD, 1000, 1000);
input = ByteUtil.getByteArrayData(this.getClass().getResourceAsStream(resource));

if (resource.endsWith("lmx")) {
Expand Down Expand Up @@ -65,50 +64,41 @@ private void reverse() {
}

@Test
public void testInvariant() throws Exception {
void invariant() throws Exception {
assertTrue(runOn("/customer.xml"));
}

@Test
public void testNotWellformed() throws Exception {
void NotWellformed() throws Exception {
assertFalse(runOn("/xml/not-wellformed.xml"));
}

@Test
public void testDTDRemoval1() throws Exception {
void DTDRemoval() throws Exception {
assertTrue(runOn("/xml/entity-expansion.lmx"));
assertTrue(output.length < input.length / 2);
assertFalse(new String(output).contains("ENTITY"));
}

@Test
public void testDTDRemoval2() throws Exception {
assertTrue(runOn("/xml/entity-external.xml"));
assertTrue(output.length < input.length * 2 / 3);
assertFalse(new String(output).contains("ENTITY"));
}

@Test
public void testExpandingEntities() throws Exception {
void expandingEntities() throws Exception {
assertTrue(runOn("/xml/entity-expansion.lmx", false));
assertTrue(output.length > input.length / 2);
assertTrue(new String(output).contains("ENTITY"));
}

@Test
public void testExternalEntities() throws Exception {
assertTrue(runOn("/xml/entity-external.xml", false));
assertTrue(output.length > input.length * 2 / 3);
assertTrue(new String(output).contains("ENTITY"));
void externalEntities() {
assertThrows(XMLProtectionException.class, () -> runOn("/xml/entity-external.xml", false));
}

@Test
public void testLongElementName() throws Exception {
void longElementName() throws Exception {
assertFalse(runOn("/xml/long-element-name.xml"));
}

@Test
public void testManyAttributes() throws Exception {
void manyAttributes() throws Exception {
assertFalse(runOn("/xml/many-attributes.xml"));
}
}
Loading