Skip to content

Commit

Permalink
refactor: improve environment variables (#163)
Browse files Browse the repository at this point in the history
Rename `AICHAT_API_KEY` to `OPENAI_API_KEY`
Add `LOCALAI_API_KEY`
  • Loading branch information
sigoden authored Oct 28, 2023
1 parent bc44026 commit 2ab2e23
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 17 deletions.
14 changes: 7 additions & 7 deletions src/client/localai.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
use super::openai::{openai_send_message, openai_send_message_streaming};
use super::{Client, ModelInfo};
use super::{set_proxy, Client, ModelInfo};

use crate::config::SharedConfig;
use crate::repl::ReplyStreamHandler;

use anyhow::{anyhow, Context, Result};
use async_trait::async_trait;
use inquire::{Confirm, Text};
use reqwest::{Client as ReqwestClient, Proxy, RequestBuilder};
use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::Deserialize;
use serde_json::json;
use std::env;
use std::time::Duration;

#[allow(clippy::module_name_repetitions)]
Expand Down Expand Up @@ -151,10 +152,7 @@ impl LocalAIClient {

let client = {
let mut builder = ReqwestClient::builder();
if let Some(proxy) = &self.local_config.proxy {
builder = builder
.proxy(Proxy::all(proxy).with_context(|| format!("Invalid proxy `{proxy}`"))?);
}
builder = set_proxy(builder, &self.local_config.proxy)?;
let timeout = Duration::from_secs(self.local_config.connect_timeout.unwrap_or(10));
builder
.connect_timeout(timeout)
Expand All @@ -165,7 +163,9 @@ impl LocalAIClient {
let mut builder = client.post(&self.local_config.url);
if let Some(api_key) = &self.local_config.api_key {
builder = builder.bearer_auth(api_key);
};
} else if let Ok(api_key) = env::var("LOCALAI_API_KEY") {
builder = builder.bearer_auth(api_key);
}
builder = builder.json(&body);

Ok(builder)
Expand Down
19 changes: 18 additions & 1 deletion src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ use self::{

use anyhow::{bail, Context, Result};
use async_trait::async_trait;
use reqwest::{ClientBuilder, Proxy};
use serde::Deserialize;
use std::time::Duration;
use std::{env, time::Duration};
use tokio::runtime::Runtime;
use tokio::time::sleep;

Expand Down Expand Up @@ -204,3 +205,19 @@ pub fn init_runtime() -> Result<Runtime> {
.build()
.with_context(|| "Failed to init tokio")
}

pub(crate) fn set_proxy(builder: ClientBuilder, proxy: &Option<String>) -> Result<ClientBuilder> {
let proxy = if let Some(proxy) = proxy {
if proxy.is_empty() || proxy == "false" || proxy == "-" {
return Ok(builder);
}
proxy.clone()
} else if let Ok(proxy) = env::var("HTTPS_PROXY").or_else(|_| env::var("ALL_PROXY")) {
proxy
} else {
return Ok(builder);
};
let builder =
builder.proxy(Proxy::all(&proxy).with_context(|| format!("Invalid proxy `{proxy}`"))?);
Ok(builder)
}
13 changes: 5 additions & 8 deletions src/client/openai.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
use super::{Client, ModelInfo};
use super::{set_proxy, Client, ModelInfo};

use crate::config::SharedConfig;
use crate::repl::ReplyStreamHandler;
use crate::{config::SharedConfig, utils::get_env_name};

use anyhow::{anyhow, bail, Context, Result};
use async_trait::async_trait;
use eventsource_stream::Eventsource;
use futures_util::StreamExt;
use inquire::{Confirm, Text};
use reqwest::{Client as ReqwestClient, Proxy, RequestBuilder};
use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::Deserialize;
use serde_json::{json, Value};
use std::env;
Expand Down Expand Up @@ -120,7 +120,7 @@ impl OpenAIClient {
fn request_builder(&self, content: &str, stream: bool) -> Result<RequestBuilder> {
let api_key = if let Some(api_key) = &self.local_config.api_key {
api_key.to_string()
} else if let Ok(api_key) = env::var(get_env_name("api_key")) {
} else if let Ok(api_key) = env::var("OPENAI_API_KEY") {
api_key.to_string()
} else {
bail!("Miss api_key")
Expand All @@ -145,10 +145,7 @@ impl OpenAIClient {

let client = {
let mut builder = ReqwestClient::builder();
if let Some(proxy) = &self.local_config.proxy {
builder = builder
.proxy(Proxy::all(proxy).with_context(|| format!("Invalid proxy `{proxy}`"))?);
}
builder = set_proxy(builder, &self.local_config.proxy)?;
let timeout = Duration::from_secs(self.local_config.connect_timeout.unwrap_or(10));
builder
.connect_timeout(timeout)
Expand Down
2 changes: 1 addition & 1 deletion src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ impl Config {
pub fn init(is_interactive: bool) -> Result<Self> {
let config_path = Self::config_file()?;

let api_key = env::var(get_env_name("api_key")).ok();
let api_key = env::var("OPENAI_API_KEY").ok();

let exist_config_path = config_path.exists();
if is_interactive && api_key.is_none() && !exist_config_path {
Expand Down

0 comments on commit 2ab2e23

Please sign in to comment.