Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Decouple ChannelFactory from Tcp classes #27286

Merged
merged 4 commits into from
Nov 8, 2017
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
import org.elasticsearch.transport.nio.channel.NioChannel;
import org.elasticsearch.transport.nio.channel.NioServerSocketChannel;
import org.elasticsearch.transport.nio.channel.NioSocketChannel;
import org.elasticsearch.transport.nio.channel.TcpReadContext;
import org.elasticsearch.transport.nio.channel.TcpWriteContext;

import java.io.IOException;
import java.net.InetSocketAddress;
Expand Down Expand Up @@ -68,7 +70,7 @@ public class NioTransport extends TcpTransport<NioChannel> {
public static final Setting<Integer> NIO_ACCEPTOR_COUNT =
intSetting("transport.nio.acceptor_count", 1, 1, Setting.Property.NodeScope);

private final TcpReadHandler tcpReadHandler = new TcpReadHandler(this);
private final Consumer<NioSocketChannel> contextSetter;
private final ConcurrentMap<String, ChannelFactory> profileToChannelFactory = newConcurrentMap();
private final OpenChannels openChannels = new OpenChannels(logger);
private final ArrayList<AcceptingSelector> acceptors = new ArrayList<>();
Expand All @@ -79,6 +81,7 @@ public class NioTransport extends TcpTransport<NioChannel> {
public NioTransport(Settings settings, ThreadPool threadPool, NetworkService networkService, BigArrays bigArrays,
NamedWriteableRegistry namedWriteableRegistry, CircuitBreakerService circuitBreakerService) {
super("nio", settings, threadPool, bigArrays, circuitBreakerService, namedWriteableRegistry, networkService);
contextSetter = (c) -> c.setContexts(new TcpReadContext(c, new TcpReadHandler(this)), new TcpWriteContext(c));
}

@Override
Expand Down Expand Up @@ -206,7 +209,7 @@ protected void doStart() {

// loop through all profiles and start them up, special handling for default one
for (ProfileSettings profileSettings : profileSettings) {
profileToChannelFactory.putIfAbsent(profileSettings.profileName, new ChannelFactory(profileSettings, tcpReadHandler));
profileToChannelFactory.putIfAbsent(profileSettings.profileName, new ChannelFactory(profileSettings, contextSetter));
bindServer(profileSettings);
}
}
Expand Down Expand Up @@ -243,7 +246,7 @@ final void exceptionCaught(NioSocketChannel channel, Throwable cause) {

private NioClient createClient() {
Supplier<SocketSelector> selectorSupplier = new RoundRobinSelectorSupplier(socketSelectors);
ChannelFactory channelFactory = new ChannelFactory(new ProfileSettings(settings, "default"), tcpReadHandler);
ChannelFactory channelFactory = new ChannelFactory(new ProfileSettings(settings, "default"), contextSetter);
return new NioClient(logger, openChannels, selectorSupplier, defaultConnectionProfile.getConnectTimeout(), channelFactory);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import org.elasticsearch.transport.TcpTransport;
import org.elasticsearch.transport.nio.AcceptingSelector;
import org.elasticsearch.transport.nio.SocketSelector;
import org.elasticsearch.transport.nio.TcpReadHandler;

import java.io.Closeable;
import java.io.IOException;
Expand All @@ -39,23 +38,27 @@

public class ChannelFactory {

private final TcpReadHandler handler;
private final Consumer<NioSocketChannel> contextSetter;
private final RawChannelFactory rawChannelFactory;

public ChannelFactory(TcpTransport.ProfileSettings profileSettings, TcpReadHandler handler) {
this(new RawChannelFactory(profileSettings), handler);
public ChannelFactory(TcpTransport.ProfileSettings profileSettings, Consumer<NioSocketChannel> contextSetter) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would you add a Javadoc for this constructor explaining the purpose of the contextSetter parameter?

this(new RawChannelFactory(profileSettings.tcpNoDelay,
profileSettings.tcpKeepAlive,
profileSettings.reuseAddress,
Math.toIntExact(profileSettings.sendBufferSize.getBytes()),
Math.toIntExact(profileSettings.receiveBufferSize.getBytes())), contextSetter);
}

ChannelFactory(RawChannelFactory rawChannelFactory, TcpReadHandler handler) {
this.handler = handler;
ChannelFactory(RawChannelFactory rawChannelFactory, Consumer<NioSocketChannel> contextSetter) {
this.contextSetter = contextSetter;
this.rawChannelFactory = rawChannelFactory;
}

public NioSocketChannel openNioChannel(InetSocketAddress remoteAddress, SocketSelector selector,
Consumer<NioChannel> closeListener) throws IOException {
SocketChannel rawChannel = rawChannelFactory.openNioChannel(remoteAddress);
NioSocketChannel channel = new NioSocketChannel(NioChannel.CLIENT, rawChannel, selector);
channel.setContexts(new TcpReadContext(channel, handler), new TcpWriteContext(channel));
setContexts(channel);
channel.getCloseFuture().addListener(ActionListener.wrap(closeListener::accept, (e) -> closeListener.accept(channel)));
scheduleChannel(channel, selector);
return channel;
Expand All @@ -65,7 +68,7 @@ public NioSocketChannel acceptNioChannel(NioServerSocketChannel serverChannel, S
Consumer<NioChannel> closeListener) throws IOException {
SocketChannel rawChannel = rawChannelFactory.acceptNioChannel(serverChannel);
NioSocketChannel channel = new NioSocketChannel(serverChannel.getProfile(), rawChannel, selector);
channel.setContexts(new TcpReadContext(channel, handler), new TcpWriteContext(channel));
setContexts(channel);
channel.getCloseFuture().addListener(ActionListener.wrap(closeListener::accept, (e) -> closeListener.accept(channel)));
scheduleChannel(channel, selector);
return channel;
Expand Down Expand Up @@ -97,6 +100,12 @@ private void scheduleServerChannel(NioServerSocketChannel channel, AcceptingSele
}
}

private void setContexts(NioSocketChannel channel) {
contextSetter.accept(channel);
assert channel.getReadContext() != null : "read context should have been set on channel";
assert channel.getWriteContext() != null : "write context should have been set on channel";
}

static class RawChannelFactory {

private final boolean tcpNoDelay;
Expand All @@ -105,12 +114,13 @@ static class RawChannelFactory {
private final int tcpSendBufferSize;
private final int tcpReceiveBufferSize;

RawChannelFactory(TcpTransport.ProfileSettings profileSettings) {
tcpNoDelay = profileSettings.tcpNoDelay;
tcpKeepAlive = profileSettings.tcpKeepAlive;
tcpReusedAddress = profileSettings.reuseAddress;
tcpSendBufferSize = Math.toIntExact(profileSettings.sendBufferSize.getBytes());
tcpReceiveBufferSize = Math.toIntExact(profileSettings.receiveBufferSize.getBytes());
RawChannelFactory(boolean tcpNoDelay, boolean tcpKeepAlive, boolean tcpReusedAddress, int tcpSendBufferSize,
int tcpReceiveBufferSize) {
this.tcpNoDelay = tcpNoDelay;
this.tcpKeepAlive = tcpKeepAlive;
this.tcpReusedAddress = tcpReusedAddress;
this.tcpSendBufferSize = tcpSendBufferSize;
this.tcpReceiveBufferSize = tcpReceiveBufferSize;
}

SocketChannel openNioChannel(InetSocketAddress remoteAddress) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
import org.elasticsearch.transport.nio.TcpReadHandler;
import org.junit.After;
import org.junit.Before;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;

import java.io.IOException;
import java.net.InetAddress;
Expand All @@ -36,6 +38,7 @@

import static org.mockito.Matchers.any;
import static org.mockito.Matchers.same;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
Expand All @@ -55,12 +58,19 @@ public class ChannelFactoryTests extends ESTestCase {
@SuppressWarnings("unchecked")
public void setupFactory() throws IOException {
rawChannelFactory = mock(ChannelFactory.RawChannelFactory.class);
channelFactory = new ChannelFactory(rawChannelFactory, mock(TcpReadHandler.class));
Consumer contextSetter = mock(Consumer.class);
channelFactory = new ChannelFactory(rawChannelFactory, contextSetter);
listener = mock(Consumer.class);
socketSelector = mock(SocketSelector.class);
acceptingSelector = mock(AcceptingSelector.class);
rawChannel = SocketChannel.open();
rawServerChannel = ServerSocketChannel.open();

doAnswer(invocationOnMock -> {
NioSocketChannel channel = (NioSocketChannel) invocationOnMock.getArguments()[0];
channel.setContexts(mock(ReadContext.class), mock(WriteContext.class));
return null;
}).when(contextSetter).accept(any());
}

@After
Expand Down