This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

Add HuggingFace's Tokenizer #271

Merged · 22 commits · May 29, 2023
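This PR adds support for loading a Hugging Face `tokenizers` tokenizer in place of the vocabulary embedded in the GGML file, threaded through as a new `vocab_path` parameter. A minimal usage sketch, assuming the `llm` crate's `Llama` model type and hypothetical file names (neither is taken from the diff):

use std::path::Path;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // `None` keeps the old behaviour (vocabulary read from the GGML file);
    // `Some(path)` loads a Hugging Face tokenizer instead.
    let _model = llm::load::<llm::models::Llama>(
        Path::new("ggml-model.bin"),       // hypothetical model file
        Some(Path::new("tokenizer.json")), // hypothetical tokenizer file
        Default::default(),
        None, // overrides
        |_progress| {},
    )?;
    Ok(())
}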
1,461 changes: 1,435 additions & 26 deletions Cargo.lock

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions binaries/llm-cli/src/cli_args.rs
@@ -332,6 +332,10 @@ pub struct ModelLoad {
#[arg(long, short = 'm')]
pub model_path: PathBuf,

/// Where to load the vocabulary (tokenizer) from
#[arg(long, short = 'v')]
pub vocab_path: Option<PathBuf>,

/// Sets the size of the context (in tokens). Allows feeding longer prompts.
/// Note that this affects memory.
///
@@ -376,6 +380,7 @@ impl ModelLoad {

let model = llm::load::<M>(
&self.model_path,
self.vocab_path.as_deref(),
params,
overrides,
|progress| match progress {
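With the new `-v`/`--vocab-path` flag, the CLI can point at either a local tokenizer file or, per the loader logic further down, a Hugging Face Hub identifier. Hypothetical invocations — the subcommand layout, prompt flag, and file names are assumptions, not shown in this diff:

llm llama infer -m ggml-model.bin -v tokenizer.json -p "Hello"
llm llama infer -m ggml-model.bin -v huggyllama/llama-7b -p "Hello"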
2 changes: 1 addition & 1 deletion binaries/llm-cli/src/main.rs
@@ -149,7 +149,7 @@ fn perplexity<M: llm::KnownModel + 'static>(
fn info<M: llm::KnownModel + 'static>(args: &cli_args::Info) -> Result<()> {
let file = File::open(&args.model_path)?;
let mut reader = BufReader::new(&file);
let mut loader: llm::Loader<M::Hyperparameters, _> = llm::Loader::new(|_| {
let mut loader: llm::Loader<M::Hyperparameters, _> = llm::Loader::new(None, |_| {
// We purposely do not print progress here, as we are only interested in the metadata
});

1 change: 1 addition & 0 deletions crates/llm-base/Cargo.toml
@@ -22,3 +22,4 @@ partial_sort = "0.2.0"
serde_bytes = "0.11"
memmap2 = "0.5.10"
half = "2.2.1"
tokenizers = "0.13.3"
8 changes: 4 additions & 4 deletions crates/llm-base/src/inference_session.rs
@@ -94,7 +94,7 @@ impl InferenceSession {
if should_call_callback {
// NOTE: No string ever tokenizes to the end of sentence. So we
// can just return the id here.
match callback(vocab.token(tk as usize)) {
match callback(&vocab.token(tk as usize)) {
Err(e) => return Err(InferenceError::UserCallback(Some(Box::new(e)))),
Ok(f) => match f {
InferenceFeedback::Continue => (),
@@ -118,7 +118,7 @@
params: &InferenceParameters,
output_request: &mut OutputRequest,
rng: &mut impl rand::Rng,
) -> Result<&'v [u8], InferenceError> {
) -> Result<Vec<u8>, InferenceError> {
if self.n_past + 1 >= model.context_size() {
return Err(InferenceError::ContextFull);
}
@@ -163,7 +163,7 @@ impl InferenceSession {
for token_id in &self.tokens {
// Buffer the token until it's valid UTF-8, then call the callback.
if let Some(tokens) =
token_utf8_buf.push(model.vocabulary().token(*token_id as usize))
token_utf8_buf.push(&model.vocabulary().token(*token_id as usize))
{
if let Err(e) = callback(InferenceResponse::SnapshotToken(tokens)) {
return Err(InferenceError::UserCallback(Some(Box::new(e))));
@@ -204,7 +204,7 @@ impl InferenceSession {
};

// Buffer the token until it's valid UTF-8, then call the callback.
if let Some(tokens) = token_utf8_buf.push(token) {
if let Some(tokens) = token_utf8_buf.push(&token) {
match callback(InferenceResponse::InferredToken(tokens)) {
Err(e) => return Err(InferenceError::UserCallback(Some(Box::new(e)))),
Ok(f) => match f {
42 changes: 38 additions & 4 deletions crates/llm-base/src/loader.rs
@@ -18,6 +18,8 @@ use ggml::{
use memmap2::Mmap;
use thiserror::Error;

use tokenizers::Tokenizer;

#[derive(Debug, PartialEq, Clone, Copy, Eq, Default)]
/// Information about the file.
pub struct FileType {
@@ -280,6 +282,15 @@ pub enum LoadError {
/// The paths that were found.
paths: Vec<PathBuf>,
},

/// The vocab file for the tokenizer could not be loaded.
#[error("could not load vocab file {path:?}")]
VocabLoadError {
/// The path that failed.
path: PathBuf,
},
}
impl From<util::FindAllModelFilesError> for LoadError {
fn from(value: util::FindAllModelFilesError) -> Self {
@@ -343,6 +354,7 @@ pub trait TensorLoader<E: std::error::Error> {
/// store any information about the architecture.
pub fn load<M: KnownModel>(
path: &Path,
vocab_path: Option<&Path>,
params: ModelParameters,
overrides: Option<M::Overrides>,
load_progress_callback: impl FnMut(LoadProgress),
@@ -364,7 +376,29 @@
})?;
let mut reader = BufReader::new(&file);

let mut loader = Loader::new(load_progress_callback);
let tokenizer = if let Some(path) = vocab_path {
    // A path that does not exist on disk but contains exactly one '/' is
    // treated as a Hugging Face Hub identifier (e.g. "org/model") and
    // fetched remotely; anything else must be a local tokenizer file.
    let tok = if !path.exists() && path.to_str().unwrap().matches("/").count() == 1 {
        Tokenizer::from_pretrained(path.to_str().unwrap(), None)
    } else if path.is_file() {
        Tokenizer::from_file(path)
    } else {
        return Err(LoadError::VocabLoadError {
            path: path.to_owned(),
        });
    };

    Some(tok.map_err(|_| LoadError::VocabLoadError {
        path: path.to_owned(),
    })?)
} else {
    None
};

let mut loader = Loader::new(tokenizer, load_progress_callback);

ggml::format::load(&mut reader, &mut loader)
.map_err(|err| LoadError::from_format_error(err, path.to_owned()))?;
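The dispatch above treats a `vocab_path` that is not present on disk and contains exactly one `/` as a Hugging Face Hub identifier, and anything else as a local file. A standalone sketch of that rule — the identifier and file name are illustrative, not from the diff:

use std::path::Path;

/// Mirrors the loader's heuristic: one '/' and nothing on disk => Hub id.
fn looks_like_hub_id(path: &Path) -> bool {
    !path.exists() && path.to_str().map_or(false, |s| s.matches('/').count() == 1)
}

fn main() {
    // Assuming neither string names an existing local path:
    assert!(looks_like_hub_id(Path::new("huggyllama/llama-7b")));
    assert!(!looks_like_hub_id(Path::new("tokenizer.json")));
}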
@@ -422,7 +456,7 @@
let mut lora_reader = BufReader::new(&lora_file);
// TODO: Consider updating the progress callback to report the progress of the LoRA file.
// Most LoRAs are small enough that this is not necessary, but it would be nice to have.
let mut lora_loader: Loader<LoraParameters, _> = Loader::new(|_| {});
let mut lora_loader: Loader<LoraParameters, _> = Loader::new(None, |_| {});
ggml::format::load(&mut lora_reader, &mut lora_loader)
.map_err(|err| LoadError::from_format_error(err, lora_path.to_owned()))?;

@@ -498,13 +532,13 @@ pub struct Loader<Hp: Hyperparameters, F: FnMut(LoadProgress)> {
}
impl<Hp: Hyperparameters, F: FnMut(LoadProgress)> Loader<Hp, F> {
/// Creates a new loader.
pub fn new(load_progress_callback: F) -> Self {
pub fn new(tokenizer: Option<Tokenizer>, load_progress_callback: F) -> Self {
Self {
load_progress_callback,

container_type: ContainerType::Ggml,
hyperparameters: Hp::default(),
vocabulary: Vocabulary::default(),
vocabulary: Vocabulary::new(tokenizer),
tensors: HashMap::default(),
}
}
7 changes: 4 additions & 3 deletions crates/llm-base/src/model/mod.rs
@@ -114,14 +114,15 @@ pub trait KnownModel: Send + Sync {
/// is a helper function on top of [llm_base::load](crate::load).
fn load(
path: &Path,
vocab_path: Option<&Path>,
params: ModelParameters,
overrides: Option<Self::Overrides>,
load_progress_callback: impl FnMut(LoadProgress),
) -> Result<Self, LoadError>
where
Self: Sized,
{
crate::load(path, params, overrides, load_progress_callback)
crate::load(path, vocab_path, params, overrides, load_progress_callback)
}

/// Creates a new model from the provided [ModelParameters] hyperparameters.
@@ -151,7 +152,7 @@ pub trait KnownModel: Send + Sync {
output_request: &mut OutputRequest,
);

/// Get the vocabulary (loaded from the GGML file) for this model.
/// Get the vocabulary for this model.
fn vocabulary(&self) -> &Vocabulary;

/// Get the context size (configured with [ModelParameters::context_size]) used by
@@ -188,7 +189,7 @@ pub trait Model: Send + Sync {
output_request: &mut OutputRequest,
);

/// Get the vocabulary (loaded from the GGML file) for this model.
/// Get the vocabulary for this model.
fn vocabulary(&self) -> &Vocabulary;

/// Get the context size (configured with [ModelParameters::context_size]) used by
2 changes: 1 addition & 1 deletion crates/llm-base/src/quantize.rs
@@ -151,7 +151,7 @@ pub fn quantize<M: KnownModel, R: BufRead + Seek, W: Write + Seek>(
// Load the model
let progress_callback = Arc::new(progress_callback);

let mut loader = Loader::<M::Hyperparameters, _>::new({
let mut loader = Loader::<M::Hyperparameters, _>::new(None, {
let progress_callback = progress_callback.clone();
move |p| {
if let LoadProgress::HyperparametersLoaded = p {
130 changes: 88 additions & 42 deletions crates/llm-base/src/vocabulary.rs
@@ -1,9 +1,10 @@
use std::{collections::HashMap, error::Error, fmt::Display, str::FromStr};

use thiserror::Error;
use tokenizers::Tokenizer;

/// The identifier of a token in a vocabulary.
pub type TokenId = i32;
pub type TokenId = u32;
pub(crate) type Token = Vec<u8>;
pub(crate) type TokenScore = f32;

@@ -34,9 +35,19 @@ pub struct Vocabulary {

/// The longest token in this vocabulary.
pub max_token_length: usize,

/// The tokenizer
pub tokenizer: Option<Tokenizer>,
}

impl Vocabulary {
/// Initialize a new vocabulary.
pub fn new(tokenizer: Option<Tokenizer>) -> Vocabulary {
    Vocabulary {
        tokenizer,
        ..Vocabulary::default()
    }
}

/// Add a token to the vocabulary.
///
/// The token added must have `id` directly after the last token in the vocabulary.
@@ -45,6 +56,10 @@
/// - This function can panic if `id` does not correspond to the next token in the vocabulary.
/// That is, if there are already `n` tokens in the vocabulary, then `id` must be `n`.
pub fn push_token(&mut self, id: TokenId, content: Token, score: TokenScore) {
// When an external (Hugging Face) tokenizer is attached, the GGML
// vocabulary entries are ignored, so there is nothing to push.
if self.tokenizer.is_some() {
    return;
}

// These are loader invariants. If this is broken, then the loader is broken and this is a bug,
// not an issue with the model itself.
assert_eq!(self.id_to_token.len(), self.id_to_token_score.len());
@@ -60,17 +75,33 @@
}

/// Converts a token index to the token it represents in this vocabulary.
pub fn token(&self, idx: usize) -> &[u8] {
&self.id_to_token[idx]
pub fn token(&self, idx: usize) -> Vec<u8> {
    if let Some(tokenizer) = &self.tokenizer {
        // Decode the single id through the Hugging Face tokenizer.
        return tokenizer
            .decode(vec![idx as u32], true)
            .unwrap()
            .as_bytes()
            .to_vec();
    }

    self.id_to_token[idx].clone()
}

/// Returns the number of tokens in the vocabulary.
pub fn len(&self) -> usize {
if let Some(tokenizer) = &self.tokenizer {
return tokenizer.get_vocab_size(false) as usize;
}

self.id_to_token.len()
}

/// Returns whether the vocabulary is empty.
pub fn is_empty(&self) -> bool {
if let Some(tokenizer) = &self.tokenizer {
return tokenizer.get_vocab_size(false) == 0;
}

self.id_to_token.is_empty()
}

@@ -82,53 +113,68 @@
&'a self,
text: &str,
bos: bool,
) -> Result<Vec<(&'a [u8], TokenId)>, TokenizationError> {
let len = text.len();

let mut score = vec![0usize; len + 1];
let mut prev = vec![TokenId::default(); len + 1];

for i in 0..len {
let max_len = (len - i).min(self.max_token_length);
for sub_len in 1..=max_len {
let sub = &text.as_bytes()[i..i + sub_len];
let token = self.token_to_id.get(sub);

if let Some(token) = token {
let token_score = sub.len() * sub.len();
let local_score = score[i] + token_score;
let next = i + sub_len;

if score[next] < local_score {
score[next] = local_score;
prev[next] = *token;
) -> Result<Vec<(Vec<u8>, TokenId)>, TokenizationError> {
if let Some(tokenizer) = &self.tokenizer {
    // Delegate to the Hugging Face tokenizer, mapping each id back to
    // its byte representation via `token()`.
    let encoding = tokenizer
        .encode(text, bos)
        .map_err(|_| TokenizationError::TokenizationFailed)?;
    Ok(encoding
        .get_ids()
        .iter()
        .map(|id| (self.token(*id as usize), *id))
        .collect::<Vec<(Vec<u8>, TokenId)>>())
} else {
let len = text.len();

let mut score = vec![0usize; len + 1];
let mut prev = vec![TokenId::default(); len + 1];

for i in 0..len {
let max_len = (len - i).min(self.max_token_length);
for sub_len in 1..=max_len {
let sub = &text.as_bytes()[i..i + sub_len];
let token = self.token_to_id.get(sub);

if let Some(token) = token {
let token_score = sub.len() * sub.len();
let local_score = score[i] + token_score;
let next = i + sub_len;

if score[next] < local_score {
score[next] = local_score;
prev[next] = *token;
}
}
}
}
}

// Backward pass
let mut res = vec![];
let mut i = len;
while i > 0 {
let token_id = prev[i];
if token_id == 0 {
return Err(TokenizationError::TokenizationFailed);
// Backward pass
let mut res = vec![];
let mut i = len;
while i > 0 {
let token_id = prev[i];
if token_id == 0 {
return Err(TokenizationError::TokenizationFailed);
}
let token = self.id_to_token[token_id as usize].as_slice();
res.push((token.to_vec(), token_id));
i -= token.len();
}
let token = self.id_to_token[token_id as usize].as_slice();
res.push((token, token_id));
i -= token.len();
}

if bos {
// TODO: replace with vocab.bos
res.push((&[], 1));
}
if bos {
// TODO: replace with vocab.bos
res.push((vec![], 1));
}

// Pieces are in reverse order so correct that
res.reverse();
// Pieces are in reverse order so correct that
res.reverse();

Ok(res)
Ok(res)
}
}
}
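With a tokenizer attached, `tokenize` delegates to `tokenizers` and maps each id back through `token()`; without one, the original greedy longest-match scoring path still runs. A usage sketch — the import path for `Vocabulary` and the tokenizer file name are assumptions:

use llm_base::Vocabulary;
use tokenizers::Tokenizer;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let tokenizer = Tokenizer::from_file("tokenizer.json")?; // hypothetical local file
    let vocab = Vocabulary::new(Some(tokenizer));
    for (bytes, id) in vocab.tokenize("Hello world", false)? {
        println!("{id} -> {:?}", String::from_utf8_lossy(&bytes));
    }
    Ok(())
}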

1 change: 1 addition & 0 deletions crates/llm/examples/inference.rs
@@ -24,6 +24,7 @@ fn main() {
let model = llm::load_dynamic(
model_architecture,
model_path,
None,
Default::default(),
overrides,
load_callback,
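To exercise the new tokenizer path in this example, the added `None` could instead carry a real path — a hypothetical variant of that argument (file name assumed):

Some(std::path::Path::new("tokenizer.json")),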