diff --git a/CHANGELOG.md b/CHANGELOG.md index 51c01351be1fd..bfc1cd70fa2b7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -46,6 +46,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) - PR reference to checkout code for changelog verifier ([#4296](https://github.com/opensearch-project/OpenSearch/pull/4296)) - Restore using the class ClusterInfoRequest and ClusterInfoRequestBuilder from package 'org.opensearch.action.support.master.info' for subclasses ([#4324](https://github.com/opensearch-project/OpenSearch/pull/4324)) - Fixed cancellation of segment replication events ([#4225](https://github.com/opensearch-project/OpenSearch/pull/4225)) +- [Segment Replication] Add check to cancel ongoing replication with old primary on onNewCheckpoint on replica ([#4363](https://github.com/opensearch-project/OpenSearch/pull/4363)) - [Segment Replication] Bump segment infos counter before commit during replica promotion ([#4365](https://github.com/opensearch-project/OpenSearch/pull/4365)) ### Security diff --git a/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationTarget.java b/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationTarget.java index d1d6104a416ca..7c28406036ddd 100644 --- a/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationTarget.java +++ b/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationTarget.java @@ -56,6 +56,10 @@ public class SegmentReplicationTarget extends ReplicationTarget { private final SegmentReplicationState state; protected final MultiFileWriter multiFileWriter; + public ReplicationCheckpoint getCheckpoint() { + return this.checkpoint; + } + public SegmentReplicationTarget( ReplicationCheckpoint checkpoint, IndexShard indexShard, diff --git a/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationTargetService.java b/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationTargetService.java index 9e6b66dc4d7d6..8fc53ccd3bc08 100644 --- a/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationTargetService.java +++ b/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationTargetService.java @@ -18,6 +18,7 @@ import org.opensearch.common.Nullable; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.CancellableThreads; +import org.opensearch.common.util.concurrent.ConcurrentCollections; import org.opensearch.index.shard.IndexEventListener; import org.opensearch.index.shard.IndexShard; import org.opensearch.index.shard.ShardId; @@ -34,7 +35,6 @@ import org.opensearch.transport.TransportRequestHandler; import org.opensearch.transport.TransportService; -import java.util.HashMap; import java.util.Map; import java.util.concurrent.atomic.AtomicLong; @@ -54,7 +54,7 @@ public class SegmentReplicationTargetService implements IndexEventListener { private final SegmentReplicationSourceFactory sourceFactory; - private final Map latestReceivedCheckpoint = new HashMap<>(); + private final Map latestReceivedCheckpoint = ConcurrentCollections.newConcurrentMap(); // Empty Implementation, only required while Segment Replication is under feature flag. public static final SegmentReplicationTargetService NO_OP = new SegmentReplicationTargetService() { @@ -151,14 +151,23 @@ public synchronized void onNewCheckpoint(final ReplicationCheckpoint receivedChe } else { latestReceivedCheckpoint.put(replicaShard.shardId(), receivedCheckpoint); } - if (onGoingReplications.isShardReplicating(replicaShard.shardId())) { - logger.trace( - () -> new ParameterizedMessage( - "Ignoring new replication checkpoint - shard is currently replicating to checkpoint {}", - replicaShard.getLatestReplicationCheckpoint() - ) - ); - return; + SegmentReplicationTarget ongoingReplicationTarget = onGoingReplications.getOngoingReplicationTarget(replicaShard.shardId()); + if (ongoingReplicationTarget != null) { + if (ongoingReplicationTarget.getCheckpoint().getPrimaryTerm() < receivedCheckpoint.getPrimaryTerm()) { + logger.trace( + "Cancelling ongoing replication from old primary with primary term {}", + ongoingReplicationTarget.getCheckpoint().getPrimaryTerm() + ); + onGoingReplications.cancel(ongoingReplicationTarget.getId(), "Cancelling stuck target after new primary"); + } else { + logger.trace( + () -> new ParameterizedMessage( + "Ignoring new replication checkpoint - shard is currently replicating to checkpoint {}", + replicaShard.getLatestReplicationCheckpoint() + ) + ); + return; + } } final Thread thread = Thread.currentThread(); if (replicaShard.shouldProcessCheckpoint(receivedCheckpoint)) { diff --git a/server/src/main/java/org/opensearch/indices/replication/common/ReplicationCollection.java b/server/src/main/java/org/opensearch/indices/replication/common/ReplicationCollection.java index d648ca6041ff8..20600856c9444 100644 --- a/server/src/main/java/org/opensearch/indices/replication/common/ReplicationCollection.java +++ b/server/src/main/java/org/opensearch/indices/replication/common/ReplicationCollection.java @@ -49,6 +49,7 @@ import java.util.Iterator; import java.util.List; import java.util.concurrent.ConcurrentMap; +import java.util.stream.Collectors; /** * This class holds a collection of all on going replication events on the current node (i.e., the node is the target node @@ -236,13 +237,18 @@ public boolean cancelForShard(ShardId shardId, String reason) { } /** - * check if a shard is currently replicating + * Get target for shard * - * @param shardId shardId for which to check if replicating - * @return true if shard is currently replicating + * @param shardId shardId + * @return ReplicationTarget for input shardId */ - public boolean isShardReplicating(ShardId shardId) { - return onGoingTargetEvents.values().stream().anyMatch(t -> t.indexShard.shardId().equals(shardId)); + public T getOngoingReplicationTarget(ShardId shardId) { + final List replicationTargetList = onGoingTargetEvents.values() + .stream() + .filter(t -> t.indexShard.shardId().equals(shardId)) + .collect(Collectors.toList()); + assert replicationTargetList.size() <= 1 : "More than one on-going replication targets"; + return replicationTargetList.size() > 0 ? replicationTargetList.get(0) : null; } /** diff --git a/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationTargetServiceTests.java b/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationTargetServiceTests.java index 7d9b0f09f21cd..1d253b0a9a300 100644 --- a/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationTargetServiceTests.java +++ b/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationTargetServiceTests.java @@ -49,6 +49,8 @@ public class SegmentReplicationTargetServiceTests extends IndexShardTestCase { private ReplicationCheckpoint initialCheckpoint; private ReplicationCheckpoint aheadCheckpoint; + private ReplicationCheckpoint newPrimaryCheckpoint; + @Override public void setUp() throws Exception { super.setUp(); @@ -74,6 +76,13 @@ public void setUp() throws Exception { initialCheckpoint.getSeqNo(), initialCheckpoint.getSegmentInfosVersion() + 1 ); + newPrimaryCheckpoint = new ReplicationCheckpoint( + initialCheckpoint.getShardId(), + initialCheckpoint.getPrimaryTerm() + 1, + initialCheckpoint.getSegmentsGen(), + initialCheckpoint.getSeqNo(), + initialCheckpoint.getSegmentInfosVersion() + 1 + ); } @Override @@ -160,7 +169,7 @@ public void testShardAlreadyReplicating() throws InterruptedException { // Create a spy of Target Service so that we can verify invocation of startReplication call with specific checkpoint on it. SegmentReplicationTargetService serviceSpy = spy(sut); final SegmentReplicationTarget target = new SegmentReplicationTarget( - checkpoint, + initialCheckpoint, replicaShard, replicationSource, mock(SegmentReplicationTargetService.SegmentReplicationListener.class) @@ -185,9 +194,47 @@ public void testShardAlreadyReplicating() throws InterruptedException { // wait for the new checkpoint to arrive, before the listener completes. latch.await(30, TimeUnit.SECONDS); + verify(targetSpy, times(0)).cancel(any()); verify(serviceSpy, times(0)).startReplication(eq(aheadCheckpoint), eq(replicaShard), any()); } + public void testOnNewCheckpointFromNewPrimaryCancelOngoingReplication() throws IOException, InterruptedException { + // Create a spy of Target Service so that we can verify invocation of startReplication call with specific checkpoint on it. + SegmentReplicationTargetService serviceSpy = spy(sut); + // Create a Mockito spy of target to stub response of few method calls. + final SegmentReplicationTarget targetSpy = spy( + new SegmentReplicationTarget( + initialCheckpoint, + replicaShard, + replicationSource, + mock(SegmentReplicationTargetService.SegmentReplicationListener.class) + ) + ); + + CountDownLatch latch = new CountDownLatch(1); + // Mocking response when startReplication is called on targetSpy we send a new checkpoint to serviceSpy and later reduce countdown + // of latch. + doAnswer(invocation -> { + final ActionListener listener = invocation.getArgument(0); + // a new checkpoint arrives before we've completed. + serviceSpy.onNewCheckpoint(newPrimaryCheckpoint, replicaShard); + listener.onResponse(null); + latch.countDown(); + return null; + }).when(targetSpy).startReplication(any()); + doNothing().when(targetSpy).onDone(); + + // start replication. This adds the target to on-ongoing replication collection + serviceSpy.startReplication(targetSpy); + + // wait for the new checkpoint to arrive, before the listener completes. + latch.await(5, TimeUnit.SECONDS); + doNothing().when(targetSpy).startReplication(any()); + verify(targetSpy, times(1)).cancel("Cancelling stuck target after new primary"); + verify(serviceSpy, times(1)).startReplication(eq(newPrimaryCheckpoint), eq(replicaShard), any()); + closeShards(replicaShard); + } + public void testNewCheckpointBehindCurrentCheckpoint() { SegmentReplicationTargetService spy = spy(sut); spy.onNewCheckpoint(checkpoint, replicaShard); diff --git a/server/src/test/java/org/opensearch/recovery/ReplicationCollectionTests.java b/server/src/test/java/org/opensearch/recovery/ReplicationCollectionTests.java index 7587f48503625..1789dd3b2a288 100644 --- a/server/src/test/java/org/opensearch/recovery/ReplicationCollectionTests.java +++ b/server/src/test/java/org/opensearch/recovery/ReplicationCollectionTests.java @@ -105,7 +105,25 @@ public void onFailure(ReplicationState state, OpenSearchException e, boolean sen collection.cancel(recoveryId, "meh"); } } + } + public void testMultiReplicationsForSingleShard() throws Exception { + try (ReplicationGroup shards = createGroup(0)) { + final ReplicationCollection collection = new ReplicationCollection<>(logger, threadPool); + final IndexShard shard1 = shards.addReplica(); + final IndexShard shard2 = shards.addReplica(); + final long recoveryId = startRecovery(collection, shards.getPrimaryNode(), shard1); + final long recoveryId2 = startRecovery(collection, shards.getPrimaryNode(), shard2); + try { + collection.getOngoingReplicationTarget(shard1.shardId()); + } catch (AssertionError e) { + assertEquals(e.getMessage(), "More than one on-going replication targets"); + } finally { + collection.cancel(recoveryId, "meh"); + collection.cancel(recoveryId2, "meh"); + } + closeShards(shard1, shard2); + } } public void testRecoveryCancellation() throws Exception {