Skip to content

Commit

Permalink
Test for TLS channel pipeline handlers (#3303)
Browse files Browse the repository at this point in the history
Adding tests for TLS-required handlers on NettyTLSConnectionInitializer

Signed-off-by: Lucas Saldanha <lucascrsaldanha@gmail.com>
  • Loading branch information
lucassaldanha authored Jan 20, 2022
1 parent 00a3875 commit 1b45cac
Show file tree
Hide file tree
Showing 5 changed files with 221 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,12 @@ public RlpxAgent build() {
LOG.debug("TLS Configuration found using NettyTLSConnectionInitializer");
connectionInitializer =
new NettyTLSConnectionInitializer(
nodeKey, config, localNode, connectionEvents, metricsSystem, p2pTLSConfiguration);
nodeKey,
config,
localNode,
connectionEvents,
metricsSystem,
p2pTLSConfiguration.get());
} else {
LOG.debug("Using default NettyConnectionInitializer");
connectionInitializer =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -267,10 +267,10 @@ private TimeoutHandler<Channel> timeoutHandler(
() -> connectionFuture.completeExceptionally(new TimeoutException(s)));
}

protected void addAdditionalOutboundHandlers(final SocketChannel ch)
void addAdditionalOutboundHandlers(final Channel ch)
throws GeneralSecurityException, IOException {}

protected void addAdditionalInboundHandlers(final SocketChannel ch)
void addAdditionalInboundHandlers(final Channel ch)
throws GeneralSecurityException, IOException {}

private IntSupplier pendingTaskCounter(final EventLoopGroup eventLoopGroup) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,12 @@
import org.hyperledger.besu.ethereum.p2p.rlpx.handshake.Handshaker;
import org.hyperledger.besu.plugin.services.MetricsSystem;

import java.io.IOException;
import java.security.GeneralSecurityException;
import java.util.Optional;
import java.util.function.Supplier;

import io.netty.channel.socket.SocketChannel;
import com.google.common.annotations.VisibleForTesting;
import io.netty.channel.Channel;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.handler.codec.LengthFieldPrepender;
import io.netty.handler.codec.compression.SnappyFrameDecoder;
Expand All @@ -41,53 +42,65 @@

public class NettyTLSConnectionInitializer extends NettyConnectionInitializer {

private final Optional<TLSConfiguration> p2pTLSConfiguration;
private final Optional<Supplier<TLSContextFactory>> tlsContextFactorySupplier;

public NettyTLSConnectionInitializer(
final NodeKey nodeKey,
final RlpxConfiguration config,
final LocalNode localNode,
final PeerConnectionEventDispatcher eventDispatcher,
final MetricsSystem metricsSystem,
final Optional<TLSConfiguration> p2pTLSConfiguration) {
final TLSConfiguration p2pTLSConfiguration) {
this(
nodeKey,
config,
localNode,
eventDispatcher,
metricsSystem,
defaultTlsContextFactorySupplier(p2pTLSConfiguration));
}

@VisibleForTesting
NettyTLSConnectionInitializer(
final NodeKey nodeKey,
final RlpxConfiguration config,
final LocalNode localNode,
final PeerConnectionEventDispatcher eventDispatcher,
final MetricsSystem metricsSystem,
final Supplier<TLSContextFactory> tlsContextFactorySupplier) {
super(nodeKey, config, localNode, eventDispatcher, metricsSystem);
this.p2pTLSConfiguration = p2pTLSConfiguration;
this.tlsContextFactorySupplier = Optional.ofNullable(tlsContextFactorySupplier);
}

@Override
protected void addAdditionalOutboundHandlers(final SocketChannel ch)
throws GeneralSecurityException, IOException {
if (p2pTLSConfiguration.isPresent()) {
SslContext sslContext =
TLSContextFactory.buildFrom(p2pTLSConfiguration.get()).createNettyClientSslContext();
ch.pipeline().addLast("ssl", sslContext.newHandler(ch.alloc()));
ch.pipeline().addLast(new SnappyFrameDecoder());
ch.pipeline().addLast(new SnappyFrameEncoder());
ch.pipeline()
.addLast(
new LengthFieldBasedFrameDecoder(
LENGTH_MAX_MESSAGE_FRAME, 0, LENGTH_FRAME_SIZE, 0, LENGTH_FRAME_SIZE));
ch.pipeline().addLast(new LengthFieldPrepender(LENGTH_FRAME_SIZE));
void addAdditionalOutboundHandlers(final Channel ch) throws GeneralSecurityException {
if (tlsContextFactorySupplier.isPresent()) {
final SslContext clientSslContext =
tlsContextFactorySupplier.get().get().createNettyClientSslContext();
addHandlersToChannelPipeline(ch, clientSslContext);
}
}

@Override
protected void addAdditionalInboundHandlers(final SocketChannel ch)
throws GeneralSecurityException, IOException {
if (p2pTLSConfiguration.isPresent()) {
SslContext sslContext =
TLSContextFactory.buildFrom(p2pTLSConfiguration.get()).createNettyServerSslContext();
ch.pipeline().addLast("ssl", sslContext.newHandler(ch.alloc()));
ch.pipeline().addLast(new SnappyFrameDecoder());
ch.pipeline().addLast(new SnappyFrameEncoder());
ch.pipeline()
.addLast(
new LengthFieldBasedFrameDecoder(
LENGTH_MAX_MESSAGE_FRAME, 0, LENGTH_FRAME_SIZE, 0, LENGTH_FRAME_SIZE));
ch.pipeline().addLast(new LengthFieldPrepender(LENGTH_FRAME_SIZE));
void addAdditionalInboundHandlers(final Channel ch) throws GeneralSecurityException {
if (tlsContextFactorySupplier.isPresent()) {
final SslContext serverSslContext =
tlsContextFactorySupplier.get().get().createNettyServerSslContext();
addHandlersToChannelPipeline(ch, serverSslContext);
}
}

private void addHandlersToChannelPipeline(final Channel ch, final SslContext sslContext) {
ch.pipeline().addLast(sslContext.newHandler(ch.alloc()));
ch.pipeline().addLast(new SnappyFrameDecoder());
ch.pipeline().addLast(new SnappyFrameEncoder());
ch.pipeline()
.addLast(
new LengthFieldBasedFrameDecoder(
LENGTH_MAX_MESSAGE_FRAME, 0, LENGTH_FRAME_SIZE, 0, LENGTH_FRAME_SIZE));
ch.pipeline().addLast(new LengthFieldPrepender(LENGTH_FRAME_SIZE));
}

@Override
public Handshaker buildInstance() {
return new PlainHandshaker();
Expand All @@ -97,4 +110,20 @@ public Handshaker buildInstance() {
public Framer buildFramer(final HandshakeSecrets secrets) {
return new PlainFramer();
}

@VisibleForTesting
static Supplier<TLSContextFactory> defaultTlsContextFactorySupplier(
final TLSConfiguration tlsConfiguration) {
if (tlsConfiguration == null) {
throw new IllegalStateException("TLSConfiguration cannot be null when using TLS");
}

return () -> {
try {
return TLSContextFactory.buildFrom(tlsConfiguration);
} catch (final Exception e) {
throw new IllegalStateException("Error creating TLSContextFactory", e);
}
};
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
/*
* Copyright Hyperledger Besu Contributors.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/

package org.hyperledger.besu.ethereum.p2p.rlpx.connections.netty;

import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import org.hyperledger.besu.crypto.NodeKey;
import org.hyperledger.besu.ethereum.p2p.config.RlpxConfiguration;
import org.hyperledger.besu.ethereum.p2p.peers.LocalNode;
import org.hyperledger.besu.ethereum.p2p.rlpx.connections.PeerConnectionEventDispatcher;
import org.hyperledger.besu.metrics.noop.NoOpMetricsSystem;

import java.util.ArrayList;
import java.util.List;

import com.google.common.collect.ImmutableList;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.handler.codec.LengthFieldPrepender;
import io.netty.handler.codec.compression.SnappyFrameDecoder;
import io.netty.handler.codec.compression.SnappyFrameEncoder;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslHandler;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;

@RunWith(MockitoJUnitRunner.class)
public class NettyTLSConnectionInitializerTest {

@Mock private NodeKey nodeKey;
@Mock private RlpxConfiguration rlpxConfiguration;
@Mock private LocalNode localNode;
@Mock private PeerConnectionEventDispatcher eventDispatcher;
@Mock private TLSContextFactory tlsContextFactory;
@Mock private SslContext sslContext;
@Mock private SslHandler sslHandler;

private NettyTLSConnectionInitializer nettyTLSConnectionInitializer;

@Before
public void before() throws Exception {
nettyTLSConnectionInitializer =
new NettyTLSConnectionInitializer(
nodeKey,
rlpxConfiguration,
localNode,
eventDispatcher,
new NoOpMetricsSystem(),
() -> tlsContextFactory);

when(tlsContextFactory.createNettyServerSslContext()).thenReturn(sslContext);
when(tlsContextFactory.createNettyClientSslContext()).thenReturn(sslContext);
when(sslContext.newHandler(any())).thenReturn(sslHandler);
}

@Test
public void addAdditionalOutboundHandlersIncludesAllExpectedHandlersToChannelPipeline()
throws Exception {
final EmbeddedChannel embeddedChannel = new EmbeddedChannel();
nettyTLSConnectionInitializer.addAdditionalOutboundHandlers(embeddedChannel);

// TLS
assertThat(embeddedChannel.pipeline().get(SslHandler.class)).isNotNull();

// Snappy compression
assertThat(embeddedChannel.pipeline().get(SnappyFrameDecoder.class)).isNotNull();
assertThat(embeddedChannel.pipeline().get(SnappyFrameEncoder.class)).isNotNull();

// Message Framing
assertThat(embeddedChannel.pipeline().get(LengthFieldBasedFrameDecoder.class)).isNotNull();
assertThat(embeddedChannel.pipeline().get(LengthFieldPrepender.class)).isNotNull();

assertHandlersOrderInPipeline(embeddedChannel.pipeline());
}

@Test
public void addAdditionalInboundHandlersIncludesAllExpectedHandlersToChannelPipeline()
throws Exception {
final EmbeddedChannel embeddedChannel = new EmbeddedChannel();
nettyTLSConnectionInitializer.addAdditionalInboundHandlers(embeddedChannel);

// TLS
assertThat(embeddedChannel.pipeline().get(SslHandler.class)).isNotNull();

// Snappy compression
assertThat(embeddedChannel.pipeline().get(SnappyFrameDecoder.class)).isNotNull();
assertThat(embeddedChannel.pipeline().get(SnappyFrameEncoder.class)).isNotNull();

// Message Framing
assertThat(embeddedChannel.pipeline().get(LengthFieldBasedFrameDecoder.class)).isNotNull();
assertThat(embeddedChannel.pipeline().get(LengthFieldPrepender.class)).isNotNull();

assertHandlersOrderInPipeline(embeddedChannel.pipeline());
}

private void assertHandlersOrderInPipeline(final ChannelPipeline pipeline) {
// Appending '#0' because Netty adds it to the handler's names
final ArrayList<String> expectedHandlerNamesInOrder =
new ArrayList<>(
ImmutableList.of(
"SslHandler#0",
"SnappyFrameDecoder#0",
"SnappyFrameEncoder#0",
"LengthFieldBasedFrameDecoder#0",
"LengthFieldPrepender#0",
"DefaultChannelPipeline$TailContext#0")); // This final handler is Netty's default

final List<String> actualHandlerNamesInOrder = pipeline.names();
assertThat(actualHandlerNamesInOrder).isEqualTo(expectedHandlerNamesInOrder);
}

@Test
public void defaultTlsContextFactorySupplierThrowsErrorWithNullTLSConfiguration() {
assertThatThrownBy(() -> NettyTLSConnectionInitializer.defaultTlsContextFactorySupplier(null))
.isInstanceOf(IllegalStateException.class)
.hasMessage("TLSConfiguration cannot be null when using TLS");
}

@Test
public void defaultTlsContextFactorySupplierCapturesInternalError() {
final TLSConfiguration tlsConfiguration = mock(TLSConfiguration.class);
when(tlsConfiguration.getKeyStoreType()).thenThrow(new RuntimeException());

assertThatThrownBy(
() ->
NettyTLSConnectionInitializer.defaultTlsContextFactorySupplier(tlsConfiguration)
.get())
.isInstanceOf(IllegalStateException.class)
.hasMessage("Error creating TLSContextFactory");
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
mock-maker-inline

0 comments on commit 1b45cac

Please sign in to comment.