This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

Refactors from llm-chain integration #200

Merged
merged 1 commit on May 9, 2023
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default.

6 changes: 3 additions & 3 deletions binaries/llm-cli/src/cli_args.rs
@@ -3,8 +3,8 @@ use std::{fmt::Debug, path::PathBuf};
 use clap::{Parser, Subcommand, ValueEnum};
 use color_eyre::eyre::{Result, WrapErr};
 use llm::{
-    ElementType, InferenceParameters, InferenceSessionConfig, LoadProgress, Model,
-    ModelKVMemoryType, ModelParameters, TokenBias,
+    ElementType, InferenceParameters, InferenceSessionConfig, InvalidTokenBias, LoadProgress,
+    Model, ModelKVMemoryType, ModelParameters, TokenBias,
 };
 use rand::SeedableRng;

@@ -276,7 +276,7 @@ impl Generate {
         }
     }
 }
-fn parse_bias(s: &str) -> Result<TokenBias, String> {
+fn parse_bias(s: &str) -> Result<TokenBias, InvalidTokenBias> {
     s.parse()
 }

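For context, a minimal sketch of how this typed error plugs into clap: a `value_parser` function may return any error type convertible into `Box<dyn Error + Send + Sync>`, which `InvalidTokenBias` satisfies via its `Display` and `Error` impls, so the CLI keeps working while callers get a real type to match on. The `Generate` struct and `token_bias` field below are illustrative stand-ins, not the crate's actual CLI definition, and assume clap 4 with the `derive` feature.

```rust
use clap::Parser;
use llm::{InvalidTokenBias, TokenBias};

#[derive(Parser)]
struct Generate {
    /// Comma-separated TID=BIAS pairs, e.g. "2=-1.0,50256=-1e4".
    #[arg(long, value_parser = parse_bias)]
    token_bias: Option<TokenBias>,
}

// Identical to the function in the diff: delegate to TokenBias's FromStr impl.
fn parse_bias(s: &str) -> Result<TokenBias, InvalidTokenBias> {
    s.parse()
}

fn main() {
    let args = Generate::parse();
    if let Some(bias) = args.token_bias {
        // TokenBias implements Display (see vocabulary.rs below).
        println!("{bias}");
    }
}
```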
2 changes: 1 addition & 1 deletion crates/llm-base/src/lib.rs
@@ -32,7 +32,7 @@ pub use memmap2::Mmap;
 pub use model::{Hyperparameters, KnownModel, Model, ModelParameters, OutputRequest};
 pub use quantize::{quantize, QuantizeError, QuantizeProgress};
 pub use util::TokenUtf8Buffer;
-pub use vocabulary::{TokenBias, TokenId, Vocabulary};
+pub use vocabulary::{InvalidTokenBias, TokenBias, TokenId, Vocabulary};

 #[derive(Clone, Debug, PartialEq)]
 /// The parameters for text generation.
25 changes: 22 additions & 3 deletions crates/llm-base/src/vocabulary.rs
@@ -1,4 +1,4 @@
-use std::{collections::HashMap, str::FromStr};
+use std::{collections::HashMap, error::Error, fmt::Display, str::FromStr};

 use crate::InferenceError;

@@ -139,7 +139,7 @@ impl TokenBias {
 }

 impl FromStr for TokenBias {
-    type Err = String;
+    type Err = InvalidTokenBias;

     /// A comma separated list of token biases. The list should be in the format
     /// "TID=BIAS,TID=BIAS" where TID is an integer token ID and BIAS is a
@@ -165,11 +165,30 @@ impl FromStr for TokenBias {
                     .map_err(|e: std::num::ParseFloatError| e.to_string())?;
                 Result::<_, String>::Ok((tid, bias))
             })
-            .collect::<Result<_, _>>()?;
+            .collect::<Result<_, _>>()
+            .map_err(InvalidTokenBias)?;
         Ok(TokenBias::new(x))
     }
 }

+/// An error was encountered when parsing a token bias string, which should be
+/// in the format "TID=BIAS,TID=BIAS" where TID is an integer token ID and BIAS
+/// is a floating point number.
+#[derive(Debug)]
+pub struct InvalidTokenBias(String);
+
+impl Display for InvalidTokenBias {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(
+            f,
+            "should be in the format <int>=<float>,<int>=<float>: {:?}",
+            self.0
+        )
+    }
+}
+
+impl Error for InvalidTokenBias {}
+
 impl std::fmt::Display for TokenBias {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         write!(f, "{:?}", self.0)
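From the caller's side, the change looks like the sketch below: `InvalidTokenBias` wraps the old `String` message (the `map_err(InvalidTokenBias)` above is the newtype constructor used as a function), but is now a proper `std::error::Error`. This assumes only the API visible in the diffs; the input strings are made up for illustration.

```rust
use llm_base::TokenBias;

fn main() {
    // Well-formed input parses as before.
    let bias: TokenBias = "2=-1.0,50256=-1e4".parse().expect("valid bias string");
    println!("{bias}");

    // Malformed input now yields a typed error instead of a bare String,
    // so callers can match on it or propagate it with `?`.
    match "not-a-bias".parse::<TokenBias>() {
        Ok(_) => unreachable!("this input is malformed"),
        // Prints: should be in the format <int>=<float>,<int>=<float>: "..."
        Err(e) => eprintln!("{e}"),
    }
}
```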
2 changes: 2 additions & 0 deletions crates/llm/Cargo.toml
@@ -15,6 +15,8 @@
 llm-gptj = { path = "../models/gptj", optional = true, version = "0.1.1" }
 llm-bloom = { path = "../models/bloom", optional = true, version = "0.1.1" }
 llm-neox = { path = "../models/neox", optional = true, version = "0.1.1" }

+serde = { workspace = true }
+
 [dev-dependencies]
 rand = { workspace = true }
9 changes: 5 additions & 4 deletions crates/llm/src/lib.rs
@@ -70,10 +70,11 @@ use std::{
 pub use llm_base::{
     ggml::format as ggml_format, load, load_progress_callback_stdout, quantize, ElementType,
     FileType, InferenceError, InferenceParameters, InferenceRequest, InferenceSession,
-    InferenceSessionConfig, InferenceSnapshot, KnownModel, LoadError, LoadProgress, Loader, Model,
-    ModelKVMemoryType, ModelParameters, OutputRequest, QuantizeError, QuantizeProgress,
-    SnapshotError, TokenBias, TokenId, TokenUtf8Buffer, Vocabulary,
+    InferenceSessionConfig, InferenceSnapshot, InvalidTokenBias, KnownModel, LoadError,
+    LoadProgress, Loader, Model, ModelKVMemoryType, ModelParameters, OutputRequest, QuantizeError,
+    QuantizeProgress, SnapshotError, TokenBias, TokenId, TokenUtf8Buffer, Vocabulary,
 };
+use serde::Serialize;

 /// All available models.
 pub mod models {
@@ -89,7 +90,7 @@ pub mod models {
     pub use llm_neox::{self as neox, NeoX};
 }

-#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize)]
 /// All available model architectures.
 pub enum ModelArchitecture {
     #[cfg(feature = "bloom")]
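A sketch of what the new `Serialize` derive buys downstream consumers such as llm-chain: a fieldless enum serializes as its variant name. This assumes `serde_json` as the serializer and the `bloom` feature enabled; neither is part of this diff.

```rust
use llm::ModelArchitecture;

fn main() -> Result<(), serde_json::Error> {
    // A unit-only enum with #[derive(Serialize)] serializes as a plain string.
    let arch = ModelArchitecture::Bloom; // requires the "bloom" feature
    assert_eq!(serde_json::to_string(&arch)?, "\"Bloom\"");
    Ok(())
}
```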