Skip to content

Commit

Permalink
Guard all native code execution behind the shared lock ...
Browse files Browse the repository at this point in the history
... in the ZstdCompressionCtx and ZstdDecompressionCtx.
  • Loading branch information
luben committed Nov 21, 2024
1 parent 9b08f1d commit cec9653
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 25 deletions.
50 changes: 35 additions & 15 deletions src/main/java/com/github/luben/zstd/ZstdCompressCtx.java
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,12 @@ public ZstdCompressCtx loadDict(byte[] dict) {
*/
public ZstdFrameProgression getFrameProgression() {
ensureOpen();
return getFrameProgression0(nativePtr);
acquireSharedLock();
try {
return getFrameProgression0(nativePtr);
} finally {
releaseSharedLock();
}
}
private static native ZstdFrameProgression getFrameProgression0(long ptr);

Expand All @@ -447,10 +452,16 @@ public ZstdFrameProgression getFrameProgression() {
*/
public void reset() {
ensureOpen();
long result = reset0(nativePtr);
if (Zstd.isError(result)) {
throw new ZstdException(result);
acquireSharedLock();
try {
long result = reset0(nativePtr);
if (Zstd.isError(result)) {
throw new ZstdException(result);
}
} finally {
releaseSharedLock();
}

}
private static native long reset0(long ptr);

Expand All @@ -464,9 +475,14 @@ public void reset() {
*/
public void setPledgedSrcSize(long srcSize) {
ensureOpen();
long result = setPledgedSrcSize0(nativePtr, srcSize);
if (Zstd.isError(result)) {
throw new ZstdException(result);
acquireSharedLock();
try {
long result = setPledgedSrcSize0(nativePtr, srcSize);
if (Zstd.isError(result)) {
throw new ZstdException(result);
}
} finally {
releaseSharedLock();
}
}
private static native long setPledgedSrcSize0(long ptr, long srcSize);
Expand All @@ -482,14 +498,19 @@ public void setPledgedSrcSize(long srcSize) {
*/
public boolean compressDirectByteBufferStream(ByteBuffer dst, ByteBuffer src, EndDirective endOp) {
ensureOpen();
long result = compressDirectByteBufferStream0(nativePtr, dst, dst.position(), dst.limit(), src, src.position(), src.limit(), endOp.value());
if ((result & 0x80000000L) != 0) {
long code = result & 0xFF;
throw new ZstdException(code, Zstd.getErrorName(code));
acquireSharedLock();
try {
long result = compressDirectByteBufferStream0(nativePtr, dst, dst.position(), dst.limit(), src, src.position(), src.limit(), endOp.value());
if ((result & 0x80000000L) != 0) {
long code = result & 0xFF;
throw new ZstdException(code, Zstd.getErrorName(code));
}
src.position((int)(result & 0x7FFFFFFF));
dst.position((int)(result >>> 32) & 0x7FFFFFFF);
return (result >>> 63) == 1;
} finally {
releaseSharedLock();
}
src.position((int)(result & 0x7FFFFFFF));
dst.position((int)(result >>> 32) & 0x7FFFFFFF);
return (result >>> 63) == 1;
}

/**
Expand Down Expand Up @@ -604,7 +625,6 @@ public int compressByteArray(byte[] dstBuff, int dstOffset, int dstSize, byte[]
* @return the size of the compressed data
*/
public int compress(ByteBuffer dstBuf, ByteBuffer srcBuf) {

int size = compressDirectByteBuffer(dstBuf, // compress into dstBuf
dstBuf.position(), // write compressed data starting at offset position()
dstBuf.limit() - dstBuf.position(), // write no more than limit() - position() bytes
Expand Down
33 changes: 23 additions & 10 deletions src/main/java/com/github/luben/zstd/ZstdDecompressCtx.java
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,18 @@ public ZstdDecompressCtx loadDict(byte[] dict) {
*/
public void reset() {
ensureOpen();
reset0(nativePtr);
acquireSharedLock();
try {
long result = reset0(nativePtr);
if (Zstd.isError(result)) {
throw new ZstdException(result);
}
} finally {
releaseSharedLock();
}

}
private static native void reset0(long nativePtr);
private static native long reset0(long nativePtr);

private void ensureOpen() {
if (nativePtr == 0) {
Expand All @@ -121,14 +130,19 @@ private void ensureOpen() {
*/
public boolean decompressDirectByteBufferStream(ByteBuffer dst, ByteBuffer src) {
ensureOpen();
long result = decompressDirectByteBufferStream0(nativePtr, dst, dst.position(), dst.limit(), src, src.position(), src.limit());
if ((result & 0x80000000L) != 0) {
long code = result & 0xFF;
throw new ZstdException(code, Zstd.getErrorName(code));
acquireSharedLock();
try {
long result = decompressDirectByteBufferStream0(nativePtr, dst, dst.position(), dst.limit(), src, src.position(), src.limit());
if ((result & 0x80000000L) != 0) {
long code = result & 0xFF;
throw new ZstdException(code, Zstd.getErrorName(code));
}
src.position((int)(result & 0x7FFFFFFF));
dst.position((int)(result >>> 32) & 0x7FFFFFFF);
return (result >>> 63) == 1;
} finally {
releaseSharedLock();
}
src.position((int)(result & 0x7FFFFFFF));
dst.position((int)(result >>> 32) & 0x7FFFFFFF);
return (result >>> 63) == 1;
}

/**
Expand Down Expand Up @@ -242,7 +256,6 @@ public int decompressByteArray(byte[] dstBuff, int dstOffset, int dstSize, byte[
* @return the size of the decompressed data.
*/
public int decompress(ByteBuffer dstBuf, ByteBuffer srcBuf) throws ZstdException {

int size = decompressDirectByteBuffer(dstBuf, // decompress into dstBuf
dstBuf.position(), // write decompressed data at offset position()
dstBuf.limit() - dstBuf.position(), // write no more than limit() - position()
Expand Down

0 comments on commit cec9653

Please sign in to comment.