
refactor(flow): make from_substrait_* async & worker handle refactor #4210

Merged · 3 commits · Jun 27, 2024
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions src/flow/Cargo.toml
@@ -10,6 +10,7 @@ workspace = true
[dependencies]
api.workspace = true
arrow-schema.workspace = true
async-recursion = "1.0"
async-trait.workspace = true
bytes.workspace = true
catalog.workspace = true
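The new `async-recursion` dependency supports the `from_substrait_*` half of this refactor: those conversion functions recurse over the Substrait plan tree, and a plain recursive `async fn` is rejected by the compiler because its future type would have infinite size. Below is a minimal sketch of the pattern, using a hypothetical `Rel` tree and `from_rel` function rather than the crate's actual `from_substrait_*` signatures:

```rust
use async_recursion::async_recursion;

// Hypothetical plan node standing in for a Substrait relation tree.
enum Rel {
    Leaf(i64),
    Node(Vec<Rel>),
}

// A recursive `async fn` would otherwise fail to compile because its future
// type has infinite size; #[async_recursion] boxes the future automatically.
#[async_recursion]
async fn from_rel(rel: &Rel) -> i64 {
    match rel {
        Rel::Leaf(v) => *v,
        Rel::Node(children) => {
            let mut sum = 0;
            for child in children {
                // Each recursive call is awaited like any other async call.
                sum += from_rel(child).await;
            }
            sum
        }
    }
}

#[tokio::main]
async fn main() {
    let plan = Rel::Node(vec![Rel::Leaf(1), Rel::Node(vec![Rel::Leaf(2), Rel::Leaf(3)])]);
    assert_eq!(from_rel(&plan).await, 6);
}
```

The attribute boxes the returned future behind the scenes, which has the same effect as writing `Box::pin` manually at every recursive call site.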
152 changes: 52 additions & 100 deletions src/flow/src/adapter/worker.rs
@@ -15,14 +15,14 @@
//! For single-thread flow worker

use std::collections::{BTreeMap, VecDeque};
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

use common_telemetry::info;
use enum_as_inner::EnumAsInner;
use hydroflow::scheduled::graph::Hydroflow;
use snafu::{ensure, OptionExt};
use tokio::sync::{broadcast, mpsc, Mutex};
use snafu::ensure;
use tokio::sync::{broadcast, mpsc, oneshot, Mutex};

use crate::adapter::error::{Error, FlowAlreadyExistSnafu, InternalSnafu, UnexpectedSnafu};
use crate::adapter::FlowId;
@@ -39,7 +39,7 @@ type ReqId = usize;
pub fn create_worker<'a>() -> (WorkerHandle, Worker<'a>) {
let (itc_client, itc_server) = create_inter_thread_call();
let worker_handle = WorkerHandle {
itc_client: Mutex::new(itc_client),
itc_client,
shutdown: AtomicBool::new(false),
};
let worker = Worker {
@@ -106,7 +106,7 @@ impl<'subgraph> ActiveDataflowState<'subgraph> {

#[derive(Debug)]
pub struct WorkerHandle {
itc_client: Mutex<InterThreadCallClient>,
itc_client: InterThreadCallClient,
shutdown: AtomicBool,
}

@@ -122,12 +122,7 @@ impl WorkerHandle {
}
);

let ret = self
.itc_client
.lock()
.await
.call_with_resp(create_reqs)
.await?;
let ret = self.itc_client.call_with_resp(create_reqs).await?;
ret.into_create().map_err(|ret| {
InternalSnafu {
reason: format!(
@@ -141,7 +136,8 @@ impl WorkerHandle {
/// remove task, return task id
pub async fn remove_flow(&self, flow_id: FlowId) -> Result<bool, Error> {
let req = Request::Remove { flow_id };
let ret = self.itc_client.lock().await.call_with_resp(req).await?;

let ret = self.itc_client.call_with_resp(req).await?;

ret.into_remove().map_err(|ret| {
InternalSnafu {
@@ -157,15 +153,12 @@ impl WorkerHandle {
///
/// the returned error is unrecoverable, and the worker should be shutdown/rebooted
pub async fn run_available(&self, now: repr::Timestamp) -> Result<(), Error> {
self.itc_client
.lock()
.await
.call_no_resp(Request::RunAvail { now })
self.itc_client.call_no_resp(Request::RunAvail { now })
}

pub async fn contains_flow(&self, flow_id: FlowId) -> Result<bool, Error> {
let req = Request::ContainTask { flow_id };
let ret = self.itc_client.lock().await.call_with_resp(req).await?;
let ret = self.itc_client.call_with_resp(req).await?;

ret.into_contain_task().map_err(|ret| {
InternalSnafu {
@@ -178,23 +171,9 @@ impl WorkerHandle {
}

/// shutdown the worker
pub async fn shutdown(&self) -> Result<(), Error> {
pub fn shutdown(&self) -> Result<(), Error> {
if !self.shutdown.fetch_or(true, Ordering::SeqCst) {
self.itc_client.lock().await.call_no_resp(Request::Shutdown)
} else {
UnexpectedSnafu {
reason: "Worker already shutdown",
}
.fail()
}
}

/// shutdown the worker
pub fn shutdown_blocking(&self) -> Result<(), Error> {
if !self.shutdown.fetch_or(true, Ordering::SeqCst) {
self.itc_client
.blocking_lock()
.call_no_resp(Request::Shutdown)
self.itc_client.call_no_resp(Request::Shutdown)
} else {
UnexpectedSnafu {
reason: "Worker already shutdown",
@@ -206,8 +185,7 @@ impl WorkerHandle {

impl Drop for WorkerHandle {
fn drop(&mut self) {
let ret = futures::executor::block_on(async { self.shutdown().await });
if let Err(ret) = ret {
if let Err(ret) = self.shutdown() {
common_telemetry::error!(
ret;
"While dropping Worker Handle, failed to shutdown worker, worker might be in inconsistent state."
@@ -276,7 +254,7 @@ impl<'s> Worker<'s> {
/// Run the worker, blocking, until shutdown signal is received
pub fn run(&mut self) {
loop {
let (req_id, req) = if let Some(ret) = self.itc_server.blocking_lock().blocking_recv() {
let (req, ret_tx) = if let Some(ret) = self.itc_server.blocking_lock().blocking_recv() {
ret
} else {
common_telemetry::error!(
Expand All @@ -285,19 +263,26 @@ impl<'s> Worker<'s> {
break;
};

let ret = self.handle_req(req_id, req);
match ret {
Ok(Some((id, resp))) => {
if let Err(err) = self.itc_server.blocking_lock().resp(id, resp) {
let ret = self.handle_req(req);
match (ret, ret_tx) {
(Ok(Some(resp)), Some(ret_tx)) => {
if let Err(err) = ret_tx.send(resp) {
common_telemetry::error!(
err;
"Worker's itc server has been closed unexpectedly, shutting down worker"
"Result receiver is dropped, can't send result"
);
break;
};
}
Ok(None) => continue,
Err(()) => {
(Ok(None), None) => continue,
(Ok(Some(resp)), None) => {
common_telemetry::error!(
"Expect no result for current request, but found {resp:?}"
)
}
(Ok(None), Some(_)) => {
common_telemetry::error!("Expect result for current request, but found nothing")
}
(Err(()), _) => {
break;
}
}
@@ -315,7 +300,7 @@ impl<'s> Worker<'s> {
/// handle request, return response if any, Err if receive shutdown signal
///
/// return `Err(())` if receive shutdown request
fn handle_req(&mut self, req_id: ReqId, req: Request) -> Result<Option<(ReqId, Response)>, ()> {
fn handle_req(&mut self, req: Request) -> Result<Option<Response>, ()> {
let ret = match req {
Request::Create {
flow_id,
@@ -339,24 +324,21 @@
create_if_not_exists,
err_collector,
);
Some((
req_id,
Response::Create {
result: task_create_result,
},
))
Some(Response::Create {
result: task_create_result,
})
}
Request::Remove { flow_id } => {
let ret = self.remove_flow(flow_id);
Some((req_id, Response::Remove { result: ret }))
Some(Response::Remove { result: ret })
}
Request::RunAvail { now } => {
self.run_tick(now);
None
}
Request::ContainTask { flow_id } => {
let ret = self.task_states.contains_key(&flow_id);
Some((req_id, Response::ContainTask { result: ret }))
Some(Response::ContainTask { result: ret })
}
Request::Shutdown => return Err(()),
};
@@ -406,83 +388,50 @@ enum Response {

fn create_inter_thread_call() -> (InterThreadCallClient, InterThreadCallServer) {
let (arg_send, arg_recv) = mpsc::unbounded_channel();
let (ret_send, ret_recv) = mpsc::unbounded_channel();
let client = InterThreadCallClient {
call_id: AtomicUsize::new(0),
arg_sender: arg_send,
ret_recv,
};
let server = InterThreadCallServer {
arg_recv,
ret_sender: ret_send,
};
let server = InterThreadCallServer { arg_recv };
(client, server)
}

#[derive(Debug)]
struct InterThreadCallClient {
call_id: AtomicUsize,
arg_sender: mpsc::UnboundedSender<(ReqId, Request)>,
ret_recv: mpsc::UnboundedReceiver<(ReqId, Response)>,
arg_sender: mpsc::UnboundedSender<(Request, Option<oneshot::Sender<Response>>)>,
}

impl InterThreadCallClient {
/// call without expecting responses or blocking
fn call_no_resp(&self, req: Request) -> Result<(), Error> {
// TODO(discord9): relax memory order later
let call_id = self.call_id.fetch_add(1, Ordering::SeqCst);
self.arg_sender
.send((call_id, req))
.map_err(from_send_error)
self.arg_sender.send((req, None)).map_err(from_send_error)
}

/// call blocking, and return the result
async fn call_with_resp(&mut self, req: Request) -> Result<Response, Error> {
// TODO(discord9): relax memory order later
let call_id = self.call_id.fetch_add(1, Ordering::SeqCst);
async fn call_with_resp(&self, req: Request) -> Result<Response, Error> {
let (tx, rx) = oneshot::channel();
self.arg_sender
.send((call_id, req))
.send((req, Some(tx)))
.map_err(from_send_error)?;

// TODO(discord9): better inter thread call impl, i.e. support multiple client(also consider if it's necessary)
// since one node manger might manage multiple worker, but one worker should only belong to one node manager
let (ret_call_id, ret) = self
.ret_recv
.recv()
.await
.context(InternalSnafu { reason: "InterThreadCallClient call_blocking failed, ret_recv has been closed and there are no remaining messages in the channel's buffer" })?;

ensure!(
ret_call_id == call_id,
rx.await.map_err(|_| {
InternalSnafu {
reason: "call id mismatch, worker/worker handler should be in sync",
reason: "Sender is dropped",
}
);
Ok(ret)
.build()
})
}
}

#[derive(Debug)]
struct InterThreadCallServer {
pub arg_recv: mpsc::UnboundedReceiver<(ReqId, Request)>,
pub ret_sender: mpsc::UnboundedSender<(ReqId, Response)>,
pub arg_recv: mpsc::UnboundedReceiver<(Request, Option<oneshot::Sender<Response>>)>,
}

impl InterThreadCallServer {
pub async fn recv(&mut self) -> Option<(usize, Request)> {
pub async fn recv(&mut self) -> Option<(Request, Option<oneshot::Sender<Response>>)> {
self.arg_recv.recv().await
}

pub fn blocking_recv(&mut self) -> Option<(usize, Request)> {
pub fn blocking_recv(&mut self) -> Option<(Request, Option<oneshot::Sender<Response>>)> {
self.arg_recv.blocking_recv()
}

/// Send response back to the client
pub fn resp(&self, call_id: ReqId, resp: Response) -> Result<(), Error> {
self.ret_sender
.send((call_id, resp))
.map_err(from_send_error)
}
}

fn from_send_error<T>(err: mpsc::error::SendError<T>) -> Error {
@@ -546,7 +495,10 @@ mod test {
create_if_not_exists: true,
err_collector: ErrCollector::default(),
};
handle.create_flow(create_reqs).await.unwrap();
assert_eq!(
handle.create_flow(create_reqs).await.unwrap(),
Some(flow_id)
);
tx.send((Row::empty(), 0, 0)).unwrap();
handle.run_available(0).await.unwrap();
assert_eq!(sink_rx.recv().await.unwrap().0, Row::empty());
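The worker.rs changes above drop the shared `ret_recv` channel and the call-id bookkeeping: each request that expects an answer now carries its own `oneshot::Sender`, so responses cannot be mismatched and `WorkerHandle` no longer needs a `Mutex` around its client. A self-contained sketch of that request/response pattern, using illustrative `Req`/`Resp` types rather than the actual flow worker enums:

```rust
use tokio::sync::{mpsc, oneshot};

// Illustrative request/response types standing in for the worker's Request/Response enums.
#[derive(Debug)]
enum Req {
    Ping,
}

#[derive(Debug)]
enum Resp {
    Pong,
}

#[tokio::main]
async fn main() {
    // Each request optionally carries its own oneshot sender, so a response
    // always reaches the caller that made the request; no call-id matching needed.
    let (tx, mut rx) = mpsc::unbounded_channel::<(Req, Option<oneshot::Sender<Resp>>)>();

    // Server side: answer each request on its dedicated oneshot channel, if one was provided.
    let server = tokio::spawn(async move {
        while let Some((req, ret_tx)) = rx.recv().await {
            if let (Req::Ping, Some(ret_tx)) = (req, ret_tx) {
                let _ = ret_tx.send(Resp::Pong);
            }
        }
    });

    // Client side: a call that expects a response creates a fresh oneshot channel per call.
    let (ret_tx, ret_rx) = oneshot::channel();
    tx.send((Req::Ping, Some(ret_tx))).unwrap();
    println!("{:?}", ret_rx.await.unwrap()); // Pong

    drop(tx); // closing the mpsc channel lets the server loop exit
    server.await.unwrap();
}
```

Because `oneshot::Sender::send` consumes the sender, each response can be delivered exactly once, which is what makes the old explicit `ReqId` matching unnecessary.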
2 changes: 1 addition & 1 deletion src/flow/src/expr/func.rs
@@ -43,7 +43,7 @@ use crate::repr::{self, value_to_internal_ts, Row};

/// UnmaterializableFunc is a function that can't be eval independently,
/// and require special handling
#[derive(Ord, PartialOrd, Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)]
#[derive(Ord, PartialOrd, Clone, Debug, Eq, PartialEq, Hash)]
pub enum UnmaterializableFunc {
Now,
CurrentSchema,
4 changes: 2 additions & 2 deletions src/flow/src/expr/linear.rs
@@ -49,7 +49,7 @@ use crate::repr::{self, value_to_internal_ts, Diff, Row};
/// expressions in `self.expressions`, even though this is not something
/// we can directly evaluate. The plan creation methods will defensively
/// ensure that the right thing happens.
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd)]
pub struct MapFilterProject {
/// A sequence of expressions that should be appended to the row.
///
@@ -462,7 +462,7 @@ impl MapFilterProject {
}

/// A wrapper type which indicates it is safe to simply evaluate all expressions.
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq, Ord, PartialOrd)]
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd)]
pub struct SafeMfpPlan {
/// the inner `MapFilterProject` that is safe to evaluate.
pub(crate) mfp: MapFilterProject,
3 changes: 1 addition & 2 deletions src/flow/src/expr/relation.rs
@@ -23,7 +23,7 @@ mod accum;
mod func;

/// Describes an aggregation expression.
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd)]
pub struct AggregateExpr {
/// Names the aggregation function.
pub func: AggregateFunc,
@@ -32,6 +32,5 @@ pub struct AggregateExpr {
/// so it only used in generate KeyValPlan from AggregateExpr
pub expr: ScalarExpr,
/// Should the aggregation be applied only to distinct results in each group.
#[serde(default)]
pub distinct: bool,
}