Skip to content

Commit

Permalink
fix: add cancellation token to write_csv/json/parquet
Browse files Browse the repository at this point in the history
  • Loading branch information
DDtKey committed Feb 6, 2023
1 parent 4233752 commit c31248e
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 97 deletions.
67 changes: 37 additions & 30 deletions datafusion/core/src/physical_plan/file_format/csv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ use std::path::Path;
use std::sync::Arc;
use std::task::Poll;
use tokio::task::{self, JoinHandle};
use tokio_util::sync::CancellationToken;

use super::{get_output_ordering, FileScanConfig};

Expand Down Expand Up @@ -286,38 +287,44 @@ pub async fn plan_to_csv(
let path = path.as_ref();
// create directory to contain the CSV files (one per partition)
let fs_path = Path::new(path);
match fs::create_dir(fs_path) {
Ok(()) => {
let mut tasks = vec![];
for i in 0..plan.output_partitioning().partition_count() {
let plan = plan.clone();
let filename = format!("part-{i}.csv");
let path = fs_path.join(filename);
let file = fs::File::create(path)?;
let mut writer = csv::Writer::new(file);
let task_ctx = Arc::new(TaskContext::from(state));
let stream = plan.execute(i, task_ctx)?;
let handle: JoinHandle<Result<()>> = task::spawn(async move {
stream
.map(|batch| writer.write(&batch?))
.try_collect()
.await
.map_err(DataFusionError::from)
});
tasks.push(handle);
}
futures::future::join_all(tasks)
.await
.into_iter()
.try_for_each(|result| {
result.map_err(|e| DataFusionError::Execution(format!("{e}")))?
})?;
Ok(())
}
Err(e) => Err(DataFusionError::Execution(format!(
if let Err(e) = fs::create_dir(fs_path) {
return Err(DataFusionError::Execution(format!(
"Could not create directory {path}: {e:?}"
))),
)));
}

let mut tasks = vec![];
// Create cancellation-token to interrupt background execution on drop.
let cancellation_token = CancellationToken::new();
let _ = cancellation_token.clone().drop_guard();
for i in 0..plan.output_partitioning().partition_count() {
let plan = plan.clone();
let filename = format!("part-{i}.csv");
let path = fs_path.join(filename);
let file = fs::File::create(path)?;
let mut writer = csv::Writer::new(file);
let task_ctx = Arc::new(TaskContext::from(state));
let stream = plan.execute(i, task_ctx)?;

let cancellation_token = cancellation_token.child_token();
let handle: JoinHandle<Result<()>> = task::spawn(async move {
let exec_future = stream.map(|batch| writer.write(&batch?)).try_collect();

tokio::select! {
res = exec_future => res.map_err(DataFusionError::from),
_ = cancellation_token.cancelled() => Err(DataFusionError::Execution("Execution was stopped by the caller".to_string()))
}
});
tasks.push(handle);
}

futures::future::join_all(tasks)
.await
.into_iter()
.try_for_each(|result| {
result.map_err(|e| DataFusionError::Execution(format!("{e}")))?
})?;
Ok(())
}

#[cfg(test)]
Expand Down
67 changes: 37 additions & 30 deletions datafusion/core/src/physical_plan/file_format/json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ use std::path::Path;
use std::sync::Arc;
use std::task::Poll;
use tokio::task::{self, JoinHandle};
use tokio_util::sync::CancellationToken;

use super::{get_output_ordering, FileScanConfig};

Expand Down Expand Up @@ -230,38 +231,44 @@ pub async fn plan_to_json(
let path = path.as_ref();
// create directory to contain the CSV files (one per partition)
let fs_path = Path::new(path);
match fs::create_dir(fs_path) {
Ok(()) => {
let mut tasks = vec![];
for i in 0..plan.output_partitioning().partition_count() {
let plan = plan.clone();
let filename = format!("part-{i}.json");
let path = fs_path.join(filename);
let file = fs::File::create(path)?;
let mut writer = json::LineDelimitedWriter::new(file);
let task_ctx = Arc::new(TaskContext::from(state));
let stream = plan.execute(i, task_ctx)?;
let handle: JoinHandle<Result<()>> = task::spawn(async move {
stream
.map(|batch| writer.write(batch?))
.try_collect()
.await
.map_err(DataFusionError::from)
});
tasks.push(handle);
}
futures::future::join_all(tasks)
.await
.into_iter()
.try_for_each(|result| {
result.map_err(|e| DataFusionError::Execution(format!("{e}")))?
})?;
Ok(())
}
Err(e) => Err(DataFusionError::Execution(format!(
if let Err(e) = fs::create_dir(fs_path) {
return Err(DataFusionError::Execution(format!(
"Could not create directory {path}: {e:?}"
))),
)));
}

let mut tasks = vec![];
// Create cancellation-token to interrupt background execution on drop.
let cancellation_token = CancellationToken::new();
let _ = cancellation_token.clone().drop_guard();
for i in 0..plan.output_partitioning().partition_count() {
let plan = plan.clone();
let filename = format!("part-{i}.json");
let path = fs_path.join(filename);
let file = fs::File::create(path)?;
let mut writer = json::LineDelimitedWriter::new(file);
let task_ctx = Arc::new(TaskContext::from(state));
let stream = plan.execute(i, task_ctx)?;

let cancellation_token = cancellation_token.child_token();
let handle: JoinHandle<Result<()>> = task::spawn(async move {
let exec_future = stream.map(|batch| writer.write(batch?)).try_collect();

tokio::select! {
res = exec_future => res.map_err(DataFusionError::from),
_ = cancellation_token.cancelled() => Err(DataFusionError::Execution("Execution was stopped by the caller".to_string()))
}
});
tasks.push(handle);
}

futures::future::join_all(tasks)
.await
.into_iter()
.try_for_each(|result| {
result.map_err(|e| DataFusionError::Execution(format!("{e}")))?
})?;
Ok(())
}

#[cfg(test)]
Expand Down
83 changes: 46 additions & 37 deletions datafusion/core/src/physical_plan/file_format/parquet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ use parquet::basic::{ConvertedType, LogicalType};
use parquet::errors::ParquetError;
use parquet::file::{metadata::ParquetMetaData, properties::WriterProperties};
use parquet::schema::types::ColumnDescriptor;
use tokio_util::sync::CancellationToken;

mod metrics;
mod page_filter;
Expand Down Expand Up @@ -706,45 +707,53 @@ pub async fn plan_to_parquet(
let path = path.as_ref();
// create directory to contain the Parquet files (one per partition)
let fs_path = std::path::Path::new(path);
match fs::create_dir(fs_path) {
Ok(()) => {
let mut tasks = vec![];
for i in 0..plan.output_partitioning().partition_count() {
let plan = plan.clone();
let filename = format!("part-{i}.parquet");
let path = fs_path.join(filename);
let file = fs::File::create(path)?;
let mut writer =
ArrowWriter::try_new(file, plan.schema(), writer_properties.clone())?;
let task_ctx = Arc::new(TaskContext::from(state));
let stream = plan.execute(i, task_ctx)?;
let handle: tokio::task::JoinHandle<Result<()>> =
tokio::task::spawn(async move {
stream
.map(|batch| {
writer
.write(&batch?)
.map_err(DataFusionError::ParquetError)
})
.try_collect()
.await
.map_err(DataFusionError::from)?;
writer.close().map_err(DataFusionError::from).map(|_| ())
});
tasks.push(handle);
}
futures::future::join_all(tasks)
.await
.into_iter()
.try_for_each(|result| {
result.map_err(|e| DataFusionError::Execution(format!("{e}")))?
})?;
Ok(())
}
Err(e) => Err(DataFusionError::Execution(format!(
if let Err(e) = fs::create_dir(fs_path) {
return Err(DataFusionError::Execution(format!(
"Could not create directory {path}: {e:?}"
))),
)));
}

let mut tasks = vec![];
// Create cancellation-token to interrupt background execution on drop.
let cancellation_token = CancellationToken::new();
let _ = cancellation_token.clone().drop_guard();
for i in 0..plan.output_partitioning().partition_count() {
let plan = plan.clone();
let filename = format!("part-{i}.parquet");
let path = fs_path.join(filename);
let file = fs::File::create(path)?;
let mut writer =
ArrowWriter::try_new(file, plan.schema(), writer_properties.clone())?;
let task_ctx = Arc::new(TaskContext::from(state));
let stream = plan.execute(i, task_ctx)?;

let cancellation_token = cancellation_token.child_token();
let handle: tokio::task::JoinHandle<Result<()>> = tokio::task::spawn(
async move {
let exec_future = stream
.map(|batch| {
writer.write(&batch?).map_err(DataFusionError::ParquetError)
})
.try_collect();

tokio::select! {
res = exec_future => res.map_err(DataFusionError::from)?,
_ = cancellation_token.cancelled() => return Err(DataFusionError::Execution("Execution was stopped by the caller".to_string()))
}

writer.close().map_err(DataFusionError::from).map(|_| ())
},
);
tasks.push(handle);
}

futures::future::join_all(tasks)
.await
.into_iter()
.try_for_each(|result| {
result.map_err(|e| DataFusionError::Execution(format!("{e}")))?
})?;
Ok(())
}

// Copy from the arrow-rs
Expand Down

0 comments on commit c31248e

Please sign in to comment.