From 5f63ceda02afc76d518b46a66042867b229aac25 Mon Sep 17 00:00:00 2001
From: Raghuvansh Raj
Date: Fri, 7 Jul 2023 17:52:58 +0530
Subject: [PATCH] Additional refactors to AsyncUploadUtils and TransferNIOGroup

Signed-off-by: Raghuvansh Raj
---
 .../repositories/s3/S3BlobContainer.java      |  33 ++--
 .../repositories/s3/S3BlobStore.java          |  12 +-
 .../repositories/s3/S3Repository.java         |  24 +--
 .../repositories/s3/S3RepositoryPlugin.java   |  42 +++-
 .../s3/async/AsyncPartsHandler.java           | 183 ++++++++++++++++++
 .../s3/S3BlobContainerMockClientTests.java    | 115 ++++++-----
 .../s3/S3BlobContainerRetriesTests.java       |  70 ++++---
 7 files changed, 328 insertions(+), 151 deletions(-)
 create mode 100644 plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/async/AsyncPartsHandler.java

diff --git a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3BlobContainer.java b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3BlobContainer.java
index 389ac02344b59..81a902a6992d8 100644
--- a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3BlobContainer.java
+++ b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3BlobContainer.java
@@ -45,6 +45,7 @@
 import org.opensearch.common.blobstore.BlobPath;
 import org.opensearch.common.blobstore.BlobStoreException;
 import org.opensearch.common.blobstore.DeleteResult;
+import org.opensearch.common.blobstore.VerifyingMultiStreamBlobContainer;
 import org.opensearch.common.blobstore.stream.write.WriteContext;
 import org.opensearch.common.blobstore.stream.write.WritePriority;
 import org.opensearch.common.blobstore.support.AbstractBlobContainer;
@@ -77,7 +78,6 @@
 import org.opensearch.core.common.Strings;
 import org.opensearch.repositories.s3.async.UploadRequest;
 import software.amazon.awssdk.services.s3.S3AsyncClient;
-import software.amazon.awssdk.utils.CompletableFutureUtils;
 
 import java.io.ByteArrayInputStream;
 import java.io.IOException;
@@ -97,7 +97,7 @@
 import static org.opensearch.repositories.s3.S3Repository.MAX_FILE_SIZE_USING_MULTIPART;
 import static org.opensearch.repositories.s3.S3Repository.MIN_PART_SIZE_USING_MULTIPART;
 
-class S3BlobContainer extends AbstractBlobContainer {
+class S3BlobContainer extends AbstractBlobContainer implements VerifyingMultiStreamBlobContainer {
 
     private static final Logger logger = LogManager.getLogger(S3BlobContainer.class);
 
@@ -175,17 +175,7 @@ public void writeBlob(String blobName, InputStream inputStream, long blobSize, b
     }
 
     @Override
-    public boolean isMultiStreamUploadSupported() {
-        return blobStore.isMultipartUploadEnabled();
-    }
-
-    @Override
-    public boolean isRemoteDataIntegritySupported() {
-        return true;
-    }
-
-    @Override
-    public CompletableFuture<Void> writeBlobByStreams(WriteContext writeContext) throws IOException {
+    public void asyncBlobUpload(WriteContext writeContext, ActionListener<Void> completionListener) throws IOException {
         UploadRequest uploadRequest = new UploadRequest(
             blobStore.bucket(),
             buildKey(writeContext.getFileName()),
@@ -196,20 +186,23 @@ public CompletableFuture<Void> writeBlobByStreams(WriteContext writeContext) thr
             writeContext.getExpectedChecksum()
         );
         try {
-            long partSize = blobStore.getAsyncUploadUtils().calculateOptimalPartSize(writeContext.getFileSize());
+            long partSize = blobStore.getAsyncTransferManager().calculateOptimalPartSize(writeContext.getFileSize());
             StreamContext streamContext = SocketAccess.doPrivileged(() -> writeContext.getStreamProvider(partSize));
             try (AmazonAsyncS3Reference amazonS3Reference = SocketAccess.doPrivileged(blobStore::asyncClientReference)) {
 
                 S3AsyncClient s3AsyncClient = writeContext.getWritePriority() == WritePriority.HIGH
                     ? amazonS3Reference.get().priorityClient()
                     : amazonS3Reference.get().client();
-                CompletableFuture<Void> returnFuture = new CompletableFuture<>();
-                CompletableFuture<Void> completableFuture = blobStore.getAsyncUploadUtils()
+                CompletableFuture<Void> completableFuture = blobStore.getAsyncTransferManager()
                     .uploadObject(s3AsyncClient, uploadRequest, streamContext);
-
-                CompletableFutureUtils.forwardExceptionTo(returnFuture, completableFuture);
-                CompletableFutureUtils.forwardResultTo(completableFuture, returnFuture);
-                return completableFuture;
+                completableFuture.whenComplete((response, throwable) -> {
+                    if (throwable == null) {
+                        completionListener.onResponse(response);
+                    } else {
+                        Exception ex = throwable instanceof Error ? new Exception(throwable) : (Exception) throwable;
+                        completionListener.onFailure(ex);
+                    }
+                });
             }
         } catch (Exception e) {
             logger.info("exception error from blob container for file {}", writeContext.getFileName());
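
[Reviewer note, not part of the patch: asyncBlobUpload() replaces the
CompletableFuture-returning writeBlobByStreams() with a listener-based contract.
A minimal sketch of how a caller can recover blocking semantics, assuming
OpenSearch's PlainActionFuture (which implements ActionListener):

    PlainActionFuture<Void> future = PlainActionFuture.newFuture();
    blobContainer.asyncBlobUpload(writeContext, future);
    future.actionGet(); // blocks; rethrows the upload failure, if any

Note that the whenComplete bridge above wraps an Error in a plain Exception
before calling onFailure, since ActionListener#onFailure accepts only Exception.]
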
diff --git a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3BlobStore.java b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3BlobStore.java
index 1aba640124776..3ec2a37ec3c87 100644
--- a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3BlobStore.java
+++ b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3BlobStore.java
@@ -43,7 +43,7 @@
 import software.amazon.awssdk.services.s3.model.ObjectCannedACL;
 import software.amazon.awssdk.services.s3.model.StorageClass;
 import org.opensearch.repositories.s3.async.AsyncExecutorBuilder;
-import org.opensearch.repositories.s3.async.AsyncUploadUtils;
+import org.opensearch.repositories.s3.async.AsyncTransferManager;
 
 import java.io.IOException;
 import java.util.Locale;
@@ -71,7 +71,7 @@ class S3BlobStore implements BlobStore {
 
     private final StatsMetricPublisher statsMetricPublisher = new StatsMetricPublisher();
 
-    private final AsyncUploadUtils asyncUploadUtils;
+    private final AsyncTransferManager asyncTransferManager;
     private final AsyncExecutorBuilder priorityExecutorBuilder;
     private final AsyncExecutorBuilder normalExecutorBuilder;
     private final boolean multipartUploadEnabled;
@@ -86,7 +86,7 @@ class S3BlobStore implements BlobStore {
         String cannedACL,
         String storageClass,
         RepositoryMetadata repositoryMetadata,
-        AsyncUploadUtils asyncUploadUtils,
+        AsyncTransferManager asyncTransferManager,
         AsyncExecutorBuilder priorityExecutorBuilder,
         AsyncExecutorBuilder normalExecutorBuilder
     ) {
@@ -99,7 +99,7 @@ class S3BlobStore implements BlobStore {
         this.cannedACL = initCannedACL(cannedACL);
         this.storageClass = initStorageClass(storageClass);
         this.repositoryMetadata = repositoryMetadata;
-        this.asyncUploadUtils = asyncUploadUtils;
+        this.asyncTransferManager = asyncTransferManager;
         this.normalExecutorBuilder = normalExecutorBuilder;
         this.priorityExecutorBuilder = priorityExecutorBuilder;
     }
@@ -203,7 +203,7 @@ public static ObjectCannedACL initCannedACL(String cannedACL) {
         throw new BlobStoreException("cannedACL is not valid: [" + cannedACL + "]");
     }
 
-    public AsyncUploadUtils getAsyncUploadUtils() {
-        return asyncUploadUtils;
+    public AsyncTransferManager getAsyncTransferManager() {
+        return asyncTransferManager;
     }
 }
diff --git a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3Repository.java b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3Repository.java
index 5f8bb885ef262..2dd7c7c4107e6 100644
--- a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3Repository.java
+++ b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3Repository.java
@@ -57,7 +57,7 @@
 import org.opensearch.repositories.ShardGenerations;
 import org.opensearch.repositories.blobstore.MeteredBlobStoreRepository;
 import org.opensearch.repositories.s3.async.AsyncExecutorBuilder;
-import org.opensearch.repositories.s3.async.AsyncUploadUtils;
+import org.opensearch.repositories.s3.async.AsyncTransferManager;
 import org.opensearch.snapshots.SnapshotId;
 import org.opensearch.snapshots.SnapshotInfo;
 import org.opensearch.threadpool.Scheduler;
@@ -172,24 +172,6 @@ class S3Repository extends MeteredBlobStoreRepository {
         Setting.Property.NodeScope
     );
 
-    /**
-     * Event loop thread count for priority uploads
-     */
-    public static Setting<Integer> PRIORITY_UPLOAD_EVENT_LOOP_THREAD_COUNT_SETTING = Setting.intSetting(
-        "parallel_multipart_upload.priority.event_loop_thread_count",
-        4,
-        Setting.Property.NodeScope
-    );
-
-    /**
-     * Event loop thread count for normal uploads
-     */
-    public static Setting<Integer> NORMAL_UPLOAD_EVENT_LOOP_THREAD_COUNT_SETTING = Setting.intSetting(
-        "parallel_multipart_upload.normal.event_loop_thread_count",
-        1,
-        Setting.Property.NodeScope
-    );
-
     /**
      * Big files can be broken down into chunks during snapshotting if needed. Defaults to 1g.
     */
@@ -237,7 +219,7 @@ class S3Repository extends MeteredBlobStoreRepository {
 
     private final RepositoryMetadata repositoryMetadata;
 
-    private final AsyncUploadUtils asyncUploadUtils;
+    private final AsyncTransferManager asyncUploadUtils;
    private final S3AsyncService s3AsyncService;
     private final boolean multipartUploadEnabled;
     private final AsyncExecutorBuilder priorityExecutorBuilder;
@@ -252,7 +234,7 @@ class S3Repository extends MeteredBlobStoreRepository {
         final S3Service service,
         final ClusterService clusterService,
         final RecoverySettings recoverySettings,
-        final AsyncUploadUtils asyncUploadUtils,
+        final AsyncTransferManager asyncUploadUtils,
         final AsyncExecutorBuilder priorityExecutorBuilder,
         final AsyncExecutorBuilder normalExecutorBuilder,
         final S3AsyncService s3AsyncService,
diff --git a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3RepositoryPlugin.java b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3RepositoryPlugin.java
index b16a1699ad8a7..82944ce543464 100644
--- a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3RepositoryPlugin.java
+++ b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3RepositoryPlugin.java
@@ -39,6 +39,7 @@
 import org.opensearch.common.io.stream.NamedWriteableRegistry;
 import org.opensearch.common.settings.Setting;
 import org.opensearch.common.settings.Settings;
+import org.opensearch.common.util.concurrent.OpenSearchExecutors;
 import org.opensearch.core.xcontent.NamedXContentRegistry;
 import org.opensearch.env.Environment;
 import org.opensearch.env.NodeEnvironment;
@@ -49,8 +50,8 @@
 import org.opensearch.repositories.RepositoriesService;
 import org.opensearch.repositories.Repository;
 import org.opensearch.repositories.s3.async.AsyncExecutorBuilder;
-import org.opensearch.repositories.s3.async.AsyncUploadUtils;
-import org.opensearch.repositories.s3.async.TransferNIOGroup;
+import org.opensearch.repositories.s3.async.AsyncTransferEventLoopGroup;
+import org.opensearch.repositories.s3.async.AsyncTransferManager;
 import org.opensearch.script.ScriptService;
 import org.opensearch.threadpool.ExecutorBuilder;
 import org.opensearch.threadpool.FixedExecutorBuilder;
@@ -92,10 +93,14 @@ public S3RepositoryPlugin(final Settings settings, final Path configPath) {
 
     @Override
     public List<ExecutorBuilder<?>> getExecutorBuilders(Settings settings) {
         List<ExecutorBuilder<?>> executorBuilders = new ArrayList<>();
-        executorBuilders.add(new FixedExecutorBuilder(settings, PRIORITY_FUTURE_COMPLETION, 4, 10_000, PRIORITY_FUTURE_COMPLETION));
-        executorBuilders.add(new FixedExecutorBuilder(settings, PRIORITY_STREAM_READER, 4, 10_000, PRIORITY_STREAM_READER));
-        executorBuilders.add(new FixedExecutorBuilder(settings, FUTURE_COMPLETION, 1, 10_000, FUTURE_COMPLETION));
-        executorBuilders.add(new FixedExecutorBuilder(settings, STREAM_READER, 1, 10_000, STREAM_READER));
+        executorBuilders.add(
+            new FixedExecutorBuilder(settings, PRIORITY_FUTURE_COMPLETION, priorityPoolCount(settings), 10_000, PRIORITY_FUTURE_COMPLETION)
+        );
+        executorBuilders.add(
+            new FixedExecutorBuilder(settings, PRIORITY_STREAM_READER, priorityPoolCount(settings), 10_000, PRIORITY_STREAM_READER)
+        );
+        executorBuilders.add(new FixedExecutorBuilder(settings, FUTURE_COMPLETION, normalPoolCount(settings), 10_000, FUTURE_COMPLETION));
+        executorBuilders.add(new FixedExecutorBuilder(settings, STREAM_READER, normalPoolCount(settings), 10_000, STREAM_READER));
         return executorBuilders;
     }
@@ -109,6 +114,22 @@ public List<ExecutorBuilder<?>> getExecutorBuilders(Settings settings) {
         this.s3AsyncService.refreshAndClearCache(clientsSettings);
     }
 
+    private static int boundedBy(int value, int min, int max) {
+        return Math.min(max, Math.max(min, value));
+    }
+
+    private static int allocatedProcessors(Settings settings) {
+        return OpenSearchExecutors.allocatedProcessors(settings);
+    }
+
+    private static int priorityPoolCount(Settings settings) {
+        return boundedBy((allocatedProcessors(settings) + 1) / 2, 2, 4);
+    }
+
+    private static int normalPoolCount(Settings settings) {
+        return boundedBy((allocatedProcessors(settings) + 7) / 8, 1, 2);
+    }
+
     @Override
     public Collection<Object> createComponents(
         final Client client,
@@ -123,15 +144,17 @@ public Collection<Object> createComponents(
         final IndexNameExpressionResolver expressionResolver,
         final Supplier<RepositoriesService> repositoriesServiceSupplier
     ) {
+        int priorityEventLoopThreads = priorityPoolCount(clusterService.getSettings());
+        int normalEventLoopThreads = normalPoolCount(clusterService.getSettings());
         this.priorityExecutorBuilder = new AsyncExecutorBuilder(
             threadPool.executor(PRIORITY_FUTURE_COMPLETION),
             threadPool.executor(PRIORITY_STREAM_READER),
-            new TransferNIOGroup(S3Repository.PRIORITY_UPLOAD_EVENT_LOOP_THREAD_COUNT_SETTING.get(clusterService.getSettings()))
+            new AsyncTransferEventLoopGroup(priorityEventLoopThreads)
         );
         this.normalExecutorBuilder = new AsyncExecutorBuilder(
             threadPool.executor(FUTURE_COMPLETION),
             threadPool.executor(STREAM_READER),
-            new TransferNIOGroup(S3Repository.NORMAL_UPLOAD_EVENT_LOOP_THREAD_COUNT_SETTING.get(clusterService.getSettings()))
+            new AsyncTransferEventLoopGroup(normalEventLoopThreads)
         );
         return Collections.emptyList();
     }
@@ -143,7 +166,8 @@ protected S3Repository createRepository(
         final ClusterService clusterService,
         final RecoverySettings recoverySettings
     ) {
-        AsyncUploadUtils asyncUploadUtils = new AsyncUploadUtils(
+
+        AsyncTransferManager asyncUploadUtils = new AsyncTransferManager(
             S3Repository.PARALLEL_MULTIPART_UPLOAD_MINIMUM_PART_SIZE_SETTING.get(clusterService.getSettings()).getBytes(),
             normalExecutorBuilder.getStreamReader(),
             priorityExecutorBuilder.getStreamReader()
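
[Reviewer note, not part of the patch: the removed fixed thread counts (4 priority,
1 normal) are now derived from allocated processors. The same arithmetic as a
standalone sketch, with example values:

    // priority pools: ~half the allocated processors, clamped to [2, 4]
    // normal pools:   ~an eighth of the allocated processors, clamped to [1, 2]
    static int boundedBy(int value, int min, int max) {
        return Math.min(max, Math.max(min, value));
    }
    // allocatedProcessors = 2  -> priority = 2, normal = 1
    // allocatedProcessors = 8  -> priority = 4, normal = 1
    // allocatedProcessors = 16 -> priority = 4, normal = 2
]
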
diff --git a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/async/AsyncPartsHandler.java b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/async/AsyncPartsHandler.java
new file mode 100644
index 0000000000000..b6af91a08ac2b
--- /dev/null
+++ b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/async/AsyncPartsHandler.java
@@ -0,0 +1,183 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.repositories.s3.async;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.apache.logging.log4j.message.ParameterizedMessage;
+import org.opensearch.common.StreamContext;
+import org.opensearch.common.blobstore.stream.write.WritePriority;
+import org.opensearch.common.io.InputStreamContainer;
+import org.opensearch.repositories.s3.SocketAccess;
+import org.opensearch.repositories.s3.io.CheckedContainer;
+import software.amazon.awssdk.core.async.AsyncRequestBody;
+import software.amazon.awssdk.services.s3.S3AsyncClient;
+import software.amazon.awssdk.services.s3.model.AbortMultipartUploadRequest;
+import software.amazon.awssdk.services.s3.model.ChecksumAlgorithm;
+import software.amazon.awssdk.services.s3.model.CompletedPart;
+import software.amazon.awssdk.services.s3.model.UploadPartRequest;
+import software.amazon.awssdk.services.s3.model.UploadPartResponse;
+import software.amazon.awssdk.utils.CompletableFutureUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.atomic.AtomicReferenceArray;
+
+/**
+ * Responsible for handling parts of the original multipart request
+ */
+public class AsyncPartsHandler {
+
+    private static Logger log = LogManager.getLogger(AsyncPartsHandler.class);
+
+    /**
+     * Uploads parts of the multipart upload request
+     * @param s3AsyncClient S3 client to use for upload
+     * @param executorService Thread pool for regular upload
+     * @param priorityExecutorService Thread pool for priority uploads
+     * @param uploadRequest request for upload
+     * @param streamContext Stream context used in supplying individual file parts
+     * @param uploadId Upload Id against which multi-part is being performed
+     * @param completedParts Reference of completed parts
+     * @param inputStreamContainers Checksum containers
+     * @return list of completable futures
+     * @throws IOException thrown in case of an IO error
+     */
+    public static List<CompletableFuture<CompletedPart>> uploadParts(
+        S3AsyncClient s3AsyncClient,
+        ExecutorService executorService,
+        ExecutorService priorityExecutorService,
+        UploadRequest uploadRequest,
+        StreamContext streamContext,
+        String uploadId,
+        AtomicReferenceArray<CompletedPart> completedParts,
+        AtomicReferenceArray<CheckedContainer> inputStreamContainers
+    ) throws IOException {
+        List<CompletableFuture<CompletedPart>> futures = new ArrayList<>();
+        for (int partIdx = 0; partIdx < streamContext.getNumberOfParts(); partIdx++) {
+            InputStreamContainer inputStreamContainer = streamContext.provideStream(partIdx);
+            inputStreamContainers.set(partIdx, new CheckedContainer(inputStreamContainer.getContentLength()));
+            UploadPartRequest.Builder uploadPartRequestBuilder = UploadPartRequest.builder()
+                .bucket(uploadRequest.getBucket())
+                .partNumber(partIdx + 1)
+                .key(uploadRequest.getKey())
+                .uploadId(uploadId)
+                .contentLength(inputStreamContainer.getContentLength());
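+            // Note (comment added in review, not in the original patch): requesting a
+            // checksum algorithm below makes the SDK compute a CRC32 checksum for the
+            // part body, which S3 validates on receipt; convertUploadPartResponse()
+            // later records the same value in the part's CheckedContainer.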
+            if (uploadRequest.doRemoteDataIntegrityCheck()) {
+                uploadPartRequestBuilder.checksumAlgorithm(ChecksumAlgorithm.CRC32);
+            }
+            uploadPart(
+                s3AsyncClient,
+                executorService,
+                priorityExecutorService,
+                completedParts,
+                inputStreamContainers,
+                futures,
+                uploadPartRequestBuilder.build(),
+                inputStreamContainer,
+                uploadRequest
+            );
+        }
+
+        return futures;
+    }
+
+    /**
+     * Cleans up parts of the original multipart request
+     * @param s3AsyncClient s3 client to use
+     * @param uploadRequest upload request
+     * @param uploadId upload id against which multi-part was carried out.
+     */
+    public static void cleanUpParts(S3AsyncClient s3AsyncClient, UploadRequest uploadRequest, String uploadId) {
+
+        AbortMultipartUploadRequest abortMultipartUploadRequest = AbortMultipartUploadRequest.builder()
+            .bucket(uploadRequest.getBucket())
+            .key(uploadRequest.getKey())
+            .uploadId(uploadId)
+            .build();
+        SocketAccess.doPrivileged(() -> s3AsyncClient.abortMultipartUpload(abortMultipartUploadRequest).exceptionally(throwable -> {
+            log.warn(
+                () -> new ParameterizedMessage(
+                    "Failed to abort previous multipart upload "
+                        + "(id: {})"
+                        + ". You may need to call "
+                        + "S3AsyncClient#abortMultiPartUpload to "
+                        + "free all storage consumed by"
+                        + " all parts. ",
+                    uploadId
+                ),
+                throwable
+            );
+            return null;
+        }));
+    }
+
+    private static void uploadPart(
+        S3AsyncClient s3AsyncClient,
+        ExecutorService executorService,
+        ExecutorService priorityExecutorService,
+        AtomicReferenceArray<CompletedPart> completedParts,
+        AtomicReferenceArray<CheckedContainer> inputStreamContainers,
+        List<CompletableFuture<CompletedPart>> futures,
+        UploadPartRequest uploadPartRequest,
+        InputStreamContainer inputStreamContainer,
+        UploadRequest uploadRequest
+    ) {
+        Integer partNumber = uploadPartRequest.partNumber();
+
+        ExecutorService streamReadExecutor = uploadRequest.getWritePriority() == WritePriority.HIGH
+            ? priorityExecutorService
+            : executorService;
+        CompletableFuture<UploadPartResponse> uploadPartResponseFuture = SocketAccess.doPrivileged(
+            () -> s3AsyncClient.uploadPart(
+                uploadPartRequest,
+                AsyncRequestBody.fromInputStream(
+                    inputStreamContainer.getInputStream(),
+                    inputStreamContainer.getContentLength(),
+                    streamReadExecutor
+                )
+            )
+        );
+
+        CompletableFuture<CompletedPart> convertFuture = uploadPartResponseFuture.thenApply(
+            uploadPartResponse -> convertUploadPartResponse(
+                completedParts,
+                inputStreamContainers,
+                uploadPartResponse,
+                partNumber,
+                uploadRequest.doRemoteDataIntegrityCheck()
+            )
+        );
+        futures.add(convertFuture);
+
+        CompletableFutureUtils.forwardExceptionTo(convertFuture, uploadPartResponseFuture);
+    }
+
+    private static CompletedPart convertUploadPartResponse(
+        AtomicReferenceArray<CompletedPart> completedParts,
+        AtomicReferenceArray<CheckedContainer> inputStreamContainers,
+        UploadPartResponse partResponse,
+        int partNumber,
+        boolean isRemoteDataIntegrityCheckEnabled
+    ) {
+        CompletedPart.Builder completedPartBuilder = CompletedPart.builder().eTag(partResponse.eTag()).partNumber(partNumber);
+        if (isRemoteDataIntegrityCheckEnabled) {
+            completedPartBuilder.checksumCRC32(partResponse.checksumCRC32());
+            CheckedContainer inputStreamCRC32Container = inputStreamContainers.get(partNumber - 1);
+            inputStreamCRC32Container.setChecksum(partResponse.checksumCRC32());
+            inputStreamContainers.set(partNumber - 1, inputStreamCRC32Container);
+        }
+        CompletedPart completedPart = completedPartBuilder.build();
+        completedParts.set(partNumber - 1, completedPart);
+        return completedPart;
+    }
+}
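
[Reviewer note, not part of the patch: a sketch of how the part futures returned by
uploadParts() are expected to be consumed by the caller (AsyncTransferManager is not
shown in this patch, so the call site is an assumption):

    // completes exceptionally if any part fails, letting the caller abort the
    // whole upload via AsyncPartsHandler.cleanUpParts(...) before surfacing the error
    static CompletableFuture<Void> allParts(List<CompletableFuture<CompletedPart>> futures) {
        return CompletableFuture.allOf(futures.toArray(new CompletableFuture<?>[0]));
    }

uploadPart() already forwards each raw SDK failure to its converted future via
CompletableFutureUtils.forwardExceptionTo, so allOf() observes every part failure.]
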
diff --git a/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobContainerMockClientTests.java b/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobContainerMockClientTests.java
index 2f2b70b7fd778..99a71849ca38c 100644
--- a/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobContainerMockClientTests.java
+++ b/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobContainerMockClientTests.java
@@ -12,16 +12,16 @@
 import org.junit.After;
 import org.junit.Before;
 import org.mockito.invocation.InvocationOnMock;
+import org.opensearch.action.ActionListener;
 import org.opensearch.cluster.metadata.RepositoryMetadata;
-import org.opensearch.common.io.InputStreamContainer;
-import org.opensearch.common.StreamContext;
 import org.opensearch.common.CheckedTriFunction;
+import org.opensearch.common.StreamContext;
 import org.opensearch.common.blobstore.BlobPath;
 import org.opensearch.common.blobstore.stream.write.StreamContextSupplier;
 import org.opensearch.common.blobstore.stream.write.WriteContext;
 import org.opensearch.common.blobstore.stream.write.WritePriority;
-import org.opensearch.common.blobstore.transfer.UploadFinalizer;
 import org.opensearch.common.blobstore.transfer.stream.OffsetRangeIndexInputStream;
+import org.opensearch.common.io.InputStreamContainer;
 import org.opensearch.common.lucene.store.ByteArrayIndexInput;
 import org.opensearch.common.settings.Settings;
 import org.opensearch.common.unit.ByteSizeValue;
@@ -53,11 +53,14 @@
 import java.util.List;
 import java.util.Locale;
 import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicLong;
+import java.util.concurrent.atomic.AtomicReference;
 
 import static org.mockito.Mockito.any;
 import static org.mockito.Mockito.doAnswer;
@@ -438,39 +441,35 @@ private void testWriteBlobByStreams(boolean expectException, boolean throwExcept
         ExecutionException, InterruptedException {
         final byte[] bytes = randomByteArrayOfLength(100);
         List<InputStream> openInputStreams = new ArrayList<>();
-        CompletableFuture<Void> completableFuture = blobContainer.writeBlobByStreams(
-            new WriteContext("write_blob_by_streams_max_retries", new StreamContextSupplier() {
-                @Override
-                public StreamContext supplyStreamContext(long partSize) {
-                    return new StreamContext(new CheckedTriFunction<Integer, Long, Long, InputStreamContainer, IOException>() {
-                        @Override
-                        public InputStreamContainer apply(Integer partNo, Long size, Long position) throws IOException {
-                            InputStream inputStream = new OffsetRangeIndexInputStream(
-                                new ByteArrayIndexInput("desc", bytes),
-                                size,
-                                position
-                            );
-                            openInputStreams.add(inputStream);
-                            return new InputStreamContainer(inputStream, size);
-                        }
-                    }, partSize, calculateLastPartSize(bytes.length, partSize), calculateNumberOfParts(bytes.length, partSize));
-                }
-            }, bytes.length, false, WritePriority.NORMAL, new UploadFinalizer() {
-                @Override
-                public void accept(boolean uploadSuccess) {
-                    assertTrue(uploadSuccess);
-                    if (throwExceptionOnFinalizeUpload) {
-                        throw new RuntimeException();
-                    }
-                }
-            }, false, null)
-        );
+        CountDownLatch countDownLatch = new CountDownLatch(1);
+        AtomicReference<Exception> exceptionRef = new AtomicReference<>();
+        ActionListener<Void> completionListener = ActionListener.wrap(resp -> { countDownLatch.countDown(); }, ex -> {
+            exceptionRef.set(ex);
+            countDownLatch.countDown();
+        });
+        blobContainer.asyncBlobUpload(new WriteContext("write_blob_by_streams_max_retries", new StreamContextSupplier() {
+            @Override
+            public StreamContext supplyStreamContext(long partSize) {
+                return new StreamContext(new CheckedTriFunction<Integer, Long, Long, InputStreamContainer, IOException>() {
+                    @Override
+                    public InputStreamContainer apply(Integer partNo, Long size, Long position) throws IOException {
+                        InputStream inputStream = new OffsetRangeIndexInputStream(new ByteArrayIndexInput("desc", bytes), size, position);
+                        openInputStreams.add(inputStream);
+                        return new InputStreamContainer(inputStream, size, position);
+                    }
+                }, partSize, calculateLastPartSize(bytes.length, partSize), calculateNumberOfParts(bytes.length, partSize));
+            }
+        }, bytes.length, false, WritePriority.NORMAL, uploadSuccess -> {
+            assertTrue(uploadSuccess);
+            if (throwExceptionOnFinalizeUpload) {
+                throw new RuntimeException();
+            }
+        }, false, null), completionListener);
+        assertTrue(countDownLatch.await(5000, TimeUnit.SECONDS));
 
         // wait for completableFuture to finish
         if (expectException || throwExceptionOnFinalizeUpload) {
-            assertThrows(ExecutionException.class, completableFuture::get);
-        } else {
-            completableFuture.get();
+            assertNotNull(exceptionRef.get());
         }
 
         asyncService.verifySingleChunkUploadCallCount(throwExceptionOnFinalizeUpload);
@@ -491,37 +490,35 @@ private void testWriteBlobByStreamsLargeBlob(boolean expectException, boolean th
         int numberOfParts = randomIntBetween(2, 5);
         final long lastPartSize = randomLongBetween(10, 512);
         final long blobSize = ((numberOfParts - 1) * partSize.getBytes()) + lastPartSize;
-
+        CountDownLatch countDownLatch = new CountDownLatch(1);
+        AtomicReference<Exception> exceptionRef = new AtomicReference<>();
+        ActionListener<Void> completionListener = ActionListener.wrap(resp -> { countDownLatch.countDown(); }, ex -> {
+            exceptionRef.set(ex);
+            countDownLatch.countDown();
+        });
         List<InputStream> openInputStreams = new ArrayList<>();
-        CompletableFuture<Void> completableFuture = blobContainer.writeBlobByStreams(
-            new WriteContext("write_large_blob", new StreamContextSupplier() {
-                @Override
-                public StreamContext supplyStreamContext(long partSize) {
-                    return new StreamContext(new CheckedTriFunction<Integer, Long, Long, InputStreamContainer, IOException>() {
-                        @Override
-                        public InputStreamContainer apply(Integer partNo, Long size, Long position) throws IOException {
-                            InputStream inputStream = new OffsetRangeIndexInputStream(new ZeroIndexInput("desc", blobSize), size, position);
-                            openInputStreams.add(inputStream);
-                            return new InputStreamContainer(inputStream, size);
-                        }
-                    }, partSize, calculateLastPartSize(blobSize, partSize), calculateNumberOfParts(blobSize, partSize));
-                }
-            }, blobSize, false, WritePriority.HIGH, new UploadFinalizer() {
-                @Override
-                public void accept(boolean uploadSuccess) {
-                    assertTrue(uploadSuccess);
-                    if (throwExceptionOnFinalizeUpload) {
-                        throw new RuntimeException();
+        blobContainer.asyncBlobUpload(new WriteContext("write_large_blob", new StreamContextSupplier() {
+            @Override
+            public StreamContext supplyStreamContext(long partSize) {
+                return new StreamContext(new CheckedTriFunction<Integer, Long, Long, InputStreamContainer, IOException>() {
+                    @Override
+                    public InputStreamContainer apply(Integer partNo, Long size, Long position) throws IOException {
+                        InputStream inputStream = new OffsetRangeIndexInputStream(new ZeroIndexInput("desc", blobSize), size, position);
+                        openInputStreams.add(inputStream);
+                        return new InputStreamContainer(inputStream, size, position);
                     }
-                }
-            }, false, null)
-        );
+                }, partSize, calculateLastPartSize(blobSize, partSize), calculateNumberOfParts(blobSize, partSize));
+            }
+        }, blobSize, false, WritePriority.HIGH, uploadSuccess -> {
+            assertTrue(uploadSuccess);
+            if (throwExceptionOnFinalizeUpload) {
+                throw new RuntimeException();
+            }
+        }, false, null), completionListener);
 
-        // wait for completableFuture to finish
+        assertTrue(countDownLatch.await(5000, TimeUnit.SECONDS));
         if (expectException || throwExceptionOnFinalizeUpload) {
-            assertThrows(ExecutionException.class, completableFuture::get);
-        } else {
-            completableFuture.get();
+            assertNotNull(exceptionRef.get());
         }
 
         asyncService.verifyMultipartUploadCallCount(numberOfParts, throwExceptionOnFinalizeUpload);
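
[Reviewer note, not part of the patch: the tests above replace completableFuture.get()
with a latch. The pattern in isolation, as a sketch:

    CountDownLatch latch = new CountDownLatch(1);
    AtomicReference<Exception> failure = new AtomicReference<>();
    ActionListener<Void> listener = ActionListener.wrap(r -> latch.countDown(), e -> {
        failure.set(e);
        latch.countDown();
    });
    // invoke the async API with `listener`, then:
    assertTrue(latch.await(30, TimeUnit.SECONDS)); // the patch uses a far larger 5000s bound
    assertNull(failure.get());
]
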
diff --git a/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobContainerRetriesTests.java b/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobContainerRetriesTests.java
index 721d3818f883d..ae611249a0add 100644
--- a/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobContainerRetriesTests.java
+++ b/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobContainerRetriesTests.java
@@ -33,22 +33,24 @@
 import org.apache.http.HttpStatus;
 import org.junit.After;
+import org.junit.Assert;
 import org.junit.Before;
+import org.opensearch.action.ActionListener;
 import org.opensearch.cluster.metadata.RepositoryMetadata;
+import org.opensearch.common.CheckedTriFunction;
 import org.opensearch.common.Nullable;
-import org.opensearch.common.io.InputStreamContainer;
 import org.opensearch.common.StreamContext;
 import org.opensearch.common.SuppressForbidden;
-import org.opensearch.common.CheckedTriFunction;
 import org.opensearch.common.blobstore.BlobContainer;
 import org.opensearch.common.blobstore.BlobPath;
+import org.opensearch.common.blobstore.VerifyingMultiStreamBlobContainer;
 import org.opensearch.common.blobstore.stream.write.StreamContextSupplier;
 import org.opensearch.common.blobstore.stream.write.WriteContext;
 import org.opensearch.common.blobstore.stream.write.WritePriority;
-import org.opensearch.common.blobstore.transfer.UploadFinalizer;
 import org.opensearch.common.blobstore.transfer.stream.OffsetRangeIndexInputStream;
 import org.opensearch.common.bytes.BytesReference;
 import org.opensearch.common.hash.MessageDigests;
+import org.opensearch.common.io.InputStreamContainer;
 import org.opensearch.common.io.Streams;
 import org.opensearch.common.lucene.store.ByteArrayIndexInput;
 import org.opensearch.common.lucene.store.InputStreamIndexInput;
@@ -61,13 +63,13 @@
 import org.opensearch.common.util.concurrent.CountDown;
 import org.opensearch.common.util.io.IOUtils;
 import org.opensearch.repositories.blobstore.AbstractBlobContainerRetriesTestCase;
-import software.amazon.awssdk.core.exception.SdkClientException;
-import software.amazon.awssdk.core.io.SdkDigestInputStream;
-import software.amazon.awssdk.utils.internal.Base16;
+import org.opensearch.repositories.blobstore.ZeroInputStream;
 import org.opensearch.repositories.s3.async.AsyncExecutorBuilder;
 import org.opensearch.repositories.s3.async.AsyncUploadUtils;
 import org.opensearch.repositories.s3.async.TransferNIOGroup;
-import org.opensearch.repositories.blobstore.ZeroInputStream;
+import software.amazon.awssdk.core.exception.SdkClientException;
+import software.amazon.awssdk.core.io.SdkDigestInputStream;
+import software.amazon.awssdk.utils.internal.Base16;
 
 import java.io.ByteArrayInputStream;
 import java.io.FilterInputStream;
@@ -80,10 +82,12 @@
 import java.util.List;
 import java.util.Locale;
 import java.util.Objects;
-import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicReference;
 
 import static org.hamcrest.Matchers.anyOf;
 import static org.hamcrest.Matchers.containsString;
@@ -156,7 +160,7 @@ protected Class<? extends Exception> unresponsiveExceptionType() {
     }
 
     @Override
-    protected BlobContainer createBlobContainer(
+    protected VerifyingMultiStreamBlobContainer createBlobContainer(
         final @Nullable Integer maxRetries,
         final @Nullable TimeValue readTimeout,
         final @Nullable Boolean disableChunkedEncoding,
@@ -315,35 +319,29 @@ public void testWriteBlobByStreamsWithRetries() throws Exception {
             }
         });
 
-        final BlobContainer blobContainer = createBlobContainer(maxRetries, null, true, null);
+        final VerifyingMultiStreamBlobContainer blobContainer = createBlobContainer(maxRetries, null, true, null);
         List<InputStream> openInputStreams = new ArrayList<>();
-        CompletableFuture<Void> completableFuture = blobContainer.writeBlobByStreams(
-            new WriteContext("write_blob_by_streams_max_retries", new StreamContextSupplier() {
-                @Override
-                public StreamContext supplyStreamContext(long partSize) {
-                    return new StreamContext(new CheckedTriFunction<Integer, Long, Long, InputStreamContainer, IOException>() {
-                        @Override
-                        public InputStreamContainer apply(Integer partNo, Long size, Long position) throws IOException {
-                            InputStream inputStream = new OffsetRangeIndexInputStream(
-                                new ByteArrayIndexInput("desc", bytes),
-                                size,
-                                position
-                            );
-                            openInputStreams.add(inputStream);
-                            return new InputStreamContainer(inputStream, size);
-                        }
-                    }, partSize, calculateLastPartSize(bytes.length, partSize), calculateNumberOfParts(bytes.length, partSize));
-                }
-            }, bytes.length, false, WritePriority.NORMAL, new UploadFinalizer() {
-                @Override
-                public void accept(boolean uploadSuccess) {
-                    assertTrue(uploadSuccess);
-                }
-            }, false, null)
-        );
+        CountDownLatch countDownLatch = new CountDownLatch(1);
+        AtomicReference<Exception> exceptionRef = new AtomicReference<>();
+        ActionListener<Void> completionListener = ActionListener.wrap(resp -> { countDownLatch.countDown(); }, ex -> {
+            exceptionRef.set(ex);
+            countDownLatch.countDown();
+        });
+        blobContainer.asyncBlobUpload(new WriteContext("write_blob_by_streams_max_retries", new StreamContextSupplier() {
+            @Override
+            public StreamContext supplyStreamContext(long partSize) {
+                return new StreamContext(new CheckedTriFunction<Integer, Long, Long, InputStreamContainer, IOException>() {
+                    @Override
+                    public InputStreamContainer apply(Integer partNo, Long size, Long position) throws IOException {
+                        InputStream inputStream = new OffsetRangeIndexInputStream(new ByteArrayIndexInput("desc", bytes), size, position);
+                        openInputStreams.add(inputStream);
+                        return new InputStreamContainer(inputStream, size, position);
+                    }
+                }, partSize, calculateLastPartSize(bytes.length, partSize), calculateNumberOfParts(bytes.length, partSize));
+            }
+        }, bytes.length, false, WritePriority.NORMAL, Assert::assertTrue, false, null), completionListener);
 
-        // wait for completableFuture to finish
-        completableFuture.get();
+        assertTrue(countDownLatch.await(5000, TimeUnit.SECONDS));
 
         assertThat(countDown.isCountedDown(), is(true));