Skip to content

Commit

Permalink
Close the input stream while decompressing data
Browse files Browse the repository at this point in the history
Add a test to check whether compressed streams are closed properly
  • Loading branch information
wendigo committed Oct 4, 2024
1 parent a524522 commit 3843fb8
Show file tree
Hide file tree
Showing 4 changed files with 167 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,17 @@
*/
package io.trino.client.spooling.encoding;

import com.google.common.io.ByteStreams;
import io.trino.client.QueryDataDecoder;
import io.trino.client.spooling.DataAttribute;
import io.trino.client.spooling.DataAttributes;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.Optional;

import static com.google.common.base.Verify.verify;
import static java.util.Objects.requireNonNull;

public abstract class CompressedQueryDataDecoder
Expand All @@ -33,16 +36,26 @@ public CompressedQueryDataDecoder(QueryDataDecoder delegate)
this.delegate = requireNonNull(delegate, "delegate is null");
}

abstract InputStream decompress(InputStream inputStream, int expectedDecompressedSize)
abstract void decompress(byte[] input, byte[] output)
throws IOException;

@Override
public QueryDataAccess decode(InputStream stream, DataAttributes metadata)
throws IOException
{
Optional<Integer> expectedDecompressedSize = metadata.getOptional(DataAttribute.UNCOMPRESSED_SIZE, Integer.class);
int segmentSize = metadata.get(DataAttribute.SEGMENT_SIZE, Integer.class);

if (expectedDecompressedSize.isPresent()) {
return delegate.decode(decompress(stream, expectedDecompressedSize.get()), metadata);
int uncompressedSize = expectedDecompressedSize.get();
try (InputStream inputStream = stream) {
byte[] input = new byte[segmentSize];
byte[] output = new byte[uncompressedSize];
int readBytes = ByteStreams.read(inputStream, input, 0, segmentSize);
verify(readBytes == segmentSize, "Expected to read %s bytes but got %s", segmentSize, readBytes);
decompress(input, output);
return delegate.decode(new ByteArrayInputStream(output), metadata);
}
}
// Data not compressed - below threshold
return delegate.decode(stream, metadata);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,10 @@
*/
package io.trino.client.spooling.encoding;

import com.google.common.io.ByteStreams;
import io.airlift.compress.lz4.Lz4Decompressor;
import io.trino.client.QueryDataDecoder;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;

import static java.lang.String.format;

Expand All @@ -32,18 +29,14 @@ public Lz4QueryDataDecoder(QueryDataDecoder delegate)
}

@Override
InputStream decompress(InputStream stream, int expectedDecompressedSize)
void decompress(byte[] input, byte[] output)
throws IOException
{
Lz4Decompressor decompressor = new Lz4Decompressor();
byte[] bytes = ByteStreams.toByteArray(stream);
byte[] output = new byte[expectedDecompressedSize];

int decompressedSize = decompressor.decompress(bytes, 0, bytes.length, output, 0, output.length);
if (decompressedSize != expectedDecompressedSize) {
throw new IOException(format("Decompressed size does not match expected segment size, expected %d, got %d", decompressedSize, expectedDecompressedSize));
int decompressedSize = decompressor.decompress(input, 0, input.length, output, 0, output.length);
if (decompressedSize != output.length) {
throw new IOException(format("Decompressed size does not match expected segment size, expected %d, got %d", decompressedSize, output.length));
}
return new ByteArrayInputStream(output);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,10 @@
*/
package io.trino.client.spooling.encoding;

import com.google.common.io.ByteStreams;
import io.airlift.compress.zstd.ZstdDecompressor;
import io.trino.client.QueryDataDecoder;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;

import static java.lang.String.format;

Expand All @@ -32,18 +29,14 @@ public ZstdQueryDataDecoder(QueryDataDecoder delegate)
}

@Override
InputStream decompress(InputStream stream, int expectedDecompressedSize)
void decompress(byte[] bytes, byte[] output)
throws IOException
{
ZstdDecompressor decompressor = new ZstdDecompressor();
byte[] bytes = ByteStreams.toByteArray(stream);
byte[] output = new byte[expectedDecompressedSize];

int decompressedSize = decompressor.decompress(bytes, 0, bytes.length, output, 0, output.length);
if (decompressedSize != expectedDecompressedSize) {
throw new IOException(format("Decompressed size does not match expected segment size, expected %d, got %d", decompressedSize, expectedDecompressedSize));
if (decompressedSize != output.length) {
throw new IOException(format("Decompressed size does not match expected segment size, expected %d, got %d", decompressedSize, output.length));
}
return new ByteArrayInputStream(output);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.client.spooling.encoding;

import com.google.common.collect.ImmutableList;
import com.google.common.io.ByteStreams;
import io.trino.client.QueryDataDecoder;
import io.trino.client.spooling.DataAttributes;
import org.junit.jupiter.api.Test;

import java.io.ByteArrayInputStream;
import java.io.FilterInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;

import static io.trino.client.spooling.DataAttribute.SEGMENT_SIZE;
import static io.trino.client.spooling.DataAttribute.UNCOMPRESSED_SIZE;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.assertj.core.api.Assertions.assertThat;

/**
 * Checks that the segment stream passed to a {@link CompressedQueryDataDecoder} ends up closed,
 * both when the segment is flagged as compressed (UNCOMPRESSED_SIZE attribute present) and when
 * it is passed through to the delegate untouched.
 */
class TestCompressedQueryDataDecoder
{
    private static final List<List<Object>> SAMPLE_VALUES = ImmutableList.of(ImmutableList.of("hello", "world"));

    @Test
    public void testClosesUnderlyingInputStreamIfCompressed()
            throws IOException
    {
        AtomicBoolean streamClosed = new AtomicBoolean();
        InputStream segment = closeTrackingStream("compressed", streamClosed);

        // This delegate never closes its input, so the final assertion can only
        // pass if the compressed decoder closed the original stream itself.
        QueryDataDecoder delegate = new QueryDataDecoder() {
            @Override
            public QueryDataAccess decode(InputStream input, DataAttributes segmentAttributes)
                    throws IOException
            {
                String payload = new String(ByteStreams.toByteArray(input), UTF_8);
                assertThat(payload).isEqualTo("decompressed");
                return () -> SAMPLE_VALUES;
            }

            @Override
            public String encoding()
            {
                return "test";
            }
        };
        QueryDataDecoder decoder = new TestQueryDataDecoder(delegate);

        DataAttributes attributes = DataAttributes.builder()
                .set(UNCOMPRESSED_SIZE, "decompressed".length())
                .set(SEGMENT_SIZE, "compressed".length())
                .build();

        assertThat(streamClosed.get()).isFalse();
        assertThat(decoder.decode(segment, attributes).toIterable()).isEqualTo(SAMPLE_VALUES);
        assertThat(streamClosed.get()).isTrue();
    }

    @Test
    public void testDelegatesClosingIfUncompressed()
            throws IOException
    {
        AtomicBoolean streamClosed = new AtomicBoolean();
        InputStream segment = closeTrackingStream("not compressed", streamClosed);

        QueryDataDecoder delegate = new QueryDataDecoder() {
            @Override
            public QueryDataAccess decode(InputStream input, DataAttributes segmentAttributes)
                    throws IOException
            {
                String payload = new String(ByteStreams.toByteArray(input), UTF_8);
                assertThat(payload).isEqualTo("not compressed");
                input.close(); // Closes input stream according to the contract
                return () -> SAMPLE_VALUES;
            }

            @Override
            public String encoding()
            {
                return "test";
            }
        };
        QueryDataDecoder decoder = new TestQueryDataDecoder(delegate);

        // No UNCOMPRESSED_SIZE attribute: the segment is treated as not compressed.
        DataAttributes attributes = DataAttributes.builder()
                .set(SEGMENT_SIZE, "not compressed".length())
                .build();

        assertThat(streamClosed.get()).isFalse();
        assertThat(decoder.decode(segment, attributes).toIterable()).isEqualTo(SAMPLE_VALUES);
        assertThat(streamClosed.get()).isTrue();
    }

    /**
     * Wraps the UTF-8 bytes of {@code content} in a stream that sets {@code closedFlag}
     * to {@code true} once {@code close()} has completed on the wrapped stream.
     */
    private static InputStream closeTrackingStream(String content, AtomicBoolean closedFlag)
    {
        return new FilterInputStream(new ByteArrayInputStream(content.getBytes(UTF_8))) {
            @Override
            public void close()
                    throws IOException
            {
                super.close();
                closedFlag.set(true);
            }
        };
    }

    /**
     * Stand-in decoder whose "decompression" checks that it received the literal
     * bytes {@code "compressed"} and writes the literal bytes {@code "decompressed"}.
     */
    private static class TestQueryDataDecoder
            extends CompressedQueryDataDecoder
    {
        public TestQueryDataDecoder(QueryDataDecoder delegate)
        {
            super(delegate);
        }

        @Override
        void decompress(byte[] bytes, byte[] output)
        {
            String received = new String(bytes, UTF_8);
            assertThat(received).isEqualTo("compressed");

            byte[] payload = "decompressed".getBytes(UTF_8);
            System.arraycopy(payload, 0, output, 0, payload.length);
        }

        @Override
        public String encoding()
        {
            return "test";
        }
    }
}

0 comments on commit 3843fb8

Please sign in to comment.