diff --git a/riposte-core/src/main/java/com/nike/riposte/server/Server.java b/riposte-core/src/main/java/com/nike/riposte/server/Server.java index 3b6f8923..8ab17b80 100644 --- a/riposte-core/src/main/java/com/nike/riposte/server/Server.java +++ b/riposte-core/src/main/java/com/nike/riposte/server/Server.java @@ -133,7 +133,7 @@ public void startup() throws CertificateException, IOException, InterruptedExcep serverConfig.requestSecurityValidator(), serverConfig.workerChannelIdleTimeoutMillis(), serverConfig.proxyRouterConnectTimeoutMillis(), serverConfig.incompleteHttpCallTimeoutMillis(), serverConfig.maxOpenIncomingServerChannels(), serverConfig.isDebugChannelLifecycleLoggingEnabled(), - serverConfig.userIdHeaderKeys() + serverConfig.userIdHeaderKeys(), serverConfig.responseCompressionThresholdBytes() ); } diff --git a/riposte-core/src/main/java/com/nike/riposte/server/channelpipeline/HttpChannelInitializer.java b/riposte-core/src/main/java/com/nike/riposte/server/channelpipeline/HttpChannelInitializer.java index 04d3bdfd..e6b39266 100644 --- a/riposte-core/src/main/java/com/nike/riposte/server/channelpipeline/HttpChannelInitializer.java +++ b/riposte-core/src/main/java/com/nike/riposte/server/channelpipeline/HttpChannelInitializer.java @@ -234,6 +234,7 @@ public class HttpChannelInitializer extends ChannelInitializer { private final int maxOpenChannelsThreshold; private final ChannelGroup openChannelsGroup; private final boolean debugChannelLifecycleLoggingEnabled; + private final int responseCompressionThresholdBytes; private final StreamingAsyncHttpClient streamingAsyncHttpClientForProxyRouterEndpoints; @@ -338,7 +339,8 @@ public HttpChannelInitializer(SslContext sslCtx, long incompleteHttpCallTimeoutMillis, int maxOpenChannelsThreshold, boolean debugChannelLifecycleLoggingEnabled, - List userIdHeaderKeys) { + List userIdHeaderKeys, + int responseCompressionThresholdBytes) { if (endpoints == null || endpoints.isEmpty()) throw new IllegalArgumentException("endpoints cannot be empty"); @@ -418,6 +420,7 @@ public HttpChannelInitializer(SslContext sslCtx, cachedResponseFilterHandler = (hasReqResFilters) ? new ResponseFilterHandler(requestAndResponseFilters) : null; this.userIdHeaderKeys = userIdHeaderKeys; + this.responseCompressionThresholdBytes = responseCompressionThresholdBytes; } @Override @@ -463,8 +466,8 @@ public void initChannel(SocketChannel ch) { // request/response/size threshold). This must be after HttpRequestDecoder on the incoming pipeline and // before HttpResponseEncoder on the outbound pipeline (keep in mind that "before" on outbound means // later in the list since outbound is processed in reverse order). - // TODO: Make the threshold configurable - p.addLast(SMART_HTTP_CONTENT_COMPRESSOR_HANDLER_NAME, new SmartHttpContentCompressor(500)); + p.addLast(SMART_HTTP_CONTENT_COMPRESSOR_HANDLER_NAME, + new SmartHttpContentCompressor(responseCompressionThresholdBytes)); // INBOUND - Add the "before security" RequestFilterHandler before security and even before routing // (if we have any filters to apply). This is here before RoutingHandler so that it can intercept requests diff --git a/riposte-core/src/test/java/com/nike/riposte/server/channelpipeline/HttpChannelInitializerTest.java b/riposte-core/src/test/java/com/nike/riposte/server/channelpipeline/HttpChannelInitializerTest.java index 039d1e63..af6bf132 100644 --- a/riposte-core/src/test/java/com/nike/riposte/server/channelpipeline/HttpChannelInitializerTest.java +++ b/riposte-core/src/test/java/com/nike/riposte/server/channelpipeline/HttpChannelInitializerTest.java @@ -154,13 +154,15 @@ public void constructor_works_with_valid_args() { int maxOpenChannelsThreshold = 1000; boolean debugChannelLifecycleLoggingEnabled = true; List userIdHeaderKeys = mock(List.class); + int responseCompressionThresholdBytes = 5678; // when HttpChannelInitializer hci = new HttpChannelInitializer( sslCtx, maxRequestSizeInBytes, endpoints, reqResFilters, longRunningTaskExecutor, riposteErrorHandler, riposteUnhandledErrorHandler, validationService, requestContentDeserializer, responseSender, metricsListener, defaultCompletableFutureTimeoutMillis, accessLogger, pipelineCreateHooks, requestSecurityValidator, workerChannelIdleTimeoutMillis, proxyRouterConnectTimeoutMillis, - incompleteHttpCallTimeoutMillis, maxOpenChannelsThreshold, debugChannelLifecycleLoggingEnabled, userIdHeaderKeys); + incompleteHttpCallTimeoutMillis, maxOpenChannelsThreshold, debugChannelLifecycleLoggingEnabled, userIdHeaderKeys, + responseCompressionThresholdBytes); // then assertThat(extractField(hci, "sslCtx"), is(sslCtx)); @@ -182,6 +184,7 @@ public void constructor_works_with_valid_args() { assertThat(extractField(hci, "maxOpenChannelsThreshold"), is(maxOpenChannelsThreshold)); assertThat(extractField(hci, "debugChannelLifecycleLoggingEnabled"), is(debugChannelLifecycleLoggingEnabled)); assertThat(extractField(hci, "userIdHeaderKeys"), is(userIdHeaderKeys)); + assertThat(extractField(hci, "responseCompressionThresholdBytes"), is(responseCompressionThresholdBytes)); StreamingAsyncHttpClient sahc = extractField(hci, "streamingAsyncHttpClientForProxyRouterEndpoints"); assertThat(extractField(sahc, "idleChannelTimeoutMillis"), is(workerChannelIdleTimeoutMillis)); @@ -206,7 +209,8 @@ public void constructor_gracefully_handles_some_null_args() { HttpChannelInitializer hci = new HttpChannelInitializer( null, 42, Arrays.asList(getMockEndpoint("/some/path")), null, null, mock(RiposteErrorHandler.class), mock(RiposteUnhandledErrorHandler.class), null, null, mock(ResponseSender.class), null, 4242L, null, - null, null, 121, 42, 321, 100, false, null); + null, null, 121, 42, 321, 100, false, null, + 123); // then assertThat(extractField(hci, "sslCtx"), nullValue()); @@ -236,7 +240,8 @@ public void constructor_handles_empty_after_security_request_handlers() { HttpChannelInitializer hci = new HttpChannelInitializer( null, 42, Arrays.asList(getMockEndpoint("/some/path")), reqResFilters, null, mock(RiposteErrorHandler.class), mock(RiposteUnhandledErrorHandler.class), null, null, mock(ResponseSender.class), null, 4242L, null, - null, null, 121, 42, 321, 100, false, null); + null, null, 121, 42, 321, 100, false, null, + 123); // then RequestFilterHandler beforeSecReqFH = extractField(hci, "beforeSecurityRequestFilterHandler"); @@ -259,7 +264,8 @@ public void constructor_handles_empty_before_security_request_handlers() { HttpChannelInitializer hci = new HttpChannelInitializer( null, 42, Arrays.asList(getMockEndpoint("/some/path")), reqResFilters, null, mock(RiposteErrorHandler.class), mock(RiposteUnhandledErrorHandler.class), null, null, mock(ResponseSender.class), null, 4242L, null, - null, null, 121, 42, 321, 100, false, null); + null, null, 121, 42, 321, 100, false, null, + 123); // then RequestFilterHandler beforeSecReqFH = extractField(hci, "afterSecurityRequestFilterHandler"); @@ -277,7 +283,8 @@ public void constructor_throws_IllegalArgumentException_if_endpoints_is_null() { new HttpChannelInitializer( null, 42, null, null, null, mock(RiposteErrorHandler.class), mock(RiposteUnhandledErrorHandler.class), null, null, mock(ResponseSender.class), null, 4242L, null, - null, null, 121, 42, 321, 100, false, null); + null, null, 121, 42, 321, 100, false, null, + 123); } @Test(expected = IllegalArgumentException.class) @@ -286,7 +293,8 @@ public void constructor_throws_IllegalArgumentException_if_endpoints_is_empty() new HttpChannelInitializer( null, 42, Collections.emptyList(), null, null, mock(RiposteErrorHandler.class), mock(RiposteUnhandledErrorHandler.class), null, null, mock(ResponseSender.class), null, 4242L, null, - null, null, 121, 42, 321, 100, false, null); + null, null, 121, 42, 321, 100, false, null, + 123); } @Test(expected = IllegalArgumentException.class) @@ -295,7 +303,8 @@ public void constructor_throws_IllegalArgumentException_if_riposteErrorHandler_i new HttpChannelInitializer( null, 42, Arrays.asList(getMockEndpoint("/some/path")), null, null, null, mock(RiposteUnhandledErrorHandler.class), null, null, mock(ResponseSender.class), null, 4242L, null, - null, null, 121, 42, 321, 100, false, null); + null, null, 121, 42, 321, 100, false, null, + 123); } @Test(expected = IllegalArgumentException.class) @@ -304,7 +313,8 @@ public void constructor_throws_IllegalArgumentException_if_riposteUnhandledError new HttpChannelInitializer( null, 42, Arrays.asList(getMockEndpoint("/some/path")), null, null, mock(RiposteErrorHandler.class), null, null, null, mock(ResponseSender.class), null, 4242L, null, - null, null, 121, 42, 321, 100, false, null); + null, null, 121, 42, 321, 100, false, null, + 123); } @Test(expected = IllegalArgumentException.class) @@ -313,7 +323,8 @@ public void constructor_throws_IllegalArgumentException_if_responseSender_is_nul new HttpChannelInitializer( null, 42, Arrays.asList(getMockEndpoint("/some/path")), null, null, mock(RiposteErrorHandler.class), mock(RiposteUnhandledErrorHandler.class), null, null, null, null, 4242L, null, - null, null, 121, 42, 321, 100, false, null); + null, null, 121, 42, 321, 100, false, null, + 123); } private Pair findChannelHandler(List channelHandlers, Class classToFind, boolean findLast) { @@ -347,7 +358,7 @@ private HttpChannelInitializer basicHttpChannelInitializer(SslContext sslCtx, lo sslCtx, 42, Arrays.asList(getMockEndpoint("/some/path")), requestAndResponseFilters, null, mock(RiposteErrorHandler.class), mock(RiposteUnhandledErrorHandler.class), validationService, null, mock(ResponseSender.class), null, 4242L, null, null, null, workerChannelIdleTimeoutMillis, 4200, 1234, maxOpenChannelsThreshold, debugChannelLifecycleLoggingEnabled, - null); + null, 123); } @Test @@ -585,7 +596,7 @@ public void initChannel_adds_AccessLogStartHandler_immediately_after_DTraceStart } @Test - public void initChannel_adds_HttpContentCompressor_before_HttpResponseEncoder_for_outbound_handler() { + public void initChannel_adds_SmartHttpContentCompressor_before_HttpResponseEncoder_for_outbound_handler() { // given HttpChannelInitializer hci = basicHttpChannelInitializerNoUtilityHandlers(); @@ -596,19 +607,24 @@ public void initChannel_adds_HttpContentCompressor_before_HttpResponseEncoder_fo ArgumentCaptor channelHandlerArgumentCaptor = ArgumentCaptor.forClass(ChannelHandler.class); verify(channelPipelineMock, atLeastOnce()).addLast(anyString(), channelHandlerArgumentCaptor.capture()); List handlers = channelHandlerArgumentCaptor.getAllValues(); - Pair httpContentCompressor = findChannelHandler(handlers, HttpContentCompressor.class); + Pair httpContentCompressor = findChannelHandler(handlers, SmartHttpContentCompressor.class); Pair httpResponseEncoder = findChannelHandler(handlers, HttpResponseEncoder.class); assertThat(httpContentCompressor, notNullValue()); assertThat(httpResponseEncoder, notNullValue()); - // HttpContentCompressor's index should be later than HttpResponseEncoder's index to verify that it comes BEFORE HttpResponseEncoder on the OUTBOUND handlers - // (since the outbound handlers are processed in reverse order). + // SmartHttpContentCompressor's index should be later than HttpResponseEncoder's index to verify that it comes + // BEFORE HttpResponseEncoder on the OUTBOUND handlers (since the outbound handlers are processed in + // reverse order). assertThat(httpContentCompressor.getLeft(), is(greaterThan(httpResponseEncoder.getLeft()))); + // Verify that SmartHttpContentCompressor's threshold is set to the specified config value. + long expectedThresholdValue = ((Integer)extractField(hci, "responseCompressionThresholdBytes")).longValue(); + assertThat(extractField(httpContentCompressor.getRight(), "responseSizeThresholdBytes"), + is(expectedThresholdValue)); } @Test - public void initChannel_adds_RequestInfoSetterHandler_after_HttpContentCompressor() { + public void initChannel_adds_RequestInfoSetterHandler_after_SmartHttpContentCompressor() { // given HttpChannelInitializer hci = basicHttpChannelInitializerNoUtilityHandlers(); @@ -619,7 +635,7 @@ public void initChannel_adds_RequestInfoSetterHandler_after_HttpContentCompresso ArgumentCaptor channelHandlerArgumentCaptor = ArgumentCaptor.forClass(ChannelHandler.class); verify(channelPipelineMock, atLeastOnce()).addLast(anyString(), channelHandlerArgumentCaptor.capture()); List handlers = channelHandlerArgumentCaptor.getAllValues(); - Pair httpContentCompressor = findChannelHandler(handlers, HttpContentCompressor.class); + Pair httpContentCompressor = findChannelHandler(handlers, SmartHttpContentCompressor.class); Pair requestInfoSetterHandler = findChannelHandler(handlers, RequestInfoSetterHandler.class); assertThat(httpContentCompressor, notNullValue()); diff --git a/riposte-core/src/test/java/com/nike/riposte/server/componenttest/VerifyAutoPayloadDecompressionComponentTest.java b/riposte-core/src/test/java/com/nike/riposte/server/componenttest/VerifyAutoPayloadDecompressionComponentTest.java index 9f12f574..3c9fe411 100644 --- a/riposte-core/src/test/java/com/nike/riposte/server/componenttest/VerifyAutoPayloadDecompressionComponentTest.java +++ b/riposte-core/src/test/java/com/nike/riposte/server/componenttest/VerifyAutoPayloadDecompressionComponentTest.java @@ -8,6 +8,7 @@ import com.nike.riposte.server.http.StandardEndpoint; import com.nike.riposte.server.http.impl.SimpleProxyRouterEndpoint; import com.nike.riposte.server.testutils.ComponentTestUtils; +import com.nike.riposte.server.testutils.ComponentTestUtils.CompressionType; import com.nike.riposte.util.Matcher; import com.fasterxml.jackson.core.JsonProcessingException; @@ -20,23 +21,12 @@ import org.junit.Test; import org.junit.runner.RunWith; -import java.io.BufferedReader; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; import java.io.IOException; -import java.io.InputStreamReader; import java.util.Arrays; -import java.util.Base64; import java.util.Collection; import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executor; -import java.util.function.Function; -import java.util.zip.DataFormatException; -import java.util.zip.Deflater; -import java.util.zip.GZIPInputStream; -import java.util.zip.GZIPOutputStream; -import java.util.zip.Inflater; import io.netty.channel.ChannelHandlerContext; import io.netty.handler.codec.http.DefaultHttpHeaders; @@ -45,6 +35,12 @@ import io.restassured.response.ExtractableResponse; import static com.nike.riposte.server.componenttest.VerifyAutoPayloadDecompressionComponentTest.DeserializationEndpointWithDecompressionEnabled.SOME_OBJ_FIELD_VALUE_HEADER_KEY; +import static com.nike.riposte.server.testutils.ComponentTestUtils.base64Decode; +import static com.nike.riposte.server.testutils.ComponentTestUtils.base64Encode; +import static com.nike.riposte.server.testutils.ComponentTestUtils.deflatePayload; +import static com.nike.riposte.server.testutils.ComponentTestUtils.gzipPayload; +import static com.nike.riposte.server.testutils.ComponentTestUtils.inflatePayload; +import static com.nike.riposte.server.testutils.ComponentTestUtils.ungzipPayload; import static io.netty.handler.codec.http.HttpHeaders.Names.CONTENT_ENCODING; import static io.netty.handler.codec.http.HttpHeaders.Names.CONTENT_LENGTH; import static io.netty.handler.codec.http.HttpHeaders.Names.TRANSFER_ENCODING; @@ -118,35 +114,6 @@ private static void verifyExpectedContentAndTransferHeaders( assertThat(response.header(RECEIVED_TRANSFER_ENCODING_HEADER)).isEqualTo(expectedTransferEncoding); } - private enum CompressionType { - GZIP(VerifyAutoPayloadDecompressionComponentTest::gzipPayload, - VerifyAutoPayloadDecompressionComponentTest::ungzipPayload, - HttpHeaders.Values.GZIP), - DEFLATE(VerifyAutoPayloadDecompressionComponentTest::deflatePayload, - VerifyAutoPayloadDecompressionComponentTest::inflatePayload, - HttpHeaders.Values.DEFLATE); - - private final Function compressionFunction; - private final Function decompressionFunction; - public final String contentEncodingHeaderValue; - - CompressionType(Function compressionFunction, - Function decompressionFunction, - String contentEncodingHeaderValue) { - this.compressionFunction = compressionFunction; - this.decompressionFunction = decompressionFunction; - this.contentEncodingHeaderValue = contentEncodingHeaderValue; - } - - public byte[] compress(String s) { - return compressionFunction.apply(s); - } - - public String decompress(byte[] compressed) { - return decompressionFunction.apply(compressed); - } - } - @DataProvider(value = { "GZIP", "DEFLATE" @@ -311,84 +278,6 @@ public void verify_compression_helper_methods_work_as_expected() { assertThat(gzipped).isNotEqualTo(deflated); } - private static byte[] gzipPayload(String payload) { - ByteArrayOutputStream bytesOut = new ByteArrayOutputStream(); - try (GZIPOutputStream gzipOutputStream = new GZIPOutputStream(bytesOut)) { - byte[] payloadBytes = payload.getBytes(UTF_8); - gzipOutputStream.write(payloadBytes); - gzipOutputStream.finish(); - return bytesOut.toByteArray(); - } - catch (IOException e) { - throw new RuntimeException(e); - } - } - - private static String ungzipPayload(byte[] compressed) { - try { - if ((compressed == null) || (compressed.length == 0)) { - throw new RuntimeException("Null/empty compressed payload. is_null=" + (compressed == null)); - } - - final StringBuilder outStr = new StringBuilder(); - final GZIPInputStream gis = new GZIPInputStream(new ByteArrayInputStream(compressed)); - final BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(gis, "UTF-8")); - - String line; - while ((line = bufferedReader.readLine()) != null) { - outStr.append(line); - } - - return outStr.toString(); - } - catch(IOException ex) { - throw new RuntimeException(ex); - } - } - - private static byte[] deflatePayload(String payload) { - Deflater deflater = new Deflater(6, false); - byte[] payloadBytes = payload.getBytes(UTF_8); - deflater.setInput(payloadBytes); - deflater.finish(); - - ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); - byte[] buffer = new byte[1024]; - while (!deflater.finished()) { - int count = deflater.deflate(buffer); - outputStream.write(buffer, 0, count); - } - - return outputStream.toByteArray(); - } - - private static String inflatePayload(byte[] compressed) { - Inflater inflater = new Inflater(); - inflater.setInput(compressed); - - ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); - byte[] buffer = new byte[1024]; - while (!inflater.finished()) { - try { - int count = inflater.inflate(buffer); - outputStream.write(buffer, 0, count); - } - catch (DataFormatException e) { - throw new RuntimeException(e); - } - } - - return new String(outputStream.toByteArray(), UTF_8); - } - - private static String base64Encode(byte[] bytes) { - return Base64.getEncoder().encodeToString(bytes); - } - - private static byte[] base64Decode(String encodedStr) { - return Base64.getDecoder().decode(encodedStr); - } - private static final String RECEIVED_PAYLOAD_BYTES_AS_BASE64_RESPONSE_HEADER_KEY = "received-payload-bytes-base-64"; private static final String RECEIVED_CONTENT_ENCODING_HEADER = "received-content-encoding"; private static final String RECEIVED_CONTENT_LENGTH_HEADER = "received-content-length"; diff --git a/riposte-core/src/test/java/com/nike/riposte/server/componenttest/VerifySmartHttpContentCompressorComponentTest.java b/riposte-core/src/test/java/com/nike/riposte/server/componenttest/VerifySmartHttpContentCompressorComponentTest.java new file mode 100644 index 00000000..0923fb1d --- /dev/null +++ b/riposte-core/src/test/java/com/nike/riposte/server/componenttest/VerifySmartHttpContentCompressorComponentTest.java @@ -0,0 +1,195 @@ +package com.nike.riposte.server.componenttest; + +import com.nike.riposte.server.Server; +import com.nike.riposte.server.config.ServerConfig; +import com.nike.riposte.server.http.Endpoint; +import com.nike.riposte.server.http.RequestInfo; +import com.nike.riposte.server.http.ResponseInfo; +import com.nike.riposte.server.http.StandardEndpoint; +import com.nike.riposte.server.testutils.ComponentTestUtils; +import com.nike.riposte.server.testutils.ComponentTestUtils.CompressionType; +import com.nike.riposte.server.testutils.ComponentTestUtils.NettyHttpClientRequestBuilder; +import com.nike.riposte.server.testutils.ComponentTestUtils.NettyHttpClientResponse; +import com.nike.riposte.util.Matcher; + +import com.tngtech.java.junit.dataprovider.DataProvider; +import com.tngtech.java.junit.dataprovider.DataProviderRunner; + +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; + +import java.io.IOException; +import java.util.Collection; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Executor; + +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.HttpResponseStatus; + +import static com.nike.riposte.server.testutils.ComponentTestUtils.generatePayload; +import static com.nike.riposte.server.testutils.ComponentTestUtils.request; +import static io.netty.handler.codec.http.HttpHeaders.Names.ACCEPT_ENCODING; +import static io.netty.handler.codec.http.HttpHeaders.Names.CONTENT_ENCODING; +import static java.util.Collections.singletonList; +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Component test verifying the functionality of {@link com.nike.riposte.server.handler.SmartHttpContentCompressor}. + */ +@RunWith(DataProviderRunner.class) +public class VerifySmartHttpContentCompressorComponentTest { + + private static Server server; + private static ServerConfig serverConfig; + private int incompleteCallTimeoutMillis = 2000; + + @BeforeClass + public static void setUpClass() throws Exception { + serverConfig = new ResponsePayloadCompressionServerConfig(); + server = new Server(serverConfig); + server.startup(); + } + + @AfterClass + public static void tearDown() throws Exception { + server.shutdown(); + } + + @DataProvider(value = { + "GZIP | 499 | false", + "DEFLATE | 499 | false", + "IDENTITY | 499 | false", + "GZIP | 500 | false", + "DEFLATE | 500 | false", + "IDENTITY | 500 | false", + "GZIP | 501 | true", + "DEFLATE | 501 | true", + "IDENTITY | 501 | false" + }, splitBy = "\\|") + @Test + public void response_should_be_compressed_based_on_payload_size_and_accept_encoding_header( + CompressionType compressionType, int desiredUncompressedPayloadSize, boolean expectCompressed + ) throws Exception { + // given + NettyHttpClientRequestBuilder request = request() + .withMethod(HttpMethod.GET) + .withUri(BasicEndpoint.MATCHING_PATH) + .withHeader(ACCEPT_ENCODING, compressionType.contentEncodingHeaderValue) + .withHeader(BasicEndpoint.DESIRED_UNCOMPRESSED_PAYLOAD_SIZE_HEADER_KEY, desiredUncompressedPayloadSize); + + // when + NettyHttpClientResponse serverResponse = request.execute(serverConfig.endpointsPort(), + incompleteCallTimeoutMillis); + + // then + assertThat(serverResponse.statusCode).isEqualTo(HttpResponseStatus.OK.code()); + + String contentEncodingHeader = serverResponse.headers.get(CONTENT_ENCODING); + String decompressedPayload; + + if (expectCompressed) { + assertThat(contentEncodingHeader).isEqualTo(compressionType.contentEncodingHeaderValue); + decompressedPayload = compressionType.decompress(serverResponse.payloadBytes); + } + else { + assertThat(contentEncodingHeader).isNull(); + decompressedPayload = serverResponse.payload; + } + + assertThat(decompressedPayload).hasSize(desiredUncompressedPayloadSize); + assertThat(decompressedPayload).startsWith(BasicEndpoint.RESPONSE_PAYLOAD_PREFIX); + + } + + @DataProvider(value = { + "GZIP", + "DEFLATE", + "IDENTITY" + }, splitBy = "\\|") + @Test + public void response_should_not_be_compressed_when_ResponseInfo_disables_compression( + CompressionType compressionType + ) throws Exception { + // given + NettyHttpClientRequestBuilder request = request() + .withMethod(HttpMethod.GET) + .withUri(BasicEndpoint.MATCHING_PATH) + .withHeader(ACCEPT_ENCODING, compressionType.contentEncodingHeaderValue) + .withHeader(BasicEndpoint.DESIRED_UNCOMPRESSED_PAYLOAD_SIZE_HEADER_KEY, 1000) + .withHeader(BasicEndpoint.DISABLE_COMPRESSION_HEADER_KEY, "true"); + + // when + NettyHttpClientResponse serverResponse = request.execute(serverConfig.endpointsPort(), + incompleteCallTimeoutMillis); + + // then + assertThat(serverResponse.statusCode).isEqualTo(HttpResponseStatus.OK.code()); + + assertThat(serverResponse.headers.get(CONTENT_ENCODING)).isNull(); + + assertThat(serverResponse.payload).hasSize(1000); + assertThat(serverResponse.payload).startsWith(BasicEndpoint.RESPONSE_PAYLOAD_PREFIX); + + } + + private static String generatePayloadOfSizeInBytes(String prefix, int length) { + return prefix + generatePayload(length - prefix.length()); + } + + private static class BasicEndpoint extends StandardEndpoint { + + public static final String MATCHING_PATH = "/basicEndpoint"; + public static final String RESPONSE_PAYLOAD_PREFIX = "basic-endpoint-" + UUID.randomUUID().toString(); + public static final String DESIRED_UNCOMPRESSED_PAYLOAD_SIZE_HEADER_KEY = "desired-uncompressed-payload-size"; + public static final String DISABLE_COMPRESSION_HEADER_KEY = "disable-compression"; + + @Override + public CompletableFuture> execute(RequestInfo request, Executor longRunningTaskExecutor, ChannelHandlerContext ctx) { + String responsePayload = generatePayloadOfSizeInBytes( + RESPONSE_PAYLOAD_PREFIX, + Integer.parseInt(request.getHeaders().get(DESIRED_UNCOMPRESSED_PAYLOAD_SIZE_HEADER_KEY)) + ); + boolean disableCompression = "true".equals(request.getHeaders().get(DISABLE_COMPRESSION_HEADER_KEY)); + + return CompletableFuture.completedFuture( + ResponseInfo.newBuilder(responsePayload) + .withPreventCompressedOutput(disableCompression) + .build() + ); + } + + @Override + public Matcher requestMatcher() { + return Matcher.match(MATCHING_PATH, HttpMethod.GET); + } + } + + public static class ResponsePayloadCompressionServerConfig implements ServerConfig { + private final Collection> endpoints = singletonList(new BasicEndpoint()); + + private final int port; + + public ResponsePayloadCompressionServerConfig() { + try { + port = ComponentTestUtils.findFreePort(); + } catch (IOException e) { + throw new RuntimeException("Couldn't allocate port", e); + } + } + + @Override + public Collection> appEndpoints() { + return endpoints; + } + + @Override + public int endpointsPort() { + return port; + } + } + +} diff --git a/riposte-core/src/test/java/com/nike/riposte/server/testutils/ComponentTestUtils.java b/riposte-core/src/test/java/com/nike/riposte/server/testutils/ComponentTestUtils.java index 1307f43a..a94ca7a0 100644 --- a/riposte-core/src/test/java/com/nike/riposte/server/testutils/ComponentTestUtils.java +++ b/riposte-core/src/test/java/com/nike/riposte/server/testutils/ComponentTestUtils.java @@ -8,13 +8,24 @@ import org.apache.commons.lang3.RandomUtils; +import java.io.BufferedReader; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.io.InputStreamReader; import java.net.ServerSocket; +import java.util.Base64; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.function.Consumer; +import java.util.function.Function; +import java.util.zip.DataFormatException; +import java.util.zip.Deflater; +import java.util.zip.GZIPInputStream; +import java.util.zip.GZIPOutputStream; +import java.util.zip.Inflater; import io.netty.bootstrap.Bootstrap; import io.netty.buffer.ByteBuf; @@ -73,6 +84,116 @@ public static ByteBuf createByteBufPayload(int payloadSize) { return Unpooled.wrappedBuffer(generatePayload(payloadSize).getBytes(UTF_8)); } + public static byte[] gzipPayload(String payload) { + ByteArrayOutputStream bytesOut = new ByteArrayOutputStream(); + try (GZIPOutputStream gzipOutputStream = new GZIPOutputStream(bytesOut)) { + byte[] payloadBytes = payload.getBytes(UTF_8); + gzipOutputStream.write(payloadBytes); + gzipOutputStream.finish(); + return bytesOut.toByteArray(); + } + catch (IOException e) { + throw new RuntimeException(e); + } + } + + public static String ungzipPayload(byte[] compressed) { + try { + if ((compressed == null) || (compressed.length == 0)) { + throw new RuntimeException("Null/empty compressed payload. is_null=" + (compressed == null)); + } + + final StringBuilder outStr = new StringBuilder(); + final GZIPInputStream gis = new GZIPInputStream(new ByteArrayInputStream(compressed)); + final BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(gis, "UTF-8")); + + String line; + while ((line = bufferedReader.readLine()) != null) { + outStr.append(line); + } + + return outStr.toString(); + } + catch(IOException ex) { + throw new RuntimeException(ex); + } + } + + public static byte[] deflatePayload(String payload) { + Deflater deflater = new Deflater(6, false); + byte[] payloadBytes = payload.getBytes(UTF_8); + deflater.setInput(payloadBytes); + deflater.finish(); + + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + byte[] buffer = new byte[1024]; + while (!deflater.finished()) { + int count = deflater.deflate(buffer); + outputStream.write(buffer, 0, count); + } + + return outputStream.toByteArray(); + } + + public static String inflatePayload(byte[] compressed) { + Inflater inflater = new Inflater(); + inflater.setInput(compressed); + + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + byte[] buffer = new byte[1024]; + while (!inflater.finished()) { + try { + int count = inflater.inflate(buffer); + outputStream.write(buffer, 0, count); + } + catch (DataFormatException e) { + throw new RuntimeException(e); + } + } + + return new String(outputStream.toByteArray(), UTF_8); + } + + public static String base64Encode(byte[] bytes) { + return Base64.getEncoder().encodeToString(bytes); + } + + public static byte[] base64Decode(String encodedStr) { + return Base64.getDecoder().decode(encodedStr); + } + + public enum CompressionType { + GZIP(ComponentTestUtils::gzipPayload, + ComponentTestUtils::ungzipPayload, + HttpHeaders.Values.GZIP), + DEFLATE(ComponentTestUtils::deflatePayload, + ComponentTestUtils::inflatePayload, + HttpHeaders.Values.DEFLATE), + IDENTITY(s -> s.getBytes(UTF_8), + b -> new String(b, UTF_8), + HttpHeaders.Values.IDENTITY); + + private final Function compressionFunction; + private final Function decompressionFunction; + public final String contentEncodingHeaderValue; + + CompressionType(Function compressionFunction, + Function decompressionFunction, + String contentEncodingHeaderValue) { + this.compressionFunction = compressionFunction; + this.decompressionFunction = decompressionFunction; + this.contentEncodingHeaderValue = contentEncodingHeaderValue; + } + + public byte[] compress(String s) { + return compressionFunction.apply(s); + } + + public String decompress(byte[] compressed) { + return decompressionFunction.apply(compressed); + } + } + public static Bootstrap createNettyHttpClientBootstrap() { Bootstrap bootstrap = new Bootstrap(); bootstrap.group(new NioEventLoopGroup()) @@ -150,12 +271,16 @@ public static class NettyHttpClientResponse { public final int statusCode; public final HttpHeaders headers; public final String payload; + public final byte[] payloadBytes; public final FullHttpResponse fullHttpResponse; public NettyHttpClientResponse(FullHttpResponse fullHttpResponse) { this.statusCode = fullHttpResponse.getStatus().code(); this.headers = fullHttpResponse.headers(); - this.payload = fullHttpResponse.content().toString(UTF_8); + ByteBuf content = fullHttpResponse.content(); + this.payloadBytes = new byte[content.readableBytes()]; + content.getBytes(content.readerIndex(), this.payloadBytes); + this.payload = new String(this.payloadBytes, UTF_8); this.fullHttpResponse = fullHttpResponse; } } diff --git a/riposte-spi/src/main/java/com/nike/riposte/server/config/ServerConfig.java b/riposte-spi/src/main/java/com/nike/riposte/server/config/ServerConfig.java index 4584246e..a90b2b04 100644 --- a/riposte-spi/src/main/java/com/nike/riposte/server/config/ServerConfig.java +++ b/riposte-spi/src/main/java/com/nike/riposte/server/config/ServerConfig.java @@ -400,6 +400,18 @@ default int maxRequestSizeInBytes() { return 0; } + /** + * @return The size threshold (in bytes) above which response payloads are eligible for gzip/deflate compression. + * Compressing small payloads can actually result in a "compressed" payload that is larger than the original and + * in that case you're using extra CPU for a worse outcome. If a response payload's size is smaller than the + * byte threshold value returned by this method then it will not be automatically compressed. You can also disable + * compression on a per-response basis by setting {@link + * com.nike.riposte.server.http.ResponseInfo#setPreventCompressedOutput(boolean)} to true. + */ + default int responseCompressionThresholdBytes() { + return 500; + } + /** * @return The {@link Executor} that should be used for long running tasks when non-blocking endpoints need to do * blocking I/O and there is no nonblocking driver/client, or if the endpoint needs to do serious number crunching diff --git a/riposte-spi/src/test/java/com/nike/riposte/server/config/ServerConfigTest.java b/riposte-spi/src/test/java/com/nike/riposte/server/config/ServerConfigTest.java index 22beaaf2..d4069025 100644 --- a/riposte-spi/src/test/java/com/nike/riposte/server/config/ServerConfigTest.java +++ b/riposte-spi/src/test/java/com/nike/riposte/server/config/ServerConfigTest.java @@ -47,6 +47,7 @@ public Collection> appEndpoints() { assertThat(defaultImpl.numBossThreads(), is(1)); assertThat(defaultImpl.numWorkerThreads(), is(0)); assertThat(defaultImpl.maxRequestSizeInBytes(), is(0)); + assertThat(defaultImpl.responseCompressionThresholdBytes(), is(500)); assertThat(defaultImpl.createSslContext(), notNullValue()); assertThat(defaultImpl.requestContentValidationService(), nullValue()); assertThat(defaultImpl.isDebugActionsEnabled(), is(false));