Skip to content

Commit

Permalink
Add new constructor to ZstdDictCompress and ZstdDictDecompress that
allows byReference semantics for the provided byte buffer. If you
set this to true, you avoid copying the dict data into a
natively malloc'ed buffer, but you must then also promise that the byte
buffer will not be modified before the CTX has been closed.
Browse files Browse the repository at this point in the history
  • Loading branch information
Morten Grouleff authored and luben committed Apr 3, 2024
1 parent a516a43 commit 2a262bf
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 15 deletions.
28 changes: 25 additions & 3 deletions src/main/java/com/github/luben/zstd/ZstdDictCompress.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,14 @@ public class ZstdDictCompress extends SharedDictBase {
}

private long nativePtr = 0;

private ByteBuffer sharedDict = null;

private int level = Zstd.defaultCompressionLevel();

private native void init(byte[] dict, int dict_offset, int dict_size, int level);

private native void initDirect(ByteBuffer dict, int dict_offset, int dict_size, int level);
private native void initDirect(ByteBuffer dict, int dict_offset, int dict_size, int level, int byReference);

private native void free();

Expand Down Expand Up @@ -59,6 +62,18 @@ public ZstdDictCompress(byte[] dict, int offset, int length, int level) {
* @param level compression level
*/
public ZstdDictCompress(ByteBuffer dict, int level) {
// Delegate to the three-argument constructor with byReference = false, which selects the
// non-by-reference native dictionary creation path.
this(dict, level, false);
}

/**
* Create a new dictionary for use with fast compress.
* If byReference is true, then the native code does not copy the data but keeps a reference to the byte buffer, which must then not be modified before this context has been closed.
*
* @param dict Direct ByteBuffer containing dictionary using position and limit to define range in buffer.
* @param level compression level
* @param byReference tell the native part to use the byte buffer directly and not copy the data when true.
*/
public ZstdDictCompress(ByteBuffer dict, int level, boolean byReference) {
this.level = level;
int length = dict.limit() - dict.position();
if (!dict.isDirect()) {
Expand All @@ -67,11 +82,14 @@ public ZstdDictCompress(ByteBuffer dict, int level) {
if (length < 0) {
throw new IllegalArgumentException("dict cannot be empty.");
}
initDirect(dict, dict.position(), length, level);
initDirect(dict, dict.position(), length, level, byReference ? 1 : 0);

if (nativePtr == 0L) {
throw new IllegalStateException("ZSTD_createCDict failed");
}
if (byReference) {
sharedDict = dict; // ensures the dict is not garbage collected while this object remains, and flags that we should not use native free.
}
// Ensures that even if ZstdDictCompress is created and published through a race, no thread could observe
// nativePtr == 0.
storeFence();
Expand All @@ -85,7 +103,11 @@ int level() {
/**
 * Releases the native compression dictionary held by this object.
 *
 * The native dictionary handle stored in nativePtr is allocated by the native init code in
 * both modes (copying and by-reference), so it must be released exactly once regardless of
 * whether the dictionary bytes were shared with a caller-supplied buffer. Only the dictionary
 * content is shared in by-reference mode; the handle itself is still owned by this object.
 *
 * Calling this method again after the handle has been released is a no-op.
 */
@Override
void doClose() {
    if (nativePtr != 0) {
        // Free the native handle once, while nativePtr still holds its address
        // (presumably free() reads that field -- it is a native method declared in this class).
        free();
        nativePtr = 0;
        // Drop the keep-alive reference to the shared buffer only after the native
        // dictionary that points into it is gone.
        sharedDict = null;
    }
}
Expand Down
26 changes: 23 additions & 3 deletions src/main/java/com/github/luben/zstd/ZstdDictDecompress.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@ public class ZstdDictDecompress extends SharedDictBase {

private long nativePtr = 0L;

private ByteBuffer sharedDict = null;

private native void init(byte[] dict, int dict_offset, int dict_size);

private native void initDirect(ByteBuffer dict, int dict_offset, int dict_size);
private native void initDirect(ByteBuffer dict, int dict_offset, int dict_size, int byReference);

private native void free();

Expand Down Expand Up @@ -52,6 +54,17 @@ public ZstdDictDecompress(byte[] dict, int offset, int length) {
* @param dict Direct ByteBuffer containing dictionary using position and limit to define range in buffer.
*/
public ZstdDictDecompress(ByteBuffer dict) {
// Delegate to the two-argument constructor with byReference = false, which selects the
// non-by-reference native dictionary creation path.
this(dict, false);
}

/**
* Create a new dictionary for use with fast decompress.
* If byReference is true, then the native code does not copy the data but keeps a reference to the byte buffer, which must then not be modified before this context has been closed.
*
* @param dict Direct ByteBuffer containing dictionary using position and limit to define range in buffer.
* @param byReference tell the native part to use the byte buffer directly and not copy the data when true.
*/
public ZstdDictDecompress(ByteBuffer dict, boolean byReference) {

int length = dict.limit() - dict.position();
if (!dict.isDirect()) {
Expand All @@ -60,11 +73,14 @@ public ZstdDictDecompress(ByteBuffer dict) {
if (length < 0) {
throw new IllegalArgumentException("dict cannot be empty.");
}
initDirect(dict, dict.position(), length);
initDirect(dict, dict.position(), length, byReference ? 1 : 0);

if (nativePtr == 0L) {
throw new IllegalStateException("ZSTD_createDDict failed");
}
if (byReference) {
sharedDict = dict; // ensures the dict is not garbage collected while this object remains, and flags that we should not use native free.
}
// Ensures that even if ZstdDictDecompress is created and published through a race, no thread could observe
// nativePtr == 0.
storeFence();
Expand All @@ -74,7 +90,11 @@ public ZstdDictDecompress(ByteBuffer dict) {
/**
 * Releases the native decompression dictionary held by this object.
 *
 * The native dictionary handle stored in nativePtr is allocated by the native init code in
 * both modes (copying and by-reference), so it must be released exactly once regardless of
 * whether the dictionary bytes were shared with a caller-supplied buffer. Only the dictionary
 * content is shared in by-reference mode; the handle itself is still owned by this object.
 *
 * Calling this method again after the handle has been released is a no-op.
 */
@Override
void doClose() {
    if (nativePtr != 0) {
        // Free the native handle once, while nativePtr still holds its address
        // (presumably free() reads that field -- it is a native method declared in this class).
        free();
        nativePtr = 0;
        // Drop the keep-alive reference to the shared buffer only after the native
        // dictionary that points into it is gone.
        sharedDict = null;
    }
}
Expand Down
22 changes: 16 additions & 6 deletions src/main/native/jni_fast_zstd.c
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,22 @@ JNIEXPORT void JNICALL Java_com_github_luben_zstd_ZstdDictCompress_init
/*
 * Class:     com_github_luben_zstd_ZstdDictCompress
 * Method:    initDirect
 * Signature: (Ljava/nio/ByteBuffer;IIII)V
 *
 * Creates a ZSTD_CDict from dict_size bytes starting at dict_offset inside the
 * direct ByteBuffer `dict` and stores the resulting pointer in the Java
 * object's `nativePtr` field.
 *
 * When byReference is 0 the dictionary is created with ZSTD_createCDict;
 * otherwise ZSTD_createCDict_byReference is used, in which case the buffer
 * content is referenced rather than copied and must stay valid and unmodified
 * for the lifetime of the CDict.
 *
 * On any failure (null buffer, no direct address, dictionary creation failed)
 * the function returns without touching `nativePtr`; the Java caller detects
 * this by checking that the field is still 0.
 */
JNIEXPORT void JNICALL Java_com_github_luben_zstd_ZstdDictCompress_initDirect
  (JNIEnv *env, jobject obj, jobject dict, jint dict_offset, jint dict_size, jint level, jint byReference)
{
    jclass clazz = (*env)->GetObjectClass(env, obj);
    compress_dict = (*env)->GetFieldID(env, clazz, "nativePtr", "J");
    if (NULL == dict) return;

    /* NULL here means the buffer is not a direct buffer (or direct access is unsupported). */
    void *dict_buff = (*env)->GetDirectBufferAddress(env, dict);
    if (NULL == dict_buff) return;

    const char *dict_start = ((char *)dict_buff) + dict_offset;
    ZSTD_CDict* cdict = NULL;
    if (byReference == 0) {
        cdict = ZSTD_createCDict(dict_start, dict_size, level);
    } else {
        cdict = ZSTD_createCDict_byReference(dict_start, dict_size, level);
    }
    if (NULL == cdict) return;

    (*env)->SetLongField(env, obj, compress_dict, (jlong)(intptr_t) cdict);
}
Expand Down Expand Up @@ -85,17 +90,22 @@ JNIEXPORT void JNICALL Java_com_github_luben_zstd_ZstdDictDecompress_init
/*
* Class: com_github_luben_zstd_ZstdDictDecompress
* Method: initDirect
* Signature: (Ljava/nio/ByteBuffer;II)V
* Signature: (Ljava/nio/ByteBuffer;III)V
*/
JNIEXPORT void JNICALL Java_com_github_luben_zstd_ZstdDictDecompress_initDirect
(JNIEnv *env, jobject obj, jobject dict, jint dict_offset, jint dict_size)
(JNIEnv *env, jobject obj, jobject dict, jint dict_offset, jint dict_size, jint byReference)
{
jclass clazz = (*env)->GetObjectClass(env, obj);
decompress_dict = (*env)->GetFieldID(env, clazz, "nativePtr", "J");
if (NULL == dict) return;
void *dict_buff = (*env)->GetDirectBufferAddress(env, dict);

ZSTD_DDict* ddict = ZSTD_createDDict(((char *)dict_buff) + dict_offset, dict_size);
ZSTD_DDict* ddict = NULL;
if (byReference == 0) {
ddict = ZSTD_createDDict(((char *)dict_buff) + dict_offset, dict_size);
} else {
ddict = ZSTD_createDDict_byReference(((char *)dict_buff) + dict_offset, dict_size);
}

if (NULL == ddict) return;
(*env)->SetLongField(env, obj, decompress_dict, (jlong)(intptr_t) ddict);
Expand Down
7 changes: 4 additions & 3 deletions src/test/scala/ZstdDict.scala
Original file line number Diff line number Diff line change
Expand Up @@ -104,17 +104,18 @@ class ZstdDictSpec extends AnyFlatSpec {
assert(input.toSeq == decompressed.toSeq)
}

it should s"round-trip compression/decompression ByteBuffers with fast dict at level $level with legacy $legacy" in {
it should s"round-trip compression/decompression ByteBuffers with fast dict at level $level with byReference $legacy" in {
val byReference = legacy // Reuse the variance flag here.
val size = input.length
val inBuf = ByteBuffer.allocateDirect(size)
inBuf.put(input)
inBuf.flip()
val cdict = new ZstdDictCompress(dictInDirectByteBuffer, level)
val cdict = new ZstdDictCompress(dictInDirectByteBuffer, level, byReference)
val compressed = ByteBuffer.allocateDirect(Zstd.compressBound(size).toInt);
Zstd.compress(compressed, inBuf, cdict)
compressed.flip()
cdict.close
val ddict = new ZstdDictDecompress(dictInDirectByteBuffer)
val ddict = new ZstdDictDecompress(dictInDirectByteBuffer, byReference)
val decompressed = ByteBuffer.allocateDirect(size)
Zstd.decompress(decompressed, compressed, ddict)
decompressed.flip()
Expand Down

0 comments on commit 2a262bf

Please sign in to comment.