Skip to content

Commit

Permalink
Fix infinite loop in continuous mode when reading an incomplete frame.
Browse files Browse the repository at this point in the history
  • Loading branch information
tgregg committed Jun 5, 2020
1 parent 3d16e51 commit 3d51bdc
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/main/java/com/github/luben/zstd/ZstdInputStream.java
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,11 @@ int readInternal(byte[] dst, int offset, int len) throws IOException {
if (frameFinished) {
return -1;
} else if (isContinuous) {
return (int)(dstPos - offset);
srcSize = (int)(dstPos - offset);
if (srcSize > 0) {
return (int) srcSize;
}
return -1;
} else {
throw new IOException("Read error or truncated source");
}
Expand Down
46 changes: 46 additions & 0 deletions src/test/scala/Zstd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,52 @@ class ZstdSpec extends FlatSpec with Checkers {
input.toSeq == output.toSeq
}
}

"ZstdInputStream in continuous mode" should s"not block when the stream ends unexpectedly at level $level" in {
check { input: Array[Byte] =>
val size = input.length
val os = new ByteArrayOutputStream(Zstd.compressBound(size.toLong).toInt)
val zos = new ZstdOutputStream(os, level)
zos.write(input)
zos.close
val compressed = os.toByteArray
// Cut the stream arbitrarily short by returning only part of the available data at first.
var releaseRemainingData = false
class IncrementalInputStream(bytes: Array[Byte], truncationAmount: Int) extends ByteArrayInputStream(bytes) {
var firstRead = true
override def read(b: Array[Byte], off: Int, len: Int): Int = {
if (firstRead) {
firstRead = false
super.read(b, off, Math.max(available() - truncationAmount, 0))
} else if (releaseRemainingData) {
super.read(b, off, truncationAmount)
} else {
-1
}
}

override def read(): Int = {
throw new IllegalStateException()
}
}
val arbitraryTruncationAmount = 7
val is = new IncrementalInputStream(compressed, arbitraryTruncationAmount)
val zis = new ZstdInputStream(is).setContinuous(true);
val output = Array.fill[Byte](size)(0)
// Read the incomplete data.
val amountRead = Math.max(0, zis.read(output))
// Read the rest of the data and assert that the entire input was decompressed.
releaseRemainingData = true
zis.read(output, amountRead, size - amountRead)
zis.close
if (input.toSeq != output.toSeq) {
println(s"AT SIZE $size")
println(input.toSeq + "!=" + output.toSeq)
println("COMPRESSED: " + compressed.toSeq)
}
input.toSeq == output.toSeq
}
}
}

for (level <- levels)
Expand Down

0 comments on commit 3d51bdc

Please sign in to comment.