Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend the ability to support third party LLM through customizable API URL #143

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ Download it from [GitHub Releases](https://github.com/sigoden/aichat/releases),
- Support proxy connection
- Dark/light theme
- Save chat messages
- Customizable API URL

## Config

Expand All @@ -58,6 +59,8 @@ api_key: "<YOUR SECRET API KEY>" # Request via https://platform.openai.com/accou
organization_id: "org-xxx" # optional, set organization id
model: "gpt-3.5-turbo" # optional, choose a model
temperature: 1.0 # optional, see https://platform.openai.com/docs/api-reference/chat/create#chat/create-temperature
api_url: "https://api.openai.com/v1/chat/completions" # optional, set the API URL to send requests to
max_tokens: 1024 # optional, set the maximum number of tokens the API may return
save: true # optional, If set true, aichat will save chat messages to message.md
highlight: true # optional, Set false to turn highlight
proxy: "socks5://127.0.0.1:1080" # optional, set proxy server. e.g. http://127.0.0.1:8080 or socks5://127.0.0.1:1080
Expand Down Expand Up @@ -158,6 +161,8 @@ roles_file /home/alice/.config/aichat/roles.yaml
messages_file /home/alice/.config/aichat/messages.md
api_key sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
organization_id -
api_url https://api.openai.com/v1/chat/completions
max_tokens 1024
model gpt-3.5-turbo
temperature -
save true
Expand Down
5 changes: 5 additions & 0 deletions src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ pub struct Cli {
/// No stream output
#[clap(short = 'S', long)]
pub no_stream: bool,
/// Define the API URL for requesting
#[clap(short = 'I', long)]
pub api_url: Option<String>,
#[clap(long)]
pub max_tokens: Option<usize>,
/// List all roles
#[clap(long)]
pub list_roles: bool,
Expand Down
11 changes: 7 additions & 4 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ use std::time::Duration;
use tokio::runtime::Runtime;
use tokio::time::sleep;

const API_URL: &str = "https://api.openai.com/v1/chat/completions";

#[allow(clippy::module_name_repetitions)]
#[derive(Debug)]
pub struct ChatGptClient {
Expand Down Expand Up @@ -138,11 +136,13 @@ impl ChatGptClient {
fn request_builder(&self, content: &str, stream: bool) -> Result<RequestBuilder> {
let (model, _) = self.config.read().get_model();
let messages = self.config.read().build_messages(content)?;
let max_tokens = self.config.read().get_max_tokens();
let mut body = json!({
"model": model,
"messages": messages,
"max_tokens" : max_tokens,
});

if let Some(v) = self.config.read().get_temperature() {
body.as_object_mut()
.and_then(|m| m.insert("temperature".into(), json!(v)));
Expand All @@ -153,11 +153,14 @@ impl ChatGptClient {
.and_then(|m| m.insert("stream".into(), json!(true)));
}

let api_url = self.config.read().get_api_url();


let (api_key, organization_id) = self.config.read().get_api_key();

let mut builder = self
.build_client()?
.post(API_URL)
.post(api_url)
.bearer_auth(api_key)
.json(&body);

Expand Down
34 changes: 31 additions & 3 deletions src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,16 @@ use std::{

/// Supported OpenAI chat models paired with their maximum context size in tokens.
/// NOTE: the array length must match the number of entries — the original
/// declared `; 4` while listing 5 models, which fails to compile.
pub const MODELS: [(&str, usize); 5] = [
    ("gpt-4", 8192),
    ("gpt-4-0613", 8192), // fewer limits when calling API compared to gpt-4
    ("gpt-4-32k", 32768),
    ("gpt-3.5-turbo", 4096),
    ("gpt-3.5-turbo-16k", 16384),
];

const CONFIG_FILE_NAME: &str = "config.yaml";
const ROLES_FILE_NAME: &str = "roles.yaml";
const HISTORY_FILE_NAME: &str = "history.txt";
const MESSAGE_FILE_NAME: &str = "messages.md";
const SET_COMPLETIONS: [&str; 8] = [
const SET_COMPLETIONS: [&str; 10] = [
".set temperature",
".set save true",
".set save false",
Expand All @@ -43,6 +43,8 @@ const SET_COMPLETIONS: [&str; 8] = [
".set proxy",
".set dry_run true",
".set dry_run false",
".set api_url",
".set max_tokens",
];

#[allow(clippy::struct_excessive_bools)]
Expand All @@ -53,6 +55,9 @@ pub struct Config {
pub api_key: Option<String>,
/// OpenAI organization id
pub organization_id: Option<String>,
/// Configurable API URL used for requests
pub api_url: String,
pub max_tokens: usize,
/// OpenAI model
#[serde(rename(serialize = "model", deserialize = "model"))]
pub model_name: Option<String>,
Expand Down Expand Up @@ -105,6 +110,8 @@ impl Default for Config {
roles: vec![],
role: None,
conversation: None,
max_tokens: 1024,
api_url: "https://api.openai.com/v1/chat/completions".into(),
model: ("gpt-3.5-turbo".into(), 4096),
}
}
Expand Down Expand Up @@ -134,6 +141,7 @@ impl Config {
if let Some(name) = config.model_name.clone() {
config.set_model(&name)?;
}

config.merge_env_vars();
config.maybe_proxy();
config.load_roles()?;
Expand Down Expand Up @@ -287,7 +295,12 @@ impl Config {
/// Returns the active model as a `(name, max_context_tokens)` pair.
pub fn get_model(&self) -> (String, usize) {
    // Clone the name, copy the token count — equivalent to cloning the tuple.
    let (name, max_context) = &self.model;
    (name.clone(), *max_context)
}

/// Returns an owned copy of the API endpoint URL used for chat requests.
pub fn get_api_url(&self) -> String {
    self.api_url.to_owned()
}
/// Returns the maximum number of tokens to request per completion.
pub fn get_max_tokens(&self) -> usize {
    // `usize` is `Copy`; cloning it is redundant (clippy::clone_on_copy).
    self.max_tokens
}
pub fn build_messages(&self, content: &str) -> Result<Vec<Message>> {
#[allow(clippy::option_if_let_else)]
let messages = if let Some(conversation) = self.conversation.as_ref() {
Expand Down Expand Up @@ -336,6 +349,8 @@ impl Config {
.temperature
.map_or_else(|| String::from("-"), |v| v.to_string());
let (api_key, organization_id) = self.get_api_key();
let api_url = self.get_api_url();
let max_tokens = self.get_max_tokens();
let api_key = mask_text(&api_key, 3, 4);
let organization_id = organization_id.map_or_else(|| "-".into(), |v| mask_text(&v, 3, 4));
let items = vec![
Expand All @@ -344,6 +359,8 @@ impl Config {
("messages_file", file_info(&Self::messages_file()?)),
("api_key", api_key),
("organization_id", organization_id),
("api_url", api_url),
("max_tokens", max_tokens.to_string()),
("model", self.model.0.to_string()),
("temperature", temperature),
("save", self.save.to_string()),
Expand Down Expand Up @@ -390,6 +407,17 @@ impl Config {
self.temperature = Some(value);
}
}
"api_url" => {
self.api_url = value.to_string();
}
"max_tokens" => {
let value = value.parse().with_context(|| "Invalid value")?;
self.max_tokens = value;
}
"light_theme" => {
let value = value.parse().with_context(|| "Invalid value")?;
self.light_theme = value;
}
"save" => {
let value = value.parse().with_context(|| "Invalid value")?;
self.save = value;
Expand Down
6 changes: 6 additions & 0 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ fn main() -> Result<()> {
if cli.dry_run {
config.write().dry_run = true;
}
// Propagate optional CLI overrides into the shared config.
// `if let` replaces the `is_some()` + `unwrap()`/`expect()` pattern, which
// re-checked the Option and could not fail anyway inside the `is_some` guard.
if let Some(max_tokens) = cli.max_tokens {
    config.write().max_tokens = max_tokens;
}
if let Some(api_url) = cli.api_url {
    // `api_url` is already a `String`; the original `.to_string()` was a
    // redundant reallocation.
    config.write().api_url = api_url;
}
let role = match &cli.role {
Some(name) => Some(
config
Expand Down