diff --git a/src/main/java/com/github/luben/zstd/ZstdInputStream.java b/src/main/java/com/github/luben/zstd/ZstdInputStream.java index b1d1e354..e75588ba 100644 --- a/src/main/java/com/github/luben/zstd/ZstdInputStream.java +++ b/src/main/java/com/github/luben/zstd/ZstdInputStream.java @@ -149,7 +149,11 @@ int readInternal(byte[] dst, int offset, int len) throws IOException { if (frameFinished) { return -1; } else if (isContinuous) { - return (int)(dstPos - offset); + srcSize = (int)(dstPos - offset); + if (srcSize > 0) { + return (int) srcSize; + } + return -1; } else { throw new IOException("Read error or truncated source"); } diff --git a/src/test/scala/Zstd.scala b/src/test/scala/Zstd.scala index 34d10849..e741c6ad 100644 --- a/src/test/scala/Zstd.scala +++ b/src/test/scala/Zstd.scala @@ -375,6 +375,52 @@ class ZstdSpec extends FlatSpec with Checkers { input.toSeq == output.toSeq } } + + "ZstdInputStream in continuous mode" should s"not block when the stream ends unexpectedly at level $level" in { + check { input: Array[Byte] => + val size = input.length + val os = new ByteArrayOutputStream(Zstd.compressBound(size.toLong).toInt) + val zos = new ZstdOutputStream(os, level) + zos.write(input) + zos.close + val compressed = os.toByteArray + // Cut the stream arbitrarily short by returning only part of the available data at first. + var releaseRemainingData = false + class IncrementalInputStream(bytes: Array[Byte], truncationAmount: Int) extends ByteArrayInputStream(bytes) { + var firstRead = true + override def read(b: Array[Byte], off: Int, len: Int): Int = { + if (firstRead) { + firstRead = false + super.read(b, off, Math.max(available() - truncationAmount, 0)) + } else if (releaseRemainingData) { + super.read(b, off, truncationAmount) + } else { + -1 + } + } + + override def read(): Int = { + throw new IllegalStateException() + } + } + val arbitraryTruncationAmount = 7 + val is = new IncrementalInputStream(compressed, arbitraryTruncationAmount) + val zis = new ZstdInputStream(is).setContinuous(true); + val output = Array.fill[Byte](size)(0) + // Read the incomplete data. + val amountRead = Math.max(0, zis.read(output)) + // Read the rest of the data and assert that the entire input was decompressed. + releaseRemainingData = true + zis.read(output, amountRead, size - amountRead) + zis.close + if (input.toSeq != output.toSeq) { + println(s"AT SIZE $size") + println(input.toSeq + "!=" + output.toSeq) + println("COMPRESSED: " + compressed.toSeq) + } + input.toSeq == output.toSeq + } + } } for (level <- levels)