feat(COPY TO): hive partitioning support #2634

Draft · wants to merge 1 commit into base: main
16 changes: 8 additions & 8 deletions Cargo.lock


263 changes: 263 additions & 0 deletions crates/datasources/src/common/sink/hive_partitioning.rs
@@ -0,0 +1,263 @@
use std::any::Any;
use std::fmt;
use std::sync::Arc;

use async_trait::async_trait;
use datafusion::arrow::compute::cast;
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::common::Result as DfResult;
use datafusion::error::DataFusionError;
use datafusion::execution::TaskContext;
use datafusion::physical_plan::insert::DataSink;
use datafusion::physical_plan::metrics::MetricsSet;
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use datafusion::physical_plan::{DisplayAs, DisplayFormatType, SendableRecordBatchStream};
use futures::StreamExt;
use object_store::path::Path as ObjectPath;
use tokio::task::JoinSet;

use crate::common::sink::write::demux::start_demuxer_task;

/// A data sink factory used to create a sink for a given path.
pub trait SinkProducer: std::fmt::Debug + Send + Sync {
    fn create_sink(&self, path: ObjectPath) -> Box<dyn DataSink>;
}

/// A data sink that takes a stream of record batches and writes them to a hive-partitioned
/// directory structure. Creation of the underlying sinks is delegated to a `SinkProducer`.
#[derive(Debug)]
pub struct HivePartitionedSinkAdapter<S: SinkProducer> {
    producer: S,
    partition_columns: Vec<String>,
    base_output_path: ObjectPath,
    file_extension: String,
    schema: Arc<Schema>,
}

impl<S: SinkProducer> HivePartitionedSinkAdapter<S> {
    pub fn new(
        producer: S,
        partition_columns: Vec<String>,
        base_output_path: ObjectPath,
        file_extension: String,
        schema: Arc<Schema>,
    ) -> Self {
        HivePartitionedSinkAdapter {
            producer,
            partition_columns,
            base_output_path,
            file_extension,
            schema,
        }
    }
}

impl<S: SinkProducer> fmt::Display for HivePartitionedSinkAdapter<S> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "HivePartitionedSinkAdapter")
    }
}

impl<S: SinkProducer> DisplayAs for HivePartitionedSinkAdapter<S> {
    fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
        match t {
            DisplayFormatType::Default => write!(f, "{self}"),
            DisplayFormatType::Verbose => write!(f, "{self}"),
        }
    }
}

#[async_trait]
impl<S: SinkProducer + 'static> DataSink for HivePartitionedSinkAdapter<S> {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn metrics(&self) -> Option<MetricsSet> {
        None
    }

    async fn write_all(
        &self,
        stream: SendableRecordBatchStream,
        context: &Arc<TaskContext>,
    ) -> DfResult<u64> {
        if self.partition_columns.is_empty() {
            let sink = self.producer.create_sink(self.base_output_path.clone());
            return sink.write_all(stream, context).await;
        }

        let utf8_schema = cast_schema_to_utf8(&self.schema, &self.partition_columns)?;
        let column_types = get_columns_with_types(&utf8_schema, self.partition_columns.clone())?;

        let utf8_schema_inner = utf8_schema.clone();
        let partition_columns = self.partition_columns.clone();

        let utf8_stream = stream.map(move |batch_result| {
            batch_result.and_then(|batch| {
                cast_record_batch_to_utf8(&batch, &partition_columns, utf8_schema_inner.clone())
            })
        });

        let utf8_rb_stream = Box::pin(RecordBatchStreamAdapter::new(utf8_schema, utf8_stream));

        let (demux_task, mut file_stream_rx) = start_demuxer_task(
            utf8_rb_stream,
            context,
            Some(column_types),
            self.base_output_path.clone(),
            self.file_extension.clone(),
        );

        let mut sink_write_tasks: JoinSet<DfResult<usize>> = JoinSet::new();
        let writer_schema = remove_partition_columns(&self.schema, &self.partition_columns);

        while let Some((path, mut rx)) = file_stream_rx.recv().await {
            let ctx = context.clone();
            let sink = self.producer.create_sink(path);

            let stream = async_stream::stream! {
                while let Some(item) = rx.recv().await {
                    yield Ok(item);
                }
            };

            let rb_stream = Box::pin(RecordBatchStreamAdapter::new(writer_schema.clone(), stream));

            sink_write_tasks.spawn(async move {
                sink.write_all(rb_stream, &ctx)
                    .await
                    .map(|row_count| row_count as usize)
            });
        }

        let mut row_count = 0;

        while let Some(result) = sink_write_tasks.join_next().await {
            match result {
                Ok(r) => {
                    row_count += r?;
                }
                Err(e) => {
                    if e.is_panic() {
                        std::panic::resume_unwind(e.into_panic());
                    } else {
                        unreachable!();
                    }
                }
            }
        }

        match demux_task.await {
            Ok(r) => r?,
            Err(e) => {
                if e.is_panic() {
                    std::panic::resume_unwind(e.into_panic());
                } else {
                    unreachable!();
                }
            }
        }

        Ok(row_count as u64)
    }
}

/// Get the partition columns with their types from the schema.
pub fn get_columns_with_types(
    schema: &Schema,
    columns: Vec<String>,
) -> DfResult<Vec<(String, DataType)>> {
    columns
        .iter()
        .map(|col| {
            schema
                .field_with_name(col)
                .map(|field| (field.name().to_owned(), field.data_type().to_owned()))
                .map_err(|e| DataFusionError::External(Box::new(e)))
        })
        .collect()
}

// Keeping this somewhat conservative for now.
//
// For more involved types like timestamps and floats, casting to strings
// that are ultimately used as file paths could be problematic because of
// special characters, precision loss, etc.
fn supported_partition_column_type(data_type: &DataType) -> bool {
    matches!(
        data_type,
        DataType::Boolean
            | DataType::Int8
            | DataType::Int16
            | DataType::Int32
            | DataType::Int64
            | DataType::UInt8
            | DataType::UInt16
            | DataType::UInt32
            | DataType::UInt64
            | DataType::Utf8
    )
}

fn cast_record_batch_to_utf8(
    batch: &RecordBatch,
    partition_columns: &[String],
    schema: Arc<Schema>,
) -> DfResult<RecordBatch> {
    let mut columns = batch.columns().to_vec();

    for column_name in partition_columns {
        let col_index = batch.schema().index_of(column_name)?;
        let casted_array = cast(batch.column(col_index).as_ref(), &DataType::Utf8)?;
        columns[col_index] = casted_array;
    }

    let casted_batch = RecordBatch::try_new(schema, columns)?;

    Ok(casted_batch)
}

fn cast_schema_to_utf8(schema: &Schema, partition_columns: &[String]) -> DfResult<Arc<Schema>> {
    let mut fields = schema.fields().to_vec();

    for column_name in partition_columns.iter() {
        let idx = schema.index_of(column_name)?;

        let data_type = fields[idx].data_type();

        if data_type == &DataType::Utf8 {
            continue;
        } else if !supported_partition_column_type(data_type) {
            return Err(DataFusionError::Execution(format!(
                "Partition column of type '{data_type}' is not supported"
            )));
        }

        let casted_field = Field::new(column_name, DataType::Utf8, fields[idx].is_nullable());
        fields[idx] = Arc::new(casted_field);
    }

    Ok(Arc::new(Schema::new_with_metadata(
        fields,
        schema.metadata().clone(),
    )))
}

fn remove_partition_columns(schema: &Schema, partition_columns: &[String]) -> Arc<Schema> {
    Arc::new(Schema::new(
        schema
            .fields()
            .iter()
            .filter(|f| !partition_columns.contains(f.name()))
            .map(|f| (**f).clone())
            .collect::<Vec<_>>(),
    ))
}
24 changes: 24 additions & 0 deletions crates/datasources/src/common/sink/json.rs
@@ -16,6 +16,7 @@ use object_store::path::Path as ObjectPath;
use object_store::ObjectStore;
use tokio::io::{AsyncWrite, AsyncWriteExt};

use super::hive_partitioning::SinkProducer;
use super::SharedBuffer;
use crate::common::errors::Result;

@@ -164,3 +165,26 @@ impl<W: AsyncWrite + Unpin + Send, F: JsonFormat> AsyncJsonWriter<W, F> {
        Ok(())
    }
}

#[derive(Debug)]
pub struct JsonSinkProducer {
    store: Arc<dyn ObjectStore>,
    opts: JsonSinkOpts,
}

impl JsonSinkProducer {
    pub fn from_obj_store(store: Arc<dyn ObjectStore>, opts: JsonSinkOpts) -> Self {
        JsonSinkProducer { store, opts }
    }

    pub fn create_sink(&self, loc: impl Into<ObjectPath>) -> JsonSink {
        JsonSink::from_obj_store(self.store.clone(), loc, self.opts.clone())
    }
}

impl SinkProducer for JsonSinkProducer {
    fn create_sink(&self, loc: ObjectPath) -> Box<dyn DataSink> {
        Box::new(self.create_sink(loc))
    }
}
3 changes: 3 additions & 0 deletions crates/datasources/src/common/sink/mod.rs
@@ -1,3 +1,6 @@
pub mod hive_partitioning;
mod write;

pub mod bson;
pub mod csv;
pub mod json;
25 changes: 25 additions & 0 deletions crates/datasources/src/common/sink/parquet.rs
@@ -15,8 +15,11 @@ use futures::StreamExt;
use object_store::path::Path as ObjectPath;
use object_store::ObjectStore;

use super::hive_partitioning::SinkProducer;

const BUFFER_SIZE: usize = 8 * 1024 * 1024;


#[derive(Debug, Clone)]
pub struct ParquetSinkOpts {
    pub row_group_size: usize,
@@ -111,3 +114,25 @@ impl DataSink for ParquetSink {
        self.stream_into_inner(data).await.map(|x| x as u64)
    }
}

#[derive(Debug)]
pub struct ParquetSinkProducer {
    store: Arc<dyn ObjectStore>,
    opts: ParquetSinkOpts,
}

impl ParquetSinkProducer {
    pub fn from_obj_store(store: Arc<dyn ObjectStore>, opts: ParquetSinkOpts) -> Self {
        ParquetSinkProducer { store, opts }
    }

    pub fn create_sink(&self, loc: impl Into<ObjectPath>) -> ParquetSink {
        ParquetSink::from_obj_store(self.store.clone(), loc, self.opts.clone())
    }
}

impl SinkProducer for ParquetSinkProducer {
    fn create_sink(&self, loc: ObjectPath) -> Box<dyn DataSink> {
        Box::new(self.create_sink(loc))
    }
}
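
One detail worth noting: each producer ends up with two `create_sink` methods, an inherent one returning the concrete sink type and the `SinkProducer` trait method returning `Box<dyn DataSink>`. Inherent methods take precedence during method resolution, so the `self.create_sink(loc)` call inside the trait impl hits the inherent version rather than recursing. A hedged sketch of calling both, assuming the producer types and trait above are in scope; `opts` is left as a parameter because the full `ParquetSinkOpts` definition is elided in this diff, and the paths are made up for illustration.

```rust
use std::sync::Arc;

use datafusion::physical_plan::insert::DataSink;
use object_store::memory::InMemory;
use object_store::ObjectStore;

// Sketch only: `opts` is passed in rather than constructed here.
fn example(opts: ParquetSinkOpts) {
    let store: Arc<dyn ObjectStore> = Arc::new(InMemory::new());
    let producer = ParquetSinkProducer::from_obj_store(store, opts);

    // Inherent method: returns a concrete `ParquetSink`.
    let _concrete = producer.create_sink("out/part-0.parquet");

    // Trait method: returns a type-erased sink, as consumed by
    // `HivePartitionedSinkAdapter` when it opens one sink per partition path.
    let _boxed: Box<dyn DataSink> =
        SinkProducer::create_sink(&producer, "out/part-1.parquet".into());
}
```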