Skip to content

Commit

Permalink
Refactor async blob read to avoid blocking calls, support non multipa…
Browse files Browse the repository at this point in the history
…rt calls (opensearch-project#10192)

Signed-off-by: Kunal Kotwani <kkotwani@amazon.com>
  • Loading branch information
kotwanikunal authored and rayshrey committed Oct 3, 2023
1 parent f8c4846 commit bd17d8d
Show file tree
Hide file tree
Showing 2 changed files with 163 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -228,35 +228,50 @@ public void readBlobAsync(String blobName, ActionListener<ReadContext> listener)
try (AmazonAsyncS3Reference amazonS3Reference = SocketAccess.doPrivileged(blobStore::asyncClientReference)) {
final S3AsyncClient s3AsyncClient = amazonS3Reference.get().client();
final String bucketName = blobStore.bucket();
final String blobKey = buildKey(blobName);

final GetObjectAttributesResponse blobMetadata = getBlobMetadata(s3AsyncClient, bucketName, blobName).get();
final CompletableFuture<GetObjectAttributesResponse> blobMetadataFuture = getBlobMetadata(s3AsyncClient, bucketName, blobKey);

final long blobSize = blobMetadata.objectSize();
final int numberOfParts = blobMetadata.objectParts().totalPartsCount();
final String blobChecksum = blobMetadata.checksum().checksumCRC32();

final List<InputStreamContainer> blobPartStreams = new ArrayList<>();
final List<CompletableFuture<InputStreamContainer>> blobPartInputStreamFutures = new ArrayList<>();
// S3 multipart files use 1 to n indexing
for (int partNumber = 1; partNumber <= numberOfParts; partNumber++) {
blobPartInputStreamFutures.add(getBlobPartInputStreamContainer(s3AsyncClient, bucketName, blobName, partNumber));
}

CompletableFuture.allOf(blobPartInputStreamFutures.toArray(CompletableFuture[]::new)).whenComplete((unused, throwable) -> {
if (throwable == null) {
listener.onResponse(
new ReadContext(
blobSize,
blobPartInputStreamFutures.stream().map(CompletableFuture::join).collect(Collectors.toList()),
blobChecksum
)
);
} else {
blobMetadataFuture.whenComplete((blobMetadata, throwable) -> {
if (throwable != null) {
Exception ex = throwable.getCause() instanceof Exception
? (Exception) throwable.getCause()
: new Exception(throwable.getCause());
listener.onFailure(ex);
return;
}

final List<CompletableFuture<InputStreamContainer>> blobPartInputStreamFutures = new ArrayList<>();
final long blobSize = blobMetadata.objectSize();
final Integer numberOfParts = blobMetadata.objectParts() == null ? null : blobMetadata.objectParts().totalPartsCount();
final String blobChecksum = blobMetadata.checksum().checksumCRC32();

if (numberOfParts == null) {
blobPartInputStreamFutures.add(getBlobPartInputStreamContainer(s3AsyncClient, bucketName, blobKey, null));
} else {
// S3 multipart files use 1 to n indexing
for (int partNumber = 1; partNumber <= numberOfParts; partNumber++) {
blobPartInputStreamFutures.add(getBlobPartInputStreamContainer(s3AsyncClient, bucketName, blobKey, partNumber));
}
}

CompletableFuture.allOf(blobPartInputStreamFutures.toArray(CompletableFuture[]::new))
.whenComplete((unused, partThrowable) -> {
if (partThrowable == null) {
listener.onResponse(
new ReadContext(
blobSize,
blobPartInputStreamFutures.stream().map(CompletableFuture::join).collect(Collectors.toList()),
blobChecksum
)
);
} else {
Exception ex = partThrowable.getCause() instanceof Exception
? (Exception) partThrowable.getCause()
: new Exception(partThrowable.getCause());
listener.onFailure(ex);
}
});
});
} catch (Exception ex) {
listener.onFailure(SdkException.create("Error occurred while fetching blob parts from the repository", ex));
Expand Down Expand Up @@ -685,41 +700,47 @@ static Tuple<Long, Long> numberOfMultiparts(final long totalSize, final long par
* the stream and its related metadata.
* @param s3AsyncClient Async client to be utilized to fetch the object part
* @param bucketName Name of the S3 bucket
* @param blobName Identifier of the blob for which the parts will be fetched
* @param partNumber Part number for the blob to be retrieved
* @param blobKey Identifier of the blob for which the parts will be fetched
* @param partNumber Optional part number for the blob to be retrieved
* @return A future of {@link InputStreamContainer} containing the stream and stream metadata.
*/
CompletableFuture<InputStreamContainer> getBlobPartInputStreamContainer(
S3AsyncClient s3AsyncClient,
String bucketName,
String blobName,
int partNumber
String blobKey,
@Nullable Integer partNumber
) {
final GetObjectRequest.Builder getObjectRequestBuilder = GetObjectRequest.builder()
.bucket(bucketName)
.key(blobName)
.partNumber(partNumber);
final boolean isMultipartObject = partNumber != null;
final GetObjectRequest.Builder getObjectRequestBuilder = GetObjectRequest.builder().bucket(bucketName).key(blobKey);

if (isMultipartObject) {
getObjectRequestBuilder.partNumber(partNumber);
}

return SocketAccess.doPrivileged(
() -> s3AsyncClient.getObject(getObjectRequestBuilder.build(), AsyncResponseTransformer.toBlockingInputStream())
.thenApply(S3BlobContainer::transformResponseToInputStreamContainer)
.thenApply(response -> transformResponseToInputStreamContainer(response, isMultipartObject))
);
}

/**
* Transforms the stream response object from S3 into an {@link InputStreamContainer}
* @param streamResponse Response stream object from S3
* @param isMultipartObject Flag to denote a multipart object response
* @return {@link InputStreamContainer} containing the stream and stream metadata
*/
// Package-Private for testing.
static InputStreamContainer transformResponseToInputStreamContainer(ResponseInputStream<GetObjectResponse> streamResponse) {
static InputStreamContainer transformResponseToInputStreamContainer(
ResponseInputStream<GetObjectResponse> streamResponse,
boolean isMultipartObject
) {
final GetObjectResponse getObjectResponse = streamResponse.response();
final String contentRange = getObjectResponse.contentRange();
final Long contentLength = getObjectResponse.contentLength();
if (contentRange == null || contentLength == null) {
if ((isMultipartObject && contentRange == null) || contentLength == null) {
throw SdkException.builder().message("Failed to fetch required metadata for blob part").build();
}
final Long offset = HttpRangeUtils.getStartOffsetFromRangeHeader(getObjectResponse.contentRange());
final long offset = isMultipartObject ? HttpRangeUtils.getStartOffsetFromRangeHeader(getObjectResponse.contentRange()) : 0L;
return new InputStreamContainer(streamResponse, getObjectResponse.contentLength(), offset);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@
import org.opensearch.common.io.InputStreamContainer;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.unit.ByteSizeUnit;
import org.opensearch.repositories.s3.async.AsyncTransferManager;
import org.opensearch.test.OpenSearchTestCase;

import java.io.ByteArrayInputStream;
Expand All @@ -100,7 +99,6 @@
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
Expand Down Expand Up @@ -919,7 +917,7 @@ public void testListBlobsByPrefixInLexicographicOrderWithLimitGreaterThanNumberO
testListBlobsByPrefixInLexicographicOrder(12, 2, BlobContainer.BlobNameSortOrder.LEXICOGRAPHIC);
}

public void testReadBlobAsync() throws Exception {
public void testReadBlobAsyncMultiPart() throws Exception {
final String bucketName = randomAlphaOfLengthBetween(1, 10);
final String blobName = randomAlphaOfLengthBetween(1, 10);
final String checksum = randomAlphaOfLength(10);
Expand All @@ -932,19 +930,14 @@ public void testReadBlobAsync() throws Exception {
final AmazonAsyncS3Reference amazonAsyncS3Reference = new AmazonAsyncS3Reference(
AmazonAsyncS3WithCredentials.create(s3AsyncClient, s3AsyncClient, null)
);
final AsyncTransferManager asyncTransferManager = new AsyncTransferManager(
10000L,
mock(ExecutorService.class),
mock(ExecutorService.class)
);

final S3BlobStore blobStore = mock(S3BlobStore.class);
final BlobPath blobPath = new BlobPath();

when(blobStore.bucket()).thenReturn(bucketName);
when(blobStore.getStatsMetricPublisher()).thenReturn(new StatsMetricPublisher());
when(blobStore.serverSideEncryption()).thenReturn(false);
when(blobStore.asyncClientReference()).thenReturn(amazonAsyncS3Reference);
when(blobStore.getAsyncTransferManager()).thenReturn(asyncTransferManager);

CompletableFuture<GetObjectAttributesResponse> getObjectAttributesResponseCompletableFuture = new CompletableFuture<>();
getObjectAttributesResponseCompletableFuture.complete(
Expand Down Expand Up @@ -984,6 +977,60 @@ public void testReadBlobAsync() throws Exception {
}
}

public void testReadBlobAsyncSinglePart() throws Exception {
final String bucketName = randomAlphaOfLengthBetween(1, 10);
final String blobName = randomAlphaOfLengthBetween(1, 10);
final String checksum = randomAlphaOfLength(10);

final int objectSize = 100;

final S3AsyncClient s3AsyncClient = mock(S3AsyncClient.class);
final AmazonAsyncS3Reference amazonAsyncS3Reference = new AmazonAsyncS3Reference(
AmazonAsyncS3WithCredentials.create(s3AsyncClient, s3AsyncClient, null)
);
final S3BlobStore blobStore = mock(S3BlobStore.class);
final BlobPath blobPath = new BlobPath();

when(blobStore.bucket()).thenReturn(bucketName);
when(blobStore.getStatsMetricPublisher()).thenReturn(new StatsMetricPublisher());
when(blobStore.serverSideEncryption()).thenReturn(false);
when(blobStore.asyncClientReference()).thenReturn(amazonAsyncS3Reference);

CompletableFuture<GetObjectAttributesResponse> getObjectAttributesResponseCompletableFuture = new CompletableFuture<>();
getObjectAttributesResponseCompletableFuture.complete(
GetObjectAttributesResponse.builder()
.checksum(Checksum.builder().checksumCRC32(checksum).build())
.objectSize((long) objectSize)
.build()
);
when(s3AsyncClient.getObjectAttributes(any(GetObjectAttributesRequest.class))).thenReturn(
getObjectAttributesResponseCompletableFuture
);

mockObjectResponse(s3AsyncClient, bucketName, blobName, objectSize);

CountDownLatch countDownLatch = new CountDownLatch(1);
CountingCompletionListener<ReadContext> readContextActionListener = new CountingCompletionListener<>();
LatchedActionListener<ReadContext> listener = new LatchedActionListener<>(readContextActionListener, countDownLatch);

final S3BlobContainer blobContainer = new S3BlobContainer(blobPath, blobStore);
blobContainer.readBlobAsync(blobName, listener);
countDownLatch.await();

assertEquals(1, readContextActionListener.getResponseCount());
assertEquals(0, readContextActionListener.getFailureCount());
ReadContext readContext = readContextActionListener.getResponse();
assertEquals(1, readContext.getNumberOfParts());
assertEquals(checksum, readContext.getBlobChecksum());
assertEquals(objectSize, readContext.getBlobSize());

InputStreamContainer inputStreamContainer = readContext.getPartStreams().stream().findFirst().get();
assertEquals(objectSize, inputStreamContainer.getContentLength());
assertEquals(0, inputStreamContainer.getOffset());
assertEquals(objectSize, inputStreamContainer.getInputStream().readAllBytes().length);

}

public void testReadBlobAsyncFailure() throws Exception {
final String bucketName = randomAlphaOfLengthBetween(1, 10);
final String blobName = randomAlphaOfLengthBetween(1, 10);
Expand All @@ -996,19 +1043,14 @@ public void testReadBlobAsyncFailure() throws Exception {
final AmazonAsyncS3Reference amazonAsyncS3Reference = new AmazonAsyncS3Reference(
AmazonAsyncS3WithCredentials.create(s3AsyncClient, s3AsyncClient, null)
);
final AsyncTransferManager asyncTransferManager = new AsyncTransferManager(
10000L,
mock(ExecutorService.class),
mock(ExecutorService.class)
);

final S3BlobStore blobStore = mock(S3BlobStore.class);
final BlobPath blobPath = new BlobPath();

when(blobStore.bucket()).thenReturn(bucketName);
when(blobStore.getStatsMetricPublisher()).thenReturn(new StatsMetricPublisher());
when(blobStore.serverSideEncryption()).thenReturn(false);
when(blobStore.asyncClientReference()).thenReturn(amazonAsyncS3Reference);
when(blobStore.getAsyncTransferManager()).thenReturn(asyncTransferManager);

CompletableFuture<GetObjectAttributesResponse> getObjectAttributesResponseCompletableFuture = new CompletableFuture<>();
getObjectAttributesResponseCompletableFuture.complete(
Expand Down Expand Up @@ -1071,7 +1113,7 @@ public void testGetBlobPartInputStream() throws Exception {
final String blobName = randomAlphaOfLengthBetween(1, 10);
final String bucketName = randomAlphaOfLengthBetween(1, 10);
final long contentLength = 10L;
final String contentRange = "bytes 0-10/100";
final String contentRange = "bytes 10-20/100";
final InputStream inputStream = ResponseInputStream.nullInputStream();

final S3AsyncClient s3AsyncClient = mock(S3AsyncClient.class);
Expand All @@ -1095,9 +1137,17 @@ public void testGetBlobPartInputStream() throws Exception {
)
).thenReturn(getObjectPartResponse);

// Header based offset in case of a multi part object request
InputStreamContainer inputStreamContainer = blobContainer.getBlobPartInputStreamContainer(s3AsyncClient, bucketName, blobName, 0)
.get();

assertEquals(10, inputStreamContainer.getOffset());
assertEquals(contentLength, inputStreamContainer.getContentLength());
assertEquals(inputStream.available(), inputStreamContainer.getInputStream().available());

// 0 offset in case of a single part object request
inputStreamContainer = blobContainer.getBlobPartInputStreamContainer(s3AsyncClient, bucketName, blobName, null).get();

assertEquals(0, inputStreamContainer.getOffset());
assertEquals(contentLength, inputStreamContainer.getContentLength());
assertEquals(inputStream.available(), inputStreamContainer.getInputStream().available());
Expand All @@ -1108,28 +1158,65 @@ public void testTransformResponseToInputStreamContainer() throws Exception {
final long contentLength = 10L;
final InputStream inputStream = ResponseInputStream.nullInputStream();

final S3AsyncClient s3AsyncClient = mock(S3AsyncClient.class);

GetObjectResponse getObjectResponse = GetObjectResponse.builder().contentLength(contentLength).build();

// Exception when content range absent for multipart object
ResponseInputStream<GetObjectResponse> responseInputStreamNoRange = new ResponseInputStream<>(getObjectResponse, inputStream);
assertThrows(SdkException.class, () -> S3BlobContainer.transformResponseToInputStreamContainer(responseInputStreamNoRange));
assertThrows(SdkException.class, () -> S3BlobContainer.transformResponseToInputStreamContainer(responseInputStreamNoRange, true));

// No exception when content range absent for single part object
ResponseInputStream<GetObjectResponse> responseInputStreamNoRangeSinglePart = new ResponseInputStream<>(
getObjectResponse,
inputStream
);
InputStreamContainer inputStreamContainer = S3BlobContainer.transformResponseToInputStreamContainer(
responseInputStreamNoRangeSinglePart,
false
);
assertEquals(contentLength, inputStreamContainer.getContentLength());
assertEquals(0, inputStreamContainer.getOffset());

// Exception when length is absent
getObjectResponse = GetObjectResponse.builder().contentRange(contentRange).build();
ResponseInputStream<GetObjectResponse> responseInputStreamNoContentLength = new ResponseInputStream<>(
getObjectResponse,
inputStream
);
assertThrows(SdkException.class, () -> S3BlobContainer.transformResponseToInputStreamContainer(responseInputStreamNoContentLength));
assertThrows(
SdkException.class,
() -> S3BlobContainer.transformResponseToInputStreamContainer(responseInputStreamNoContentLength, true)
);

// No exception when range and length both are present
getObjectResponse = GetObjectResponse.builder().contentRange(contentRange).contentLength(contentLength).build();
ResponseInputStream<GetObjectResponse> responseInputStream = new ResponseInputStream<>(getObjectResponse, inputStream);
InputStreamContainer inputStreamContainer = S3BlobContainer.transformResponseToInputStreamContainer(responseInputStream);
inputStreamContainer = S3BlobContainer.transformResponseToInputStreamContainer(responseInputStream, true);
assertEquals(contentLength, inputStreamContainer.getContentLength());
assertEquals(0, inputStreamContainer.getOffset());
assertEquals(inputStream.available(), inputStreamContainer.getInputStream().available());
}

private void mockObjectResponse(S3AsyncClient s3AsyncClient, String bucketName, String blobName, int objectSize) {

final InputStream inputStream = new ByteArrayInputStream(randomByteArrayOfLength(objectSize));

GetObjectResponse getObjectResponse = GetObjectResponse.builder().contentLength((long) objectSize).build();

CompletableFuture<ResponseInputStream<GetObjectResponse>> getObjectPartResponse = new CompletableFuture<>();
ResponseInputStream<GetObjectResponse> responseInputStream = new ResponseInputStream<>(getObjectResponse, inputStream);
getObjectPartResponse.complete(responseInputStream);

GetObjectRequest getObjectRequest = GetObjectRequest.builder().bucket(bucketName).key(blobName).build();

when(
s3AsyncClient.getObject(
eq(getObjectRequest),
ArgumentMatchers.<AsyncResponseTransformer<GetObjectResponse, ResponseInputStream<GetObjectResponse>>>any()
)
).thenReturn(getObjectPartResponse);

}

private void mockObjectPartResponse(
S3AsyncClient s3AsyncClient,
String bucketName,
Expand Down

0 comments on commit bd17d8d

Please sign in to comment.