Skip to content

Commit

Permalink
Fixes #999, hanlde UTF16 surrogate charactors properly.
Browse files Browse the repository at this point in the history
Change-Id: I19e77cf5a8282bea901434041806eb102549ec0f
  • Loading branch information
frankfliu committed Jun 9, 2021
1 parent 8286930 commit 4721cae
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 4 deletions.
34 changes: 30 additions & 4 deletions api/src/main/native/djl/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,21 @@ inline std::string GetStringFromJString(JNIEnv* env, jstring jstr) {
if (jstr == nullptr) {
return std::string();
}
const char* c_str = env->GetStringUTFChars(jstr, JNI_FALSE);
std::string str = std::string(c_str);
env->ReleaseStringUTFChars(jstr, c_str);

// TODO: cache reflection to improve performance
const jclass string_class = env->GetObjectClass(jstr);
const jmethodID getbytes_method = env->GetMethodID(string_class, "getBytes", "(Ljava/lang/String;)[B");

const jstring charset = env->NewStringUTF("UTF-8");
const jbyteArray jbytes = (jbyteArray) env->CallObjectMethod(jstr, getbytes_method, charset);
env->DeleteLocalRef(charset);

const jsize length = env->GetArrayLength(jbytes);
jbyte* c_str = env->GetByteArrayElements(jbytes, NULL);
std::string str = std::string(reinterpret_cast<const char *>(c_str), length);

env->ReleaseByteArrayElements(jbytes, c_str, RELEASE_MODE);
env->DeleteLocalRef(jbytes);
return str;
}

Expand Down Expand Up @@ -100,9 +112,23 @@ inline std::vector<std::string> GetVecFromJStringArray(JNIEnv* env, jobjectArray
// String[]
inline jobjectArray GetStringArrayFromVec(JNIEnv* env, const std::vector <std::string> &vec) {
jobjectArray array = env->NewObjectArray(vec.size(), env->FindClass("Ljava/lang/String;"), nullptr);

// TODO: cache reflection to improve performance
const jclass string_class = env->FindClass("java/lang/String");
const jmethodID ctor = env->GetMethodID(string_class, "<init>", "([BLjava/lang/String;)V");
const jstring charset = env->NewStringUTF("UTF-8");

for (int i = 0; i < vec.size(); ++i) {
env->SetObjectArrayElement(array, i, env->NewStringUTF(vec[i].c_str()));
const char* c_str = vec[i].c_str();
int len = vec[i].length();
auto jbytes = env->NewByteArray(len);
env->SetByteArrayRegion(jbytes, 0, len, reinterpret_cast<const jbyte*>(c_str));
jobject jstr = env->NewObject(string_class, ctor, jbytes, charset);
env->DeleteLocalRef(jbytes);
env->SetObjectArrayElement(array, i, jstr);
}

env->DeleteLocalRef(charset);
return array;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,20 @@ public void testTokenize() throws IOException {
}
}

@Test
public void testUtf16Tokenize() throws IOException {
if (System.getProperty("os.name").startsWith("Win")) {
throw new SkipException("Skip windows test.");
}
Path modelPath = Paths.get("build/test/models/sententpiece_test_model.model");
try (SpTokenizer tokenizer = new SpTokenizer(modelPath)) {
String original = "\uD83D\uDC4B\uD83D\uDC4B";
List<String> tokens = tokenizer.tokenize(original);
List<String> expected = Arrays.asList("▁", "\uD83D\uDC4B\uD83D\uDC4B");
Assert.assertEquals(tokens, expected);
}
}

@Test
public void testEncodeDecode() throws IOException {
if (System.getProperty("os.name").startsWith("Win")) {
Expand Down

0 comments on commit 4721cae

Please sign in to comment.