feat: tiktoken integration #60

Merged
merged 38 commits · Jul 24, 2024
26847a9
cpp-tiktoken integration init
sangjanai Jul 16, 2024
96dd32f
remove cpp-tiktoken folder
sangjanai Jul 16, 2024
8e9983c
tiktoken integration init
sangjanai Jul 16, 2024
4f0fc5d
move fmt to third-party dependency
sangjanai Jul 16, 2024
06effaf
remove unnecessary option in cortex CMakeLists
sangjanai Jul 16, 2024
f4a6703
remove unnecessary submodule
sangjanai Jul 16, 2024
c7c6030
make pcre2 as third-party lib
sangjanai Jul 16, 2024
06f323d
fix CI bug build in window
sangjanai Jul 16, 2024
3077181
fix CI bug in linux
sangjanai Jul 16, 2024
5a4a79c
link directory pcre2
sangjanai Jul 17, 2024
2e059b0
Fix CI bug in linux 'pcre2.h' not found
sangjanai Jul 17, 2024
9c106fa
Fix bug build static third party lib
nguyenhoangthuan99 Jul 19, 2024
1044ea2
Integate Llama3 successfully
nguyenhoangthuan99 Jul 19, 2024
7c4eb94
refactor with model type
nguyenhoangthuan99 Jul 19, 2024
9bb8fca
remove unnecessary comment
nguyenhoangthuan99 Jul 19, 2024
406fb57
format and code convention
nguyenhoangthuan99 Jul 19, 2024
4985a45
fix logic of mistral when preparing input
nguyenhoangthuan99 Jul 19, 2024
3ab27be
Add remove end_of_text and end_of_turn token feature for llama3
nguyenhoangthuan99 Jul 19, 2024
26f0557
fix: patch pcre2 CMakeLists.txt file
nguyenhoangthuan99 Jul 19, 2024
d7a6dea
fix: do not add fPIC to cpp-tiktoken cmake file
nguyenhoangthuan99 Jul 19, 2024
6cec922
fix: suppress warnings
nguyenhoangthuan99 Jul 19, 2024
58afe8f
fix: add CMAKE_C_FLAGS
nguyenhoangthuan99 Jul 22, 2024
63b7484
fix: CI bug build for window in cpp-tiktoken
nguyenhoangthuan99 Jul 22, 2024
7d807d2
fix: CI bug for window fmt build fail without utf-8 flag
nguyenhoangthuan99 Jul 22, 2024
2f48d96
fix: CI bug for window stop using fmt
nguyenhoangthuan99 Jul 22, 2024
6772650
fix: build bugs for window - using latest pcre2 third party
nguyenhoangthuan99 Jul 22, 2024
654c9e9
fix: build bugs for window - pcre2 static runtime on
nguyenhoangthuan99 Jul 22, 2024
95679f7
fix CI bug for window build pcre2
nguyenhoangthuan99 Jul 22, 2024
4521928
rename link lib to pcre2-8-static in windows
nguyenhoangthuan99 Jul 22, 2024
3bb3aa0
fix: use pcre2 dynamic lib for windows
nguyenhoangthuan99 Jul 22, 2024
a3f66a7
fix: pack tensorrt_llm_nvrtc_wrapper.dll
nguyenhoangthuan99 Jul 23, 2024
27ab0be
test: only windows
nguyenhoangthuan99 Jul 24, 2024
d1b8392
fix: disable setLoggerFinder for msvc
nguyenhoangthuan99 Jul 24, 2024
11e2925
fix: cleanup cmake files
nguyenhoangthuan99 Jul 24, 2024
d634357
fix: remove fmt
nguyenhoangthuan99 Jul 24, 2024
f82af99
fix: pack libtensorrt_llm_nvrtc_wrapper
nguyenhoangthuan99 Jul 24, 2024
358955c
fix: patches windows
nguyenhoangthuan99 Jul 24, 2024
b4ddb21
fix: rm fmt
nguyenhoangthuan99 Jul 24, 2024
3 changes: 3 additions & 0 deletions .github/patches/windows/msvc-19.40/msvcp140.dll
Git LFS file not shown
3 changes: 3 additions & 0 deletions .github/patches/windows/msvc-19.40/vcruntime140.dll
Git LFS file not shown
3 changes: 3 additions & 0 deletions .github/patches/windows/msvc-19.40/vcruntime140_1.dll
Git LFS file not shown
7 changes: 6 additions & 1 deletion cpp/Makefile
@@ -38,7 +38,11 @@ ifeq ($(OS),Windows_NT)
@powershell -Command "cd tensorrt_llm\cortex.tensorrt-llm\; cp -Force C:\workspace\TensorRT-10.0.1.6\lib\nvinfer_10.dll cortex.tensorrt-llm\;"
@powershell -Command "cd tensorrt_llm\cortex.tensorrt-llm\; cp -Force ..\..\build\tensorrt_llm\plugins\nvinfer_plugin_tensorrt_llm.dll cortex.tensorrt-llm\;"
@powershell -Command "cd tensorrt_llm\cortex.tensorrt-llm\; cp -Force ..\..\build\tensorrt_llm\tensorrt_llm.dll cortex.tensorrt-llm\;"
@powershell -Command "cd tensorrt_llm\cortex.tensorrt-llm\; cp -Force .\build_deps\_install\bin\zlib.dll cortex.tensorrt-llm\;"
@powershell -Command "cd tensorrt_llm\cortex.tensorrt-llm\; cp -Force ..\..\build\tensorrt_llm\kernels\decoderMaskedMultiheadAttention\decoderXQAImplJIT\nvrtcWrapper\tensorrt_llm_nvrtc_wrapper.dll cortex.tensorrt-llm\;"
@powershell -Command "cd tensorrt_llm\cortex.tensorrt-llm\; cp -Force .\build_deps\_install\bin\pcre2-8.dll cortex.tensorrt-llm\;"
@powershell -Command "cd tensorrt_llm\cortex.tensorrt-llm\; cp -Force ..\..\..\.github\patches\windows\msvc-19.40\msvcp140.dll cortex.tensorrt-llm\;"
@powershell -Command "cd tensorrt_llm\cortex.tensorrt-llm\; cp -Force ..\..\..\.github\patches\windows\msvc-19.40\vcruntime140_1.dll cortex.tensorrt-llm\;"
@powershell -Command "cd tensorrt_llm\cortex.tensorrt-llm\; cp -Force ..\..\..\.github\patches\windows\msvc-19.40\vcruntime140.dll cortex.tensorrt-llm\;"
@powershell -Command "cd tensorrt_llm\cortex.tensorrt-llm\; cp -Force ..\..\..\.github\patches\windows\msmpi.dll cortex.tensorrt-llm\;"
else
cd ./tensorrt_llm/cortex.tensorrt-llm && \
@@ -47,6 +51,7 @@ else
cp /usr/local/tensorrt/targets/x86_64-linux-gnu/lib/libnvinfer.so cortex.tensorrt-llm && \
cp /home/runner/actions-runner/_work/cortex.tensorrt-llm/cortex.tensorrt-llm/cpp/build/tensorrt_llm/plugins/libnvinfer_plugin_tensorrt_llm.so cortex.tensorrt-llm && \
cp /home/runner/actions-runner/_work/cortex.tensorrt-llm/cortex.tensorrt-llm/cpp/build/tensorrt_llm/libtensorrt_llm.so cortex.tensorrt-llm && \
cp /home/runner/actions-runner/_work/cortex.tensorrt-llm/cortex.tensorrt-llm/cpp/build/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/libtensorrt_llm_nvrtc_wrapper.so cortex.tensorrt-llm && \
cp /opt/hpcx/ompi/lib/libmpi.so cortex.tensorrt-llm && \
cp /usr/lib/x86_64-linux-gnu/libnccl.so cortex.tensorrt-llm
endif
12 changes: 10 additions & 2 deletions cpp/tensorrt_llm/cortex.tensorrt-llm/CMakeLists.txt
@@ -14,6 +14,12 @@
# the License.
# C++17
# engine init

if(UNIX AND NOT APPLE)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC")
add_compile_options(-fPIC)
endif()

include(CheckIncludeFileCXX)
check_include_file_cxx(any HAS_ANY)
check_include_file_cxx(string_view HAS_STRING_VIEW)
@@ -62,9 +68,11 @@ endif()
message(STATUS "SentencePiece library dirs: ${SENTENCEPIECE_LIBRARY_DIRS}")
message(STATUS "SentencePiece header dirs: ${SENTENCEPIECE_INCLUDE_DIRS}")

add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/src/cpp-tiktoken)

include_directories(${PROJECT_SOURCE_DIR}/include ${SENTENCEPIECE_INCLUDE_DIRS})

link_directories(${SENTENCEPIECE_LIBRARY_DIRS})
link_directories(${SENTENCEPIECE_LIBRARY_DIRS} ${CMAKE_CURRENT_SOURCE_DIR}/src/cpp-tiktoken)

set(TOP_LEVEL_DIR "${PROJECT_SOURCE_DIR}/..")

@@ -75,7 +83,7 @@ add_subdirectory(${CXXOPTS_SRC_DIR} ${CMAKE_CURRENT_BINARY_DIR}/cxxopts)

add_library(engine SHARED src/tensorrt-llm_engine.cc)
target_link_libraries(
engine PUBLIC ${SHARED_TARGET} nvinfer_plugin_tensorrt_llm cxxopts::cxxopts sentencepiece PRIVATE ${JSONCPP} ${TRANTOR} ${CMAKE_THREAD_LIBS_INIT} )
engine PUBLIC ${SHARED_TARGET} tiktoken nvinfer_plugin_tensorrt_llm cxxopts::cxxopts sentencepiece PRIVATE ${JSONCPP} ${TRANTOR} ${CMAKE_THREAD_LIBS_INIT} )

target_compile_features(engine PRIVATE cxx_std_17)
target_compile_definitions(engine PUBLIC TOP_LEVEL_DIR="${TOP_LEVEL_DIR}")
@@ -0,0 +1,11 @@
---
BasedOnStyle: WebKit
BreakConstructorInitializers: AfterColon
FixNamespaceComments: true
IndentCaseLabels: true
NamespaceIndentation: None
PointerAlignment: Right
SpaceAfterTemplateKeyword: false
SpaceBeforeInheritanceColon: false
SpaceBeforeRangeBasedForLoopColon: false
...
@@ -0,0 +1,33 @@
cmake_minimum_required(VERSION 3.1.2)

project(tiktoken LANGUAGES CXX)

set(CMAKE_CXX_STANDARD 20)
set(CMAKE_PREFIX_PATH ${CMAKE_CURRENT_SOURCE_DIR}/../../build_deps/_install)

find_library(PCRE2
NAMES pcre2-8
HINTS "${CMAKE_PREFIX_PATH}/lib"
)

option(CPP_TIKTOKEN_TESTING "Enable testing" OFF)

set(OPENAPI_SOURCES byte_pair_encoding.cc emdedded_resource_reader.cc modelparams.cc encoding.cc encoding_utils.cc pcre2_regex.cc)

add_library(tiktoken ${OPENAPI_SOURCES})
include_directories(${CMAKE_PREFIX_PATH}/include)
link_directories(${CMAKE_PREFIX_PATH}/lib)

target_link_libraries(tiktoken pcre2-8)

target_include_directories(tiktoken PUBLIC ${CMAKE_CURRENT_LIST_DIR} ${CMAKE_PREFIX_PATH}/include)

if (NOT CPP_TIKTOKEN_TESTING)
message(STATUS "Tests off")
else()
add_subdirectory(ut)
endif()

MESSAGE(STATUS "Copying tokenizers to '${CMAKE_BINARY_DIR}/tokenizers'.")
FILE(COPY o200k_base.tiktoken cl100k_base.tiktoken p50k_base.tiktoken r50k_base.tiktoken tokenizer.model DESTINATION "${CMAKE_BINARY_DIR}/tokenizers")
MESSAGE(STATUS "Tokenizers copied.")
17 changes: 17 additions & 0 deletions cpp/tensorrt_llm/cortex.tensorrt-llm/src/cpp-tiktoken/License.md
@@ -0,0 +1,17 @@
Copyright (c) 2023 by Mark Tarrabain. All rights reserved. Redistribution and use in source and binary forms,
with or without modification, are permitted provided that the following conditions are met:
Redistributions of source code must retain the above copyright notice, this list of conditions and the following
disclaimer.

Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following
disclaimer in the documentation and/or other materials provided with the distribution.
Neither the name of the nor the names of its contributors may be used to endorse or promote products derived from
this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33 changes: 33 additions & 0 deletions cpp/tensorrt_llm/cortex.tensorrt-llm/src/cpp-tiktoken/README.md
@@ -0,0 +1,33 @@
# Cpp-Tiktoken

This is a C++ implementation of a tiktoken tokenizer library. It was heavily inspired
by https://github.com/dmitry-brazhenko/SharpToken

To use it, add lines like the following to your project:

#include "tiktoken/encoding.h"

....

auto encoder = GptEncoding::get_encoding(<model name>);
The value returned from this function is a `std::shared_ptr`, so you do not have to manage its memory.

Supported language models that you can pass as a parameter to this function are:

LanguageModel::O200K_BASE
LanguageModel::CL100K_BASE
LanguageModel::R50K_BASE
LanguageModel::P50K_BASE
LanguageModel::P50K_EDIT
After obtaining an encoder, you can then call

auto tokens = encoder->encode(string_to_encode);
This returns a vector of the tokens for that language model.

You can decode a vector of tokens back into its original string with

auto string_value = encoder->decode(tokens);
If you like this project, and find it useful, you are invited to make a donation of whatever amount you believe
is appropriate via paypal to markt AT nerdflat.com. There is absolutely no obligation to donate.

Requires pcre2: install the pcre2 development package (e.g. `pcre2-devel`).
@@ -0,0 +1,174 @@
/*
* Copyright (c) 2023 by Mark Tarrabain All rights reserved. Redistribution and use in source and binary forms,
* with or without modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this list of conditions and the following
* disclaimer.
* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following
* disclaimer in the documentation and/or other materials provided with the distribution.
* Neither the name of the nor the names of its contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
* INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
* WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include "byte_pair_encoding.h"
#include "pcre2_regex.h"
#include <limits>
#include <optional>
#include <sstream>
#include <string>
#include <algorithm>

BytePairEncodingCore::BytePairEncodingCore(const std::unordered_map<std::vector<uint8_t>, int, VectorHash> &byte_pair_ranks,
const std::unordered_map<std::string, int> &special_token_mappings,
const std::shared_ptr<PCRERegex> &pattern_string) :
byte_pair_ranks_(byte_pair_ranks),
special_token_mappings_(special_token_mappings),
pattern_string_(pattern_string) { }

std::vector<int> BytePairEncodingCore::byte_pair_merge(const std::vector<uint8_t> &piece,
const std::unordered_map<std::vector<uint8_t>, int, VectorHash> &ranks,
const std::function<int(int, int)> &f)
{
std::vector<std::pair<int, int>> partitions(piece.size() + 1);
for (size_t i = 0; i <= piece.size(); ++i) {
partitions[i] = { static_cast<int>(i), std::numeric_limits<int>::max() };
}
auto get_rank = [&piece, &partitions, &ranks](size_t idx, int skip) -> std::optional<int> {
if (idx + skip + 2 >= partitions.size()) {
return std::nullopt;
}
std::vector<uint8_t> key(piece.begin() + partitions[idx].first, piece.begin() + partitions[idx + skip + 2].first);
auto rank_iter = ranks.find(key);
return (rank_iter != ranks.end()) ? std::optional<int>(rank_iter->second) : std::nullopt;
};
for (size_t i = 0; i < partitions.size() - 2; ++i) {
auto rank = get_rank(i, 0);
if (rank.has_value()) {
partitions[i].second = rank.value();
}
}
while (partitions.size() > 1) {
int min_rank = std::numeric_limits<int>::max();
size_t min_rank_idx = 0;
for (size_t i = 0; i < partitions.size() - 1; ++i) {
if (partitions[i].second < min_rank) {
min_rank = partitions[i].second;
min_rank_idx = i;
}
}
if (min_rank != std::numeric_limits<int>::max()) {
partitions[min_rank_idx].second = get_rank(min_rank_idx, 1).value_or(std::numeric_limits<int>::max());

if (min_rank_idx > 0) {
partitions[min_rank_idx - 1].second = get_rank(min_rank_idx - 1, 1).value_or(std::numeric_limits<int>::max());
}
partitions.erase(partitions.begin() + static_cast<long long>(min_rank_idx) + 1);
} else {
break;
}
}
std::vector<int> output;
output.reserve(partitions.size() - 1);
for (size_t i = 0; i < partitions.size() - 1; ++i) {
output.push_back(f(partitions[i].first, partitions[i + 1].first));
}
return output;
}
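The merge loop above can be hard to follow from the diff alone. Below is a minimal, self-contained sketch of the same partition-based BPE merge, simplified to `std::string` keys and a `std::map` rank table (the helper name `bpe_merge_sketch` is hypothetical, not part of this PR): repeatedly find the adjacent pair of segments with the lowest merge rank, merge it, recompute the ranks of its neighbours, and stop when no mergeable pair remains.

```cpp
#include <cstddef>
#include <limits>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Sketch of the partition-based BPE merge used by byte_pair_merge above.
std::vector<int> bpe_merge_sketch(const std::string &piece,
                                  const std::map<std::string, int> &ranks) {
    // parts[i] = {byte offset where segment i starts,
    //             rank of merging segment i with segment i+1}
    std::vector<std::pair<int, int>> parts(piece.size() + 1);
    for (std::size_t i = 0; i <= piece.size(); ++i)
        parts[i] = {static_cast<int>(i), std::numeric_limits<int>::max()};

    // Rank of the segment obtained by merging parts[i] .. parts[i + skip + 1].
    auto rank_of = [&](std::size_t i, int skip) -> int {
        if (i + skip + 2 >= parts.size()) return std::numeric_limits<int>::max();
        std::string key = piece.substr(parts[i].first,
                                       parts[i + skip + 2].first - parts[i].first);
        auto it = ranks.find(key);
        return it != ranks.end() ? it->second : std::numeric_limits<int>::max();
    };

    for (std::size_t i = 0; i + 2 < parts.size(); ++i)
        parts[i].second = rank_of(i, 0);

    // Repeatedly merge the adjacent pair with the lowest rank.
    while (parts.size() > 1) {
        std::size_t best = 0;
        int min_rank = std::numeric_limits<int>::max();
        for (std::size_t i = 0; i + 1 < parts.size(); ++i)
            if (parts[i].second < min_rank) { min_rank = parts[i].second; best = i; }
        if (min_rank == std::numeric_limits<int>::max()) break;  // nothing mergeable
        parts[best].second = rank_of(best, 1);
        if (best > 0) parts[best - 1].second = rank_of(best - 1, 1);
        parts.erase(parts.begin() + static_cast<std::ptrdiff_t>(best) + 1);
    }

    // Emit the rank of each final segment (assumes each segment is in ranks,
    // which holds in BPE because every single byte has a rank).
    std::vector<int> out;
    for (std::size_t i = 0; i + 1 < parts.size(); ++i)
        out.push_back(ranks.at(piece.substr(parts[i].first,
                                            parts[i + 1].first - parts[i].first)));
    return out;
}
```

For example, with ranks `{"a":0, "b":1, "c":2, "ab":3}`, the input `"abc"` merges `"a"`+`"b"` into `"ab"` and yields `[3, 2]`.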


std::vector<std::string> BytePairEncodingCore::break_into_specials(std::string const& line_to_encode, const std::unordered_set<std::string> &allowed_special) {
std::vector<std::pair<size_t, size_t>> separator_offsets;
std::string::size_type pos = 0;
for (auto& sep: special_token_mappings_) {
if (!sep.first.empty()) {
while ((pos = line_to_encode.find(sep.first, pos)) != std::string::npos) {
separator_offsets.push_back({ pos, pos + sep.first.size() });
pos += sep.first.size();
}
pos = 0;
} else if (allowed_special.count("")) {
separator_offsets.push_back({ 0, 0 });
}
}
std::sort(separator_offsets.begin(), separator_offsets.end());
std::vector<std::string> lines;
for (auto [begin, end]: separator_offsets) {
lines.push_back(line_to_encode.substr(pos, begin - pos));
lines.push_back(line_to_encode.substr(begin, end - begin));
pos = end;
}
lines.push_back(line_to_encode.substr(pos, line_to_encode.size() - pos));
return lines;
}
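`break_into_specials` above collects the offsets of every special-token occurrence, sorts them, and rebuilds the input as alternating plain/special segments. A self-contained sketch of that splitting step (the name `split_on_special` is hypothetical, and the special tokens are passed in directly rather than read from the member map):

```cpp
#include <algorithm>
#include <string>
#include <utility>
#include <vector>

// Split text into alternating plain/special segments, in document order.
// Mirrors the collect-offsets-then-sort approach of break_into_specials.
std::vector<std::string> split_on_special(const std::string &text,
                                          const std::vector<std::string> &specials) {
    std::vector<std::pair<std::size_t, std::size_t>> spans;  // [begin, end) of each hit
    for (const auto &tok : specials) {
        if (tok.empty()) continue;
        std::size_t pos = 0;
        while ((pos = text.find(tok, pos)) != std::string::npos) {
            spans.push_back({pos, pos + tok.size()});
            pos += tok.size();
        }
    }
    std::sort(spans.begin(), spans.end());
    std::vector<std::string> out;
    std::size_t pos = 0;
    for (auto [b, e] : spans) {
        out.push_back(text.substr(pos, b - pos));  // plain text before the special
        out.push_back(text.substr(b, e - b));      // the special token itself
        pos = e;
    }
    out.push_back(text.substr(pos));  // trailing plain text (may be empty)
    return out;
}
```

As in the original, adjacent special tokens produce empty plain segments between them; `encode_native` then simply tokenizes those empty strings to nothing.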

std::pair<std::vector<int>, std::vector<int>> BytePairEncodingCore::encode_native(const std::string &line_to_encode,
const std::unordered_set<std::string> &allowed_special)
{
std::vector<int> tokens;
std::vector<int> segment_ids;
auto lines = break_into_specials(line_to_encode, allowed_special);
for(auto line:lines) {
auto special_mapping = special_token_mappings_.find(line);
if (special_mapping != special_token_mappings_.end() && allowed_special.count(line) > 0) {
tokens.push_back(special_mapping->second);
segment_ids.push_back(0);
} else {
auto matches = pattern_string_->get_all_matches(line);
for (auto token: matches) {
auto special_mapping = special_token_mappings_.find(token);
if (special_mapping != special_token_mappings_.end() && allowed_special.count(token) > 0) {
if (!token.empty()) {
tokens.push_back(special_mapping->second);
segment_ids.push_back(0);
}
} else {
std::vector<uint8_t> utf8_encoded(token.begin(), token.end());
if (utf8_encoded.size() == 1) {
auto rank_iter = byte_pair_ranks_.find(utf8_encoded);
if (rank_iter != byte_pair_ranks_.end()) {
tokens.push_back(rank_iter->second);
segment_ids.push_back(0);
}
} else {
auto byte_pairs = byte_pair_merge(utf8_encoded, byte_pair_ranks_, [&](int start, int end) {
std::vector<uint8_t> key(utf8_encoded.begin() + start, utf8_encoded.begin() + end);
return byte_pair_ranks_[key];
});
tokens.insert(tokens.end(), byte_pairs.begin(), byte_pairs.end());
segment_ids.insert(segment_ids.end(), byte_pairs.size(), 0);
}
}
}
}
}
return std::make_pair(tokens, segment_ids);
}

std::string BytePairEncodingCore::decode_native(const std::vector<int> &input_tokens_to_decode)
{
std::stringstream decoded_string;
for (const int token_id: input_tokens_to_decode) {
auto special_token = std::find_if(special_token_mappings_.begin(), special_token_mappings_.end(),
[token_id](const auto &pair) { return pair.second == token_id; });
if (special_token != special_token_mappings_.end()) {
decoded_string << special_token->first;
} else {
for (const auto &byte_pair: byte_pair_ranks_) {
if (byte_pair.second == token_id) {
decoded_string << std::string(byte_pair.first.begin(), byte_pair.first.end());
break;
}
}
}
}
return decoded_string.str();
}
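`decode_native` above linearly scans `byte_pair_ranks_` for every token id. A small sketch of the same lookup done through a precomputed reverse map instead (the function name is illustrative, not part of the PR), which turns each per-token scan into an O(1) hash lookup:

```cpp
#include <map>
#include <string>
#include <unordered_map>
#include <vector>

// Decode token ids back to text by inverting the rank table once up front,
// rather than scanning it for every token as decode_native does.
std::string decode_with_reverse_map(const std::vector<int> &tokens,
                                    const std::map<std::string, int> &ranks) {
    std::unordered_map<int, std::string> id_to_bytes;
    for (const auto &kv : ranks)
        id_to_bytes.emplace(kv.second, kv.first);  // rank id -> byte sequence
    std::string out;
    for (int id : tokens) {
        auto it = id_to_bytes.find(id);
        if (it != id_to_bytes.end())
            out += it->second;  // unknown ids contribute nothing, as in decode_native
    }
    return out;
}
```

With ranks `{"a":0, "b":1, "c":2, "ab":3}`, decoding `[3, 2]` reproduces `"abc"`, matching the encode example above.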

const std::unordered_map<std::vector<uint8_t>, int, VectorHash>& BytePairEncodingCore::getBytePairRanks() const {
return byte_pair_ranks_;
}
@@ -0,0 +1,50 @@
/*
* Copyright (c) 2023 by Mark Tarrabain All rights reserved. Redistribution and use in source and binary forms,
* with or without modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this list of conditions and the following
* disclaimer.
* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following
* disclaimer in the documentation and/or other materials provided with the distribution.
* Neither the name of the nor the names of its contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
* INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
* WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#pragma once

#include "encoding_utils.h"
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

class PCRERegex;

class BytePairEncodingCore {
std::unordered_map<std::vector<uint8_t>, int, VectorHash> byte_pair_ranks_;
std::unordered_map<std::string, int> special_token_mappings_;
std::shared_ptr<PCRERegex> pattern_string_;

static std::vector<int> byte_pair_merge(const std::vector<uint8_t> &piece,
const std::unordered_map<std::vector<uint8_t>, int, VectorHash> &ranks,
const std::function<int(int, int)> &f);

public:
BytePairEncodingCore(const std::unordered_map<std::vector<uint8_t>, int, VectorHash> &byte_pair_ranks,
const std::unordered_map<std::string, int> &special_token_mappings,
const std::shared_ptr<PCRERegex> &pattern_string);

std::pair<std::vector<int>, std::vector<int>> encode_native(const std::string &line_to_encode,
const std::unordered_set<std::string> &allowed_special);
std::string decode_native(const std::vector<int> &input_tokens_to_decode);
std::vector<std::string> break_into_specials(std::string const& line_to_encode, const std::unordered_set<std::string> &allowed_special);

[[nodiscard]] const std::unordered_map<std::vector<uint8_t>, int, VectorHash>& getBytePairRanks() const;
};