fix: deadlock between background job and requests #720

Merged 2 commits on Nov 7, 2023
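Context for the change: in the removed code, background_job locks requests before engine, while generate_stream locks engine and then requests, so each task can end up holding the lock the other is waiting on. The new design drops the shared Mutexes entirely: LlamaService spawns a dedicated thread that exclusively owns the engine and the per-request state, and request handlers hand it work over an mpsc channel. The snippet below is a minimal sketch of that ownership pattern only, not code from this PR; Request and spawn_worker are illustrative names.

// Minimal sketch of the single-owner-plus-channel pattern (illustrative only).
use tokio::sync::mpsc::{channel, Receiver, Sender};

struct Request {
    prompt: String,
    // Per-request channel used to stream results back to the caller.
    tx: Sender<String>,
}

fn spawn_worker(mut rx: Receiver<Request>) -> std::thread::JoinHandle<()> {
    // The worker thread owns all mutable state; nothing is shared behind a Mutex.
    std::thread::spawn(move || {
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        rt.block_on(async move {
            while let Some(req) = rx.recv().await {
                // Stand-in for the engine's decoding step: echo the prompt back.
                let _ = req.tx.send(format!("echo: {}", req.prompt)).await;
            }
        });
    })
}

#[tokio::main]
async fn main() {
    let (tx, rx) = channel::<Request>(16);
    let _worker = spawn_worker(rx);

    let (resp_tx, mut resp_rx) = channel::<String>(8);
    if tx.send(Request { prompt: "hello".into(), tx: resp_tx }).await.is_err() {
        eprintln!("worker thread stopped");
        return;
    }

    while let Some(chunk) = resp_rx.recv().await {
        println!("{chunk}");
    }
}

Because only the worker thread ever touches the engine state, there is no lock ordering to get wrong, which is what the new LlamaService relies on to avoid the deadlock.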
crates/llama-cpp-bindings/src/lib.rs: 28 additions, 127 deletions
@@ -1,20 +1,14 @@
mod llama;
mod utils;

use std::{collections::HashMap, sync::Arc};

use async_stream::stream;
use async_trait::async_trait;
use cxx::UniquePtr;
use derive_builder::Builder;
use ffi::create_engine;
use futures::{lock::Mutex, stream::BoxStream};
use futures::stream::BoxStream;
use llama::LlamaService;
use tabby_inference::{
decoding::{StopCondition, StopConditionFactory},
helpers, TextGeneration, TextGenerationOptions,
};
use tokio::{
sync::mpsc::{channel, Sender},
task::yield_now,
decoding::StopConditionFactory, helpers, TextGeneration, TextGenerationOptions,
};

#[cxx::bridge(namespace = "llama")]
@@ -45,66 +39,36 @@ mod ffi {
unsafe impl Send for ffi::TextInferenceEngine {}
unsafe impl Sync for ffi::TextInferenceEngine {}

struct InferenceRequest {
tx: Sender<String>,
stop_condition: StopCondition,
#[derive(Builder, Debug)]
pub struct LlamaTextGenerationOptions {
model_path: String,
use_gpu: bool,
}

struct AsyncTextInferenceEngine {
engine: Mutex<cxx::UniquePtr<ffi::TextInferenceEngine>>,
pub struct LlamaTextGeneration {
service: LlamaService,
stop_condition_factory: StopConditionFactory,
requests: Mutex<HashMap<u32, InferenceRequest>>,

next_request_id: Mutex<u32>,
}

impl AsyncTextInferenceEngine {
fn create(engine: UniquePtr<ffi::TextInferenceEngine>) -> Self {
impl LlamaTextGeneration {
pub fn new(options: LlamaTextGenerationOptions) -> Self {
let engine = create_engine(options.use_gpu, &options.model_path);
if engine.is_null() {
fatal!("Unable to load model: {}", options.model_path);
}

Self {
engine: Mutex::new(engine),
service: LlamaService::new(engine),
stop_condition_factory: StopConditionFactory::default(),
requests: Mutex::new(HashMap::new()),
next_request_id: Mutex::new(0),
}
}
}

async fn background_job(&self) {
let mut requests = self.requests.lock().await;
if requests.len() == 0 {
return;
}

let mut engine = self.engine.lock().await;

let result = match engine.as_mut().unwrap().step() {
Ok(result) => result,
Err(err) => {
fatal!("Failed to step: {}", err)
}
};

for ffi::StepOutput { request_id, text } in result {
let mut stopped = false;
let InferenceRequest { tx, stop_condition } = requests.get_mut(&request_id).unwrap();

if tx.is_closed() || text.is_empty() {
// Cancelled by client side or hit eos.
stopped = true;
} else if !stop_condition.should_stop(&text) {
match tx.send(text).await {
Ok(_) => (),
Err(_) => stopped = true,
}
} else {
// Stopped by stop words.
stopped = true;
}

if stopped {
requests.remove(&request_id);
engine.as_mut().unwrap().stop_request(request_id);
}
}
#[async_trait]
impl TextGeneration for LlamaTextGeneration {
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
let s = self.generate_stream(prompt, options).await;
helpers::stream_to_string(s).await
}

async fn generate_stream(
@@ -114,23 +78,10 @@ impl AsyncTextInferenceEngine {
) -> BoxStream<String> {
let stop_condition = self.stop_condition_factory.create(prompt, options.language);

let (tx, mut rx) = channel::<String>(4);
{
let mut engine = self.engine.lock().await;

let mut request_id = self.next_request_id.lock().await;
self.requests
.lock()
.await
.insert(*request_id, InferenceRequest { tx, stop_condition });
engine
.as_mut()
.unwrap()
.add_request(*request_id, prompt, options.max_input_length);

// 2048 should be large enough to avoid collision.
*request_id = (*request_id + 1) % 2048;
}
let mut rx = self
.service
.add_request(prompt, options.max_input_length, stop_condition)
.await;

let s = stream! {
let mut length = 0;
@@ -148,53 +99,3 @@ impl AsyncTextInferenceEngine {
Box::pin(s)
}
}

#[derive(Builder, Debug)]
pub struct LlamaTextGenerationOptions {
model_path: String,
use_gpu: bool,
}

pub struct LlamaTextGeneration {
engine: Arc<AsyncTextInferenceEngine>,
}

impl LlamaTextGeneration {
pub fn create(options: LlamaTextGenerationOptions) -> Self {
let engine = create_engine(options.use_gpu, &options.model_path);
if engine.is_null() {
fatal!("Unable to load model: {}", options.model_path);
}
let ret = LlamaTextGeneration {
engine: Arc::new(AsyncTextInferenceEngine::create(engine)),
};
ret.start_background_job();
ret
}

pub fn start_background_job(&self) {
let engine = self.engine.clone();
tokio::spawn(async move {
loop {
engine.background_job().await;
yield_now().await;
}
});
}
}

#[async_trait]
impl TextGeneration for LlamaTextGeneration {
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
let s = self.generate_stream(prompt, options).await;
helpers::stream_to_string(s).await
}

async fn generate_stream(
&self,
prompt: &str,
options: TextGenerationOptions,
) -> BoxStream<String> {
self.engine.generate_stream(prompt, options).await
}
}
crates/llama-cpp-bindings/src/llama.rs: 155 additions, 0 deletions (new file)
@@ -0,0 +1,155 @@
use std::{collections::HashMap, thread::JoinHandle};

use cxx::UniquePtr;
use tabby_inference::decoding::StopCondition;
use tokio::sync::mpsc::{channel, Receiver, Sender};

use crate::ffi;

struct LlamaInitRequest {
prompt: String,
max_input_length: usize,

tx: Sender<String>,
stop_condition: StopCondition,
}

struct LlamaRunningRequest {
tx: Sender<String>,
stop_condition: StopCondition,
}

struct LlamaServiceImpl {
next_request_id: u32,
engine: cxx::UniquePtr<ffi::TextInferenceEngine>,
rx: Receiver<LlamaInitRequest>,
requests: HashMap<u32, LlamaRunningRequest>,
}

impl LlamaServiceImpl {
fn new(engine: UniquePtr<ffi::TextInferenceEngine>, rx: Receiver<LlamaInitRequest>) -> Self {
Self {
next_request_id: 0,
engine,
rx,
requests: HashMap::new(),
}
}

fn alloc_request_id(&mut self) -> u32 {
let ret = self.next_request_id;
self.next_request_id += 1;
ret
}

async fn next_request(&mut self) -> Option<LlamaInitRequest> {
if self.requests.is_empty() {
self.rx.recv().await
} else {
self.rx.try_recv().ok()
}
}

async fn background_job(&mut self) {
while let Some(LlamaInitRequest {
prompt,
tx,
max_input_length,
stop_condition,
}) = self.next_request().await
{
let request_id = self.alloc_request_id();
self.requests
.insert(request_id, LlamaRunningRequest { tx, stop_condition });
self.engine
.as_mut()
.unwrap()
.add_request(request_id, &prompt, max_input_length);
}

let result = match self.engine.as_mut().unwrap().step() {
Ok(result) => result,
Err(err) => {
crate::fatal!("Failed to step: {}", err)
}
};

for ffi::StepOutput { request_id, text } in result {
let mut stopped = false;
let LlamaRunningRequest { tx, stop_condition } =
self.requests.get_mut(&request_id).unwrap();

if tx.is_closed() || text.is_empty() {
// Cancelled by client side or hit eos.
stopped = true;
} else if !stop_condition.should_stop(&text) {
match tx.send(text).await {
Ok(_) => (),
Err(_) => stopped = true,
}
} else {
// Stopped by stop words.
stopped = true;
}

if stopped {
self.requests.remove(&request_id);
self.engine.as_mut().unwrap().stop_request(request_id);
}
}
}
}

fn start_llama_service_impl(
engine: UniquePtr<ffi::TextInferenceEngine>,
rx: Receiver<LlamaInitRequest>,
) -> JoinHandle<()> {
let mut service = LlamaServiceImpl::new(engine, rx);
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();

std::thread::spawn(move || {
let local = tokio::task::LocalSet::new();
local.spawn_local(async move {
loop {
service.background_job().await;
}
});

rt.block_on(local);
})
}

pub struct LlamaService {
tx: Sender<LlamaInitRequest>,
}

impl LlamaService {
pub fn new(engine: UniquePtr<ffi::TextInferenceEngine>) -> Self {
let (tx, rx) = channel(20);
start_llama_service_impl(engine, rx);
Self { tx }
}

pub async fn add_request(
&self,
prompt: &str,
max_input_length: usize,
stop_condition: StopCondition,
) -> Receiver<String> {
let (tx, rx) = channel(8);
self.tx
.send(LlamaInitRequest {
prompt: prompt.to_owned(),
tx,
max_input_length,
stop_condition,
})
.await
.expect("Failed to add request");

rx
}
}
crates/tabby/src/serve/engine.rs: 1 addition, 1 deletion
@@ -64,5 +64,5 @@ fn create_ggml_engine(device: &super::Device, model_path: &str) -> Box<dyn TextGeneration> {
.build()
.unwrap();

Box::new(llama_cpp_bindings::LlamaTextGeneration::create(options))
Box::new(llama_cpp_bindings::LlamaTextGeneration::new(options))
}