Merge pull request #200 from danforbes/dfo/chore/refactors
Refactors from llm-chain integration
philpax authored May 9, 2023
2 parents 6a9d9cb + 9f2f02a commit 67ee753
Showing 6 changed files with 34 additions and 11 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

(Cargo.lock is a generated file; its diff is not rendered.)

6 changes: 3 additions & 3 deletions binaries/llm-cli/src/cli_args.rs
@@ -3,8 +3,8 @@ use std::{fmt::Debug, path::PathBuf};
 use clap::{Parser, Subcommand, ValueEnum};
 use color_eyre::eyre::{Result, WrapErr};
 use llm::{
-    ElementType, InferenceParameters, InferenceSessionConfig, LoadProgress, Model,
-    ModelKVMemoryType, ModelParameters, TokenBias,
+    ElementType, InferenceParameters, InferenceSessionConfig, InvalidTokenBias, LoadProgress,
+    Model, ModelKVMemoryType, ModelParameters, TokenBias,
 };
 use rand::SeedableRng;
 
@@ -276,7 +276,7 @@ impl Generate {
         }
     }
 }
-fn parse_bias(s: &str) -> Result<TokenBias, String> {
+fn parse_bias(s: &str) -> Result<TokenBias, InvalidTokenBias> {
     s.parse()
 }
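This continues to work with clap because a `value_parser` function may return `Result<T, E>` for any error type convertible into `Box<dyn Error + Send + Sync>`, and the new `InvalidTokenBias` implements `std::error::Error`. A minimal sketch of such a parser wired into a hypothetical CLI (the `Args` struct and `token_bias` flag below are illustrative, not the crate's actual definitions):

use clap::Parser;
use llm::{InvalidTokenBias, TokenBias};

// Same shape as the refactored parse_bias: delegate to FromStr and let the
// typed error flow through clap's own error reporting.
fn parse_bias(s: &str) -> Result<TokenBias, InvalidTokenBias> {
    s.parse()
}

#[derive(Parser)]
struct Args {
    /// Token biases, e.g. --token-bias "50256=-10.0,2=1.5"
    #[arg(long, value_parser = parse_bias)]
    token_bias: Option<TokenBias>,
}

fn main() {
    let args = Args::parse();
    if let Some(bias) = args.token_bias {
        println!("biases: {bias}");
    }
}

On a malformed flag value, clap now surfaces the `Display` output of `InvalidTokenBias` (a format hint) rather than an opaque `String`.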
2 changes: 1 addition & 1 deletion crates/llm-base/src/lib.rs
@@ -32,7 +32,7 @@ pub use memmap2::Mmap;
 pub use model::{Hyperparameters, KnownModel, Model, ModelParameters, OutputRequest};
 pub use quantize::{quantize, QuantizeError, QuantizeProgress};
 pub use util::TokenUtf8Buffer;
-pub use vocabulary::{TokenBias, TokenId, Vocabulary};
+pub use vocabulary::{InvalidTokenBias, TokenBias, TokenId, Vocabulary};
 
 #[derive(Clone, Debug, PartialEq)]
 /// The parameters for text generation.
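Because `llm` re-exports `llm-base`'s public items (see the `crates/llm/src/lib.rs` diff below), `llm::InvalidTokenBias` and `llm_base::InvalidTokenBias` name the same type. A trivial sketch, assuming a consumer that depends on both crates:

use llm_base::TokenBias;

fn main() {
    // Obtain the error through the public parse API (its field is private),
    // then note that the llm re-export is the identical type.
    if let Err(e) = "x=y".parse::<TokenBias>() {
        let same: llm::InvalidTokenBias = e; // compiles: same type, no conversion
        println!("{same}");
    }
}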
25 changes: 22 additions & 3 deletions crates/llm-base/src/vocabulary.rs
@@ -1,4 +1,4 @@
-use std::{collections::HashMap, str::FromStr};
+use std::{collections::HashMap, error::Error, fmt::Display, str::FromStr};
 
 use crate::InferenceError;
 
@@ -139,7 +139,7 @@ impl TokenBias {
 }
 
 impl FromStr for TokenBias {
-    type Err = String;
+    type Err = InvalidTokenBias;
 
     /// A comma separated list of token biases. The list should be in the format
     /// "TID=BIAS,TID=BIAS" where TID is an integer token ID and BIAS is a
@@ -165,11 +165,30 @@ impl FromStr for TokenBias {
                     .map_err(|e: std::num::ParseFloatError| e.to_string())?;
                 Result::<_, String>::Ok((tid, bias))
             })
-            .collect::<Result<_, _>>()?;
+            .collect::<Result<_, _>>()
+            .map_err(InvalidTokenBias)?;
         Ok(TokenBias::new(x))
     }
 }
 
+/// An error was encountered when parsing a token bias string, which should be
+/// in the format "TID=BIAS,TID=BIAS" where TID is an integer token ID and BIAS
+/// is a floating point number.
+#[derive(Debug)]
+pub struct InvalidTokenBias(String);
+
+impl Display for InvalidTokenBias {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(
+            f,
+            "should be in the format <int>=<float>,<int>=<float>: {:?}",
+            self.0
+        )
+    }
+}
+
+impl Error for InvalidTokenBias {}
+
 impl std::fmt::Display for TokenBias {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         write!(f, "{:?}", self.0)
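Parsing behavior is unchanged; only the error type is richer. A short sketch of how a caller might exercise both paths (assuming the `llm-base` crate as of this commit):

use llm_base::TokenBias;

fn main() {
    // A well-formed "TID=BIAS,TID=BIAS" list parses exactly as before.
    let bias: TokenBias = "2=-1.0,50256=-10.0".parse().expect("valid bias list");
    println!("parsed: {bias}");

    // Malformed input now yields InvalidTokenBias, which implements
    // std::error::Error and renders a format hint through Display.
    match "2=oops".parse::<TokenBias>() {
        Ok(_) => unreachable!("'oops' is not a float"),
        // Expected output, given std's ParseFloatError message:
        // parse failed: should be in the format <int>=<float>,<int>=<float>: "invalid float literal"
        Err(e) => println!("parse failed: {e}"),
    }
}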
2 changes: 2 additions & 0 deletions crates/llm/Cargo.toml
@@ -15,6 +15,8 @@ llm-gptj = { path = "../models/gptj", optional = true, version = "0.1.1" }
 llm-bloom = { path = "../models/bloom", optional = true, version = "0.1.1" }
 llm-neox = { path = "../models/neox", optional = true, version = "0.1.1" }
 
+serde = { workspace = true }
+
 [dev-dependencies]
 rand = { workspace = true }
9 changes: 5 additions & 4 deletions crates/llm/src/lib.rs
@@ -70,10 +70,11 @@ use std::{
 pub use llm_base::{
     ggml::format as ggml_format, load, load_progress_callback_stdout, quantize, ElementType,
     FileType, InferenceError, InferenceParameters, InferenceRequest, InferenceSession,
-    InferenceSessionConfig, InferenceSnapshot, KnownModel, LoadError, LoadProgress, Loader, Model,
-    ModelKVMemoryType, ModelParameters, OutputRequest, QuantizeError, QuantizeProgress,
-    SnapshotError, TokenBias, TokenId, TokenUtf8Buffer, Vocabulary,
+    InferenceSessionConfig, InferenceSnapshot, InvalidTokenBias, KnownModel, LoadError,
+    LoadProgress, Loader, Model, ModelKVMemoryType, ModelParameters, OutputRequest, QuantizeError,
+    QuantizeProgress, SnapshotError, TokenBias, TokenId, TokenUtf8Buffer, Vocabulary,
 };
+use serde::Serialize;
 
 /// All available models.
 pub mod models {
@@ -89,7 +90,7 @@ pub mod models {
     pub use llm_neox::{self as neox, NeoX};
 }
 
-#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize)]
 /// All available model architectures.
 pub enum ModelArchitecture {
     #[cfg(feature = "bloom")]
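With `Serialize` derived (enabled by the `serde` dependency added to crates/llm/Cargo.toml above), downstream consumers such as llm-chain can encode the architecture directly. A minimal sketch, assuming `serde_json` is available in the consuming crate (this commit adds only `serde` itself) and the default `llama` feature is enabled:

use llm::ModelArchitecture;

fn main() -> Result<(), serde_json::Error> {
    let arch = ModelArchitecture::Llama;
    // A fieldless enum variant serializes as its name by default.
    let json = serde_json::to_string(&arch)?;
    println!("{json}"); // "Llama"
    Ok(())
}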
