diff --git a/plugins/repository-s3/src/internalClusterTest/java/org/opensearch/repositories/s3/S3BlobStoreRepositoryTests.java b/plugins/repository-s3/src/internalClusterTest/java/org/opensearch/repositories/s3/S3BlobStoreRepositoryTests.java index 61268cf00a77a..3070c654a96ee 100644 --- a/plugins/repository-s3/src/internalClusterTest/java/org/opensearch/repositories/s3/S3BlobStoreRepositoryTests.java +++ b/plugins/repository-s3/src/internalClusterTest/java/org/opensearch/repositories/s3/S3BlobStoreRepositoryTests.java @@ -172,7 +172,7 @@ protected S3Repository createRepository( ClusterService clusterService, RecoverySettings recoverySettings ) { - return new S3Repository(metadata, registry, service, clusterService, recoverySettings) { + return new S3Repository(metadata, registry, service, clusterService, recoverySettings, null, null, null, null, false) { @Override public BlobStore blobStore() { diff --git a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3BlobContainer.java b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3BlobContainer.java index 49ebce77a59ad..81a902a6992d8 100644 --- a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3BlobContainer.java +++ b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3BlobContainer.java @@ -39,11 +39,15 @@ import org.opensearch.action.ActionListener; import org.opensearch.common.Nullable; import org.opensearch.common.SetOnce; +import org.opensearch.common.StreamContext; import org.opensearch.common.blobstore.BlobContainer; import org.opensearch.common.blobstore.BlobMetadata; import org.opensearch.common.blobstore.BlobPath; import org.opensearch.common.blobstore.BlobStoreException; import org.opensearch.common.blobstore.DeleteResult; +import org.opensearch.common.blobstore.VerifyingMultiStreamBlobContainer; +import org.opensearch.common.blobstore.stream.write.WriteContext; +import org.opensearch.common.blobstore.stream.write.WritePriority; import org.opensearch.common.blobstore.support.AbstractBlobContainer; import org.opensearch.common.blobstore.support.PlainBlobMetadata; import org.opensearch.common.collect.Tuple; @@ -72,6 +76,8 @@ import software.amazon.awssdk.services.s3.model.UploadPartResponse; import software.amazon.awssdk.services.s3.paginators.ListObjectsV2Iterable; import org.opensearch.core.common.Strings; +import org.opensearch.repositories.s3.async.UploadRequest; +import software.amazon.awssdk.services.s3.S3AsyncClient; import java.io.ByteArrayInputStream; import java.io.IOException; @@ -82,6 +88,7 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.atomic.AtomicLong; import java.util.function.Function; import java.util.stream.Collectors; @@ -90,12 +97,13 @@ import static org.opensearch.repositories.s3.S3Repository.MAX_FILE_SIZE_USING_MULTIPART; import static org.opensearch.repositories.s3.S3Repository.MIN_PART_SIZE_USING_MULTIPART; -class S3BlobContainer extends AbstractBlobContainer { +class S3BlobContainer extends AbstractBlobContainer implements VerifyingMultiStreamBlobContainer { private static final Logger logger = LogManager.getLogger(S3BlobContainer.class); /** * Maximum number of deletes in a {@link DeleteObjectsRequest}. + * * @see S3 Documentation. 
*/ private static final int MAX_BULK_DELETES = 1000; @@ -166,6 +174,42 @@ public void writeBlob(String blobName, InputStream inputStream, long blobSize, b }); } + @Override + public void asyncBlobUpload(WriteContext writeContext, ActionListener<Void> completionListener) throws IOException { + UploadRequest uploadRequest = new UploadRequest( + blobStore.bucket(), + buildKey(writeContext.getFileName()), + writeContext.getFileSize(), + writeContext.getWritePriority(), + writeContext.getUploadFinalizer(), + writeContext.doRemoteDataIntegrityCheck(), + writeContext.getExpectedChecksum() + ); + try { + long partSize = blobStore.getAsyncTransferManager().calculateOptimalPartSize(writeContext.getFileSize()); + StreamContext streamContext = SocketAccess.doPrivileged(() -> writeContext.getStreamProvider(partSize)); + try (AmazonAsyncS3Reference amazonS3Reference = SocketAccess.doPrivileged(blobStore::asyncClientReference)) { + + S3AsyncClient s3AsyncClient = writeContext.getWritePriority() == WritePriority.HIGH + ? amazonS3Reference.get().priorityClient() + : amazonS3Reference.get().client(); + CompletableFuture<Void> completableFuture = blobStore.getAsyncTransferManager() + .uploadObject(s3AsyncClient, uploadRequest, streamContext); + completableFuture.whenComplete((response, throwable) -> { + if (throwable == null) { + completionListener.onResponse(response); + } else { + Exception ex = throwable instanceof Error ? new Exception(throwable) : (Exception) throwable; + completionListener.onFailure(ex); + } + }); + } + } catch (Exception e) { + logger.info("exception while uploading file {} to S3", writeContext.getFileName()); + throw new IOException(e); + } + } + // package private for testing long getLargeBlobThresholdInBytes() { return blobStore.bufferSizeInBytes(); diff --git a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3BlobStore.java b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3BlobStore.java index 6a9be2df2bf72..30040e182cbc9 100644 --- a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3BlobStore.java +++ b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3BlobStore.java @@ -42,6 +42,8 @@ import org.opensearch.common.unit.ByteSizeValue; import software.amazon.awssdk.services.s3.model.ObjectCannedACL; import software.amazon.awssdk.services.s3.model.StorageClass; +import org.opensearch.repositories.s3.async.AsyncExecutorContainer; +import org.opensearch.repositories.s3.async.AsyncTransferManager; import java.io.IOException; import java.util.Locale; @@ -53,6 +55,8 @@ class S3BlobStore implements BlobStore { private final S3Service service; + private final S3AsyncService s3AsyncService; + private final String bucket; private final ByteSizeValue bufferSize; @@ -67,22 +71,41 @@ class S3BlobStore implements BlobStore { private final StatsMetricPublisher statsMetricPublisher = new StatsMetricPublisher(); + private final AsyncTransferManager asyncTransferManager; + private final AsyncExecutorContainer priorityExecutorBuilder; + private final AsyncExecutorContainer normalExecutorBuilder; + private final boolean multipartUploadEnabled; + S3BlobStore( S3Service service, + S3AsyncService s3AsyncService, + boolean multipartUploadEnabled, String bucket, boolean serverSideEncryption, ByteSizeValue bufferSize, String cannedACL, String storageClass, - RepositoryMetadata repositoryMetadata + RepositoryMetadata repositoryMetadata, + AsyncTransferManager asyncTransferManager, + AsyncExecutorContainer priorityExecutorBuilder, + 
AsyncExecutorContainer normalExecutorBuilder ) { this.service = service; + this.s3AsyncService = s3AsyncService; + this.multipartUploadEnabled = multipartUploadEnabled; this.bucket = bucket; this.serverSideEncryption = serverSideEncryption; this.bufferSize = bufferSize; this.cannedACL = initCannedACL(cannedACL); this.storageClass = initStorageClass(storageClass); this.repositoryMetadata = repositoryMetadata; + this.asyncTransferManager = asyncTransferManager; + this.normalExecutorBuilder = normalExecutorBuilder; + this.priorityExecutorBuilder = priorityExecutorBuilder; + } + + public boolean isMultipartUploadEnabled() { + return multipartUploadEnabled; } @Override @@ -94,6 +117,10 @@ public AmazonS3Reference clientReference() { return service.client(repositoryMetadata); } + public AmazonAsyncS3Reference asyncClientReference() { + return s3AsyncService.client(repositoryMetadata, priorityExecutorBuilder, normalExecutorBuilder); + } + int getMaxRetries() { return service.settings(repositoryMetadata).maxRetries; } @@ -117,7 +144,12 @@ public BlobContainer blobContainer(BlobPath path) { @Override public void close() throws IOException { - this.service.close(); + if (service != null) { + this.service.close(); + } + if (s3AsyncService != null) { + this.s3AsyncService.close(); + } } @Override @@ -170,4 +202,8 @@ public static ObjectCannedACL initCannedACL(String cannedACL) { throw new BlobStoreException("cannedACL is not valid: [" + cannedACL + "]"); } + + public AsyncTransferManager getAsyncTransferManager() { + return asyncTransferManager; + } } diff --git a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3Repository.java b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3Repository.java index 07abb69c11bdd..d42bfc0be7e4f 100644 --- a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3Repository.java +++ b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3Repository.java @@ -34,7 +34,6 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; - import org.opensearch.Version; import org.opensearch.action.ActionListener; import org.opensearch.cluster.ClusterState; @@ -57,6 +56,8 @@ import org.opensearch.repositories.RepositoryException; import org.opensearch.repositories.ShardGenerations; import org.opensearch.repositories.blobstore.MeteredBlobStoreRepository; +import org.opensearch.repositories.s3.async.AsyncExecutorContainer; +import org.opensearch.repositories.s3.async.AsyncTransferManager; import org.opensearch.snapshots.SnapshotId; import org.opensearch.snapshots.SnapshotInfo; import org.opensearch.threadpool.Scheduler; @@ -103,6 +104,11 @@ class S3Repository extends MeteredBlobStoreRepository { ByteSizeUnit.BYTES ); + private static final ByteSizeValue DEFAULT_MULTIPART_UPLOAD_MINIMUM_PART_SIZE = new ByteSizeValue( + ByteSizeUnit.MB.toBytes(16), + ByteSizeUnit.BYTES + ); + static final Setting<String> BUCKET_SETTING = Setting.simpleString("bucket"); /** @@ -146,6 +152,26 @@ class S3Repository extends MeteredBlobStoreRepository { MAX_PART_SIZE_USING_MULTIPART ); + /** + * Minimum part size for parallel multipart uploads + */ + static final Setting<ByteSizeValue> PARALLEL_MULTIPART_UPLOAD_MINIMUM_PART_SIZE_SETTING = Setting.byteSizeSetting( + "parallel_multipart_upload.minimum_part_size", + DEFAULT_MULTIPART_UPLOAD_MINIMUM_PART_SIZE, + MIN_PART_SIZE_USING_MULTIPART, + MAX_PART_SIZE_USING_MULTIPART, + Setting.Property.NodeScope + ); + + /** + * This setting controls whether parallel multipart uploads will be used 
when calling S3 or not + */ + public static Setting<Boolean> PARALLEL_MULTIPART_UPLOAD_ENABLED_SETTING = Setting.boolSetting( + "parallel_multipart_upload.enabled", + true, + Setting.Property.NodeScope + ); + /** * Big files can be broken down into chunks during snapshotting if needed. Defaults to 1g. */ @@ -193,6 +219,12 @@ class S3Repository extends MeteredBlobStoreRepository { private final RepositoryMetadata repositoryMetadata; + private final AsyncTransferManager asyncUploadUtils; + private final S3AsyncService s3AsyncService; + private final boolean multipartUploadEnabled; + private final AsyncExecutorContainer priorityExecutorBuilder; + private final AsyncExecutorContainer normalExecutorBuilder; + /** * Constructs an s3 backed repository */ @@ -201,7 +233,12 @@ class S3Repository extends MeteredBlobStoreRepository { final NamedXContentRegistry namedXContentRegistry, final S3Service service, final ClusterService clusterService, - final RecoverySettings recoverySettings + final RecoverySettings recoverySettings, + final AsyncTransferManager asyncUploadUtils, + final AsyncExecutorContainer priorityExecutorBuilder, + final AsyncExecutorContainer normalExecutorBuilder, + final S3AsyncService s3AsyncService, + final boolean multipartUploadEnabled ) { super( metadata, @@ -212,8 +249,13 @@ buildLocation(metadata) ); this.service = service; + this.s3AsyncService = s3AsyncService; + this.multipartUploadEnabled = multipartUploadEnabled; this.repositoryMetadata = metadata; + this.asyncUploadUtils = asyncUploadUtils; + this.priorityExecutorBuilder = priorityExecutorBuilder; + this.normalExecutorBuilder = normalExecutorBuilder; // Parse and validate the user's S3 Storage Class setting this.bucket = BUCKET_SETTING.get(metadata.settings()); @@ -314,7 +356,20 @@ public void deleteSnapshots( @Override protected S3BlobStore createBlobStore() { - return new S3BlobStore(service, bucket, serverSideEncryption, bufferSize, cannedACL, storageClass, repositoryMetadata); + return new S3BlobStore( + service, + s3AsyncService, + multipartUploadEnabled, + bucket, + serverSideEncryption, + bufferSize, + cannedACL, + storageClass, + repositoryMetadata, + asyncUploadUtils, + priorityExecutorBuilder, + normalExecutorBuilder + ); } // only use for testing diff --git a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3RepositoryPlugin.java b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3RepositoryPlugin.java index 828bf85fd7889..30f792346f9be 100644 --- a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3RepositoryPlugin.java +++ b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3RepositoryPlugin.java @@ -32,44 +32,131 @@ package org.opensearch.repositories.s3; +import org.opensearch.client.Client; +import org.opensearch.cluster.metadata.IndexNameExpressionResolver; import org.opensearch.cluster.metadata.RepositoryMetadata; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.env.Environment; +import org.opensearch.env.NodeEnvironment; import org.opensearch.indices.recovery.RecoverySettings; import org.opensearch.plugins.Plugin; import org.opensearch.plugins.ReloadablePlugin; import 
org.opensearch.plugins.RepositoryPlugin; +import org.opensearch.repositories.RepositoriesService; import org.opensearch.repositories.Repository; +import org.opensearch.repositories.s3.async.AsyncExecutorContainer; +import org.opensearch.repositories.s3.async.AsyncTransferEventLoopGroup; +import org.opensearch.repositories.s3.async.AsyncTransferManager; +import org.opensearch.script.ScriptService; +import org.opensearch.threadpool.ExecutorBuilder; +import org.opensearch.threadpool.FixedExecutorBuilder; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.watcher.ResourceWatcherService; import java.io.IOException; import java.nio.file.Path; +import java.util.ArrayList; import java.util.Arrays; +import java.util.Collection; import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.function.Supplier; /** * A plugin to add a repository type that writes to and from the AWS S3. */ public class S3RepositoryPlugin extends Plugin implements RepositoryPlugin, ReloadablePlugin { + private static final String PRIORITY_FUTURE_COMPLETION = "priority_future_completion"; + private static final String PRIORITY_STREAM_READER = "priority_stream_reader"; + private static final String FUTURE_COMPLETION = "future_completion"; + private static final String STREAM_READER = "stream_reader"; protected final S3Service service; + private final S3AsyncService s3AsyncService; + private final Path configPath; + private AsyncExecutorContainer priorityExecutorBuilder; + private AsyncExecutorContainer normalExecutorBuilder; + public S3RepositoryPlugin(final Settings settings, final Path configPath) { - this(settings, configPath, new S3Service(configPath)); + this(settings, configPath, new S3Service(configPath), new S3AsyncService(configPath)); + } + + @Override + public List<ExecutorBuilder<?>> getExecutorBuilders(Settings settings) { + List<ExecutorBuilder<?>> executorBuilders = new ArrayList<>(); + executorBuilders.add( + new FixedExecutorBuilder(settings, PRIORITY_FUTURE_COMPLETION, priorityPoolCount(settings), 10_000, PRIORITY_FUTURE_COMPLETION) + ); + executorBuilders.add( + new FixedExecutorBuilder(settings, PRIORITY_STREAM_READER, priorityPoolCount(settings), 10_000, PRIORITY_STREAM_READER) + ); + executorBuilders.add(new FixedExecutorBuilder(settings, FUTURE_COMPLETION, normalPoolCount(settings), 10_000, FUTURE_COMPLETION)); + executorBuilders.add(new FixedExecutorBuilder(settings, STREAM_READER, normalPoolCount(settings), 10_000, STREAM_READER)); + return executorBuilders; } - S3RepositoryPlugin(final Settings settings, final Path configPath, final S3Service service) { + S3RepositoryPlugin(final Settings settings, final Path configPath, final S3Service service, final S3AsyncService s3AsyncService) { this.service = Objects.requireNonNull(service, "S3 service must not be null"); this.configPath = configPath; // eagerly load client settings so that secure settings are read - final Map<String, S3ClientSettings> clientsSettings = S3ClientSettings.load(settings, configPath); + Map<String, S3ClientSettings> clientsSettings = S3ClientSettings.load(settings, configPath); + this.s3AsyncService = Objects.requireNonNull(s3AsyncService, "S3AsyncService must not be null"); this.service.refreshAndClearCache(clientsSettings); + this.s3AsyncService.refreshAndClearCache(clientsSettings); + } + + private static int boundedBy(int value, int min, int max) { + return Math.min(max, Math.max(min, value)); + } + + private static int allocatedProcessors(Settings settings) { + return OpenSearchExecutors.allocatedProcessors(settings); + } + + private 
static int priorityPoolCount(Settings settings) { + return boundedBy((allocatedProcessors(settings) + 1) / 2, 2, 4); + } + + private static int normalPoolCount(Settings settings) { + return boundedBy((allocatedProcessors(settings) + 7) / 8, 1, 2); + } + + @Override + public Collection<Object> createComponents( + final Client client, + final ClusterService clusterService, + final ThreadPool threadPool, + final ResourceWatcherService resourceWatcherService, + final ScriptService scriptService, + final NamedXContentRegistry xContentRegistry, + final Environment environment, + final NodeEnvironment nodeEnvironment, + final NamedWriteableRegistry namedWriteableRegistry, + final IndexNameExpressionResolver expressionResolver, + final Supplier<RepositoriesService> repositoriesServiceSupplier + ) { + int priorityEventLoopThreads = priorityPoolCount(clusterService.getSettings()); + int normalEventLoopThreads = normalPoolCount(clusterService.getSettings()); + this.priorityExecutorBuilder = new AsyncExecutorContainer( + threadPool.executor(PRIORITY_FUTURE_COMPLETION), + threadPool.executor(PRIORITY_STREAM_READER), + new AsyncTransferEventLoopGroup(priorityEventLoopThreads) + ); + this.normalExecutorBuilder = new AsyncExecutorContainer( + threadPool.executor(FUTURE_COMPLETION), + threadPool.executor(STREAM_READER), + new AsyncTransferEventLoopGroup(normalEventLoopThreads) + ); + return Collections.emptyList(); } // proxy method for testing @@ -79,7 +166,24 @@ protected S3Repository createRepository( final ClusterService clusterService, final RecoverySettings recoverySettings ) { - return new S3Repository(metadata, registry, service, clusterService, recoverySettings); + + AsyncTransferManager asyncUploadUtils = new AsyncTransferManager( + S3Repository.PARALLEL_MULTIPART_UPLOAD_MINIMUM_PART_SIZE_SETTING.get(clusterService.getSettings()).getBytes(), + normalExecutorBuilder.getStreamReader(), + priorityExecutorBuilder.getStreamReader() + ); + return new S3Repository( + metadata, + registry, + service, + clusterService, + recoverySettings, + asyncUploadUtils, + priorityExecutorBuilder, + normalExecutorBuilder, + s3AsyncService, + S3Repository.PARALLEL_MULTIPART_UPLOAD_ENABLED_SETTING.get(clusterService.getSettings()) + ); } @Override @@ -119,7 +223,9 @@ public List<Setting<?>> getSettings() { S3ClientSettings.REGION, S3ClientSettings.ROLE_ARN_SETTING, S3ClientSettings.IDENTITY_TOKEN_FILE_SETTING, - S3ClientSettings.ROLE_SESSION_NAME_SETTING + S3ClientSettings.ROLE_SESSION_NAME_SETTING, + S3Repository.PARALLEL_MULTIPART_UPLOAD_MINIMUM_PART_SIZE_SETTING, + S3Repository.PARALLEL_MULTIPART_UPLOAD_ENABLED_SETTING ); } @@ -128,10 +234,12 @@ public void reload(Settings settings) { // secure settings should be readable final Map<String, S3ClientSettings> clientsSettings = S3ClientSettings.load(settings, configPath); service.refreshAndClearCache(clientsSettings); + s3AsyncService.refreshAndClearCache(clientsSettings); } @Override public void close() throws IOException { service.close(); + s3AsyncService.close(); } } diff --git a/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/RepositoryCredentialsTests.java b/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/RepositoryCredentialsTests.java index 3ccf6553c479d..46e589f7fa41f 100644 --- a/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/RepositoryCredentialsTests.java +++ b/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/RepositoryCredentialsTests.java @@ -291,7 +291,7 @@ private void createRepository(final String name, final Settings repositorySettin public 
static final class ProxyS3RepositoryPlugin extends S3RepositoryPlugin { public ProxyS3RepositoryPlugin(Settings settings, Path configPath) { - super(settings, configPath, new ProxyS3Service(configPath)); + super(settings, configPath, new ProxyS3Service(configPath), new S3AsyncService(configPath)); } @Override @@ -301,7 +301,7 @@ protected S3Repository createRepository( ClusterService clusterService, RecoverySettings recoverySettings ) { - return new S3Repository(metadata, registry, service, clusterService, recoverySettings) { + return new S3Repository(metadata, registry, service, clusterService, recoverySettings, null, null, null, null, false) { @Override protected void assertSnapshotOrGenericThread() { // eliminate thread name check as we create repo manually on test/main threads diff --git a/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobContainerMockClientTests.java b/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobContainerMockClientTests.java new file mode 100644 index 0000000000000..10137f0475177 --- /dev/null +++ b/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobContainerMockClientTests.java @@ -0,0 +1,542 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.repositories.s3; + +import org.apache.lucene.store.IndexInput; +import org.junit.After; +import org.junit.Before; +import org.mockito.invocation.InvocationOnMock; +import org.opensearch.action.ActionListener; +import org.opensearch.cluster.metadata.RepositoryMetadata; +import org.opensearch.common.CheckedTriFunction; +import org.opensearch.common.StreamContext; +import org.opensearch.common.blobstore.BlobPath; +import org.opensearch.common.blobstore.stream.write.StreamContextSupplier; +import org.opensearch.common.blobstore.stream.write.WriteContext; +import org.opensearch.common.blobstore.stream.write.WritePriority; +import org.opensearch.common.blobstore.transfer.stream.OffsetRangeIndexInputStream; +import org.opensearch.common.io.InputStreamContainer; +import org.opensearch.common.lucene.store.ByteArrayIndexInput; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.ByteSizeValue; +import org.opensearch.common.util.io.IOUtils; +import org.opensearch.repositories.s3.async.AsyncExecutorContainer; +import org.opensearch.repositories.s3.async.AsyncTransferManager; +import org.opensearch.repositories.s3.async.AsyncTransferEventLoopGroup; +import org.opensearch.test.OpenSearchTestCase; +import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.model.AbortMultipartUploadRequest; +import software.amazon.awssdk.services.s3.model.AbortMultipartUploadResponse; +import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadRequest; +import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadResponse; +import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; +import software.amazon.awssdk.services.s3.model.CreateMultipartUploadResponse; +import software.amazon.awssdk.services.s3.model.DeleteObjectRequest; +import software.amazon.awssdk.services.s3.model.DeleteObjectResponse; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectResponse; 
+import software.amazon.awssdk.services.s3.model.UploadPartRequest; +import software.amazon.awssdk.services.s3.model.UploadPartResponse; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Locale; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; + +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +public class S3BlobContainerMockClientTests extends OpenSearchTestCase implements ConfigPathSupport { + + private MockS3AsyncService asyncService; + private ExecutorService futureCompletionService; + private ExecutorService streamReaderService; + private AsyncTransferEventLoopGroup transferNIOGroup; + private S3BlobContainer blobContainer; + + static class MockS3AsyncService extends S3AsyncService { + + private final S3AsyncClient asyncClient = mock(S3AsyncClient.class); + private final int maxDelayInFutureCompletionMillis; + + private boolean failPutObjectRequest; + private boolean failCreateMultipartUploadRequest; + private boolean failUploadPartRequest; + private boolean failCompleteMultipartUploadRequest; + + private String multipartUploadId; + + public MockS3AsyncService(Path configPath, int maxDelayInFutureCompletionMillis) { + super(configPath); + this.maxDelayInFutureCompletionMillis = maxDelayInFutureCompletionMillis; + } + + public void initializeMocks( + boolean failPutObjectRequest, + boolean failCreateMultipartUploadRequest, + boolean failUploadPartRequest, + boolean failCompleteMultipartUploadRequest + ) { + setupFailureBooleans( + failPutObjectRequest, + failCreateMultipartUploadRequest, + failUploadPartRequest, + failCompleteMultipartUploadRequest + ); + doAnswer(this::doOnPutObject).when(asyncClient).putObject(any(PutObjectRequest.class), any(AsyncRequestBody.class)); + doAnswer(this::doOnDeleteObject).when(asyncClient).deleteObject(any(DeleteObjectRequest.class)); + doAnswer(this::doOnCreateMultipartUpload).when(asyncClient).createMultipartUpload(any(CreateMultipartUploadRequest.class)); + doAnswer(this::doOnPartUpload).when(asyncClient).uploadPart(any(UploadPartRequest.class), any(AsyncRequestBody.class)); + doAnswer(this::doOnCompleteMultipartUpload).when(asyncClient) + .completeMultipartUpload(any(CompleteMultipartUploadRequest.class)); + doAnswer(this::doOnAbortMultipartUpload).when(asyncClient).abortMultipartUpload(any(AbortMultipartUploadRequest.class)); + } + + private void setupFailureBooleans( + boolean failPutObjectRequest, + boolean failCreateMultipartUploadRequest, + boolean failUploadPartRequest, + boolean failCompleteMultipartUploadRequest + ) { + this.failPutObjectRequest = failPutObjectRequest; + this.failCreateMultipartUploadRequest = failCreateMultipartUploadRequest; + this.failUploadPartRequest = failUploadPartRequest; + this.failCompleteMultipartUploadRequest = failCompleteMultipartUploadRequest; + } + + private CompletableFuture<PutObjectResponse> doOnPutObject(InvocationOnMock invocationOnMock) { + CompletableFuture<PutObjectResponse> 
completableFuture = new CompletableFuture<>(); + new Thread(() -> { + try { + Thread.sleep(randomInt(maxDelayInFutureCompletionMillis)); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + if (failPutObjectRequest) { + completableFuture.completeExceptionally(new IOException()); + } else { + completableFuture.complete(PutObjectResponse.builder().build()); + } + }).start(); + + return completableFuture; + } + + private CompletableFuture<DeleteObjectResponse> doOnDeleteObject(InvocationOnMock invocationOnMock) { + CompletableFuture<DeleteObjectResponse> completableFuture = new CompletableFuture<>(); + new Thread(() -> { + try { + Thread.sleep(randomInt(maxDelayInFutureCompletionMillis)); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + if (failPutObjectRequest) { + completableFuture.completeExceptionally(new IOException()); + } else { + completableFuture.complete(DeleteObjectResponse.builder().build()); + } + }).start(); + + return completableFuture; + } + + private CompletableFuture<CreateMultipartUploadResponse> doOnCreateMultipartUpload(InvocationOnMock invocationOnMock) { + multipartUploadId = randomAlphaOfLength(5); + CompletableFuture<CreateMultipartUploadResponse> completableFuture = new CompletableFuture<>(); + new Thread(() -> { + try { + Thread.sleep(randomInt(maxDelayInFutureCompletionMillis)); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + if (failCreateMultipartUploadRequest) { + completableFuture.completeExceptionally(new IOException()); + } else { + completableFuture.complete(CreateMultipartUploadResponse.builder().uploadId(multipartUploadId).build()); + } + }).start(); + + return completableFuture; + } + + private CompletableFuture<UploadPartResponse> doOnPartUpload(InvocationOnMock invocationOnMock) { + UploadPartRequest uploadPartRequest = invocationOnMock.getArgument(0); + assertEquals(multipartUploadId, uploadPartRequest.uploadId()); + CompletableFuture<UploadPartResponse> completableFuture = new CompletableFuture<>(); + new Thread(() -> { + try { + Thread.sleep(randomInt(maxDelayInFutureCompletionMillis)); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + if (failUploadPartRequest) { + completableFuture.completeExceptionally(new IOException()); + } else { + completableFuture.complete(UploadPartResponse.builder().eTag("eTag").build()); + } + }).start(); + + return completableFuture; + } + + private CompletableFuture<CompleteMultipartUploadResponse> doOnCompleteMultipartUpload(InvocationOnMock invocationOnMock) { + CompleteMultipartUploadRequest completeMultipartUploadRequest = invocationOnMock.getArgument(0); + assertEquals(multipartUploadId, completeMultipartUploadRequest.uploadId()); + CompletableFuture<CompleteMultipartUploadResponse> completableFuture = new CompletableFuture<>(); + new Thread(() -> { + try { + Thread.sleep(randomInt(maxDelayInFutureCompletionMillis)); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + if (failCompleteMultipartUploadRequest) { + completableFuture.completeExceptionally(new IOException()); + } else { + completableFuture.complete(CompleteMultipartUploadResponse.builder().build()); + } + }).start(); + + return completableFuture; + } + + private CompletableFuture<AbortMultipartUploadResponse> doOnAbortMultipartUpload(InvocationOnMock invocationOnMock) { + AbortMultipartUploadRequest abortMultipartUploadRequest = invocationOnMock.getArgument(0); + assertEquals(multipartUploadId, abortMultipartUploadRequest.uploadId()); + CompletableFuture<AbortMultipartUploadResponse> completableFuture = new CompletableFuture<>(); + new Thread(() -> { + try { + Thread.sleep(randomInt(maxDelayInFutureCompletionMillis)); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + 
completableFuture.complete(AbortMultipartUploadResponse.builder().build()); + + }).start(); + + return completableFuture; + } + + public void verifyMultipartUploadCallCount(int numberOfParts, boolean finalizeUploadFailure) { + verify(asyncClient, times(1)).createMultipartUpload(any(CreateMultipartUploadRequest.class)); + verify(asyncClient, times(!failCreateMultipartUploadRequest ? numberOfParts : 0)).uploadPart( + any(UploadPartRequest.class), + any(AsyncRequestBody.class) + ); + verify(asyncClient, times(!failCreateMultipartUploadRequest && !failUploadPartRequest && !finalizeUploadFailure ? 1 : 0)) + .completeMultipartUpload(any(CompleteMultipartUploadRequest.class)); + verify( + asyncClient, + times( + (!failCreateMultipartUploadRequest && (failUploadPartRequest || failCompleteMultipartUploadRequest)) + || finalizeUploadFailure ? 1 : 0 + ) + ).abortMultipartUpload(any(AbortMultipartUploadRequest.class)); + } + + public void verifySingleChunkUploadCallCount(boolean finalizeUploadFailure) { + verify(asyncClient, times(1)).putObject(any(PutObjectRequest.class), any(AsyncRequestBody.class)); + verify(asyncClient, times(finalizeUploadFailure ? 1 : 0)).deleteObject(any(DeleteObjectRequest.class)); + } + + @Override + public AmazonAsyncS3Reference client( + RepositoryMetadata repositoryMetadata, + AsyncExecutorContainer priorityExecutorBuilder, + AsyncExecutorContainer normalExecutorBuilder + ) { + return new AmazonAsyncS3Reference(AmazonAsyncS3WithCredentials.create(asyncClient, asyncClient, null)); + } + } + + /** + * An IndexInput implementation that serves only zeroes + */ + static class ZeroIndexInput extends IndexInput { + + private final AtomicBoolean closed = new AtomicBoolean(false); + private final AtomicLong reads = new AtomicLong(0); + private final long length; + + /** + * @param resourceDescription resourceDescription should be a non-null, opaque string describing this resource; it's returned + * from {@link #toString}. + */ + public ZeroIndexInput(String resourceDescription, final long length) { + super(resourceDescription); + this.length = length; + } + + @Override + public void close() throws IOException { + closed.set(true); + } + + @Override + public long getFilePointer() { + return reads.get(); + } + + @Override + public void seek(long pos) throws IOException { + reads.set(pos); + } + + @Override + public long length() { + return length; + } + + @Override + public IndexInput slice(String sliceDescription, long offset, long length) throws IOException { + return new ZeroIndexInput(sliceDescription, length); + } + + @Override + public byte readByte() throws IOException { + ensureOpen(); + return (byte) ((reads.incrementAndGet() <= length) ? 
0 : -1); + } + + @Override + public void readBytes(byte[] b, int offset, int len) throws IOException { + ensureOpen(); + final long available = available(); + final int toCopy = Math.min(len, (int) available); + Arrays.fill(b, offset, offset + toCopy, (byte) 0); + reads.addAndGet(toCopy); + } + + private long available() { + return Math.max(length - reads.get(), 0); + } + + private void ensureOpen() throws IOException { + if (closed.get()) { + throw new IOException("Stream closed"); + } + } + } + + @Override + @Before + public void setUp() throws Exception { + asyncService = new MockS3AsyncService(configPath(), 1000); + futureCompletionService = Executors.newSingleThreadExecutor(); + streamReaderService = Executors.newSingleThreadExecutor(); + transferNIOGroup = new AsyncTransferEventLoopGroup(1); + blobContainer = createBlobContainer(); + super.setUp(); + } + + @Override + @After + public void tearDown() throws Exception { + IOUtils.close(asyncService); + super.tearDown(); + } + + private S3BlobContainer createBlobContainer() { + return new S3BlobContainer(BlobPath.cleanPath(), createBlobStore()); + } + + private S3BlobStore createBlobStore() { + final String clientName = randomAlphaOfLength(5).toLowerCase(Locale.ROOT); + + final RepositoryMetadata repositoryMetadata = new RepositoryMetadata( + "repository", + S3Repository.TYPE, + Settings.builder().put(S3Repository.CLIENT_NAME.getKey(), clientName).build() + ); + + AsyncExecutorContainer asyncExecutorContainer = new AsyncExecutorContainer( + futureCompletionService, + streamReaderService, + transferNIOGroup + ); + + return new S3BlobStore( + null, + asyncService, + true, + "bucket", + S3Repository.SERVER_SIDE_ENCRYPTION_SETTING.getDefault(Settings.EMPTY), + S3Repository.BUFFER_SIZE_SETTING.getDefault(Settings.EMPTY), + S3Repository.CANNED_ACL_SETTING.getDefault(Settings.EMPTY), + S3Repository.STORAGE_CLASS_SETTING.getDefault(Settings.EMPTY), + repositoryMetadata, + new AsyncTransferManager( + S3Repository.PARALLEL_MULTIPART_UPLOAD_MINIMUM_PART_SIZE_SETTING.getDefault(Settings.EMPTY).getBytes(), + asyncExecutorContainer.getStreamReader(), + asyncExecutorContainer.getStreamReader() + ), + asyncExecutorContainer, + asyncExecutorContainer + ); + } + + public void testWriteBlobByStreamsNoFailure() throws IOException, ExecutionException, InterruptedException { + asyncService.initializeMocks(false, false, false, false); + testWriteBlobByStreamsLargeBlob(false, false); + } + + public void testWriteBlobByStreamsFinalizeUploadFailure() throws IOException, ExecutionException, InterruptedException { + asyncService.initializeMocks(false, false, false, false); + testWriteBlobByStreamsLargeBlob(false, true); + } + + public void testWriteBlobByStreamsCreateMultipartRequestFailure() throws IOException, ExecutionException, InterruptedException { + asyncService.initializeMocks(false, true, false, false); + testWriteBlobByStreamsLargeBlob(true, false); + } + + public void testWriteBlobByStreamsUploadPartRequestFailure() throws IOException, ExecutionException, InterruptedException { + asyncService.initializeMocks(false, false, true, false); + testWriteBlobByStreamsLargeBlob(true, false); + } + + public void testWriteBlobByStreamsCompleteMultipartRequestFailure() throws IOException, ExecutionException, InterruptedException { + asyncService.initializeMocks(false, false, false, true); + testWriteBlobByStreamsLargeBlob(true, false); + } + + public void testWriteBlobByStreamsSingleChunkUploadNoFailure() throws IOException, ExecutionException, InterruptedException 
{ + asyncService.initializeMocks(false, false, false, false); + testWriteBlobByStreams(false, false); + } + + public void testWriteBlobByStreamsSingleChunkUploadPutObjectFailure() throws IOException, ExecutionException, InterruptedException { + asyncService.initializeMocks(true, false, false, false); + testWriteBlobByStreams(true, false); + } + + public void testWriteBlobByStreamsSingleChunkUploadFinalizeUploadFailure() throws IOException, ExecutionException, + InterruptedException { + asyncService.initializeMocks(false, false, false, false); + testWriteBlobByStreams(false, true); + } + + private void testWriteBlobByStreams(boolean expectException, boolean throwExceptionOnFinalizeUpload) throws IOException, + ExecutionException, InterruptedException { + final byte[] bytes = randomByteArrayOfLength(100); + List<InputStream> openInputStreams = new ArrayList<>(); + CountDownLatch countDownLatch = new CountDownLatch(1); + AtomicReference<Exception> exceptionRef = new AtomicReference<>(); + ActionListener<Void> completionListener = ActionListener.wrap(resp -> { countDownLatch.countDown(); }, ex -> { + exceptionRef.set(ex); + countDownLatch.countDown(); + }); + blobContainer.asyncBlobUpload(new WriteContext("write_blob_by_streams_max_retries", new StreamContextSupplier() { + @Override + public StreamContext supplyStreamContext(long partSize) { + return new StreamContext(new CheckedTriFunction<Integer, Long, Long, InputStreamContainer>() { + @Override + public InputStreamContainer apply(Integer partNo, Long size, Long position) throws IOException { + InputStream inputStream = new OffsetRangeIndexInputStream(new ByteArrayIndexInput("desc", bytes), size, position); + openInputStreams.add(inputStream); + return new InputStreamContainer(inputStream, size, position); + } + }, partSize, calculateLastPartSize(bytes.length, partSize), calculateNumberOfParts(bytes.length, partSize)); + } + }, bytes.length, false, WritePriority.NORMAL, uploadSuccess -> { + assertTrue(uploadSuccess); + if (throwExceptionOnFinalizeUpload) { + throw new RuntimeException(); + } + }, false, null), completionListener); + + assertTrue(countDownLatch.await(5000, TimeUnit.SECONDS)); + // wait for completableFuture to finish + if (expectException || throwExceptionOnFinalizeUpload) { + assertNotNull(exceptionRef.get()); + } + + asyncService.verifySingleChunkUploadCallCount(throwExceptionOnFinalizeUpload); + + openInputStreams.forEach(inputStream -> { + try { + inputStream.close(); + } catch (IOException e) { + fail("Failure while closing open input streams"); + } + }); + } + + private void testWriteBlobByStreamsLargeBlob(boolean expectException, boolean throwExceptionOnFinalizeUpload) throws IOException, + ExecutionException, InterruptedException { + final ByteSizeValue partSize = S3Repository.PARALLEL_MULTIPART_UPLOAD_MINIMUM_PART_SIZE_SETTING.getDefault(Settings.EMPTY); + + int numberOfParts = randomIntBetween(2, 5); + final long lastPartSize = randomLongBetween(10, 512); + final long blobSize = ((numberOfParts - 1) * partSize.getBytes()) + lastPartSize; + CountDownLatch countDownLatch = new CountDownLatch(1); + AtomicReference<Exception> exceptionRef = new AtomicReference<>(); + ActionListener<Void> completionListener = ActionListener.wrap(resp -> { countDownLatch.countDown(); }, ex -> { + exceptionRef.set(ex); + countDownLatch.countDown(); + }); + List<InputStream> openInputStreams = new ArrayList<>(); + blobContainer.asyncBlobUpload(new WriteContext("write_large_blob", new StreamContextSupplier() { + @Override + public StreamContext supplyStreamContext(long partSize) { + return new StreamContext(new CheckedTriFunction<Integer, Long, Long, InputStreamContainer>() { + 
@Override + public InputStreamContainer apply(Integer partNo, Long size, Long position) throws IOException { + InputStream inputStream = new OffsetRangeIndexInputStream(new ZeroIndexInput("desc", blobSize), size, position); + openInputStreams.add(inputStream); + return new InputStreamContainer(inputStream, size, position); + } + }, partSize, calculateLastPartSize(blobSize, partSize), calculateNumberOfParts(blobSize, partSize)); + } + }, blobSize, false, WritePriority.HIGH, uploadSuccess -> { + assertTrue(uploadSuccess); + if (throwExceptionOnFinalizeUpload) { + throw new RuntimeException(); + } + }, false, null), completionListener); + + assertTrue(countDownLatch.await(5000, TimeUnit.SECONDS)); + if (expectException || throwExceptionOnFinalizeUpload) { + assertNotNull(exceptionRef.get()); + } + + asyncService.verifyMultipartUploadCallCount(numberOfParts, throwExceptionOnFinalizeUpload); + + openInputStreams.forEach(inputStream -> { + try { + inputStream.close(); + } catch (IOException ex) { + logger.error("Error closing input stream"); + } + }); + } + + private long calculateLastPartSize(long totalSize, long partSize) { + return totalSize % partSize == 0 ? partSize : totalSize % partSize; + } + + private int calculateNumberOfParts(long contentLength, long partSize) { + return (int) ((contentLength % partSize) == 0 ? contentLength / partSize : (contentLength / partSize) + 1); + } +} diff --git a/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobContainerRetriesTests.java b/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobContainerRetriesTests.java index 045ce73daf5a3..1a1fb123aa5ea 100644 --- a/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobContainerRetriesTests.java +++ b/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobContainerRetriesTests.java @@ -33,14 +33,24 @@ import org.apache.http.HttpStatus; import org.junit.After; +import org.junit.Assert; import org.junit.Before; +import org.opensearch.action.ActionListener; import org.opensearch.cluster.metadata.RepositoryMetadata; +import org.opensearch.common.CheckedTriFunction; import org.opensearch.common.Nullable; +import org.opensearch.common.StreamContext; import org.opensearch.common.SuppressForbidden; import org.opensearch.common.blobstore.BlobContainer; import org.opensearch.common.blobstore.BlobPath; +import org.opensearch.common.blobstore.VerifyingMultiStreamBlobContainer; +import org.opensearch.common.blobstore.stream.write.StreamContextSupplier; +import org.opensearch.common.blobstore.stream.write.WriteContext; +import org.opensearch.common.blobstore.stream.write.WritePriority; +import org.opensearch.common.blobstore.transfer.stream.OffsetRangeIndexInputStream; import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.common.hash.MessageDigests; +import org.opensearch.common.io.InputStreamContainer; import org.opensearch.common.io.Streams; import org.opensearch.common.lucene.store.ByteArrayIndexInput; import org.opensearch.common.lucene.store.InputStreamIndexInput; @@ -53,10 +63,13 @@ import org.opensearch.common.util.concurrent.CountDown; import org.opensearch.common.util.io.IOUtils; import org.opensearch.repositories.blobstore.AbstractBlobContainerRetriesTestCase; +import org.opensearch.repositories.blobstore.ZeroInputStream; +import org.opensearch.repositories.s3.async.AsyncExecutorContainer; +import org.opensearch.repositories.s3.async.AsyncTransferManager; +import 
org.opensearch.repositories.s3.async.AsyncTransferEventLoopGroup; import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.io.SdkDigestInputStream; import software.amazon.awssdk.utils.internal.Base16; -import org.opensearch.repositories.blobstore.ZeroInputStream; import java.io.ByteArrayInputStream; import java.io.FilterInputStream; @@ -65,9 +78,16 @@ import java.net.InetSocketAddress; import java.net.SocketTimeoutException; import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; import java.util.Locale; import java.util.Objects; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; import static org.hamcrest.Matchers.anyOf; import static org.hamcrest.Matchers.containsString; @@ -88,17 +108,34 @@ public class S3BlobContainerRetriesTests extends AbstractBlobContainerRetriesTes private S3Service service; private String previousOpenSearchPathConf; + private S3AsyncService asyncService; + private ExecutorService futureCompletionService; + private ExecutorService streamReaderService; + private AsyncTransferEventLoopGroup transferNIOGroup; @Before public void setUp() throws Exception { previousOpenSearchPathConf = SocketAccess.doPrivileged(() -> System.setProperty("opensearch.path.conf", configPath().toString())); service = new S3Service(configPath()); + asyncService = new S3AsyncService(configPath()); + + futureCompletionService = Executors.newSingleThreadExecutor(); + streamReaderService = Executors.newSingleThreadExecutor(); + transferNIOGroup = new AsyncTransferEventLoopGroup(1); + + // needed by S3AsyncService + SocketAccess.doPrivileged(() -> System.setProperty("opensearch.path.conf", configPath().toString())); super.setUp(); } @After public void tearDown() throws Exception { - IOUtils.close(service); + IOUtils.close(service, asyncService); + + streamReaderService.shutdown(); + futureCompletionService.shutdown(); + IOUtils.close(transferNIOGroup); + if (previousOpenSearchPathConf != null) { SocketAccess.doPrivileged(() -> System.setProperty("opensearch.path.conf", previousOpenSearchPathConf)); } else { @@ -123,7 +160,7 @@ protected Class unresponsiveExceptionType() { } @Override - protected BlobContainer createBlobContainer( + protected VerifyingMultiStreamBlobContainer createBlobContainer( final @Nullable Integer maxRetries, final @Nullable TimeValue readTimeout, final @Nullable Boolean disableChunkedEncoding, @@ -152,6 +189,7 @@ protected BlobContainer createBlobContainer( secureSettings.setString(S3ClientSettings.SECRET_KEY_SETTING.getConcreteSettingForNamespace(clientName).getKey(), "secret"); clientSettings.setSecureSettings(secureSettings); service.refreshAndClearCache(S3ClientSettings.load(clientSettings.build(), configPath())); + asyncService.refreshAndClearCache(S3ClientSettings.load(clientSettings.build(), configPath())); final RepositoryMetadata repositoryMetadata = new RepositoryMetadata( "repository", @@ -159,16 +197,31 @@ protected BlobContainer createBlobContainer( Settings.builder().put(S3Repository.CLIENT_NAME.getKey(), clientName).build() ); + AsyncExecutorContainer asyncExecutorContainer = new AsyncExecutorContainer( + futureCompletionService, + streamReaderService, + transferNIOGroup + ); + return new S3BlobContainer( BlobPath.cleanPath(), new S3BlobStore( service, + 
asyncService, + true, "bucket", S3Repository.SERVER_SIDE_ENCRYPTION_SETTING.getDefault(Settings.EMPTY), bufferSize == null ? S3Repository.BUFFER_SIZE_SETTING.getDefault(Settings.EMPTY) : bufferSize, S3Repository.CANNED_ACL_SETTING.getDefault(Settings.EMPTY), S3Repository.STORAGE_CLASS_SETTING.getDefault(Settings.EMPTY), - repositoryMetadata + repositoryMetadata, + new AsyncTransferManager( + S3Repository.PARALLEL_MULTIPART_UPLOAD_MINIMUM_PART_SIZE_SETTING.getDefault(Settings.EMPTY).getBytes(), + asyncExecutorContainer.getStreamReader(), + asyncExecutorContainer.getStreamReader() + ), + asyncExecutorContainer, + asyncExecutorContainer ) ) { @Override @@ -228,6 +281,87 @@ public void testWriteBlobWithRetries() throws Exception { assertThat(countDown.isCountedDown(), is(true)); } + public void testWriteBlobByStreamsWithRetries() throws Exception { + final int maxRetries = randomInt(5); + final CountDown countDown = new CountDown(maxRetries + 1); + + final byte[] bytes = randomBlobContent(); + httpServer.createContext("/bucket/write_blob_by_streams_max_retries", exchange -> { + if ("PUT".equals(exchange.getRequestMethod()) && exchange.getRequestURI().getQuery() == null) { + if (countDown.countDown()) { + final BytesReference body = Streams.readFully(exchange.getRequestBody()); + if (Objects.deepEquals(bytes, BytesReference.toBytes(body))) { + exchange.sendResponseHeaders(HttpStatus.SC_OK, -1); + } else { + exchange.sendResponseHeaders(HttpStatus.SC_BAD_REQUEST, -1); + } + exchange.close(); + return; + } + + if (randomBoolean()) { + if (randomBoolean()) { + Streams.readFully(exchange.getRequestBody(), new byte[randomIntBetween(1, Math.max(1, bytes.length - 1))]); + } else { + Streams.readFully(exchange.getRequestBody()); + exchange.sendResponseHeaders( + randomFrom( + HttpStatus.SC_INTERNAL_SERVER_ERROR, + HttpStatus.SC_BAD_GATEWAY, + HttpStatus.SC_SERVICE_UNAVAILABLE, + HttpStatus.SC_GATEWAY_TIMEOUT + ), + -1 + ); + } + } + exchange.close(); + } + }); + + final VerifyingMultiStreamBlobContainer blobContainer = createBlobContainer(maxRetries, null, true, null); + List<InputStream> openInputStreams = new ArrayList<>(); + CountDownLatch countDownLatch = new CountDownLatch(1); + AtomicReference<Exception> exceptionRef = new AtomicReference<>(); + ActionListener<Void> completionListener = ActionListener.wrap(resp -> { countDownLatch.countDown(); }, ex -> { + exceptionRef.set(ex); + countDownLatch.countDown(); + }); + blobContainer.asyncBlobUpload(new WriteContext("write_blob_by_streams_max_retries", new StreamContextSupplier() { + @Override + public StreamContext supplyStreamContext(long partSize) { + return new StreamContext(new CheckedTriFunction<Integer, Long, Long, InputStreamContainer>() { + @Override + public InputStreamContainer apply(Integer partNo, Long size, Long position) throws IOException { + InputStream inputStream = new OffsetRangeIndexInputStream(new ByteArrayIndexInput("desc", bytes), size, position); + openInputStreams.add(inputStream); + return new InputStreamContainer(inputStream, size, position); + } + }, partSize, calculateLastPartSize(bytes.length, partSize), calculateNumberOfParts(bytes.length, partSize)); + } + }, bytes.length, false, WritePriority.NORMAL, Assert::assertTrue, false, null), completionListener); + + assertTrue(countDownLatch.await(5000, TimeUnit.SECONDS)); + + assertThat(countDown.isCountedDown(), is(true)); + + openInputStreams.forEach(inputStream -> { + try { + inputStream.close(); + } catch (IOException e) { + fail("Failure while closing open input streams"); + } + }); + } + + private long calculateLastPartSize(long totalSize, 
long partSize) { + return totalSize % partSize == 0 ? partSize : totalSize % partSize; + } + + private int calculateNumberOfParts(long contentLength, long partSize) { + return (int) ((contentLength % partSize) == 0 ? contentLength / partSize : (contentLength / partSize) + 1); + } + public void testWriteBlobWithReadTimeouts() { final byte[] bytes = randomByteArrayOfLength(randomIntBetween(10, 128)); final TimeValue readTimeout = TimeValue.timeValueMillis(randomIntBetween(100, 500)); diff --git a/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3ClientSettingsTests.java b/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3ClientSettingsTests.java index 130b8efca0512..1edf8d53c1e73 100644 --- a/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3ClientSettingsTests.java +++ b/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3ClientSettingsTests.java @@ -68,7 +68,11 @@ public void testThereIsADefaultClientByDefault() { assertThat(defaultSettings.endpoint, is(emptyString())); assertThat(defaultSettings.protocol, is(Protocol.HTTPS)); assertThat(defaultSettings.proxySettings, is(ProxySettings.NO_PROXY_SETTINGS)); - assertThat(defaultSettings.readTimeoutMillis, is(50_000)); + assertThat(defaultSettings.readTimeoutMillis, is(50 * 1000)); + assertThat(defaultSettings.requestTimeoutMillis, is(120 * 1000)); + assertThat(defaultSettings.connectionTimeoutMillis, is(10 * 1000)); + assertThat(defaultSettings.connectionTTLMillis, is(5 * 1000)); + assertThat(defaultSettings.maxConnections, is(100)); assertThat(defaultSettings.maxRetries, is(3)); assertThat(defaultSettings.throttleRetries, is(true)); } diff --git a/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3RepositoryTests.java b/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3RepositoryTests.java index dc63ed50d5f3d..84d56c7ae2854 100644 --- a/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3RepositoryTests.java +++ b/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3RepositoryTests.java @@ -138,7 +138,12 @@ private S3Repository createS3Repo(RepositoryMetadata metadata) { NamedXContentRegistry.EMPTY, new DummyS3Service(configPath()), BlobStoreTestUtil.mockClusterService(), - new RecoverySettings(Settings.EMPTY, new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS)) + new RecoverySettings(Settings.EMPTY, new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS)), + null, + null, + null, + null, + false ) { @Override protected void assertSnapshotOrGenericThread() {
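For reviewers trying the new API outside the test harness, here is a minimal, illustrative sketch of how a caller can drive VerifyingMultiStreamBlobContainer#asyncBlobUpload, mirroring the WriteContext construction used by the tests in this diff. The class and helper names below are hypothetical, and the meaning of the two boolean WriteContext arguments (fail-if-already-exists, remote data integrity check) is inferred from the test call sites rather than confirmed by this patch.

import org.opensearch.action.ActionListener;
import org.opensearch.common.StreamContext;
import org.opensearch.common.blobstore.VerifyingMultiStreamBlobContainer;
import org.opensearch.common.blobstore.stream.write.WriteContext;
import org.opensearch.common.blobstore.stream.write.WritePriority;
import org.opensearch.common.blobstore.transfer.stream.OffsetRangeIndexInputStream;
import org.opensearch.common.io.InputStreamContainer;
import org.opensearch.common.lucene.store.ByteArrayIndexInput;

import java.io.IOException;

// Hypothetical helper class; only the OpenSearch types above come from this change.
public final class AsyncUploadExample {

    public static void upload(VerifyingMultiStreamBlobContainer container, String blobName, byte[] bytes) throws IOException {
        WriteContext writeContext = new WriteContext(
            blobName,
            // StreamContextSupplier: the container picks the part size, we slice the source per part
            partSize -> new StreamContext(
                // CheckedTriFunction<Integer, Long, Long, InputStreamContainer>: (partNo, size, position) -> one stream per part
                (partNo, size, position) -> new InputStreamContainer(
                    new OffsetRangeIndexInputStream(new ByteArrayIndexInput("desc", bytes), size, position),
                    size,
                    position
                ),
                partSize,
                lastPartSize(bytes.length, partSize),
                numberOfParts(bytes.length, partSize)
            ),
            bytes.length,
            false,                 // same flag the tests pass (inferred: fail if the blob already exists)
            WritePriority.NORMAL,  // WritePriority.HIGH is routed to the priority async client
            uploadSuccess -> {},   // upload finalizer, invoked once after all parts complete
            false,                 // doRemoteDataIntegrityCheck, disabled as in the tests
            null                   // expected checksum, unused when the integrity check is off
        );
        container.asyncBlobUpload(writeContext, ActionListener.wrap(resp -> {}, e -> {
            throw new RuntimeException("upload failed for " + blobName, e);
        }));
    }

    // Same part math as the helpers duplicated in both test classes above:
    // the last part carries the remainder unless the blob is an exact multiple of the part size.
    private static long lastPartSize(long totalSize, long partSize) {
        return totalSize % partSize == 0 ? partSize : totalSize % partSize;
    }

    private static int numberOfParts(long totalSize, long partSize) {
        return (int) (totalSize % partSize == 0 ? totalSize / partSize : totalSize / partSize + 1);
    }
}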
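The executor sizing in S3RepositoryPlugin is easiest to sanity-check with concrete numbers (illustrative only; clamp here is just a local name for the patch's boundedBy helper):

// priorityPoolCount = clamp((allocatedProcessors + 1) / 2, 2, 4)
// normalPoolCount   = clamp((allocatedProcessors + 7) / 8, 1, 2)
//   2 processors  -> priority 2 (raised to the floor), normal 1
//   8 processors  -> priority 4, normal 1
//   32 processors -> priority 4 (capped), normal 2 (capped)
static int clamp(int value, int min, int max) {
    return Math.min(max, Math.max(min, value));
}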