Skip to content

Commit

Permalink
introduce factory for stream ticket
Browse files Browse the repository at this point in the history
Signed-off-by: Rishabh Maurya <rishabhmaurya05@gmail.com>
  • Loading branch information
rishabhmaurya committed Nov 26, 2024
1 parent d1f14e3 commit dffb8e6
Show file tree
Hide file tree
Showing 15 changed files with 154 additions and 73 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,12 @@ public interface StreamManager extends AutoCloseable {
* @throws IllegalStateException if the stream has been cancelled or closed
*/
StreamReader getStreamReader(StreamTicket ticket);

/**
* Gets the StreamTicketFactory instance associated with this StreamManager.
* By default, returns the singleton instance of StreamTicketFactory.
*
* @return the StreamTicketFactory instance
*/
StreamTicketFactory getStreamTicketFactory();
}
Original file line number Diff line number Diff line change
Expand Up @@ -102,17 +102,13 @@ public interface StreamProducer extends Closeable {
*
* @return Estimated number of rows, or -1 if unknown
*/
default int estimatedRowCount() {
return -1;
}
int estimatedRowCount();

/**
* Task action name
* @return action name
*/
default String getAction() {
return "";
}
String getAction();

/**
* BatchedJob interface for producing stream data in batches.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,4 @@ public interface StreamTicket {
* @return Base64 encoded byte array containing the ticket information
*/
byte[] toBytes();

/**
* Creates a StreamTicket from its serialized byte representation.
*
* @param bytes Base64 encoded byte array containing ticket information
* @return a new StreamTicket instance
* @throws IllegalArgumentException if the input is invalid
*/
static StreamTicket fromBytes(byte[] bytes) {
throw new UnsupportedOperationException("Implementation must be provided by concrete class");
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.arrow.spi;

import org.opensearch.common.annotation.ExperimentalApi;

/**
* Factory interface for creating and managing StreamTicket instances.
* This factory provides methods to create and deserialize StreamTickets,
* ensuring consistent ticket creation.
*/
@ExperimentalApi
public interface StreamTicketFactory {
/**
* Generates a new StreamTicket
*
* @return A new StreamTicket instance
*/
StreamTicket generateTicket();

/**
* Deserializes a StreamTicket from its byte representation.
*
* @param bytes The byte array containing the serialized ticket data
* @return A StreamTicket instance reconstructed from the byte array
* @throws IllegalArgumentException if bytes is null or invalid
*/
StreamTicket fromBytes(byte[] bytes);
}
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ public BaseFlightProducer(FlightClientManager flightClientManager, FlightStreamM
*/
@Override
public void getStream(CallContext context, Ticket ticket, ServerStreamListener listener) {
StreamTicket streamTicket = FlightStreamTicket.fromBytes(ticket.getBytes());
StreamTicket streamTicket = streamManager.getStreamTicketFactory().fromBytes(ticket.getBytes());
try {
FlightStreamManager.StreamProducerHolder streamProducerHolder;
if (streamTicket.getNodeID().equals(flightClientManager.getLocalNodeId())) {
Expand Down Expand Up @@ -127,7 +127,7 @@ public void getStream(CallContext context, Ticket ticket, ServerStreamListener l
@Override
public FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor) {
// TODO: this api should only be used internally
StreamTicket streamTicket = FlightStreamTicket.fromBytes(descriptor.getCommand());
StreamTicket streamTicket = streamManager.getStreamTicketFactory().fromBytes(descriptor.getCommand());
FlightStreamManager.StreamProducerHolder streamProducerHolder;
if (streamTicket.getNodeID().equals(flightClientManager.getLocalNodeId())) {
streamProducerHolder = streamManager.getStreamProducer(streamTicket);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.arrow.flight.core;

import org.opensearch.arrow.spi.StreamTicket;
import org.opensearch.arrow.spi.StreamTicketFactory;
import org.opensearch.common.annotation.ExperimentalApi;

import java.util.UUID;
import java.util.function.Supplier;

/**
* Default implementation of StreamTicketFactory
*/
@ExperimentalApi
public class DefaultStreamTicketFactory implements StreamTicketFactory {

private final Supplier<String> nodeId;

/**
* Constructs a new DefaultStreamTicketFactory instance.
*
* @param nodeId A Supplier that provides the node ID for the StreamTicket
*/
public DefaultStreamTicketFactory(Supplier<String> nodeId) {
this.nodeId = nodeId;
}

/**
* Generates a new StreamTicket with a unique ticket ID.
*
* @return A new StreamTicket instance
*/
@Override
public StreamTicket generateTicket() {
return new FlightStreamTicket(generateUniqueTicket(), nodeId.get());
}

/**
* Deserializes a StreamTicket from its byte representation.
*
* @param bytes The byte array containing the serialized ticket data
* @return A StreamTicket instance reconstructed from the byte array
* @throws IllegalArgumentException if bytes is null or invalid
*/
@Override
public StreamTicket fromBytes(byte[] bytes) {
return FlightStreamTicket.fromBytes(bytes);
}

private String generateUniqueTicket() {
return UUID.randomUUID().toString();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.opensearch.arrow.spi.StreamProducer;
import org.opensearch.arrow.spi.StreamReader;
import org.opensearch.arrow.spi.StreamTicket;
import org.opensearch.arrow.spi.StreamTicketFactory;
import org.opensearch.common.SetOnce;
import org.opensearch.common.cache.Cache;
import org.opensearch.common.cache.CacheBuilder;
Expand All @@ -33,7 +34,7 @@
*/
public class FlightStreamManager implements StreamManager {

private final StreamTicketFactory ticketFactory;
private final DefaultStreamTicketFactory ticketFactory;
private final FlightClientManager clientManager;
private final Supplier<BufferAllocator> allocatorSupplier;
private final Cache<String, StreamProducerHolder> streamProducers;
Expand All @@ -54,7 +55,7 @@ public FlightStreamManager(Supplier<BufferAllocator> allocatorSupplier, FlightCl
.setExpireAfterWrite(expireAfter)
.setMaximumWeight(MAX_PRODUCERS)
.build();
this.ticketFactory = new StreamTicketFactory(clientManager::getLocalNodeId);
this.ticketFactory = new DefaultStreamTicketFactory(clientManager::getLocalNodeId);
}

/**
Expand All @@ -65,7 +66,7 @@ public FlightStreamManager(Supplier<BufferAllocator> allocatorSupplier, FlightCl
*/
@Override
public StreamTicket registerStream(StreamProducer provider, TaskId parentTaskId) {
FlightStreamTicket ticket = ticketFactory.createTicket();
StreamTicket ticket = ticketFactory.generateTicket();
streamProducers.put(ticket.getTicketID(), new StreamProducerHolder(provider, allocatorSupplier.get()));
return ticket;
}
Expand All @@ -81,6 +82,11 @@ public StreamReader getStreamReader(StreamTicket ticket) {
return new FlightStreamReader(stream);
}

@Override
public StreamTicketFactory getStreamTicketFactory() {
return ticketFactory;
}

/**
* Retrieves the ArrowStreamProvider associated with the given StreamTicket.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ public byte[] toBytes() {
return Base64.getEncoder().encode(buffer.array());
}

public static StreamTicket fromBytes(byte[] bytes) {
static StreamTicket fromBytes(byte[] bytes) {
if (bytes == null || bytes.length < 4) {
throw new IllegalArgumentException("Invalid byte array input.");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,24 @@ public BatchedJob createJob(BufferAllocator allocator) {
return new ProxyBatchedJob(remoteStream);
}

/**
* Provides an estimate of the total number of rows that will be produced.
*/
@Override
public int estimatedRowCount() {
// TODO get it from remote flight stream
return -1;
}

/**
* Task action name
*/
@Override
public String getAction() {
// TODO get it from remote flight stream
return "";
}

/**
* Closes the remote FlightStream.
*/
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -135,14 +135,6 @@ public void testInitializeWithoutSecureTransportSettingsProvider() {
});
}

public void testDoubleInitialization() {
flightService.initialize(clusterService, threadPool);

flightService.initialize(clusterService, threadPool);

assertNotNull(flightService.getStreamManager());
}

public void testStopWithoutStart() {
flightService.initialize(clusterService, threadPool);

Expand Down Expand Up @@ -182,6 +174,11 @@ public void testLifecycleStateTransitions() throws Exception {
assertEquals("CLOSED", testService.lifecycleState().toString());
}

@Override
public void tearDown() throws Exception {
super.tearDown();
}

private void verifyServerRunning(FlightService flightService, int clientPort) throws InterruptedException {
FlightClientBuilder builder = new FlightClientBuilder(
"localhost",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ public class BaseFlightProducerTests extends OpenSearchTestCase {
public void setUp() throws Exception {
super.setUp();
streamManager = mock(FlightStreamManager.class);
when(streamManager.getStreamTicketFactory()).thenReturn(new DefaultStreamTicketFactory(() -> LOCAL_NODE_ID));
when(flightClientManager.getLocalNodeId()).thenReturn(LOCAL_NODE_ID);
allocator = mock(BufferAllocator.class);
streamProducer = mock(StreamProducer.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ public StreamReader getStreamReader(StreamTicket ticket) {
return streamManager.getStreamReader(ticket);
}

@Override
public StreamTicketFactory getStreamTicketFactory() {
return streamManager.getStreamTicketFactory();
}

@Override
public void close() throws Exception {
streamManager.close();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ public String getAction() {
}
}, searchContext.getTask().getParentTaskId());
StreamSearchResult streamSearchResult = searchContext.streamSearchResult();
streamSearchResult.flights(List.of(new OSTicket(ticket)));
streamSearchResult.flights(List.of(new OSTicket(ticket.toBytes())));
return false;
}
}
Expand Down
15 changes: 7 additions & 8 deletions server/src/main/java/org/opensearch/search/stream/OSTicket.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

package org.opensearch.search.stream;

import org.opensearch.arrow.spi.StreamManager;
import org.opensearch.arrow.spi.StreamTicket;
import org.opensearch.common.annotation.ExperimentalApi;
import org.opensearch.core.common.io.stream.StreamInput;
Expand All @@ -25,30 +26,28 @@
@ExperimentalApi
public class OSTicket implements Writeable, ToXContentFragment {

private final StreamTicket streamTicket;
private final byte[] bytes;

public OSTicket(StreamTicket ticket) {
this.streamTicket = ticket;
public OSTicket(byte[] bytes) {
this.bytes = bytes;
}

public OSTicket(StreamInput in) throws IOException {
byte[] bytes = in.readByteArray();
this.streamTicket = StreamTicket.fromBytes(bytes);
bytes = in.readByteArray();
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
byte[] bytes = streamTicket.toBytes();
return builder.value(new String(bytes, StandardCharsets.UTF_8));
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeByteArray(streamTicket.toBytes());
out.writeByteArray(bytes);
}

@Override
public String toString() {
return "OSTicket{" + "ticketID='" + streamTicket.getTicketID() + '\'' + ", nodeID='" + streamTicket.getNodeID() + '\'' + '}';
return "OSTicket{" + new String(bytes, StandardCharsets.UTF_8) + "}";
}
}

0 comments on commit dffb8e6

Please sign in to comment.