Skip to content

Commit

Permalink
[Segment Replication] Checkpoint Replay on Replica Shard (#3658)
Browse files Browse the repository at this point in the history
* Adding Latest Recevied checkpoint, replay checkpoint logic along with tests

Signed-off-by: Rishikesh1159 <rishireddy1159@gmail.com>
  • Loading branch information
Rishikesh1159 authored Jul 21, 2022
1 parent 8b5a10c commit bb593d6
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
import org.opensearch.transport.TransportRequestHandler;
import org.opensearch.transport.TransportService;

import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicLong;

/**
Expand All @@ -49,6 +51,8 @@ public class SegmentReplicationTargetService implements IndexEventListener {

private final SegmentReplicationSourceFactory sourceFactory;

private final Map<ShardId, ReplicationCheckpoint> latestReceivedCheckpoint = new HashMap<>();

/**
* The internal actions
*
Expand Down Expand Up @@ -91,6 +95,15 @@ public void beforeIndexShardClosed(ShardId shardId, @Nullable IndexShard indexSh
* @param replicaShard replica shard on which checkpoint is received
*/
public synchronized void onNewCheckpoint(final ReplicationCheckpoint receivedCheckpoint, final IndexShard replicaShard) {

// Checks if received checkpoint is already present and ahead then it replaces old received checkpoint
if (latestReceivedCheckpoint.get(replicaShard.shardId()) != null) {
if (receivedCheckpoint.isAheadOf(latestReceivedCheckpoint.get(replicaShard.shardId()))) {
latestReceivedCheckpoint.replace(replicaShard.shardId(), receivedCheckpoint);
}
} else {
latestReceivedCheckpoint.put(replicaShard.shardId(), receivedCheckpoint);
}
if (onGoingReplications.isShardReplicating(replicaShard.shardId())) {
logger.trace(
() -> new ParameterizedMessage(
Expand All @@ -100,10 +113,23 @@ public synchronized void onNewCheckpoint(final ReplicationCheckpoint receivedChe
);
return;
}
final Thread thread = Thread.currentThread();
if (replicaShard.shouldProcessCheckpoint(receivedCheckpoint)) {
startReplication(receivedCheckpoint, replicaShard, new SegmentReplicationListener() {
@Override
public void onReplicationDone(SegmentReplicationState state) {}
public void onReplicationDone(SegmentReplicationState state) {
// if we received a checkpoint during the copy event that is ahead of this
// try and process it.
if (latestReceivedCheckpoint.get(replicaShard.shardId()).isAheadOf(replicaShard.getLatestReplicationCheckpoint())) {
Runnable runnable = () -> onNewCheckpoint(latestReceivedCheckpoint.get(replicaShard.shardId()), replicaShard);
// Checks if we are using same thread and forks if necessary.
if (thread == Thread.currentThread()) {
threadPool.generic().execute(runnable);
} else {
runnable.run();
}
}
}

@Override
public void onReplicationFailure(SegmentReplicationState state, OpenSearchException e, boolean sendShardFailure) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.timeout;
import static org.mockito.Mockito.eq;

public class SegmentReplicationTargetServiceTests extends IndexShardTestCase {
Expand Down Expand Up @@ -203,6 +204,40 @@ public void testNewCheckpoint_validationPassesAndReplicationFails() throws IOExc
closeShard(indexShard, false);
}

public void testReplicationOnDone() throws IOException {
SegmentReplicationTargetService spy = spy(sut);
IndexShard spyShard = spy(indexShard);
ReplicationCheckpoint cp = indexShard.getLatestReplicationCheckpoint();
ReplicationCheckpoint newCheckpoint = new ReplicationCheckpoint(
cp.getShardId(),
cp.getPrimaryTerm(),
cp.getSegmentsGen(),
cp.getSeqNo(),
cp.getSegmentInfosVersion() + 1
);
ReplicationCheckpoint anotherNewCheckpoint = new ReplicationCheckpoint(
cp.getShardId(),
cp.getPrimaryTerm(),
cp.getSegmentsGen(),
cp.getSeqNo(),
cp.getSegmentInfosVersion() + 2
);
ArgumentCaptor<SegmentReplicationTargetService.SegmentReplicationListener> captor = ArgumentCaptor.forClass(
SegmentReplicationTargetService.SegmentReplicationListener.class
);
doNothing().when(spy).startReplication(any(), any(), any());
spy.onNewCheckpoint(newCheckpoint, spyShard);
spy.onNewCheckpoint(anotherNewCheckpoint, spyShard);
verify(spy, times(1)).startReplication(eq(newCheckpoint), any(), captor.capture());
verify(spy, times(1)).onNewCheckpoint(eq(anotherNewCheckpoint), any());
SegmentReplicationTargetService.SegmentReplicationListener listener = captor.getValue();
listener.onDone(new SegmentReplicationState(new ReplicationLuceneIndex()));
doNothing().when(spy).onNewCheckpoint(any(), any());
verify(spy, timeout(0).times(2)).onNewCheckpoint(eq(anotherNewCheckpoint), any());
closeShard(indexShard, false);

}

public void testBeforeIndexShardClosed_CancelsOngoingReplications() {
final SegmentReplicationTarget target = new SegmentReplicationTarget(
checkpoint,
Expand Down

0 comments on commit bb593d6

Please sign in to comment.