Skip to content

Commit

Permalink
Fixes to prevent segment fault errors arising due to unexpected SDK b… (
Browse files Browse the repository at this point in the history
#11334) (#11509)

(cherry picked from commit 2a4aafd)

Signed-off-by: vikasvb90 <vikasvb@amazon.com>
Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
  • Loading branch information
1 parent 5a820ca commit d46cc07
Show file tree
Hide file tree
Showing 6 changed files with 302 additions and 57 deletions.
2 changes: 2 additions & 0 deletions server/src/main/java/org/opensearch/common/StreamContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ protected StreamContext(StreamContext streamContext) {
/**
* Vendor plugins can use this method to create new streams only when they are required for processing
* New streams won't be created till this method is called with the specific <code>partNumber</code>
* It is the responsibility of caller to ensure that stream is properly closed after consumption
* otherwise it can leak resources.
*
* @param partNumber The index of the part
* @return A stream reference to the part requested
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.opensearch.common.blobstore.stream.write.WriteContext;
import org.opensearch.common.blobstore.stream.write.WritePriority;
import org.opensearch.common.blobstore.transfer.stream.OffsetRangeInputStream;
import org.opensearch.common.blobstore.transfer.stream.RateLimitingOffsetRangeInputStream;
import org.opensearch.common.blobstore.transfer.stream.ResettableCheckedInputStream;
import org.opensearch.common.io.InputStreamContainer;
import org.opensearch.common.util.ByteUtils;
Expand All @@ -27,6 +28,8 @@
import java.io.IOException;
import java.io.InputStream;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Supplier;
import java.util.zip.CRC32;

import com.jcraft.jzlib.JZlib;
Expand All @@ -43,14 +46,15 @@ public class RemoteTransferContainer implements Closeable {
private long lastPartSize;

private final long contentLength;
private final SetOnce<InputStream[]> inputStreams = new SetOnce<>();
private final SetOnce<Supplier<Long>[]> checksumSuppliers = new SetOnce<>();
private final String fileName;
private final String remoteFileName;
private final boolean failTransferIfFileExists;
private final WritePriority writePriority;
private final long expectedChecksum;
private final OffsetRangeInputStreamSupplier offsetRangeInputStreamSupplier;
private final boolean isRemoteDataIntegritySupported;
private final AtomicBoolean readBlock = new AtomicBoolean();

private static final Logger log = LogManager.getLogger(RemoteTransferContainer.class);

Expand Down Expand Up @@ -120,23 +124,24 @@ StreamContext supplyStreamContext(long partSize) {
}
}

@SuppressWarnings({ "unchecked" })
private StreamContext openMultipartStreams(long partSize) throws IOException {
if (inputStreams.get() != null) {
if (checksumSuppliers.get() != null) {
throw new IOException("Multi-part streams are already created.");
}

this.partSize = partSize;
this.lastPartSize = (contentLength % partSize) != 0 ? contentLength % partSize : partSize;
this.numberOfParts = (int) ((contentLength % partSize) == 0 ? contentLength / partSize : (contentLength / partSize) + 1);
InputStream[] streams = new InputStream[numberOfParts];
inputStreams.set(streams);
Supplier<Long>[] suppliers = new Supplier[numberOfParts];
checksumSuppliers.set(suppliers);

return new StreamContext(getTransferPartStreamSupplier(), partSize, lastPartSize, numberOfParts);
}

private CheckedTriFunction<Integer, Long, Long, InputStreamContainer, IOException> getTransferPartStreamSupplier() {
return ((partNo, size, position) -> {
assert inputStreams.get() != null : "expected inputStreams to be initialised";
assert checksumSuppliers.get() != null : "expected container to be initialised";
return getMultipartStreamSupplier(partNo, size, position).get();
});
}
Expand All @@ -160,10 +165,21 @@ private LocalStreamSupplier<InputStreamContainer> getMultipartStreamSupplier(
return () -> {
try {
OffsetRangeInputStream offsetRangeInputStream = offsetRangeInputStreamSupplier.get(size, position);
InputStream inputStream = !isRemoteDataIntegrityCheckPossible()
? new ResettableCheckedInputStream(offsetRangeInputStream, fileName)
: offsetRangeInputStream;
Objects.requireNonNull(inputStreams.get())[streamIdx] = inputStream;
if (offsetRangeInputStream instanceof RateLimitingOffsetRangeInputStream) {
RateLimitingOffsetRangeInputStream rangeIndexInputStream = (RateLimitingOffsetRangeInputStream) offsetRangeInputStream;
rangeIndexInputStream.setReadBlock(readBlock);
}
InputStream inputStream;
if (isRemoteDataIntegrityCheckPossible() == false) {
ResettableCheckedInputStream resettableCheckedInputStream = new ResettableCheckedInputStream(
offsetRangeInputStream,
fileName
);
Objects.requireNonNull(checksumSuppliers.get())[streamIdx] = resettableCheckedInputStream::getChecksum;
inputStream = resettableCheckedInputStream;
} else {
inputStream = offsetRangeInputStream;
}

return new InputStreamContainer(inputStream, size, position);
} catch (IOException e) {
Expand Down Expand Up @@ -205,48 +221,23 @@ public long getContentLength() {
return contentLength;
}

private long getInputStreamChecksum(InputStream inputStream) {
assert inputStream instanceof ResettableCheckedInputStream
: "expected passed inputStream to be instance of ResettableCheckedInputStream";
return ((ResettableCheckedInputStream) inputStream).getChecksum();
}

private long getActualChecksum() {
InputStream[] currentInputStreams = Objects.requireNonNull(inputStreams.get());
long checksum = getInputStreamChecksum(currentInputStreams[0]);
for (int checkSumIdx = 1; checkSumIdx < Objects.requireNonNull(inputStreams.get()).length - 1; checkSumIdx++) {
checksum = JZlib.crc32_combine(checksum, getInputStreamChecksum(currentInputStreams[checkSumIdx]), partSize);
Supplier<Long>[] ckSumSuppliers = Objects.requireNonNull(checksumSuppliers.get());
long checksum = ckSumSuppliers[0].get();
for (int checkSumIdx = 1; checkSumIdx < ckSumSuppliers.length - 1; checkSumIdx++) {
checksum = JZlib.crc32_combine(checksum, ckSumSuppliers[checkSumIdx].get(), partSize);
}
if (numberOfParts > 1) {
checksum = JZlib.crc32_combine(checksum, getInputStreamChecksum(currentInputStreams[numberOfParts - 1]), lastPartSize);
checksum = JZlib.crc32_combine(checksum, ckSumSuppliers[numberOfParts - 1].get(), lastPartSize);
}

return checksum;
}

@Override
public void close() throws IOException {
if (inputStreams.get() == null) {
log.warn("Input streams cannot be closed since they are not yet set for multi stream upload");
return;
}

boolean closeStreamException = false;
for (InputStream is : Objects.requireNonNull(inputStreams.get())) {
try {
if (is != null) {
is.close();
}
} catch (IOException ex) {
closeStreamException = true;
// Attempting to close all streams first before throwing exception.
log.error("Multipart stream failed to close ", ex);
}
}

if (closeStreamException) {
throw new IOException("Closure of some of the multi-part streams failed.");
}
// Setting a read block on all streams ever created by the container.
readBlock.set(true);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,29 @@

package org.opensearch.common.blobstore.transfer.stream;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.store.AlreadyClosedException;
import org.apache.lucene.store.IndexInput;
import org.opensearch.common.concurrent.RefCountedReleasable;
import org.opensearch.common.lucene.store.InputStreamIndexInput;
import org.opensearch.common.util.concurrent.RunOnce;

import java.io.IOException;
import java.util.concurrent.atomic.AtomicBoolean;

/**
* OffsetRangeIndexInputStream extends InputStream to read from a specified offset using IndexInput
*
* @opensearch.internal
*/
public class OffsetRangeIndexInputStream extends OffsetRangeInputStream {

private static final Logger logger = LogManager.getLogger(OffsetRangeIndexInputStream.class);
private final InputStreamIndexInput inputStreamIndexInput;
private final IndexInput indexInput;
private AtomicBoolean readBlock;
private final OffsetRangeRefCount offsetRangeRefCount;
private final RunOnce closeOnce;

/**
* Construct a new OffsetRangeIndexInputStream object
Expand All @@ -35,16 +44,68 @@ public OffsetRangeIndexInputStream(IndexInput indexInput, long size, long positi
indexInput.seek(position);
this.indexInput = indexInput;
this.inputStreamIndexInput = new InputStreamIndexInput(indexInput, size);
ClosingStreams closingStreams = new ClosingStreams(inputStreamIndexInput, indexInput);
offsetRangeRefCount = new OffsetRangeRefCount(closingStreams);
closeOnce = new RunOnce(offsetRangeRefCount::decRef);
}

@Override
public void setReadBlock(AtomicBoolean readBlock) {
this.readBlock = readBlock;
}

@Override
public int read(byte[] b, int off, int len) throws IOException {
return inputStreamIndexInput.read(b, off, len);
// There are two levels of check to ensure that we don't read an already closed stream and
// to not close the stream if it is already being read.
// 1. First check is a coarse-grained check outside reference check which allows us to fail fast if read
// was invoked after the stream was closed. We need a separate atomic boolean closed because we don't want a
// future read to succeed when #close has been invoked even if there are on-going reads. On-going reads would
// hold reference and since ref count will not be 0 even after close was invoked, future reads will go through
// without a check on closed. Also, we do need to set closed externally. It is shared across all streams of the
// file. Check on closed in this class makes sure that no other stream allows subsequent reads. closed is
// being set to true in RemoteTransferContainer#close which is invoked when we are done processing all
// parts/file. Processing completes when either all parts are completed successfully or if either of the parts
// failed. In successful case, subsequent read will anyway not go through since all streams would have been
// consumed fully but in case of failure, SDK can continue to invoke read and this would be a wasted compute
// and IO.
// 2. In second check, a tryIncRef is invoked which tries to increment reference under lock and fails if ref
// is already closed. If reference is successfully obtained by the stream then stream will not be closed.
// Ref counting ensures that stream isn't closed in between reads.
//
// All these protection mechanisms are required in order to prevent invalid access to streams happening
// from the new S3 async SDK.
ensureReadable();
try (OffsetRangeRefCount ignored = getStreamReference()) {
return inputStreamIndexInput.read(b, off, len);
}
}

private OffsetRangeRefCount getStreamReference() {
boolean successIncrement = offsetRangeRefCount.tryIncRef();
if (successIncrement == false) {
throw alreadyClosed("OffsetRangeIndexInputStream is already unreferenced.");
}
return offsetRangeRefCount;
}

private void ensureReadable() {
if (readBlock != null && readBlock.get() == true) {
logger.debug("Read attempted on a stream which was read blocked!");
throw alreadyClosed("Read blocked stream.");
}
}

AlreadyClosedException alreadyClosed(String msg) {
return new AlreadyClosedException(msg + this);
}

@Override
public int read() throws IOException {
return inputStreamIndexInput.read();
ensureReadable();
try (OffsetRangeRefCount ignored = getStreamReference()) {
return inputStreamIndexInput.read();
}
}

@Override
Expand All @@ -67,9 +128,42 @@ public long getFilePointer() throws IOException {
return indexInput.getFilePointer();
}

@Override
public String toString() {
return "OffsetRangeIndexInputStream{" + "indexInput=" + indexInput + ", readBlock=" + readBlock + '}';
}

private static class ClosingStreams {
private final InputStreamIndexInput inputStreamIndexInput;
private final IndexInput indexInput;

public ClosingStreams(InputStreamIndexInput inputStreamIndexInput, IndexInput indexInput) {
this.inputStreamIndexInput = inputStreamIndexInput;
this.indexInput = indexInput;
}
}

private static class OffsetRangeRefCount extends RefCountedReleasable<ClosingStreams> {
private static final Logger logger = LogManager.getLogger(OffsetRangeRefCount.class);

public OffsetRangeRefCount(ClosingStreams ref) {
super("OffsetRangeRefCount", ref, () -> {
try {
ref.inputStreamIndexInput.close();
} catch (IOException ex) {
logger.error("Failed to close indexStreamIndexInput", ex);
}
try {
ref.indexInput.close();
} catch (IOException ex) {
logger.error("Failed to close indexInput", ex);
}
});
}
}

@Override
public void close() throws IOException {
inputStreamIndexInput.close();
indexInput.close();
closeOnce.run();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import java.io.IOException;
import java.io.InputStream;
import java.util.concurrent.atomic.AtomicBoolean;

/**
* OffsetRangeInputStream is an abstract class that extends from {@link InputStream}
Expand All @@ -19,4 +20,8 @@
*/
public abstract class OffsetRangeInputStream extends InputStream {
public abstract long getFilePointer() throws IOException;

public void setReadBlock(AtomicBoolean readBlock) {
// Nothing
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.opensearch.common.StreamLimiter;

import java.io.IOException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Supplier;

/**
Expand Down Expand Up @@ -40,6 +41,10 @@ public RateLimitingOffsetRangeInputStream(
this.delegate = delegate;
}

public void setReadBlock(AtomicBoolean readBlock) {
delegate.setReadBlock(readBlock);
}

@Override
public int read() throws IOException {
int b = delegate.read();
Expand Down
Loading

0 comments on commit d46cc07

Please sign in to comment.