From 6b2a0825b6a5a235ebfd97a210d06706f966e0ed Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Fri, 18 Oct 2024 10:13:49 -0700 Subject: [PATCH] Fix issue with partial UTF-8 string (#6317) Summary: It will cause JNI exception if we don't pass in UTF-8 string. Alternative 1 (this): wait until we have complete UTF-8 tokens. Alternative 2 (?): Fix this from runner layer Alternative 3 (no): Change the API to use uint8_t array, but if we want to display on app in real time, this is still an issue. Pull Request resolved: https://github.com/pytorch/executorch/pull/6317 Reviewed By: Riandy Differential Revision: D64580932 Pulled By: kirklandsign fbshipit-source-id: 341ea906097707fae0f97d32dea974ad44425083 --- extension/android/jni/jni_layer_llama.cpp | 50 +++++++++++++++++++++-- 1 file changed, 47 insertions(+), 3 deletions(-) diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp index 6ffc88d810..1049b9da30 100644 --- a/extension/android/jni/jni_layer_llama.cpp +++ b/extension/android/jni/jni_layer_llama.cpp @@ -6,11 +6,9 @@ * LICENSE file in the root directory of this source tree. */ -#include #include -#include +#include #include -#include #include #include #include @@ -33,6 +31,43 @@ namespace llm = ::executorch::extension::llm; using ::executorch::runtime::Error; +namespace { +bool utf8_check_validity(const char* str, size_t length) { + for (size_t i = 0; i < length; ++i) { + uint8_t byte = static_cast(str[i]); + if (byte >= 0x80) { // Non-ASCII byte + if (i + 1 >= length) { // Incomplete sequence + return false; + } + uint8_t next_byte = static_cast(str[i + 1]); + if ((byte & 0xE0) == 0xC0 && + (next_byte & 0xC0) == 0x80) { // 2-byte sequence + i += 2; + } else if ( + (byte & 0xF0) == 0xE0 && (next_byte & 0xC0) == 0x80 && + (i + 2 < length) && + (static_cast(str[i + 2]) & 0xC0) == + 0x80) { // 3-byte sequence + i += 3; + } else if ( + (byte & 0xF8) == 0xF0 && (next_byte & 0xC0) == 0x80 && + (i + 2 < length) && + (static_cast(str[i + 2]) & 0xC0) == 0x80 && + (i + 3 < length) && + (static_cast(str[i + 3]) & 0xC0) == + 0x80) { // 4-byte sequence + i += 4; + } else { + return false; // Invalid sequence + } + } + } + return true; // All bytes were valid +} + +std::string token_buffer; +} // namespace + namespace executorch_jni { class ExecuTorchLlamaCallbackJni @@ -45,6 +80,15 @@ class ExecuTorchLlamaCallbackJni static auto cls = ExecuTorchLlamaCallbackJni::javaClassStatic(); static const auto method = cls->getMethod)>("onResult"); + + token_buffer += result; + if (!utf8_check_validity(token_buffer.c_str(), token_buffer.size())) { + ET_LOG( + Info, "Current token buffer is not valid UTF-8. Waiting for more."); + return; + } + result = token_buffer; + token_buffer = ""; facebook::jni::local_ref s = facebook::jni::make_jstring(result); method(self(), s); }