Skip to content

Commit

Permalink
feat: proxy chat-completions api with tools support (#850)
Browse files Browse the repository at this point in the history
  • Loading branch information
sigoden committed Sep 9, 2024
1 parent a56d5f2 commit 6996546
Show file tree
Hide file tree
Showing 3 changed files with 315 additions and 72 deletions.
50 changes: 28 additions & 22 deletions src/client/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,12 @@ pub async fn openai_chat_completions_streaming(
let handle = |message: SseMmessage| -> Result<bool> {
if message.data == "[DONE]" {
if !function_name.is_empty() {
let arguments: Value = function_arguments.parse().with_context(|| {
format!("Tool call '{function_name}' is invalid: arguments must be in valid JSON format")
})?;
handler.tool_call(ToolCall::new(
function_name.clone(),
json!(function_arguments),
arguments,
normalize_function_id(&function_id),
))?;
}
Expand All @@ -128,9 +131,12 @@ pub async fn openai_chat_completions_streaming(
let index = index.unwrap_or_default();
if index != function_index {
if !function_name.is_empty() {
let arguments: Value = function_arguments.parse().with_context(|| {
format!("Tool call '{function_name}' is invalid: arguments must be in valid JSON format")
})?;
handler.tool_call(ToolCall::new(
function_name.clone(),
json!(function_arguments),
arguments,
normalize_function_id(&function_id),
))?;
}
Expand Down Expand Up @@ -207,7 +213,7 @@ pub fn openai_build_chat_completions_body(data: ChatCompletionsData, model: &Mod
"type": "function",
"function": {
"name": tool_result.call.name,
"arguments": tool_result.call.arguments,
"arguments": tool_result.call.arguments.to_string(),
},
})
}).collect();
Expand Down Expand Up @@ -237,7 +243,7 @@ pub fn openai_build_chat_completions_body(data: ChatCompletionsData, model: &Mod
"type": "function",
"function": {
"name": tool_result.call.name,
"arguments": tool_result.call.arguments,
"arguments": tool_result.call.arguments.to_string(),
},
}
]
Expand Down Expand Up @@ -302,24 +308,24 @@ pub fn openai_extract_chat_completions(data: &Value) -> Result<ChatCompletionsOu

let mut tool_calls = vec![];
if let Some(calls) = data["choices"][0]["message"]["tool_calls"].as_array() {
tool_calls = calls
.iter()
.filter_map(|call| {
if let (Some(name), Some(arguments), Some(id)) = (
call["function"]["name"].as_str(),
call["function"]["arguments"].as_str(),
call["id"].as_str(),
) {
Some(ToolCall::new(
name.to_string(),
json!(arguments),
Some(id.to_string()),
))
} else {
None
}
})
.collect()
for call in calls {
if let (Some(name), Some(arguments), Some(id)) = (
call["function"]["name"].as_str(),
call["function"]["arguments"].as_str(),
call["id"].as_str(),
) {
let arguments: Value = arguments.parse().with_context(|| {
format!(
"Tool call '{name}' is invalid: arguments must be in valid JSON format"
)
})?;
tool_calls.push(ToolCall::new(
name.to_string(),
arguments,
Some(id.to_string()),
));
}
}
};

if text.is_empty() && tool_calls.is_empty() {
Expand Down
4 changes: 4 additions & 0 deletions src/client/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ impl SseHandler {
self.abort.clone()
}

pub fn get_tool_calls(&self) -> &[ToolCall] {
&self.tool_calls
}

pub fn take(self) -> (String, Vec<ToolCall>) {
let Self {
buffer, tool_calls, ..
Expand Down
Loading

0 comments on commit 6996546

Please sign in to comment.