Skip to content

Commit

Permalink
feat: support qianwen:qwen-vl-plus (#275)
Browse files Browse the repository at this point in the history
  • Loading branch information
sigoden authored Dec 19, 2023
1 parent 34d58b2 commit 6c9d7a6
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 48 deletions.
19 changes: 11 additions & 8 deletions src/client/gemini.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use super::{
patch_system_message, Client, ExtraConfig, GeminiClient, Model, PromptType, SendData,
TokensCountFactors,
message::*, patch_system_message, Client, ExtraConfig, GeminiClient, Model, PromptType,
SendData, TokensCountFactors,
};

use crate::{client::*, config::GlobalConfig, render::ReplyHandler, utils::PromptKind};
use crate::{config::GlobalConfig, render::ReplyHandler, utils::PromptKind};

use anyhow::{anyhow, bail, Result};
use async_trait::async_trait;
Expand Down Expand Up @@ -123,7 +123,7 @@ async fn send_message_streaming(builder: RequestBuilder, handler: &mut ReplyHand
for i in cursor..buffer.len() {
let ch = buffer[i];
if quoting {
if ch == '"' && buffer[i-1] != '\\' {
if ch == '"' && buffer[i - 1] != '\\' {
quoting = false;
}
continue;
Expand Down Expand Up @@ -189,7 +189,7 @@ fn build_body(data: SendData, _model: String) -> Result<Value> {

patch_system_message(&mut messages);

let mut invalid_urls = vec![];
let mut network_image_urls = vec![];
let contents: Vec<Value> = messages
.into_iter()
.map(|message| {
Expand All @@ -211,7 +211,7 @@ fn build_body(data: SendData, _model: String) -> Result<Value> {
if let Some((mime_type, data)) = url.strip_prefix("data:").and_then(|v| v.split_once(";base64,")) {
json!({ "inline_data": { "mime_type": mime_type, "data": data } })
} else {
invalid_urls.push(url.clone());
network_image_urls.push(url.clone());
json!({ "url": url })
}
},
Expand All @@ -223,8 +223,11 @@ fn build_body(data: SendData, _model: String) -> Result<Value> {
})
.collect();

if !invalid_urls.is_empty() {
bail!("The model does not support non-data URLs: {:?}", invalid_urls);
if !network_image_urls.is_empty() {
bail!(
"The model does not support network images: {:?}",
network_image_urls
);
}

let mut body = json!({
Expand Down
150 changes: 110 additions & 40 deletions src/client/qianwen.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::{Client, ExtraConfig, Model, PromptType, QianwenClient, SendData};
use super::{message::*, Client, ExtraConfig, Model, PromptType, QianwenClient, SendData};

use crate::{config::GlobalConfig, render::ReplyHandler, utils::PromptKind};

Expand All @@ -13,10 +13,15 @@ use serde_json::{json, Value};
// NOTE(review): this span comes from a diff rendering without +/- markers, so two
// `MODELS` declarations appear back to back. The 3-entry one (which has no `];` of
// its own) looks like the pre-commit text and the 5-entry one like the post-commit
// text — confirm against the repository.
//
// Endpoint that `request_builder` posts to when `is_vl()` is false.
const API_URL: &str =
"https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation";

// (model name, usize) pairs; the usize is presumably a token limit — TODO confirm,
// it is not read anywhere in the visible code.
const MODELS: [(&str, usize); 3] = [
("qwen-turbo", 6144),
("qwen-plus", 6144),
("qwen-max", 6144),
// Endpoint that `request_builder` posts to when `is_vl()` is true.
const API_URL_VL: &str =
"https://dashscope.aliyuncs.com/api/v1/services/aigc/multimodal-generation/generation";

// Same (model name, usize) layout, now with five entries.
// NOTE(review): `qwen-vl-plus` carries 0; what 0 means for the second field is not
// visible here — verify where `MODELS` is consumed.
const MODELS: [(&str, usize); 5] = [
("qwen-turbo", 8192),
("qwen-plus", 32768),
("qwen-max", 8192),
("qwen-max-longcontext", 30720),
("qwen-vl-plus", 0),
];

#[derive(Debug, Clone, Deserialize, Default)]
Expand All @@ -34,7 +39,7 @@ impl Client for QianwenClient {

async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String> {
let builder = self.request_builder(client, data)?;
send_message(builder).await
send_message(builder, self.is_vl()).await
}

async fn send_message_streaming_inner(
Expand All @@ -44,7 +49,7 @@ impl Client for QianwenClient {
data: SendData,
) -> Result<()> {
let builder = self.request_builder(client, data)?;
send_message_streaming(builder, handler).await
send_message_streaming(builder, handler, self.is_vl()).await
}
}

Expand All @@ -68,49 +73,71 @@ impl QianwenClient {
let api_key = self.get_api_key()?;

let stream = data.stream;
let body = build_body(data, self.model.name.clone());

debug!("Qianwen Request: {API_URL} {body}");
let is_vl = self.is_vl();
let url = match is_vl {
true => API_URL_VL,
false => API_URL,
};
let body = build_body(data, self.model.name.clone(), is_vl)?;

debug!("Qianwen Request: {url} {body}");

let mut builder = client.post(API_URL).bearer_auth(api_key).json(&body);
let mut builder = client.post(url).bearer_auth(api_key).json(&body);
if stream {
builder = builder.header("X-DashScope-SSE", "enable");
}

Ok(builder)
}

/// Reports whether the configured model name begins with `qwen-vl`.
///
/// `request_builder` uses the answer to choose between `API_URL_VL` and `API_URL`,
/// and the same flag is handed to `build_body` and the response readers.
fn is_vl(&self) -> bool {
    const VL_PREFIX: &str = "qwen-vl";
    let model_name: &str = &self.model.name;
    model_name.starts_with(VL_PREFIX)
}
}

// Sends the prepared request, parses the response body as JSON and returns the
// reply text as an owned `String`.
//
// NOTE(review): rendered diff without +/- markers — the two signatures and the two
// ways of computing `output` below are alternative versions of the same code (the
// one without `is_vl` appears to be the pre-commit text); only one of each exists
// in the real file.
//
// Errors: propagates send/JSON failures, whatever `check_error` reports, and an
// "Unexpected response" error when the expected text field is not a string.
async fn send_message(builder: RequestBuilder) -> Result<String> {
async fn send_message(builder: RequestBuilder, is_vl: bool) -> Result<String> {
let data: Value = builder.send().await?.json().await?;
// Surface an error payload as `Err` before trying to read the output.
check_error(&data)?;

// Version without `is_vl`: the text is always read from `output.text`.
let output = data["output"]["text"]
.as_str()
.ok_or_else(|| anyhow!("Unexpected response {data}"))?;
// Version with `is_vl`: VL replies are read from the first content part of the
// first choice; all other replies still come from `output.text`.
let output = if is_vl {
data["output"]["choices"][0]["message"]["content"][0]["text"].as_str()
} else {
data["output"]["text"].as_str()
};

// A missing or non-string field is reported together with the whole payload.
let output = output.ok_or_else(|| anyhow!("Unexpected response {data}"))?;

Ok(output.to_string())
}

async fn send_message_streaming(builder: RequestBuilder, handler: &mut ReplyHandler) -> Result<()> {
async fn send_message_streaming(
builder: RequestBuilder,
handler: &mut ReplyHandler,
is_vl: bool,
) -> Result<()> {
let mut es = builder.eventsource()?;
let mut offset = 0;

while let Some(event) = es.next().await {
match event {
Ok(Event::Open) => {}
Ok(Event::Message(message)) => {
let data: Value = serde_json::from_str(&message.data)?;
if let Some(text) = data["output"]["text"].as_str() {
check_error(&data)?;
if is_vl {
let text = data["output"]["choices"][0]["message"]["content"][0]["text"].as_str();
if let Some(text) = text {
let text = &text[offset..];
handler.text(text)?;
offset += text.len();
}
} else if let Some(text) = data["output"]["text"].as_str() {
handler.text(text)?;
}
}
Err(err) => {
match err {
EventSourceError::InvalidStatusCode(_, res) => {
let data: Value = res.json().await?;
check_error(&data)?;
bail!("Request failed");
}
EventSourceError::StreamEnded => {}
_ => {
bail!("{}", err);
Expand All @@ -125,38 +152,81 @@ async fn send_message_streaming(builder: RequestBuilder, handler: &mut ReplyHand
}

// Returns an error when the response JSON carries an error payload, `Ok(())`
// otherwise.
//
// NOTE(review): rendered diff without +/- markers — two alternative bodies are
// interleaved below, so the braces do not balance as shown.
fn check_error(data: &Value) -> Result<()> {
// Variant A (appears to be the pre-commit text): any string `code` is an error;
// the message is used as the error text, with the code as the fallback.
if let Some(code) = data["code"].as_str() {
if let Some(message) = data["message"].as_str() {
bail!("{message}");
} else {
bail!("{code}");
}
// Variant B (appears to be the post-commit text): fails with "{code}: {message}",
// but only when BOTH fields are strings.
// NOTE(review): with variant B a payload that has a `code` but no string `message`
// falls through to `Ok(())`, whereas variant A would fail with the code — confirm
// this is intended.
if let (Some(code), Some(message)) = (data["code"].as_str(), data["message"].as_str()) {
bail!("{code}: {message}");
}
Ok(())
}

// Builds the JSON request body (`model`, `input`, `parameters`) from `data`.
//
// NOTE(review): rendered diff without +/- markers. Lines of the version without
// `is_vl` (returns `Value`) are interleaved with lines of the version with `is_vl`
// (returns `Result<Value>`), so the text below does not parse as shown. Comments
// describe the `is_vl` version unless they say otherwise.
fn build_body(data: SendData, model: String) -> Value {
fn build_body(data: SendData, model: String, is_vl: bool) -> Result<Value> {
let SendData {
messages,
temperature,
stream,
} = data;

// (version without `is_vl`, apparently) one `parameters` object is filled up front:
// from `stream` here and from `temperature` a few lines further down.
let mut parameters = json!({});

if stream {
parameters["incremental_output"] = true.into();
}
// Both branches of this `if` evaluate to an `(input, parameters)` pair.
let (input, parameters) = if is_vl {
// Set by the closure below as soon as any image URL starts with `data:`.
let mut exist_embeded_image = false;

if let Some(v) = temperature {
parameters["temperature"] = v.into();
}
// Convert every message into {"role", "content"}, where `content` is always an
// array of {"text": ..} / {"image": ..} parts (plain text becomes a one-item array).
let messages: Vec<Value> = messages
.into_iter()
.map(|message| {
let role = message.role;
let content = match message.content {
MessageContent::Text(text) => vec![json!({"text": text})],
MessageContent::Array(list) => list
.into_iter()
.map(|item| match item {
MessageContentPart::Text { text } => json!({"text": text}),
MessageContentPart::ImageUrl {
image_url: ImageUrl { url },
} => {
if url.starts_with("data:") {
exist_embeded_image = true;
}
json!({"image": url})
},
})
.collect(),
};
json!({ "role": role, "content": content })
})
.collect();

// (version without `is_vl`, apparently) start of the body that was returned
// directly; its closing `}),` and `})` show up unmatched further down.
json!({
"model": model,
"input": json!({
// `data:` URLs are rejected, but only after all messages have been converted.
// NOTE(review): "embeded" is misspelled in both the identifier and the
// user-facing message (should be "embedded").
if exist_embeded_image {
bail!("The model does not support embeded images");
}

let input = json!({
"messages": messages,
});

// In this branch `temperature` is not forwarded as such: it is multiplied by 50,
// rounded and sent as `top_k`; `incremental_output` is not set here.
// NOTE(review): confirm the temperature -> top_k mapping is intended.
let mut parameters = json!({});
if let Some(v) = temperature {
parameters["top_k"] = ((v * 50.0).round() as usize).into();
}
(input, parameters)
} else {
// Non-VL branch: the messages are passed through unchanged.
let input = json!({
"messages": messages,
}),
});

// `incremental_output` is set when streaming, `temperature` when one was provided.
let mut parameters = json!({});
if stream {
parameters["incremental_output"] = true.into();
}

if let Some(v) = temperature {
parameters["temperature"] = v.into();
}
(input, parameters)
};

// Assemble the final body from the pieces chosen above.
let body = json!({
"model": model,
"input": input,
"parameters": parameters
})
});
Ok(body)
}

0 comments on commit 6c9d7a6

Please sign in to comment.