diff --git a/netty/src/main/java/io/grpc/netty/NettyClientHandler.java b/netty/src/main/java/io/grpc/netty/NettyClientHandler.java index 792d76b1358..2891579d55c 100644 --- a/netty/src/main/java/io/grpc/netty/NettyClientHandler.java +++ b/netty/src/main/java/io/grpc/netty/NettyClientHandler.java @@ -54,6 +54,7 @@ import io.netty.handler.codec.http2.DefaultHttp2ConnectionEncoder; import io.netty.handler.codec.http2.DefaultHttp2FrameReader; import io.netty.handler.codec.http2.DefaultHttp2FrameWriter; +import io.netty.handler.codec.http2.DefaultHttp2HeadersEncoder; import io.netty.handler.codec.http2.DefaultHttp2LocalFlowController; import io.netty.handler.codec.http2.DefaultHttp2RemoteFlowController; import io.netty.handler.codec.http2.Http2CodecUtil; @@ -69,6 +70,7 @@ import io.netty.handler.codec.http2.Http2FrameWriter; import io.netty.handler.codec.http2.Http2Headers; import io.netty.handler.codec.http2.Http2HeadersDecoder; +import io.netty.handler.codec.http2.Http2HeadersEncoder; import io.netty.handler.codec.http2.Http2InboundFrameLogger; import io.netty.handler.codec.http2.Http2OutboundFrameLogger; import io.netty.handler.codec.http2.Http2Settings; @@ -150,7 +152,9 @@ static NettyClientHandler newHandler( Preconditions.checkArgument(maxHeaderListSize > 0, "maxHeaderListSize must be positive"); Http2HeadersDecoder headersDecoder = new GrpcHttp2ClientHeadersDecoder(maxHeaderListSize); Http2FrameReader frameReader = new DefaultHttp2FrameReader(headersDecoder); - Http2FrameWriter frameWriter = new DefaultHttp2FrameWriter(); + Http2HeadersEncoder encoder = new DefaultHttp2HeadersEncoder( + Http2HeadersEncoder.NEVER_SENSITIVE, false, 16, Integer.MAX_VALUE); + Http2FrameWriter frameWriter = new DefaultHttp2FrameWriter(encoder); Http2Connection connection = new DefaultHttp2Connection(false); WeightedFairQueueByteDistributor dist = new WeightedFairQueueByteDistributor(connection); dist.allocationQuantum(16 * 1024); // Make benchmarks fast again. diff --git a/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java b/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java index 61dfcb15ecf..5ddfc10d98a 100644 --- a/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java @@ -36,6 +36,7 @@ import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; +import com.google.common.base.Strings; import com.google.common.base.Ticker; import com.google.common.io.ByteStreams; import com.google.common.util.concurrent.SettableFuture; @@ -69,6 +70,7 @@ import io.grpc.netty.NettyChannelBuilder.LocalSocketPicker; import io.grpc.netty.NettyTestUtil.TrackingObjectPoolForTest; import io.grpc.testing.TlsTesting; +import io.netty.buffer.ByteBuf; import io.netty.channel.Channel; import io.netty.channel.ChannelConfig; import io.netty.channel.ChannelDuplexHandler; @@ -76,6 +78,7 @@ import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelOption; +import io.netty.channel.ChannelPromise; import io.netty.channel.EventLoopGroup; import io.netty.channel.ReflectiveChannelFactory; import io.netty.channel.local.LocalChannel; @@ -92,6 +95,7 @@ import java.io.InputStream; import java.net.InetSocketAddress; import java.net.SocketAddress; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; @@ -101,6 +105,7 @@ import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicBoolean; import javax.annotation.Nullable; import javax.net.ssl.SSLException; import javax.net.ssl.SSLHandshakeException; @@ -538,6 +543,46 @@ public void maxHeaderListSizeShouldBeEnforcedOnClient() throws Exception { } } + @Test + public void huffmanCodingShouldNotBePerformed() throws Exception { + String longStringOfA = Strings.repeat("a", 128); + + negotiator = ProtocolNegotiators.serverPlaintext(); + startServer(); + + NettyClientTransport transport = newTransport(ProtocolNegotiators.plaintext(), + DEFAULT_MAX_MESSAGE_SIZE, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, null, false, + TimeUnit.SECONDS.toNanos(10L), TimeUnit.SECONDS.toNanos(1L), + new ReflectiveChannelFactory<>(NioSocketChannel.class), group); + + Metadata headers = new Metadata(); + headers.put(Metadata.Key.of("test", Metadata.ASCII_STRING_MARSHALLER), + longStringOfA); + + callMeMaybe(transport.start(clientTransportListener)); + + AtomicBoolean foundExpectedHeaderBytes = new AtomicBoolean(false); + + transport.channel().pipeline().addFirst(new ChannelDuplexHandler() { + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) + throws Exception { + if (msg instanceof ByteBuf) { + if (((ByteBuf) msg).toString(StandardCharsets.UTF_8).contains(longStringOfA)) { + foundExpectedHeaderBytes.set(true); + } + } + super.write(ctx, msg, promise); + } + }); + + new Rpc(transport, headers).halfClose().waitForResponse(); + + if (!foundExpectedHeaderBytes.get()) { + fail("expected to find UTF-8 encoded 'a's in the header"); + } + } + @Test public void maxHeaderListSizeShouldBeEnforcedOnServer() throws Exception { startServer(100, 1);