Skip to content

Commit

Permalink
Fix download reporting for segment replication (#10644) (#10705)
Browse files Browse the repository at this point in the history
(cherry picked from commit c001417)

Signed-off-by: Kunal Kotwani <kkotwani@amazon.com>
Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
  • Loading branch information
1 parent 21045f8 commit f934b33
Show file tree
Hide file tree
Showing 13 changed files with 198 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.opensearch.transport.TransportService;

import java.util.List;
import java.util.function.BiConsumer;

import static org.opensearch.indices.replication.SegmentReplicationSourceService.Actions.GET_CHECKPOINT_INFO;
import static org.opensearch.indices.replication.SegmentReplicationSourceService.Actions.GET_SEGMENT_FILES;
Expand Down Expand Up @@ -80,8 +81,13 @@ public void getSegmentFiles(
ReplicationCheckpoint checkpoint,
List<StoreFileMetadata> filesToFetch,
IndexShard indexShard,
BiConsumer<String, Long> fileProgressTracker,
ActionListener<GetSegmentFilesResponse> listener
) {
// fileProgressTracker is a no-op for node to node recovery
// MultiFileWriter takes care of progress tracking for downloads in this scenario
// TODO: Move state management and tracking into replication methods and use chunking and data
// copy mechanisms only from MultiFileWriter
final Writeable.Reader<GetSegmentFilesResponse> reader = GetSegmentFilesResponse::new;
final ActionListener<GetSegmentFilesResponse> responseListener = ActionListener.map(listener, r -> r);
final GetSegmentFilesRequest request = new GetSegmentFilesRequest(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.BiConsumer;
import java.util.stream.Collectors;

/**
Expand Down Expand Up @@ -95,6 +96,7 @@ public void getSegmentFiles(
ReplicationCheckpoint checkpoint,
List<StoreFileMetadata> filesToFetch,
IndexShard indexShard,
BiConsumer<String, Long> fileProgressTracker,
ActionListener<GetSegmentFilesResponse> listener
) {
try {
Expand All @@ -117,7 +119,12 @@ public void getSegmentFiles(
assert directoryFiles.contains(file) == false : "Local store already contains the file " + file;
toDownloadSegmentNames.add(file);
}
indexShard.getFileDownloader().download(remoteDirectory, storeDirectory, toDownloadSegmentNames);
indexShard.getFileDownloader()
.download(
remoteDirectory,
new ReplicationStatsDirectoryWrapper(storeDirectory, fileProgressTracker),
toDownloadSegmentNames
);
logger.debug("Downloaded segment files from remote store {}", filesToFetch);
} finally {
indexShard.store().decRef();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,19 @@

package org.opensearch.indices.replication;

import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FilterDirectory;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.opensearch.common.util.CancellableThreads.ExecutionCancelledException;
import org.opensearch.core.action.ActionListener;
import org.opensearch.index.shard.IndexShard;
import org.opensearch.index.store.StoreFileMetadata;
import org.opensearch.indices.replication.checkpoint.ReplicationCheckpoint;

import java.io.IOException;
import java.util.List;
import java.util.function.BiConsumer;

/**
* Represents the source of a replication event.
Expand All @@ -39,13 +45,15 @@ public interface SegmentReplicationSource {
* @param checkpoint {@link ReplicationCheckpoint} Checkpoint to fetch metadata for.
* @param filesToFetch {@link List} List of files to fetch.
* @param indexShard {@link IndexShard} Reference to the IndexShard.
* @param fileProgressTracker {@link BiConsumer} A consumer that updates the replication progress for shard files.
* @param listener {@link ActionListener} Listener that completes with the list of files copied.
*/
void getSegmentFiles(
long replicationId,
ReplicationCheckpoint checkpoint,
List<StoreFileMetadata> filesToFetch,
IndexShard indexShard,
BiConsumer<String, Long> fileProgressTracker,
ActionListener<GetSegmentFilesResponse> listener
);

Expand All @@ -58,4 +66,69 @@ void getSegmentFiles(
* Cancel any ongoing requests, should resolve any ongoing listeners with onFailure with a {@link ExecutionCancelledException}.
*/
default void cancel() {}

/**
* Directory wrapper that records copy process for replication statistics
*
* @opensearch.internal
*/
final class ReplicationStatsDirectoryWrapper extends FilterDirectory {
private final BiConsumer<String, Long> fileProgressTracker;

ReplicationStatsDirectoryWrapper(Directory in, BiConsumer<String, Long> fileProgressTracker) {
super(in);
this.fileProgressTracker = fileProgressTracker;
}

@Override
public void copyFrom(Directory from, String src, String dest, IOContext context) throws IOException {
// here we wrap the index input form the source directory to report progress of file copy for the recovery stats.
// we increment the num bytes recovered in the readBytes method below, if users pull statistics they can see immediately
// how much has been recovered.
in.copyFrom(new FilterDirectory(from) {
@Override
public IndexInput openInput(String name, IOContext context) throws IOException {
final IndexInput input = in.openInput(name, context);
return new IndexInput("StatsDirectoryWrapper(" + input.toString() + ")") {
@Override
public void close() throws IOException {
input.close();
}

@Override
public long getFilePointer() {
throw new UnsupportedOperationException("only straight copies are supported");
}

@Override
public void seek(long pos) throws IOException {
throw new UnsupportedOperationException("seeks are not supported");
}

@Override
public long length() {
return input.length();
}

@Override
public IndexInput slice(String sliceDescription, long offset, long length) throws IOException {
throw new UnsupportedOperationException("slices are not supported");
}

@Override
public byte readByte() throws IOException {
throw new UnsupportedOperationException("use a buffer if you wanna perform well");
}

@Override
public void readBytes(byte[] b, int offset, int len) throws IOException {
// we rely on the fact that copyFrom uses a buffer
input.readBytes(b, offset, len);
fileProgressTracker.accept(dest, (long) len);
}
};
}
}, src, dest, context);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,14 @@ public void startReplication(ActionListener<Void> listener) {
final List<StoreFileMetadata> filesToFetch = getFiles(checkpointInfo);
state.setStage(SegmentReplicationState.Stage.GET_FILES);
cancellableThreads.checkForCancel();
source.getSegmentFiles(getId(), checkpointInfo.getCheckpoint(), filesToFetch, indexShard, getFilesListener);
source.getSegmentFiles(
getId(),
checkpointInfo.getCheckpoint(),
filesToFetch,
indexShard,
this::updateFileRecoveryBytes,
getFilesListener
);
}, listener::onFailure);

getFilesListener.whenComplete(response -> {
Expand Down Expand Up @@ -240,6 +247,20 @@ private boolean validateLocalChecksum(StoreFileMetadata file) {
}
}

/**
* Updates the state to reflect recovery progress for the given file and
* updates the last access time for the target.
* @param fileName Name of the file being downloaded
* @param bytesRecovered Number of bytes recovered
*/
private void updateFileRecoveryBytes(String fileName, long bytesRecovered) {
ReplicationLuceneIndex index = state.getIndex();
if (index != null) {
index.addRecoveredBytesToFile(fileName, bytesRecovered);
}
setLastAccessTime();
}

private void finalizeReplication(CheckpointInfoResponse checkpointInfoResponse) throws OpenSearchCorruptionException {
cancellableThreads.checkForCancel();
state.setStage(SegmentReplicationState.Stage.FINALIZE_REPLICATION);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiConsumer;
import java.util.stream.Collectors;

import static org.opensearch.index.engine.EngineTestCase.assertAtMostOneLuceneDocumentPerSequenceNumber;
Expand Down Expand Up @@ -388,9 +389,10 @@ public void getSegmentFiles(
ReplicationCheckpoint checkpoint,
List<StoreFileMetadata> filesToFetch,
IndexShard indexShard,
BiConsumer<String, Long> fileProgressTracker,
ActionListener<GetSegmentFilesResponse> listener
) {
super.getSegmentFiles(replicationId, checkpoint, filesToFetch, indexShard, listener);
super.getSegmentFiles(replicationId, checkpoint, filesToFetch, indexShard, (fileName, bytesRecovered) -> {}, listener);
runAfterGetFiles[index.getAndIncrement()].run();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiConsumer;
import java.util.function.Function;

import static org.opensearch.index.engine.EngineTestCase.assertAtMostOneLuceneDocumentPerSequenceNumber;
Expand Down Expand Up @@ -725,6 +726,7 @@ public void getSegmentFiles(
ReplicationCheckpoint checkpoint,
List<StoreFileMetadata> filesToFetch,
IndexShard indexShard,
BiConsumer<String, Long> fileProgressTracker,
ActionListener<GetSegmentFilesResponse> listener
) {
// set the listener, we will only fail it once we cancel the source.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -87,6 +88,7 @@ public void getSegmentFiles(
ReplicationCheckpoint checkpoint,
List<StoreFileMetadata> filesToFetch,
IndexShard indexShard,
BiConsumer<String, Long> fileProgressTracker,
ActionListener<GetSegmentFilesResponse> listener
) {
// randomly resolve the listener, indicating the source has resolved.
Expand Down Expand Up @@ -131,6 +133,7 @@ public void getSegmentFiles(
ReplicationCheckpoint checkpoint,
List<StoreFileMetadata> filesToFetch,
IndexShard indexShard,
BiConsumer<String, Long> fileProgressTracker,
ActionListener<GetSegmentFilesResponse> listener
) {
Assert.fail("Should not be reached");
Expand Down Expand Up @@ -176,6 +179,7 @@ public void getSegmentFiles(
ReplicationCheckpoint checkpoint,
List<StoreFileMetadata> filesToFetch,
IndexShard indexShard,
BiConsumer<String, Long> fileProgressTracker,
ActionListener<GetSegmentFilesResponse> listener
) {
Assert.fail("Unreachable");
Expand Down Expand Up @@ -223,6 +227,7 @@ public void getSegmentFiles(
ReplicationCheckpoint checkpoint,
List<StoreFileMetadata> filesToFetch,
IndexShard indexShard,
BiConsumer<String, Long> fileProgressTracker,
ActionListener<GetSegmentFilesResponse> listener
) {}
};
Expand Down Expand Up @@ -269,6 +274,7 @@ public void getSegmentFiles(
ReplicationCheckpoint checkpoint,
List<StoreFileMetadata> filesToFetch,
IndexShard indexShard,
BiConsumer<String, Long> fileProgressTracker,
ActionListener<GetSegmentFilesResponse> listener
) {
listener.onResponse(new GetSegmentFilesResponse(Collections.emptyList()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ public void testGetSegmentFiles() {
checkpoint,
Arrays.asList(testMetadata),
mock(IndexShard.class),
(fileName, bytesRecovered) -> {},
mock(ActionListener.class)
);
CapturingTransport.CapturedRequest[] requestList = transport.getCapturedRequestsAndClear();
Expand Down Expand Up @@ -153,6 +154,7 @@ public void testTransportTimeoutForGetSegmentFilesAction() {
checkpoint,
Arrays.asList(testMetadata),
mock(IndexShard.class),
(fileName, bytesRecovered) -> {},
mock(ActionListener.class)
);
CapturingTransport.CapturedRequest[] requestList = transport.getCapturedRequestsAndClear();
Expand All @@ -178,6 +180,7 @@ public void testGetSegmentFiles_CancelWhileRequestOpen() throws InterruptedExcep
checkpoint,
Arrays.asList(testMetadata),
mock(IndexShard.class),
(fileName, bytesRecovered) -> {},
new ActionListener<>() {
@Override
public void onResponse(GetSegmentFilesResponse getSegmentFilesResponse) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ public void testGetSegmentFiles() throws ExecutionException, InterruptedExceptio
List<StoreFileMetadata> filesToFetch = primaryShard.getSegmentMetadataMap().values().stream().collect(Collectors.toList());
final PlainActionFuture<GetSegmentFilesResponse> res = PlainActionFuture.newFuture();
replicationSource = new RemoteStoreReplicationSource(primaryShard);
replicationSource.getSegmentFiles(REPLICATION_ID, checkpoint, filesToFetch, replicaShard, res);
replicationSource.getSegmentFiles(REPLICATION_ID, checkpoint, filesToFetch, replicaShard, (fileName, bytesRecovered) -> {}, res);
GetSegmentFilesResponse response = res.get();
assertEquals(response.files.size(), filesToFetch.size());
assertTrue(response.files.containsAll(filesToFetch));
Expand All @@ -104,7 +104,14 @@ public void testGetSegmentFilesAlreadyExists() throws IOException, InterruptedEx
try {
final PlainActionFuture<GetSegmentFilesResponse> res = PlainActionFuture.newFuture();
replicationSource = new RemoteStoreReplicationSource(primaryShard);
replicationSource.getSegmentFiles(REPLICATION_ID, checkpoint, filesToFetch, primaryShard, res);
replicationSource.getSegmentFiles(
REPLICATION_ID,
checkpoint,
filesToFetch,
primaryShard,
(fileName, bytesRecovered) -> {},
res
);
res.get();
} catch (AssertionError | ExecutionException ex) {
latch.countDown();
Expand All @@ -118,7 +125,14 @@ public void testGetSegmentFilesReturnEmptyResponse() throws ExecutionException,
final ReplicationCheckpoint checkpoint = primaryShard.getLatestReplicationCheckpoint();
final PlainActionFuture<GetSegmentFilesResponse> res = PlainActionFuture.newFuture();
replicationSource = new RemoteStoreReplicationSource(primaryShard);
replicationSource.getSegmentFiles(REPLICATION_ID, checkpoint, Collections.emptyList(), primaryShard, res);
replicationSource.getSegmentFiles(
REPLICATION_ID,
checkpoint,
Collections.emptyList(),
primaryShard,
(fileName, bytesRecovered) -> {},
res
);
GetSegmentFilesResponse response = res.get();
assert (response.files.isEmpty());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.function.BiConsumer;

import static org.junit.Assert.assertEquals;
import static org.mockito.ArgumentMatchers.any;
Expand Down Expand Up @@ -215,6 +216,7 @@ public void getSegmentFiles(
ReplicationCheckpoint checkpoint,
List<StoreFileMetadata> filesToFetch,
IndexShard indexShard,
BiConsumer<String, Long> fileProgressTracker,
ActionListener<GetSegmentFilesResponse> listener
) {
Assert.fail("Should not be called");
Expand Down Expand Up @@ -280,6 +282,7 @@ public void getSegmentFiles(
ReplicationCheckpoint checkpoint,
List<StoreFileMetadata> filesToFetch,
IndexShard indexShard,
BiConsumer<String, Long> fileProgressTracker,
ActionListener<GetSegmentFilesResponse> listener
) {
listener.onResponse(new GetSegmentFilesResponse(Collections.emptyList()));
Expand Down Expand Up @@ -337,6 +340,7 @@ public void getSegmentFiles(
ReplicationCheckpoint checkpoint,
List<StoreFileMetadata> filesToFetch,
IndexShard indexShard,
BiConsumer<String, Long> fileProgressTracker,
ActionListener<GetSegmentFilesResponse> listener
) {
Assert.fail("Unreachable");
Expand Down
Loading

0 comments on commit f934b33

Please sign in to comment.