Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Azure OpenAI #106

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 60 additions & 23 deletions src/client.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::config::SharedConfig;
use crate::config::{Config, SharedConfig};
use crate::repl::{ReplyStreamHandler, SharedAbortSignal};

use anyhow::{anyhow, bail, Context, Result};
Expand Down Expand Up @@ -92,6 +92,7 @@ impl ChatGptClient {
handler.text(&self.config.read().echo_messages(content))?;
return Ok(());
}
let chat_api = self.config.read().use_chat_api();
let builder = self.request_builder(content, true)?;
let res = builder.send().await?;
if !res.status().is_success() {
Expand All @@ -108,10 +109,13 @@ impl ChatGptClient {
break;
} else {
let data: Value = serde_json::from_str(&chunk)?;
let text = data["choices"][0]["delta"]["content"]
.as_str()
.unwrap_or_default();
if text.is_empty() {
let text = if chat_api {
&data["choices"][0]["delta"]["content"]
} else {
&data["choices"][0]["text"]
};
let text = text.as_str().unwrap_or_default();
if text.is_empty() || text == "<|im_end|>" {
continue;
}
handler.text(text)?;
Expand All @@ -136,36 +140,69 @@ impl ChatGptClient {
}

fn request_builder(&self, content: &str, stream: bool) -> Result<RequestBuilder> {
let (model, _) = self.config.read().get_model();
let (api_key, organization_id) = self.config.read().get_api_key();
let messages = self.config.read().build_messages(content)?;
let mut body = json!({
"model": model,
"messages": messages,
});

let (builder, mut body) = if let Some((endpoint, deployment)) =
self.config.read().get_aoai_endpoint()
{
// Azure OpenAI: https://learn.microsoft.com/en-gb/azure/cognitive-services/openai/reference

let (url, body) = if self.config.read().use_chat_api() {
let url = format!(
"{endpoint}/openai/deployments/{deployment}/chat/completions?api-version=2023-03-15-preview"
);
let body = json!({
"messages": &messages,
});

(url, body)
} else {
let url = format!(
"{endpoint}/openai/deployments/{deployment}/completions?api-version=2022-12-01"
);
let body = json!({
"prompt": Config::render_messages(&messages),
});
(url, body)
};

let builder = self.build_client()?.post(url).header("api-key", api_key);

(builder, body)
} else {
// OpenAI: https://platform.openai.com/docs/api-reference/chat
let (model, _) = self.config.read().get_model();
let body = json!({
"model": model,
"messages": messages,
});

let mut builder = self.build_client()?.post(API_URL).bearer_auth(api_key);

if let Some(organization_id) = organization_id {
builder = builder.header("OpenAI-Organization", organization_id);
}

(builder, body)
};

if let Some(v) = self.config.read().get_temperature() {
body.as_object_mut()
.and_then(|m| m.insert("temperature".into(), json!(v)));
}

if stream {
if let Some(v) = self.config.read().get_max_tokens() {
body.as_object_mut()
.and_then(|m| m.insert("stream".into(), json!(true)));
.and_then(|m| m.insert("max_tokens".into(), json!(v)));
}

let (api_key, organization_id) = self.config.read().get_api_key();

let mut builder = self
.build_client()?
.post(API_URL)
.bearer_auth(api_key)
.json(&body);

if let Some(organization_id) = organization_id {
builder = builder.header("OpenAI-Organization", organization_id);
if stream {
body.as_object_mut()
.and_then(|m| m.insert("stream".into(), json!(true)));
}

Ok(builder)
Ok(builder.json(&body))
}
}

Expand Down
25 changes: 25 additions & 0 deletions src/config/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,21 @@ impl Message {
content: content.to_string(),
}
}

/// Render this message in ChatML form:
/// `<|im_start|>{role}\n{content}<|im_end|>\n`.
pub fn render(&self) -> String {
    format!(
        "<|im_start|>{}\n{}<|im_end|>\n",
        self.role.render(),
        self.content
    )
}

pub fn render_all(messages: &[Message]) -> String {
let mut result = String::new();
for message in messages {
result.push_str(&message.render())
}
result.push_str("<|im_start|>assistant\n");
result
}
}

#[derive(Debug, Clone, Deserialize, Serialize)]
Expand All @@ -25,6 +40,16 @@ pub enum MessageRole {
User,
}

impl MessageRole {
pub fn render(&self) -> &'static str {
match self {
MessageRole::System => "system",
MessageRole::Assistant => "assistant",
MessageRole::User => "user",
}
}
}

pub fn num_tokens_from_messages(messages: &[Message]) -> usize {
let mut num_tokens = 0;
for message in messages.iter() {
Expand Down
96 changes: 93 additions & 3 deletions src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@ const CONFIG_FILE_NAME: &str = "config.yaml";
const ROLES_FILE_NAME: &str = "roles.yaml";
const HISTORY_FILE_NAME: &str = "history.txt";
const MESSAGE_FILE_NAME: &str = "messages.md";
const SET_COMPLETIONS: [&str; 8] = [
const SET_COMPLETIONS: [&str; 9] = [
".set temperature",
".set max_tokens",
".set save true",
".set save false",
".set highlight true",
Expand All @@ -47,15 +48,23 @@ const SET_COMPLETIONS: [&str; 8] = [
#[derive(Debug, Clone, Deserialize)]
#[serde(default)]
pub struct Config {
/// Openai api key
/// OpenAI or Azure OpenAI API key
pub api_key: Option<String>,
/// Azure OpenAI endpoint (set this to access via Azure OpenAI)
pub aoai_endpoint: Option<String>,
/// Azure OpenAI model deployment name
pub aoai_deployment: Option<String>,
/// Azure OpenAI - is this a chat-only model?
pub aoai_use_chat: Option<bool>,
/// Openai organization id
pub organization_id: Option<String>,
/// Openai model
#[serde(rename(serialize = "model", deserialize = "model"))]
pub model_name: Option<String>,
/// What sampling temperature to use, between 0 and 2
pub temperature: Option<f64>,
/// The maximum size of the response in tokens
pub max_tokens: Option<u32>,
/// Whether to persistently save chat messages
pub save: bool,
/// Whether to disable highlight
Expand Down Expand Up @@ -87,9 +96,13 @@ impl Default for Config {
fn default() -> Self {
Self {
api_key: None,
aoai_endpoint: None,
aoai_deployment: None,
aoai_use_chat: None,
organization_id: None,
model_name: None,
temperature: None,
max_tokens: None,
save: false,
highlight: true,
proxy: None,
Expand Down Expand Up @@ -210,6 +223,30 @@ impl Config {
(api_key.into(), organization_id.cloned())
}

/// If using Azure OpenAI, returns `Some((endpoint, deployment))`, else `None`.
///
/// Panics if `aoai_endpoint` is configured without `aoai_deployment`.
pub fn get_aoai_endpoint(&self) -> Option<(String, String)> {
    let endpoint = self.aoai_endpoint.as_ref()?;
    // A deployment name is mandatory once an Azure endpoint is configured.
    let deployment = self
        .aoai_deployment
        .as_ref()
        .expect("aoai_deployment not set");
    Some((endpoint.clone(), deployment.clone()))
}

/// Whether requests should go through the chat-completions API.
///
/// OpenAI proper is always driven through the chat API; with Azure OpenAI
/// the legacy completions API is used when `aoai_use_chat` is explicitly
/// set to `false`.
pub fn use_chat_api(&self) -> bool {
    // Check the raw option instead of calling get_aoai_endpoint(): that
    // avoids cloning two Strings (and the missing-deployment panic inside
    // get_aoai_endpoint) just to answer a yes/no question.
    self.aoai_endpoint.is_none() || self.aoai_use_chat.unwrap_or(true)
}

pub fn roles_file() -> Result<PathBuf> {
let env_name = get_env_name("roles_file");
if let Ok(value) = env::var(env_name) {
Expand Down Expand Up @@ -251,7 +288,7 @@ impl Config {
}

pub fn add_prompt(&mut self, prompt: &str) -> Result<()> {
let role = Role::new(prompt, self.temperature);
let role = Role::new(prompt, self.temperature, self.max_tokens);
if let Some(conversation) = self.conversation.as_mut() {
conversation.update_role(&role)?;
}
Expand All @@ -266,6 +303,13 @@ impl Config {
.or(self.temperature)
}

/// Effective `max_tokens`: the active role's setting wins over the global
/// configuration; `None` means no explicit limit is applied.
pub fn get_max_tokens(&self) -> Option<u32> {
    match self.role.as_ref().and_then(|role| role.max_tokens) {
        Some(v) => Some(v),
        None => self.max_tokens,
    }
}

pub fn echo_messages(&self, content: &str) -> String {
if let Some(conversation) = self.conversation.as_ref() {
conversation.echo_messages(content)
Expand Down Expand Up @@ -301,6 +345,10 @@ impl Config {
Ok(messages)
}

/// Flatten `messages` into a single ChatML prompt string (used for the
/// legacy completions API, which takes a `prompt` rather than `messages`).
pub fn render_messages(messages: &[Message]) -> String {
    Message::render_all(messages)
}

pub fn set_model(&mut self, name: &str) -> Result<()> {
if let Some(token) = MODELS.iter().find(|(v, _)| *v == name).map(|(_, v)| *v) {
self.model = (name.to_string(), token);
Expand Down Expand Up @@ -332,8 +380,16 @@ impl Config {
.temperature
.map(|v| v.to_string())
.unwrap_or("-".into());
let max_tokens = self.max_tokens.map(|v| v.to_string()).unwrap_or("-".into());
let (api_key, organization_id) = self.get_api_key();
let api_key = mask_text(&api_key, 3, 4);
let aoai_endpoint = self.aoai_endpoint.clone().unwrap_or("-".into());
let aoai_deployment = self.aoai_deployment.clone().unwrap_or("-".into());
let aoai_use_chat = self
.aoai_use_chat
.as_ref()
.map(|v| v.to_string())
.unwrap_or("-".into());
let organization_id = organization_id
.map(|v| mask_text(&v, 3, 4))
.unwrap_or("-".into());
Expand All @@ -342,9 +398,13 @@ impl Config {
("roles_file", file_info(&Config::roles_file()?)),
("messages_file", file_info(&Config::messages_file()?)),
("api_key", api_key),
("aoai_endpoint", aoai_endpoint.to_string()),
("aoai_deployment", aoai_deployment.to_string()),
("aoai_use_chat", aoai_use_chat),
("organization_id", organization_id),
("model", self.model.0.to_string()),
("temperature", temperature),
("max_tokens", max_tokens),
("save", self.save.to_string()),
("highlight", self.highlight.to_string()),
("proxy", proxy),
Expand Down Expand Up @@ -389,6 +449,14 @@ impl Config {
self.temperature = Some(value);
}
}
"max_tokens" => {
if unset {
self.max_tokens = None;
} else {
let value = value.parse().with_context(|| "Invalid value")?;
self.max_tokens = Some(value);
}
}
"save" => {
let value = value.parse().with_context(|| "Invalid value")?;
self.save = value;
Expand Down Expand Up @@ -520,6 +588,28 @@ fn create_config_file(config_path: &Path) -> Result<()> {
.map_err(text_map_err)?;
let mut raw_config = format!("api_key: {api_key}\n");

let ans = Confirm::new("Use OpenAI via Azure?")
.with_default(false)
.prompt()
.map_err(confirm_map_err)?;
if ans {
let endpoint = Text::new("Azure OpenAI endpoint:")
.prompt()
.map_err(text_map_err)?;
raw_config.push_str(&format!("aoai_endpoint: {endpoint}\n"));

let deployment = Text::new("Model deployment name:")
.prompt()
.map_err(text_map_err)?;
raw_config.push_str(&format!("aoai_deployment: {deployment}\n"));

let use_chat: bool = Confirm::new("Use chat API (not available for older models)?")
.with_default(true)
.prompt()
.map_err(text_map_err)?;
raw_config.push_str(&format!("aoai_use_chat: {use_chat}\n"));
}

let ans = Confirm::new("Use proxy?")
.with_default(false)
.prompt()
Expand Down
5 changes: 4 additions & 1 deletion src/config/role.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,17 @@ pub struct Role {
pub prompt: String,
/// What sampling temperature to use, between 0 and 2
pub temperature: Option<f64>,
/// Maximum number of tokens to return
pub max_tokens: Option<u32>,
}

impl Role {
pub fn new(prompt: &str, temperature: Option<f64>) -> Self {
pub fn new(prompt: &str, temperature: Option<f64>, max_tokens: Option<u32>) -> Self {
Self {
name: TEMP_NAME.into(),
prompt: prompt.into(),
temperature,
max_tokens,
}
}

Expand Down