Skip to content

Commit

Permalink
code refactor, bug fixes and more test coverage
Browse files Browse the repository at this point in the history
Signed-off-by: Rishabh Maurya <rishabhmaurya05@gmail.com>
  • Loading branch information
rishabhmaurya committed Dec 4, 2024
1 parent 68aba4f commit 7c7437d
Show file tree
Hide file tree
Showing 17 changed files with 293 additions and 91 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ public class RunTask extends DefaultTestClustersTask {
public static final String CUSTOM_SETTINGS_PREFIX = "tests.opensearch.";
private static final int DEFAULT_HTTP_PORT = 9200;
private static final int DEFAULT_TRANSPORT_PORT = 9300;
private static final int DEFAULT_STREAM_PORT = 8815;
private static final int DEFAULT_STREAM_PORT = 9880;
private static final int DEFAULT_DEBUG_PORT = 5005;
public static final String LOCALHOST_ADDRESS_PREFIX = "127.0.0.1:";

Expand Down
6 changes: 6 additions & 0 deletions modules/arrow-flight-rpc/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ dependencies {
runtimeOnly 'org.apache.parquet:parquet-arrow:1.13.1'
}

// Configure the 'test' task so that JaCoCo excludes everything matching
// org/apache/arrow/flight/** from coverage.
tasks.named('test').configure {
jacoco {
excludes = ['org/apache/arrow/flight/**']
}
}

tasks.named('forbiddenApisMain').configure {
replaceSignatureFiles 'jdk-signatures'

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

package org.opensearch.arrow.flight.bootstrap;

import org.apache.arrow.flight.FlightProducer;
import org.apache.arrow.flight.OpenSearchFlightServer;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
Expand Down Expand Up @@ -75,7 +76,7 @@ public FlightService(Settings settings) {
* @param threadPool The ThreadPool instance.
*/
public void initialize(ClusterService clusterService, ThreadPool threadPool) {
this.threadPool.trySet(threadPool);
this.threadPool.trySet(Objects.requireNonNull(threadPool));
if (ServerConfig.isSslEnabled()) {
sslContextProvider = new DefaultSslContextProvider(secureTransportSettingsProvider::get);
} else {
Expand Down Expand Up @@ -103,7 +104,7 @@ protected void doStart() {
(PrivilegedExceptionAction<BufferAllocator>) () -> new RootAllocator(Integer.MAX_VALUE)
);

BaseFlightProducer producer = new BaseFlightProducer(clientManager, streamManager, allocator);
FlightProducer producer = new BaseFlightProducer(clientManager, streamManager, allocator);
FlightServerBuilder builder = new FlightServerBuilder(threadPool.get(), () -> allocator, producer, sslContextProvider);
server = builder.build();
server.start();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
import org.apache.arrow.flight.OpenSearchFlightClient;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.util.VisibleForTesting;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.Version;
import org.opensearch.arrow.flight.bootstrap.tls.SslContextProvider;
import org.opensearch.cluster.ClusterChangedEvent;
import org.opensearch.cluster.ClusterStateListener;
Expand All @@ -38,6 +41,8 @@ public class FlightClientManager implements ClusterStateListener, AutoCloseable
private final Supplier<BufferAllocator> allocator;
private final SslContextProvider sslContextProvider;

private static final Logger logger = LogManager.getLogger(FlightClientManager.class);

/**
* Creates a new FlightClientManager instance.
*
Expand Down Expand Up @@ -70,7 +75,8 @@ public void clusterChanged(ClusterChangedEvent event) {
* @return An OpenSearchFlightClient instance for the specified node
*/
public OpenSearchFlightClient getFlightClient(String nodeId) {
return flightClients.computeIfAbsent(nodeId, this::buildFlightClient).flightClient;
FlightClientHolder clientHolder = flightClients.computeIfAbsent(nodeId, this::buildFlightClient);
return clientHolder == null ? null : clientHolder.flightClient;
}

/**
Expand All @@ -79,7 +85,8 @@ public OpenSearchFlightClient getFlightClient(String nodeId) {
* @return The Location of the Flight client for the specified node
*/
public Location getFlightClientLocation(String nodeId) {
return flightClients.computeIfAbsent(nodeId, this::buildFlightClient).location;
FlightClientHolder clientHolder = flightClients.computeIfAbsent(nodeId, this::buildFlightClient);
return clientHolder == null ? null : clientHolder.location;
}

/**
Expand All @@ -105,10 +112,18 @@ public void close() throws Exception {
private FlightClientHolder buildFlightClient(String nodeId) {
DiscoveryNode node = Objects.requireNonNull(clusterService).state().nodes().get(nodeId);
if (node == null) {
throw new IllegalArgumentException("Node with id " + nodeId + " not found in cluster");
return null;
}
Version minVersion = Version.fromString("3.0.0");
if (node.getVersion().before(minVersion)) {
return null;
}
// TODO: handle cases where flight server isn't running like mixed cluster with nodes of previous version
// ideally streaming shouldn't be supported on mixed cluster.

String arrowStreamsEnabled = node.getAttributes().get("arrow.streams.enabled");
if (!"true".equals(arrowStreamsEnabled)) {
return null;
}

String clientPort = node.getAttributes().get("transport.stream.port");
FlightClientBuilder builder = new FlightClientBuilder(
node.getHostAddress(),
Expand All @@ -128,10 +143,7 @@ void updateFlightClients() {

private void initializeFlightClients() {
for (DiscoveryNode node : Objects.requireNonNull(clusterService).state().nodes()) {
String nodeId = node.getId();
if (!flightClients.containsKey(nodeId)) {
getFlightClient(nodeId);
}
getFlightClient(node.getId());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.threadpool.ScalingExecutorBuilder;
import org.opensearch.threadpool.ThreadPool;

import java.util.ArrayList;
import java.util.Arrays;
Expand All @@ -33,7 +32,7 @@ public ServerConfig() {}

static final Setting<Integer> STREAM_PORT = Setting.intSetting(
"node.attr.transport.stream.port",
8815,
9880,
1024,
65535,
Setting.Property.NodeScope
Expand Down Expand Up @@ -72,7 +71,7 @@ public ServerConfig() {}

static final Setting<Integer> FLIGHT_THREAD_POOL_MAX_SIZE = Setting.intSetting(
"thread_pool.flight-server.max",
100000,
100000, // TODO depends on max concurrent streams per node, decide after benchmark. To be controlled by admission control layer.
1,
Setting.Property.NodeScope
);
Expand All @@ -93,7 +92,6 @@ public ServerConfig() {}

private static final String host = "localhost";
private static int port;
private static ThreadPool threadPool;
private static boolean enableSsl;
private static ScalingExecutorBuilder executorBuilder;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,10 @@ public void getStream(CallContext context, Ticket ticket, ServerStreamListener l
streamProducerHolder = streamManager.getStreamProducer(streamTicket);
} else {
OpenSearchFlightClient remoteClient = flightClientManager.getFlightClient(streamTicket.getNodeId());
StreamProducer proxyProvider = new ProxyStreamProducer(remoteClient.getStream(ticket));
if (remoteClient == null) {
listener.error(CallStatus.UNAVAILABLE.withDescription("Client doesn't support Stream").cause());
}
StreamProducer proxyProvider = new ProxyStreamProducer(new FlightStreamReader(remoteClient.getStream(ticket)));
streamProducerHolder = new FlightStreamManager.StreamProducerHolder(proxyProvider, allocator);
}
if (streamProducerHolder == null) {
Expand Down Expand Up @@ -144,6 +147,9 @@ public FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor
return infoBuilder.build();
} else {
OpenSearchFlightClient remoteClient = flightClientManager.getFlightClient(streamTicket.getNodeId());
if (remoteClient == null) {
throw CallStatus.UNAVAILABLE.withDescription("Client doesn't support Stream").toRuntimeException();
}
return remoteClient.getInfo(descriptor);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
*/
public class FlightStreamManager implements StreamManager {

private final DefaultStreamTicketFactory ticketFactory;
private final FlightStreamTicketFactory ticketFactory;
private final FlightClientManager clientManager;
private final Supplier<BufferAllocator> allocatorSupplier;
private final Cache<String, StreamProducerHolder> streamProducers;
Expand All @@ -55,7 +55,7 @@ public FlightStreamManager(Supplier<BufferAllocator> allocatorSupplier, FlightCl
.setExpireAfterWrite(expireAfter)
.setMaximumWeight(MAX_PRODUCERS)
.build();
this.ticketFactory = new DefaultStreamTicketFactory(clientManager::getLocalNodeId);
this.ticketFactory = new FlightStreamTicketFactory(clientManager::getLocalNodeId);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import org.apache.arrow.flight.FlightStream;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.opensearch.ExceptionsHelper;
import org.opensearch.arrow.spi.StreamReader;

/**
Expand Down Expand Up @@ -52,10 +53,6 @@ public VectorSchemaRoot getRoot() {
*/
@Override
public void close() {
try {
flightStream.close();
} catch (Exception e) {
throw new RuntimeException(e);
}
ExceptionsHelper.catchAsRuntimeException(flightStream::close);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
* Default implementation of StreamTicketFactory
*/
@ExperimentalApi
public class DefaultStreamTicketFactory implements StreamTicketFactory {
public class FlightStreamTicketFactory implements StreamTicketFactory {

private final Supplier<String> nodeId;

Expand All @@ -28,7 +28,7 @@ public class DefaultStreamTicketFactory implements StreamTicketFactory {
*
* @param nodeId A Supplier that provides the node ID for the StreamTicket
*/
public DefaultStreamTicketFactory(Supplier<String> nodeId) {
public FlightStreamTicketFactory(Supplier<String> nodeId) {
this.nodeId = nodeId;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,15 @@

package org.opensearch.arrow.flight.core;

import org.apache.arrow.flight.FlightStream;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.opensearch.ExceptionsHelper;
import org.opensearch.arrow.spi.StreamProducer;
import org.opensearch.arrow.spi.StreamReader;
import org.opensearch.arrow.spi.StreamTicket;

import java.util.concurrent.atomic.AtomicBoolean;

/**
* ProxyStreamProducer acts as a forward proxy for FlightStream.
* It creates a BatchedJob to handle the streaming of data from the remote FlightStream.
Expand All @@ -22,21 +25,21 @@
*/
public class ProxyStreamProducer implements StreamProducer {

private final FlightStream remoteStream;
private final StreamReader remoteStream;

/**
* Constructs a new ProxyStreamProducer instance.
*
* @param remoteStream The remote FlightStream to be proxied.
*/
public ProxyStreamProducer(FlightStream remoteStream) {
public ProxyStreamProducer(StreamReader remoteStream) {
this.remoteStream = remoteStream;
}

/**
* Creates a VectorSchemaRoot for the remote FlightStream.
* @param allocator The allocator to use for creating vectors
* @return
* @return A VectorSchemaRoot representing the schema of the remote FlightStream
*/
@Override
public VectorSchemaRoot createRoot(BufferAllocator allocator) {
Expand Down Expand Up @@ -75,47 +78,35 @@ public String getAction() {
*/
@Override
public void close() {
try {
remoteStream.close();
} catch (Exception e) {
throw new RuntimeException(e);
}
ExceptionsHelper.catchAsRuntimeException(remoteStream::close);
}

static class ProxyBatchedJob implements BatchedJob {

private final FlightStream remoteStream;
private final StreamReader remoteStream;
private final AtomicBoolean isCancelled = new AtomicBoolean(false);

ProxyBatchedJob(FlightStream remoteStream) {
ProxyBatchedJob(StreamReader remoteStream) {
this.remoteStream = remoteStream;
}

@Override
public void run(VectorSchemaRoot root, FlushSignal flushSignal) {
while (remoteStream.next()) {
while (!isCancelled.get() && remoteStream.next()) {
flushSignal.awaitConsumption(1000);
}
try {
remoteStream.close();
} catch (Exception e) {
throw new RuntimeException(e);
}
}

@Override
public void onCancel() {
try {
remoteStream.close();
} catch (Exception e) {
throw new RuntimeException(e);
}
isCancelled.set(true);
}

@Override
public boolean isCancelled() {
// Proxy streams don't have any business logic to set this flag;
// they piggyback on the remote stream getting cancelled.
return false;
return isCancelled.get();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public class FlightStreamPluginTests extends OpenSearchTestCase {
@Override
public void setUp() throws Exception {
super.setUp();
settings = Settings.builder().put("node.attr.transport.stream.port", "8815").put(ARROW_STREAMS_SETTING.getKey(), true).build();
settings = Settings.builder().put("node.attr.transport.stream.port", "9880").put(ARROW_STREAMS_SETTING.getKey(), true).build();
clusterService = mock(ClusterService.class);
ClusterState clusterState = mock(ClusterState.class);
DiscoveryNodes nodes = mock(DiscoveryNodes.class);
Expand Down Expand Up @@ -76,10 +76,12 @@ public void testPluginEnableAndDisable() throws IOException {
assertNotNull(settings);
assertFalse(settings.isEmpty());

assertNotNull(plugin.getSecureTransports(null, null, null, null, null, null, null, null));

plugin.close();

Settings disabledSettings = Settings.builder()
.put("node.attr.transport.stream.port", "8815")
.put("node.attr.transport.stream.port", "9880")
.put(ARROW_STREAMS_SETTING.getKey(), false)
.build();
FeatureFlags.initializeFeatureFlags(disabledSettings);
Expand All @@ -102,6 +104,11 @@ public void testPluginEnableAndDisable() throws IOException {
assertTrue(disabledPluginComponents.isEmpty());
assertNull(disabledPlugin.getStreamManager());
assertTrue(disabledPlugin.getExecutorBuilders(disabledSettings).isEmpty());
assertNotNull(disabledPlugin.getSettings());
assertTrue(disabledPlugin.getSettings().isEmpty());

assertNotNull(disabledPlugin.getSecureTransports(null, null, null, null, null, null, null, null));

disabledPlugin.close();
}
}
Loading

0 comments on commit 7c7437d

Please sign in to comment.