diff --git a/server/src/main/java/org/opensearch/common/StreamContext.java b/server/src/main/java/org/opensearch/common/StreamContext.java
index b163ba65dc7db..47a3d2b8571ea 100644
--- a/server/src/main/java/org/opensearch/common/StreamContext.java
+++ b/server/src/main/java/org/opensearch/common/StreamContext.java
@@ -57,6 +57,8 @@ protected StreamContext(StreamContext streamContext) {
/**
* Vendor plugins can use this method to create new streams only when they are required for processing
* New streams won't be created till this method is called with the specific partNumber
+ * It is the responsibility of caller to ensure that stream is properly closed after consumption
+ * otherwise it can leak resources.
*
* @param partNumber The index of the part
* @return A stream reference to the part requested
diff --git a/server/src/main/java/org/opensearch/common/blobstore/transfer/RemoteTransferContainer.java b/server/src/main/java/org/opensearch/common/blobstore/transfer/RemoteTransferContainer.java
index 077e02988d8b3..e14f512917685 100644
--- a/server/src/main/java/org/opensearch/common/blobstore/transfer/RemoteTransferContainer.java
+++ b/server/src/main/java/org/opensearch/common/blobstore/transfer/RemoteTransferContainer.java
@@ -29,6 +29,7 @@
import java.io.InputStream;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.function.Supplier;
import java.util.zip.CRC32;
import com.jcraft.jzlib.JZlib;
@@ -45,7 +46,7 @@ public class RemoteTransferContainer implements Closeable {
private long lastPartSize;
private final long contentLength;
- private final SetOnce inputStreams = new SetOnce<>();
+ private final SetOnce[]> checksumSuppliers = new SetOnce<>();
private final String fileName;
private final String remoteFileName;
private final boolean failTransferIfFileExists;
@@ -123,23 +124,24 @@ StreamContext supplyStreamContext(long partSize) {
}
}
+ @SuppressWarnings({ "unchecked" })
private StreamContext openMultipartStreams(long partSize) throws IOException {
- if (inputStreams.get() != null) {
+ if (checksumSuppliers.get() != null) {
throw new IOException("Multi-part streams are already created.");
}
this.partSize = partSize;
this.lastPartSize = (contentLength % partSize) != 0 ? contentLength % partSize : partSize;
this.numberOfParts = (int) ((contentLength % partSize) == 0 ? contentLength / partSize : (contentLength / partSize) + 1);
- InputStream[] streams = new InputStream[numberOfParts];
- inputStreams.set(streams);
+ Supplier[] suppliers = new Supplier[numberOfParts];
+ checksumSuppliers.set(suppliers);
return new StreamContext(getTransferPartStreamSupplier(), partSize, lastPartSize, numberOfParts);
}
private CheckedTriFunction getTransferPartStreamSupplier() {
return ((partNo, size, position) -> {
- assert inputStreams.get() != null : "expected inputStreams to be initialised";
+ assert checksumSuppliers.get() != null : "expected container to be initialised";
return getMultipartStreamSupplier(partNo, size, position).get();
});
}
@@ -167,10 +169,17 @@ private LocalStreamSupplier getMultipartStreamSupplier(
RateLimitingOffsetRangeInputStream rangeIndexInputStream = (RateLimitingOffsetRangeInputStream) offsetRangeInputStream;
rangeIndexInputStream.setClose(closed);
}
- InputStream inputStream = !isRemoteDataIntegrityCheckPossible()
- ? new ResettableCheckedInputStream(offsetRangeInputStream, fileName)
- : offsetRangeInputStream;
- Objects.requireNonNull(inputStreams.get())[streamIdx] = inputStream;
+ InputStream inputStream;
+ if (isRemoteDataIntegrityCheckPossible() == false) {
+ ResettableCheckedInputStream resettableCheckedInputStream = new ResettableCheckedInputStream(
+ offsetRangeInputStream,
+ fileName
+ );
+ Objects.requireNonNull(checksumSuppliers.get())[streamIdx] = resettableCheckedInputStream::getChecksum;
+ inputStream = resettableCheckedInputStream;
+ } else {
+ inputStream = offsetRangeInputStream;
+ }
return new InputStreamContainer(inputStream, size, position);
} catch (IOException e) {
@@ -212,20 +221,14 @@ public long getContentLength() {
return contentLength;
}
- private long getInputStreamChecksum(InputStream inputStream) {
- assert inputStream instanceof ResettableCheckedInputStream
- : "expected passed inputStream to be instance of ResettableCheckedInputStream";
- return ((ResettableCheckedInputStream) inputStream).getChecksum();
- }
-
private long getActualChecksum() {
- InputStream[] currentInputStreams = Objects.requireNonNull(inputStreams.get());
- long checksum = getInputStreamChecksum(currentInputStreams[0]);
- for (int checkSumIdx = 1; checkSumIdx < Objects.requireNonNull(inputStreams.get()).length - 1; checkSumIdx++) {
- checksum = JZlib.crc32_combine(checksum, getInputStreamChecksum(currentInputStreams[checkSumIdx]), partSize);
+ Supplier[] ckSumSuppliers = Objects.requireNonNull(checksumSuppliers.get());
+ long checksum = ckSumSuppliers[0].get();
+ for (int checkSumIdx = 1; checkSumIdx < ckSumSuppliers.length - 1; checkSumIdx++) {
+ checksum = JZlib.crc32_combine(checksum, ckSumSuppliers[checkSumIdx].get(), partSize);
}
if (numberOfParts > 1) {
- checksum = JZlib.crc32_combine(checksum, getInputStreamChecksum(currentInputStreams[numberOfParts - 1]), lastPartSize);
+ checksum = JZlib.crc32_combine(checksum, ckSumSuppliers[numberOfParts - 1].get(), lastPartSize);
}
return checksum;