Skip to content

Commit

Permalink
Fix issue with partial UTF-8 string (#6317)
Browse files Browse the repository at this point in the history
Summary:
It will cause JNI exception if we don't pass in UTF-8 string.

Alternative 1 (this): wait until we have complete UTF-8 tokens.
Alternative 2 (?): Fix this from runner layer
Alternative 3 (no): Change the API to use uint8_t array, but if we want to display on app in real time, this is still an issue.

Pull Request resolved: #6317

Reviewed By: Riandy

Differential Revision: D64580932

Pulled By: kirklandsign

fbshipit-source-id: 341ea906097707fae0f97d32dea974ad44425083
  • Loading branch information
kirklandsign authored and facebook-github-bot committed Oct 18, 2024
1 parent 0eeea82 commit 6b2a082
Showing 1 changed file with 47 additions and 3 deletions.
50 changes: 47 additions & 3 deletions extension/android/jni/jni_layer_llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,9 @@
* LICENSE file in the root directory of this source tree.
*/

#include <cassert>
#include <chrono>
#include <iostream>
#include <cstdint>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>
Expand All @@ -33,6 +31,43 @@
namespace llm = ::executorch::extension::llm;
using ::executorch::runtime::Error;

namespace {
bool utf8_check_validity(const char* str, size_t length) {
for (size_t i = 0; i < length; ++i) {
uint8_t byte = static_cast<uint8_t>(str[i]);
if (byte >= 0x80) { // Non-ASCII byte
if (i + 1 >= length) { // Incomplete sequence
return false;
}
uint8_t next_byte = static_cast<uint8_t>(str[i + 1]);
if ((byte & 0xE0) == 0xC0 &&
(next_byte & 0xC0) == 0x80) { // 2-byte sequence
i += 2;
} else if (
(byte & 0xF0) == 0xE0 && (next_byte & 0xC0) == 0x80 &&
(i + 2 < length) &&
(static_cast<uint8_t>(str[i + 2]) & 0xC0) ==
0x80) { // 3-byte sequence
i += 3;
} else if (
(byte & 0xF8) == 0xF0 && (next_byte & 0xC0) == 0x80 &&
(i + 2 < length) &&
(static_cast<uint8_t>(str[i + 2]) & 0xC0) == 0x80 &&
(i + 3 < length) &&
(static_cast<uint8_t>(str[i + 3]) & 0xC0) ==
0x80) { // 4-byte sequence
i += 4;
} else {
return false; // Invalid sequence
}
}
}
return true; // All bytes were valid
}

std::string token_buffer;
} // namespace

namespace executorch_jni {

class ExecuTorchLlamaCallbackJni
Expand All @@ -45,6 +80,15 @@ class ExecuTorchLlamaCallbackJni
static auto cls = ExecuTorchLlamaCallbackJni::javaClassStatic();
static const auto method =
cls->getMethod<void(facebook::jni::local_ref<jstring>)>("onResult");

token_buffer += result;
if (!utf8_check_validity(token_buffer.c_str(), token_buffer.size())) {
ET_LOG(
Info, "Current token buffer is not valid UTF-8. Waiting for more.");
return;
}
result = token_buffer;
token_buffer = "";
facebook::jni::local_ref<jstring> s = facebook::jni::make_jstring(result);
method(self(), s);
}
Expand Down

0 comments on commit 6b2a082

Please sign in to comment.