Skip to content

Commit

Permalink
Added check and block for external entity references in XMLProtection (
Browse files Browse the repository at this point in the history
…#1430)

* Added check for external entity references. Block if there are some.

* Minor refactorings

* Formatted code

---------

Co-authored-by: t-burch <119930761+t-burch@users.noreply.github.com>
  • Loading branch information
predic8 and t-burch authored Dec 19, 2024
1 parent 07bc40a commit e2beb06
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 71 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package com.predic8.membrane.core.interceptor.xmlprotection;

public class XMLProtectionException extends Exception {
public XMLProtectionException(String message) {
super(message);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,16 @@

package com.predic8.membrane.core.interceptor.xmlprotection;

import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.util.Iterator;
import org.jetbrains.annotations.*;
import org.slf4j.*;

import javax.xml.stream.XMLEventReader;
import javax.xml.stream.XMLEventWriter;
import javax.xml.stream.XMLInputFactory;
import javax.xml.stream.XMLOutputFactory;
import javax.xml.stream.XMLStreamException;
import javax.xml.stream.events.StartElement;
import javax.xml.stream.events.XMLEvent;
import javax.xml.stream.*;
import javax.xml.stream.events.*;
import java.io.*;
import java.util.*;
import java.util.function.*;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import static javax.xml.stream.XMLInputFactory.*;

/**
* Filters XML streams, removing potentially malicious elements:
Expand All @@ -43,15 +39,15 @@
* an error response should be returned to the requestor.
*/
public class XMLProtector {
private static Logger log = LoggerFactory.getLogger(XMLProtector.class.getName());
private static XMLInputFactory xmlInputFactory = XMLInputFactory.newInstance();
private static final Logger log = LoggerFactory.getLogger(XMLProtector.class.getName());
private static final XMLInputFactory xmlInputFactory = XMLInputFactory.newInstance();
static {
xmlInputFactory.setProperty(XMLInputFactory.IS_REPLACING_ENTITY_REFERENCES, false);
xmlInputFactory.setProperty(XMLInputFactory.IS_SUPPORTING_EXTERNAL_ENTITIES, false);
xmlInputFactory.setProperty(XMLInputFactory.SUPPORT_DTD,false);
xmlInputFactory.setProperty(IS_REPLACING_ENTITY_REFERENCES, false);
xmlInputFactory.setProperty(IS_SUPPORTING_EXTERNAL_ENTITIES, false);
xmlInputFactory.setProperty(SUPPORT_DTD,false);
}

private XMLEventWriter writer;
private final XMLEventWriter writer;
private final int maxAttibuteCount;
private final int maxElementNameLength;
private final boolean removeDTD;
Expand All @@ -63,10 +59,16 @@ public XMLProtector(OutputStreamWriter osw, boolean removeDTD, int maxElementNam
this.maxAttibuteCount = maxAttibuteCount;

if(!removeDTD)
xmlInputFactory.setProperty(XMLInputFactory.SUPPORT_DTD,true);
xmlInputFactory.setProperty(SUPPORT_DTD,true);
}

public boolean protect(InputStreamReader isr) {
/**
* Is XML secure?
* @param isr Stream with XML
* @return false if there is any security problem in the XML
* @throws XMLProtectionException if there are critical issues like external entity references
*/
public boolean protect(InputStreamReader isr) throws XMLProtectionException {
try {
XMLEventReader parser;
synchronized(xmlInputFactory) {
Expand All @@ -91,7 +93,9 @@ public boolean protect(InputStreamReader isr) {
return false;
}
}
} if (event instanceof javax.xml.stream.events.DTD) {
}
if (event instanceof javax.xml.stream.events.DTD dtd) {
checkExternalEntities(dtd);
if (removeDTD) {
log.debug("removed DTD.");
continue;
Expand All @@ -107,4 +111,24 @@ public boolean protect(InputStreamReader isr) {
return true;
}

private static void checkExternalEntities(DTD dtd) throws XMLProtectionException {
if (containsExternalEntityReferences(dtd)) {
String msg = "Possible attack. External entity found in DTD.";
log.warn(msg);
throw new XMLProtectionException(msg);
}
}

private static boolean containsExternalEntityReferences(DTD dtd) {
var entities = dtd.getEntities();
if (entities == null || entities.isEmpty())
return false;

return entities.stream().anyMatch(isExternalEntity());
}

private static @NotNull Predicate<EntityDeclaration> isExternalEntity() {
return ed -> ed.getPublicId() != null || ed.getSystemId() != null;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -14,45 +14,63 @@

package com.predic8.membrane.core.interceptor.xmlprotection;

import com.predic8.membrane.core.exchange.*;
import com.predic8.membrane.core.http.*;
import com.predic8.membrane.core.interceptor.*;
import com.predic8.membrane.core.util.*;
import org.junit.jupiter.api.*;

import static com.predic8.membrane.core.http.MimeType.*;
import static com.predic8.membrane.core.interceptor.Outcome.*;
import static org.junit.jupiter.api.Assertions.*;

import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
public class XMLProtectionInterceptorTest {
private static Exchange exc;
private static XMLProtectionInterceptor interceptor;

import com.predic8.membrane.core.exchange.Exchange;
import com.predic8.membrane.core.interceptor.Outcome;
import com.predic8.membrane.core.util.ByteUtil;
import com.predic8.membrane.core.util.MessageUtil;
@BeforeAll
public static void setUp() throws Exception {
exc = new Exchange(null);
exc.setRequest(MessageUtil.getGetRequest("/axis2/services/BLZService"));
exc.setOriginalHostHeader("thomas-bayer.com:80");

public class XMLProtectionInterceptorTest {
private static Exchange exc;
private static XMLProtectionInterceptor interceptor;
interceptor = new XMLProtectionInterceptor();
}

@BeforeAll
public static void setUp() throws Exception {
exc = new Exchange(null);
exc.setRequest(MessageUtil.getGetRequest("/axis2/services/BLZService"));
exc.setOriginalHostHeader("thomas-bayer.com:80");
private void runOn(String resource, boolean expectSuccess) throws Exception {
exc.getRequest().getHeader().setContentType(APPLICATION_XML);
exc.getRequest().setBodyContent(ByteUtil.getByteArrayData(this.getClass().getResourceAsStream(resource)));
Outcome outcome = interceptor.handleRequest(exc);
assertEquals(expectSuccess ? CONTINUE : ABORT, outcome);
}

interceptor = new XMLProtectionInterceptor();
}
@Test
void testInvariant() throws Exception {
runOn("/customer.xml", true);
}

private void runOn(String resource, boolean expectSuccess) throws Exception {
exc.getRequest().getHeader().setContentType("application/xml");
exc.getRequest().setBodyContent(ByteUtil.getByteArrayData(this.getClass().getResourceAsStream(resource)));
Outcome outcome = interceptor.handleRequest(exc);
assertEquals(expectSuccess ? Outcome.CONTINUE : Outcome.ABORT, outcome);
}
@Test
void testNotWellformed() throws Exception {
runOn("/xml/not-wellformed.xml", false);
}

@Test
public void testInvariant() throws Exception {
runOn("/customer.xml", true);
}
@Test
void removeDTD() throws Exception {
exc.setRequest(Request.post("/").body("""
<?xml version="1.0" encoding="ISO-8859-1"?>
<!DOCTYPE foo [
<!ELEMENT foo ANY >
]>
<foo/>
""").contentType(APPLICATION_XML).build());

@Test
public void testNotWellformed() throws Exception {
runOn("/xml/not-wellformed.xml", false);
}
// Should pass
assertEquals(CONTINUE, interceptor.handleRequest(exc));

// Should still contain the XML
assertTrue(exc.getRequest().getBodyAsStringDecoded().contains("<foo"));

// DTD should be removed
assertFalse(exc.getRequest().getBodyAsStringDecoded().contains("DOCTYPE"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,15 @@
public class XMLProtectorTest {

private static final Logger LOG = LoggerFactory.getLogger(XMLProtectorTest.class);
private XMLProtector xmlProtector;
private byte[] input, output;
private byte[] input, output;

private boolean runOn(String resource) throws Exception {
return runOn(resource, true);
}

private boolean runOn(String resource, boolean removeDTD) throws Exception {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
xmlProtector = new XMLProtector(new OutputStreamWriter(baos, UTF_8), removeDTD, 1000, 1000);
XMLProtector xmlProtector = new XMLProtector(new OutputStreamWriter(baos, UTF_8), removeDTD, 1000, 1000);
input = ByteUtil.getByteArrayData(this.getClass().getResourceAsStream(resource));

if (resource.endsWith("lmx")) {
Expand Down Expand Up @@ -65,50 +64,41 @@ private void reverse() {
}

@Test
public void testInvariant() throws Exception {
void invariant() throws Exception {
assertTrue(runOn("/customer.xml"));
}

@Test
public void testNotWellformed() throws Exception {
void NotWellformed() throws Exception {
assertFalse(runOn("/xml/not-wellformed.xml"));
}

@Test
public void testDTDRemoval1() throws Exception {
void DTDRemoval() throws Exception {
assertTrue(runOn("/xml/entity-expansion.lmx"));
assertTrue(output.length < input.length / 2);
assertFalse(new String(output).contains("ENTITY"));
}

@Test
public void testDTDRemoval2() throws Exception {
assertTrue(runOn("/xml/entity-external.xml"));
assertTrue(output.length < input.length * 2 / 3);
assertFalse(new String(output).contains("ENTITY"));
}

@Test
public void testExpandingEntities() throws Exception {
void expandingEntities() throws Exception {
assertTrue(runOn("/xml/entity-expansion.lmx", false));
assertTrue(output.length > input.length / 2);
assertTrue(new String(output).contains("ENTITY"));
}

@Test
public void testExternalEntities() throws Exception {
assertTrue(runOn("/xml/entity-external.xml", false));
assertTrue(output.length > input.length * 2 / 3);
assertTrue(new String(output).contains("ENTITY"));
void externalEntities() {
assertThrows(XMLProtectionException.class, () -> runOn("/xml/entity-external.xml", false));
}

@Test
public void testLongElementName() throws Exception {
void longElementName() throws Exception {
assertFalse(runOn("/xml/long-element-name.xml"));
}

@Test
public void testManyAttributes() throws Exception {
void manyAttributes() throws Exception {
assertFalse(runOn("/xml/many-attributes.xml"));
}
}

0 comments on commit e2beb06

Please sign in to comment.