Skip to content

Commit

Permalink
Close the input stream while decompressing data
Browse files Browse the repository at this point in the history
  • Loading branch information
wendigo committed Oct 4, 2024
1 parent a00d05f commit 559bbed
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@
import io.trino.client.QueryDataDecoder;
import io.trino.client.spooling.DataAttribute;
import io.trino.client.spooling.DataAttributes;
import org.gaul.modernizer_maven_annotations.SuppressModernizer;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.Optional;

import static com.google.common.base.Preconditions.checkState;
import static java.lang.Math.toIntExact;
import static java.util.Objects.requireNonNull;

public abstract class CompressedQueryDataDecoder
Expand All @@ -35,7 +38,7 @@ public CompressedQueryDataDecoder(QueryDataDecoder delegate)
this.delegate = requireNonNull(delegate, "delegate is null");
}

abstract InputStream decompress(byte[] bytes, int expectedDecompressedSize)
abstract void decompress(byte[] input, byte[] output)
throws IOException;

@Override
Expand All @@ -44,8 +47,11 @@ public QueryDataAccess decode(InputStream stream, DataAttributes metadata)
{
Optional<Integer> expectedDecompressedSize = metadata.getOptional(DataAttribute.UNCOMPRESSED_SIZE, Integer.class);
if (expectedDecompressedSize.isPresent()) {
try (ReadOnceByteSource source = new ReadOnceByteSource(stream, expectedDecompressedSize.get())) {
return delegate.decode(decompress(source.read(), expectedDecompressedSize.get()), metadata);
long uncompressedSize = expectedDecompressedSize.get();
try (ReadOnceByteSource source = new ReadOnceByteSource(stream, uncompressedSize)) {
byte[] output = new byte[toIntExact(uncompressedSize)];
decompress(source.read(), output);
return delegate.decode(new ByteArrayInputStream(output), metadata);
}
}
// Data not compressed - below threshold
Expand Down Expand Up @@ -73,6 +79,7 @@ public InputStream openStream()
return inputStream;
}

@SuppressModernizer // This is Guava's API
@Override
public com.google.common.base.Optional<Long> sizeIfKnown()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,10 @@
*/
package io.trino.client.spooling.encoding;

import com.google.common.io.ByteStreams;
import io.airlift.compress.lz4.Lz4Decompressor;
import io.trino.client.QueryDataDecoder;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;

import static java.lang.String.format;

Expand All @@ -32,17 +29,14 @@ public Lz4QueryDataDecoder(QueryDataDecoder delegate)
}

@Override
InputStream decompress(byte[] bytes, int expectedDecompressedSize)
void decompress(byte[] input, byte[] output)
throws IOException
{
Lz4Decompressor decompressor = new Lz4Decompressor();
byte[] output = new byte[expectedDecompressedSize];

int decompressedSize = decompressor.decompress(bytes, 0, bytes.length, output, 0, output.length);
if (decompressedSize != expectedDecompressedSize) {
throw new IOException(format("Decompressed size does not match expected segment size, expected %d, got %d", decompressedSize, expectedDecompressedSize));
int decompressedSize = decompressor.decompress(input, 0, input.length, output, 0, output.length);
if (decompressedSize != output.length) {
throw new IOException(format("Decompressed size does not match expected segment size, expected %d, got %d", decompressedSize, output.length));
}
return new ByteArrayInputStream(output);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@
import io.airlift.compress.zstd.ZstdDecompressor;
import io.trino.client.QueryDataDecoder;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;

import static java.lang.String.format;

Expand All @@ -31,17 +29,14 @@ public ZstdQueryDataDecoder(QueryDataDecoder delegate)
}

@Override
InputStream decompress(byte[] bytes, int expectedDecompressedSize)
void decompress(byte[] bytes, byte[] output)
throws IOException
{
ZstdDecompressor decompressor = new ZstdDecompressor();
byte[] output = new byte[expectedDecompressedSize];

int decompressedSize = decompressor.decompress(bytes, 0, bytes.length, output, 0, output.length);
if (decompressedSize != expectedDecompressedSize) {
throw new IOException(format("Decompressed size does not match expected segment size, expected %d, got %d", decompressedSize, expectedDecompressedSize));
if (decompressedSize != output.length) {
throw new IOException(format("Decompressed size does not match expected segment size, expected %d, got %d", decompressedSize, output.length));
}
return new ByteArrayInputStream(output);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@
import java.io.FilterInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;

import static io.trino.client.spooling.DataAttribute.UNCOMPRESSED_SIZE;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.assertj.core.api.Assertions.assertThat;

class TestCompressedQueryDataDecoder
Expand All @@ -39,7 +39,7 @@ public void testClosesUnderlyingInputStreamIfCompressed()
throws IOException
{
AtomicBoolean closed = new AtomicBoolean();
InputStream stream = new FilterInputStream(new ByteArrayInputStream("compressed".getBytes())) {
InputStream stream = new FilterInputStream(new ByteArrayInputStream("compressed".getBytes(UTF_8))) {
@Override
public void close()
throws IOException
Expand All @@ -54,7 +54,7 @@ public void close()
public QueryDataAccess decode(InputStream input, DataAttributes segmentAttributes)
throws IOException
{
assertThat(new String(ByteStreams.toByteArray(input), StandardCharsets.UTF_8))
assertThat(new String(ByteStreams.toByteArray(input), UTF_8))
.isEqualTo("decompressed");
return () -> SAMPLE_VALUES;
}
Expand All @@ -67,15 +67,17 @@ public String encoding()
});

assertThat(closed.get()).isFalse();
assertThat(decoder.decode(stream, DataAttributes.builder().set(UNCOMPRESSED_SIZE, 5).build()).toIterable())
assertThat(decoder.decode(stream, DataAttributes.builder().set(UNCOMPRESSED_SIZE, 12).build()).toIterable())
.isEqualTo(SAMPLE_VALUES);
assertThat(closed.get()).isTrue();
}

@Test
public void testDelegatesClosingIfUncompressed() throws IOException {
public void testDelegatesClosingIfUncompressed()
throws IOException
{
AtomicBoolean closed = new AtomicBoolean();
InputStream stream = new FilterInputStream(new ByteArrayInputStream("not compressed".getBytes())) {
InputStream stream = new FilterInputStream(new ByteArrayInputStream("not compressed".getBytes(UTF_8))) {
@Override
public void close()
throws IOException
Expand All @@ -90,7 +92,7 @@ public void close()
public QueryDataAccess decode(InputStream input, DataAttributes segmentAttributes)
throws IOException
{
assertThat(new String(ByteStreams.toByteArray(input), StandardCharsets.UTF_8))
assertThat(new String(ByteStreams.toByteArray(input), UTF_8))
.isEqualTo("not compressed");
input.close(); // Closes input stream according to the contract
return () -> SAMPLE_VALUES;
Expand Down Expand Up @@ -118,11 +120,13 @@ public TestQueryDataDecoder(QueryDataDecoder delegate)
}

@Override
InputStream decompress(byte[] bytes, int expectedDecompressedSize)
void decompress(byte[] bytes, byte[] output)
{
assertThat(new String(bytes, StandardCharsets.UTF_8))
assertThat(new String(bytes, UTF_8))
.isEqualTo("compressed");
return new ByteArrayInputStream("decompressed".getBytes(StandardCharsets.UTF_8));

byte[] uncompressed = "decompressed".getBytes(UTF_8);
System.arraycopy(uncompressed, 0, output, 0, uncompressed.length);
}

@Override
Expand Down

0 comments on commit 559bbed

Please sign in to comment.