diff --git a/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationTargetService.java b/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationTargetService.java index 0c3350f224b11..f9b40d14b0d53 100644 --- a/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationTargetService.java +++ b/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationTargetService.java @@ -31,6 +31,8 @@ import org.opensearch.transport.TransportRequestHandler; import org.opensearch.transport.TransportService; +import java.util.HashMap; +import java.util.Map; import java.util.concurrent.atomic.AtomicLong; /** @@ -49,6 +51,8 @@ public class SegmentReplicationTargetService implements IndexEventListener { private final SegmentReplicationSourceFactory sourceFactory; + private final Map latestReceivedCheckpoint = new HashMap<>(); + /** * The internal actions * @@ -91,6 +95,15 @@ public void beforeIndexShardClosed(ShardId shardId, @Nullable IndexShard indexSh * @param replicaShard replica shard on which checkpoint is received */ public synchronized void onNewCheckpoint(final ReplicationCheckpoint receivedCheckpoint, final IndexShard replicaShard) { + + // Checks if received checkpoint is already present and ahead then it replaces old received checkpoint + if (latestReceivedCheckpoint.get(replicaShard.shardId()) != null) { + if (receivedCheckpoint.isAheadOf(latestReceivedCheckpoint.get(replicaShard.shardId()))) { + latestReceivedCheckpoint.replace(replicaShard.shardId(), receivedCheckpoint); + } + } else { + latestReceivedCheckpoint.put(replicaShard.shardId(), receivedCheckpoint); + } if (onGoingReplications.isShardReplicating(replicaShard.shardId())) { logger.trace( () -> new ParameterizedMessage( @@ -100,10 +113,23 @@ public synchronized void onNewCheckpoint(final ReplicationCheckpoint receivedChe ); return; } + final Thread thread = Thread.currentThread(); if (replicaShard.shouldProcessCheckpoint(receivedCheckpoint)) { startReplication(receivedCheckpoint, replicaShard, new SegmentReplicationListener() { @Override - public void onReplicationDone(SegmentReplicationState state) {} + public void onReplicationDone(SegmentReplicationState state) { + // if we received a checkpoint during the copy event that is ahead of this + // try and process it. + if (latestReceivedCheckpoint.get(replicaShard.shardId()).isAheadOf(replicaShard.getLatestReplicationCheckpoint())) { + Runnable runnable = () -> onNewCheckpoint(latestReceivedCheckpoint.get(replicaShard.shardId()), replicaShard); + // Checks if we are using same thread and forks if necessary. + if (thread == Thread.currentThread()) { + threadPool.generic().execute(runnable); + } else { + runnable.run(); + } + } + } @Override public void onReplicationFailure(SegmentReplicationState state, OpenSearchException e, boolean sendShardFailure) { diff --git a/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationTargetServiceTests.java b/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationTargetServiceTests.java index 7ff6e3dceabc9..8b4bda7de50ad 100644 --- a/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationTargetServiceTests.java +++ b/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationTargetServiceTests.java @@ -34,6 +34,7 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.times; import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.timeout; import static org.mockito.Mockito.eq; public class SegmentReplicationTargetServiceTests extends IndexShardTestCase { @@ -203,6 +204,40 @@ public void testNewCheckpoint_validationPassesAndReplicationFails() throws IOExc closeShard(indexShard, false); } + public void testReplicationOnDone() throws IOException { + SegmentReplicationTargetService spy = spy(sut); + IndexShard spyShard = spy(indexShard); + ReplicationCheckpoint cp = indexShard.getLatestReplicationCheckpoint(); + ReplicationCheckpoint newCheckpoint = new ReplicationCheckpoint( + cp.getShardId(), + cp.getPrimaryTerm(), + cp.getSegmentsGen(), + cp.getSeqNo(), + cp.getSegmentInfosVersion() + 1 + ); + ReplicationCheckpoint anotherNewCheckpoint = new ReplicationCheckpoint( + cp.getShardId(), + cp.getPrimaryTerm(), + cp.getSegmentsGen(), + cp.getSeqNo(), + cp.getSegmentInfosVersion() + 2 + ); + ArgumentCaptor captor = ArgumentCaptor.forClass( + SegmentReplicationTargetService.SegmentReplicationListener.class + ); + doNothing().when(spy).startReplication(any(), any(), any()); + spy.onNewCheckpoint(newCheckpoint, spyShard); + spy.onNewCheckpoint(anotherNewCheckpoint, spyShard); + verify(spy, times(1)).startReplication(eq(newCheckpoint), any(), captor.capture()); + verify(spy, times(1)).onNewCheckpoint(eq(anotherNewCheckpoint), any()); + SegmentReplicationTargetService.SegmentReplicationListener listener = captor.getValue(); + listener.onDone(new SegmentReplicationState(new ReplicationLuceneIndex())); + doNothing().when(spy).onNewCheckpoint(any(), any()); + verify(spy, timeout(0).times(2)).onNewCheckpoint(eq(anotherNewCheckpoint), any()); + closeShard(indexShard, false); + + } + public void testBeforeIndexShardClosed_CancelsOngoingReplications() { final SegmentReplicationTarget target = new SegmentReplicationTarget( checkpoint,