Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support qianwen:qwen-vl-plus #275

Merged
merged 1 commit into from
Dec 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions src/client/gemini.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use super::{
patch_system_message, Client, ExtraConfig, GeminiClient, Model, PromptType, SendData,
TokensCountFactors,
message::*, patch_system_message, Client, ExtraConfig, GeminiClient, Model, PromptType,
SendData, TokensCountFactors,
};

use crate::{client::*, config::GlobalConfig, render::ReplyHandler, utils::PromptKind};
use crate::{config::GlobalConfig, render::ReplyHandler, utils::PromptKind};

use anyhow::{anyhow, bail, Result};
use async_trait::async_trait;
Expand Down Expand Up @@ -123,7 +123,7 @@ async fn send_message_streaming(builder: RequestBuilder, handler: &mut ReplyHand
for i in cursor..buffer.len() {
let ch = buffer[i];
if quoting {
if ch == '"' && buffer[i-1] != '\\' {
if ch == '"' && buffer[i - 1] != '\\' {
quoting = false;
}
continue;
Expand Down Expand Up @@ -189,7 +189,7 @@ fn build_body(data: SendData, _model: String) -> Result<Value> {

patch_system_message(&mut messages);

let mut invalid_urls = vec![];
let mut network_image_urls = vec![];
let contents: Vec<Value> = messages
.into_iter()
.map(|message| {
Expand All @@ -211,7 +211,7 @@ fn build_body(data: SendData, _model: String) -> Result<Value> {
if let Some((mime_type, data)) = url.strip_prefix("data:").and_then(|v| v.split_once(";base64,")) {
json!({ "inline_data": { "mime_type": mime_type, "data": data } })
} else {
invalid_urls.push(url.clone());
network_image_urls.push(url.clone());
json!({ "url": url })
}
},
Expand All @@ -223,8 +223,11 @@ fn build_body(data: SendData, _model: String) -> Result<Value> {
})
.collect();

if !invalid_urls.is_empty() {
bail!("The model does not support non-data URLs: {:?}", invalid_urls);
if !network_image_urls.is_empty() {
bail!(
"The model does not support network images: {:?}",
network_image_urls
);
}

let mut body = json!({
Expand Down
150 changes: 110 additions & 40 deletions src/client/qianwen.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::{Client, ExtraConfig, Model, PromptType, QianwenClient, SendData};
use super::{message::*, Client, ExtraConfig, Model, PromptType, QianwenClient, SendData};

use crate::{config::GlobalConfig, render::ReplyHandler, utils::PromptKind};

Expand All @@ -13,10 +13,15 @@ use serde_json::{json, Value};
/// DashScope endpoint for the text-generation service; used for every model
/// that is not a `qwen-vl*` model.
const API_URL: &str =
    "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation";

/// DashScope endpoint for the multimodal-generation service; used for the
/// `qwen-vl*` (vision-language) models.
const API_URL_VL: &str =
    "https://dashscope.aliyuncs.com/api/v1/services/aigc/multimodal-generation/generation";

/// Known model names, each paired with a numeric limit.
///
/// NOTE(review): the number is presumably the model's maximum token count, and
/// `0` for `qwen-vl-plus` presumably means "no known limit" — confirm against
/// the code that consumes this table.
const MODELS: [(&str, usize); 5] = [
    ("qwen-turbo", 8192),
    ("qwen-plus", 32768),
    ("qwen-max", 8192),
    ("qwen-max-longcontext", 30720),
    ("qwen-vl-plus", 0),
];

#[derive(Debug, Clone, Deserialize, Default)]
Expand All @@ -34,7 +39,7 @@ impl Client for QianwenClient {

async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String> {
let builder = self.request_builder(client, data)?;
send_message(builder).await
send_message(builder, self.is_vl()).await
}

async fn send_message_streaming_inner(
Expand All @@ -44,7 +49,7 @@ impl Client for QianwenClient {
data: SendData,
) -> Result<()> {
let builder = self.request_builder(client, data)?;
send_message_streaming(builder, handler).await
send_message_streaming(builder, handler, self.is_vl()).await
}
}

Expand All @@ -68,49 +73,71 @@ impl QianwenClient {
let api_key = self.get_api_key()?;

let stream = data.stream;
let body = build_body(data, self.model.name.clone());

debug!("Qianwen Request: {API_URL} {body}");
let is_vl = self.is_vl();
let url = match is_vl {
true => API_URL_VL,
false => API_URL,
};
let body = build_body(data, self.model.name.clone(), is_vl)?;

debug!("Qianwen Request: {url} {body}");

let mut builder = client.post(API_URL).bearer_auth(api_key).json(&body);
let mut builder = client.post(url).bearer_auth(api_key).json(&body);
if stream {
builder = builder.header("X-DashScope-SSE", "enable");
}

Ok(builder)
}

/// Reports whether the configured model is a vision-language model, which is
/// decided purely by the model name beginning with `qwen-vl`.
fn is_vl(&self) -> bool {
    let name: &str = &self.model.name;
    name.starts_with("qwen-vl")
}
}

async fn send_message(builder: RequestBuilder) -> Result<String> {
async fn send_message(builder: RequestBuilder, is_vl: bool) -> Result<String> {
let data: Value = builder.send().await?.json().await?;
check_error(&data)?;

let output = data["output"]["text"]
.as_str()
.ok_or_else(|| anyhow!("Unexpected response {data}"))?;
let output = if is_vl {
data["output"]["choices"][0]["message"]["content"][0]["text"].as_str()
} else {
data["output"]["text"].as_str()
};

let output = output.ok_or_else(|| anyhow!("Unexpected response {data}"))?;

Ok(output.to_string())
}

async fn send_message_streaming(builder: RequestBuilder, handler: &mut ReplyHandler) -> Result<()> {
async fn send_message_streaming(
builder: RequestBuilder,
handler: &mut ReplyHandler,
is_vl: bool,
) -> Result<()> {
let mut es = builder.eventsource()?;
let mut offset = 0;

while let Some(event) = es.next().await {
match event {
Ok(Event::Open) => {}
Ok(Event::Message(message)) => {
let data: Value = serde_json::from_str(&message.data)?;
if let Some(text) = data["output"]["text"].as_str() {
check_error(&data)?;
if is_vl {
let text = data["output"]["choices"][0]["message"]["content"][0]["text"].as_str();
if let Some(text) = text {
let text = &text[offset..];
handler.text(text)?;
offset += text.len();
}
} else if let Some(text) = data["output"]["text"].as_str() {
handler.text(text)?;
}
}
Err(err) => {
match err {
EventSourceError::InvalidStatusCode(_, res) => {
let data: Value = res.json().await?;
check_error(&data)?;
bail!("Request failed");
}
EventSourceError::StreamEnded => {}
_ => {
bail!("{}", err);
Expand All @@ -125,38 +152,81 @@ async fn send_message_streaming(builder: RequestBuilder, handler: &mut ReplyHand
}

fn check_error(data: &Value) -> Result<()> {
if let Some(code) = data["code"].as_str() {
if let Some(message) = data["message"].as_str() {
bail!("{message}");
} else {
bail!("{code}");
}
if let (Some(code), Some(message)) = (data["code"].as_str(), data["message"].as_str()) {
bail!("{code}: {message}");
}
Ok(())
}

fn build_body(data: SendData, model: String) -> Value {
fn build_body(data: SendData, model: String, is_vl: bool) -> Result<Value> {
let SendData {
messages,
temperature,
stream,
} = data;

let mut parameters = json!({});

if stream {
parameters["incremental_output"] = true.into();
}
let (input, parameters) = if is_vl {
let mut exist_embeded_image = false;

if let Some(v) = temperature {
parameters["temperature"] = v.into();
}
let messages: Vec<Value> = messages
.into_iter()
.map(|message| {
let role = message.role;
let content = match message.content {
MessageContent::Text(text) => vec![json!({"text": text})],
MessageContent::Array(list) => list
.into_iter()
.map(|item| match item {
MessageContentPart::Text { text } => json!({"text": text}),
MessageContentPart::ImageUrl {
image_url: ImageUrl { url },
} => {
if url.starts_with("data:") {
exist_embeded_image = true;
}
json!({"image": url})
},
})
.collect(),
};
json!({ "role": role, "content": content })
})
.collect();

json!({
"model": model,
"input": json!({
if exist_embeded_image {
bail!("The model does not support embeded images");
}

let input = json!({
"messages": messages,
});

let mut parameters = json!({});
if let Some(v) = temperature {
parameters["top_k"] = ((v * 50.0).round() as usize).into();
}
(input, parameters)
} else {
let input = json!({
"messages": messages,
}),
});

let mut parameters = json!({});
if stream {
parameters["incremental_output"] = true.into();
}

if let Some(v) = temperature {
parameters["temperature"] = v.into();
}
(input, parameters)
};

let body = json!({
"model": model,
"input": input,
"parameters": parameters
})
});
Ok(body)
}