Skip to content

Commit

Permalink
Wire up Azure OpenAI completion provider (#8646)
Browse files Browse the repository at this point in the history
This PR wires up support for [Azure
OpenAI](https://learn.microsoft.com/en-us/azure/ai-services/openai/overview)
as an alternative AI provider in the assistant panel.

This can be configured using the following in the settings file:

```json
{
  "assistant": {
    "provider": {
      "type": "azure_openai",
      "api_url": "https://{your-resource-name}.openai.azure.com",
      "deployment_id": "gpt-4",
      "api_version": "2023-05-15"
    }
  }
}
```

You will need to deploy a model within Azure and update the settings
accordingly.

Release Notes:

- N/A
  • Loading branch information
maxdeviant authored Mar 1, 2024
1 parent 7c1ef96 commit eb1ab69
Show file tree
Hide file tree
Showing 7 changed files with 291 additions and 66 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 15 additions & 1 deletion assets/settings/default.json
Original file line number Diff line number Diff line change
Expand Up @@ -228,15 +228,29 @@
"default_width": 640,
// Default height when the assistant is docked to the bottom.
"default_height": 320,
// Deprecated: Please use `provider.api_url` instead.
// The default OpenAI API endpoint to use when starting new conversations.
"openai_api_url": "https://api.openai.com/v1",
// Deprecated: Please use `provider.default_model` instead.
// The default OpenAI model to use when starting new conversations. This
// setting can take three values:
//
// 1. "gpt-3.5-turbo-0613"
// 2. "gpt-4-0613"
// 3. "gpt-4-1106-preview"
"default_open_ai_model": "gpt-4-1106-preview"
"default_open_ai_model": "gpt-4-1106-preview",
"provider": {
"type": "openai",
// The default OpenAI API endpoint to use when starting new conversations.
"api_url": "https://api.openai.com/v1",
// The default OpenAI model to use when starting new conversations. This
// setting can take three values:
//
// 1. "gpt-3.5-turbo-0613"
// 2. "gpt-4-0613"
// 3. "gpt-4-1106-preview"
"default_model": "gpt-4-1106-preview"
}
},
// Whether the screen sharing icon is shown in the os status bar.
"show_call_status_icon": true,
Expand Down
1 change: 1 addition & 0 deletions crates/ai/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ parse_duration = "2.1.1"
postage.workspace = true
rand.workspace = true
rusqlite = { version = "0.29.0", features = ["blob", "array", "modern_sqlite"] }
schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
tiktoken-rs.workspace = true
Expand Down
67 changes: 56 additions & 11 deletions crates/ai/src/providers/open_ai/completion.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
use std::{
env,
fmt::{self, Display},
io,
sync::Arc,
};

use anyhow::{anyhow, Result};
use futures::{
future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
Expand All @@ -6,23 +13,17 @@ use futures::{
use gpui::{AppContext, BackgroundExecutor};
use isahc::{http::StatusCode, Request, RequestExt};
use parking_lot::RwLock;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::{
env,
fmt::{self, Display},
io,
sync::Arc,
};
use util::ResultExt;

use crate::providers::open_ai::{OpenAiLanguageModel, OPEN_AI_API_URL};
use crate::{
auth::{CredentialProvider, ProviderCredential},
completion::{CompletionProvider, CompletionRequest},
models::LanguageModel,
};

use crate::providers::open_ai::{OpenAiLanguageModel, OPEN_AI_API_URL};

#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
Expand Down Expand Up @@ -196,12 +197,56 @@ async fn stream_completion(
}
}

/// The set of `api-version` values accepted when calling an Azure OpenAI
/// deployment. Each variant serializes to the exact date string Azure
/// expects (via the `serde(rename = …)` attributes below).
#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema)]
pub enum AzureOpenAiApiVersion {
    /// Retiring April 2, 2024.
    #[serde(rename = "2023-03-15-preview")]
    V2023_03_15Preview,
    #[serde(rename = "2023-05-15")]
    V2023_05_15,
    /// Retiring April 2, 2024.
    #[serde(rename = "2023-06-01-preview")]
    V2023_06_01Preview,
    /// Retiring April 2, 2024.
    #[serde(rename = "2023-07-01-preview")]
    V2023_07_01Preview,
    /// Retiring April 2, 2024.
    #[serde(rename = "2023-08-01-preview")]
    V2023_08_01Preview,
    /// Retiring April 2, 2024.
    #[serde(rename = "2023-09-01-preview")]
    V2023_09_01Preview,
    #[serde(rename = "2023-12-01-preview")]
    V2023_12_01Preview,
    #[serde(rename = "2024-02-15-preview")]
    V2024_02_15Preview,
}

impl fmt::Display for AzureOpenAiApiVersion {
    /// Formats the version as the literal string used in the `api-version`
    /// query parameter; each arm mirrors the variant's serde rename.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let version = match self {
            Self::V2023_03_15Preview => "2023-03-15-preview",
            Self::V2023_05_15 => "2023-05-15",
            Self::V2023_06_01Preview => "2023-06-01-preview",
            Self::V2023_07_01Preview => "2023-07-01-preview",
            Self::V2023_08_01Preview => "2023-08-01-preview",
            Self::V2023_09_01Preview => "2023-09-01-preview",
            Self::V2023_12_01Preview => "2023-12-01-preview",
            Self::V2024_02_15Preview => "2024-02-15-preview",
        };
        f.write_str(version)
    }
}

#[derive(Clone)]
pub enum OpenAiCompletionProviderKind {
OpenAi,
AzureOpenAi {
deployment_id: String,
api_version: String,
api_version: AzureOpenAiApiVersion,
},
}

Expand All @@ -217,8 +262,8 @@ impl OpenAiCompletionProviderKind {
deployment_id,
api_version,
} => {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#completions
format!("{api_url}/openai/deployments/{deployment_id}/completions?api-version={api_version}")
// https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions
format!("{api_url}/openai/deployments/{deployment_id}/chat/completions?api-version={api_version}")
}
}
}
Expand Down
77 changes: 45 additions & 32 deletions crates/assistant/src/assistant_panel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,16 +124,18 @@ impl AssistantPanel {
.await
.log_err()
.unwrap_or_default();
let (api_url, model_name) = cx.update(|cx| {
let (provider_kind, api_url, model_name) = cx.update(|cx| {
let settings = AssistantSettings::get_global(cx);
(
settings.openai_api_url.clone(),
settings.default_open_ai_model.full_name().to_string(),
)
})?;
anyhow::Ok((
settings.provider_kind()?,
settings.provider_api_url()?,
settings.provider_model_name()?,
))
})??;

let completion_provider = OpenAiCompletionProvider::new(
api_url,
OpenAiCompletionProviderKind::OpenAi,
provider_kind,
model_name,
cx.background_executor().clone(),
)
Expand Down Expand Up @@ -693,24 +695,29 @@ impl AssistantPanel {
Task::ready(Ok(Vec::new()))
};

let mut model = AssistantSettings::get_global(cx)
.default_open_ai_model
.clone();
let model_name = model.full_name();

let prompt = cx.background_executor().spawn(async move {
let snippets = snippets.await?;
let Some(mut model_name) = AssistantSettings::get_global(cx)
.provider_model_name()
.log_err()
else {
return;
};

let language_name = language_name.as_deref();
generate_content_prompt(
user_prompt,
language_name,
buffer,
range,
snippets,
model_name,
project_name,
)
let prompt = cx.background_executor().spawn({
let model_name = model_name.clone();
async move {
let snippets = snippets.await?;

let language_name = language_name.as_deref();
generate_content_prompt(
user_prompt,
language_name,
buffer,
range,
snippets,
&model_name,
project_name,
)
}
});

let mut messages = Vec::new();
Expand All @@ -722,7 +729,7 @@ impl AssistantPanel {
.messages(cx)
.map(|message| message.to_open_ai_message(buffer)),
);
model = conversation.model.clone();
model_name = conversation.model.full_name().to_string();
}

cx.spawn(|_, mut cx| async move {
Expand All @@ -735,7 +742,7 @@ impl AssistantPanel {
});

let request = Box::new(OpenAiRequest {
model: model.full_name().into(),
model: model_name,
messages,
stream: true,
stop: vec!["|END|>".to_string()],
Expand Down Expand Up @@ -1454,8 +1461,14 @@ impl Conversation {
});

let settings = AssistantSettings::get_global(cx);
let model = settings.default_open_ai_model.clone();
let api_url = settings.openai_api_url.clone();
let model = settings
.provider_model()
.log_err()
.unwrap_or(OpenAiModel::FourTurbo);
let api_url = settings
.provider_api_url()
.log_err()
.unwrap_or_else(|| OPEN_AI_API_URL.to_string());

let mut this = Self {
id: Some(Uuid::new_v4().to_string()),
Expand Down Expand Up @@ -3655,9 +3668,9 @@ fn report_assistant_event(
let client = workspace.read(cx).project().read(cx).client();
let telemetry = client.telemetry();

let model = AssistantSettings::get_global(cx)
.default_open_ai_model
.clone();
let Ok(model_name) = AssistantSettings::get_global(cx).provider_model_name() else {
return;
};

telemetry.report_assistant_event(conversation_id, assistant_kind, model.full_name())
telemetry.report_assistant_event(conversation_id, assistant_kind, &model_name)
}
Loading

0 comments on commit eb1ab69

Please sign in to comment.