refactor sort exec stream and combine batches
Jiayu Liu committed Jun 6, 2021
1 parent b84789a commit 5a0425b
Showing 5 changed files with 205 additions and 145 deletions.
96 changes: 85 additions & 11 deletions datafusion/src/physical_plan/common.rs
@@ -17,24 +17,22 @@

//! Defines common code used in execution plans

-use std::fs;
-use std::fs::metadata;
-use std::sync::Arc;
-use std::task::{Context, Poll};
-
+use super::{RecordBatchStream, SendableRecordBatchStream};
+use crate::error::{DataFusionError, Result};
+use crate::physical_plan::ExecutionPlan;
+use arrow::compute::concat;
use arrow::datatypes::SchemaRef;
+use arrow::error::ArrowError;
use arrow::error::Result as ArrowResult;
use arrow::record_batch::RecordBatch;
use futures::channel::mpsc;
use futures::{SinkExt, Stream, StreamExt, TryStreamExt};
+use std::fs;
+use std::fs::metadata;
+use std::sync::Arc;
+use std::task::{Context, Poll};
use tokio::task::JoinHandle;
-
-use crate::arrow::error::ArrowError;
-use crate::error::{DataFusionError, Result};
-use crate::physical_plan::ExecutionPlan;
-
-use super::{RecordBatchStream, SendableRecordBatchStream};

/// Stream of record batches
pub struct SizedRecordBatchStream {
    schema: SchemaRef,
@@ -83,6 +81,32 @@ pub async fn collect(stream: SendableRecordBatchStream) -> Result<Vec<RecordBatc
        .map_err(DataFusionError::from)
}

+/// Combine a slice of record batches into one, returning `None` if the slice
+/// is empty; all the record batches in the slice must share the same schema.
+pub(crate) fn combine_batches(
+    batches: &[RecordBatch],
+    schema: SchemaRef,
+) -> ArrowResult<Option<RecordBatch>> {
+    if batches.is_empty() {
+        Ok(None)
+    } else {
+        let columns = schema
+            .fields()
+            .iter()
+            .enumerate()
+            .map(|(i, _)| {
+                concat(
+                    &batches
+                        .iter()
+                        .map(|batch| batch.column(i).as_ref())
+                        .collect::<Vec<_>>(),
+                )
+            })
+            .collect::<ArrowResult<Vec<_>>>()?;
+        Ok(Some(RecordBatch::try_new(schema.clone(), columns)?))
+    }
+}

/// Recursively builds a list of files in a directory with a given extension
pub fn build_file_list(dir: &str, ext: &str) -> Result<Vec<String>> {
    let mut filenames: Vec<String> = Vec::new();
@@ -144,3 +168,53 @@ pub(crate) fn spawn_execution(
        }
    })
}

+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow::{
+        array::{Float32Array, Float64Array},
+        datatypes::{DataType, Field, Schema},
+        record_batch::RecordBatch,
+    };
+
+    #[test]
+    fn test_combine_batches_empty() -> Result<()> {
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("f32", DataType::Float32, false),
+            Field::new("f64", DataType::Float64, false),
+        ]));
+        let result = combine_batches(&[], schema)?;
+        assert!(result.is_none());
+        Ok(())
+    }
+
+    #[test]
+    fn test_combine_batches() -> Result<()> {
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("f32", DataType::Float32, false),
+            Field::new("f64", DataType::Float64, false),
+        ]));
+
+        let batch_count = 1000;
+        let batch_size = 10;
+        let batches = (0..batch_count)
+            .map(|i| {
+                RecordBatch::try_new(
+                    Arc::clone(&schema),
+                    vec![
+                        Arc::new(Float32Array::from(vec![i as f32; batch_size])),
+                        Arc::new(Float64Array::from(vec![i as f64; batch_size])),
+                    ],
+                )
+                .unwrap()
+            })
+            .collect::<Vec<_>>();
+
+        let result = combine_batches(&batches, schema)?;
+        assert!(result.is_some());
+        let result = result.unwrap();
+        assert_eq!(batch_count * batch_size, result.num_rows());
+        Ok(())
+    }
+}
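
For context, combine_batches is a thin wrapper over arrow's concat kernel, applied to column i of every input batch. A minimal standalone sketch of that kernel, with toy arrays, assuming the arrow 4.x API this commit builds against:

use arrow::array::{Array, Int32Array};
use arrow::compute::concat;

fn main() -> arrow::error::Result<()> {
    let a = Int32Array::from(vec![1, 2]);
    let b = Int32Array::from(vec![3]);
    // concat stitches same-typed arrays into one new array; combine_batches
    // applies exactly this, column by column, across all input batches.
    let merged = concat(&[&a as &dyn Array, &b])?;
    assert_eq!(merged.len(), 3);
    Ok(())
}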
87 changes: 32 additions & 55 deletions datafusion/src/physical_plan/sort.rs
@@ -17,32 +17,28 @@

//! Defines the SORT plan

-use std::any::Any;
-use std::pin::Pin;
-use std::sync::Arc;
-use std::task::{Context, Poll};
-use std::time::Instant;
-
-use async_trait::async_trait;
-use futures::stream::Stream;
-use futures::Future;
-use hashbrown::HashMap;
-
-use pin_project_lite::pin_project;
-
-pub use arrow::compute::SortOptions;
-use arrow::compute::{concat, lexsort_to_indices, take, SortColumn, TakeOptions};
-use arrow::datatypes::SchemaRef;
-use arrow::error::Result as ArrowResult;
-use arrow::record_batch::RecordBatch;
-use arrow::{array::ArrayRef, error::ArrowError};

use super::{RecordBatchStream, SendableRecordBatchStream};
use crate::error::{DataFusionError, Result};
use crate::physical_plan::expressions::PhysicalSortExpr;
use crate::physical_plan::{
    common, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, SQLMetric,
};
+pub use arrow::compute::SortOptions;
+use arrow::compute::{lexsort_to_indices, take, SortColumn, TakeOptions};
+use arrow::datatypes::SchemaRef;
+use arrow::error::Result as ArrowResult;
+use arrow::record_batch::RecordBatch;
+use arrow::{array::ArrayRef, error::ArrowError};
+use async_trait::async_trait;
+use futures::stream::Stream;
+use futures::Future;
+use hashbrown::HashMap;
+use pin_project_lite::pin_project;
+use std::any::Any;
+use std::pin::Pin;
+use std::sync::Arc;
+use std::task::{Context, Poll};
+use std::time::Instant;

/// Sort execution plan
#[derive(Debug)]
@@ -190,47 +186,25 @@ impl ExecutionPlan for SortExec {
    }
}

-fn sort_batches(
-    batches: &[RecordBatch],
-    schema: &SchemaRef,
+fn sort_batch(
+    batch: RecordBatch,
+    schema: SchemaRef,
    expr: &[PhysicalSortExpr],
-) -> ArrowResult<Option<RecordBatch>> {
-    if batches.is_empty() {
-        return Ok(None);
-    }
-    // combine all record batches into one for each column
-    let combined_batch = RecordBatch::try_new(
-        schema.clone(),
-        schema
-            .fields()
-            .iter()
-            .enumerate()
-            .map(|(i, _)| {
-                concat(
-                    &batches
-                        .iter()
-                        .map(|batch| batch.column(i).as_ref())
-                        .collect::<Vec<_>>(),
-                )
-            })
-            .collect::<ArrowResult<Vec<ArrayRef>>>()?,
-    )?;
-
-    // sort combined record batch
+) -> ArrowResult<RecordBatch> {
+    // TODO: push the limit expression down into the sort
    let indices = lexsort_to_indices(
        &expr
            .iter()
-            .map(|e| e.evaluate_to_sort_column(&combined_batch))
+            .map(|e| e.evaluate_to_sort_column(&batch))
            .collect::<Result<Vec<SortColumn>>>()
            .map_err(DataFusionError::into_arrow_external_error)?,
        None,
    )?;

    // reorder all rows based on sorted indices
-    let sorted_batch = RecordBatch::try_new(
-        schema.clone(),
-        combined_batch
+    RecordBatch::try_new(
+        schema,
+        batch
            .columns()
            .iter()
            .map(|column| {
@@ -245,8 +219,7 @@ fn sort_batches(
                )
            })
            .collect::<ArrowResult<Vec<ArrayRef>>>()?,
-    );
-    sorted_batch.map(Some)
+    )
}
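
An aside on the TODO above: lexsort_to_indices already takes an optional row limit (the None argument), so pushing a query LIMIT down would mean passing Some(k) and materializing only the top k rows. A self-contained sketch of that idea, with toy data and a hypothetical limit of 2, assuming the arrow 4.x sort and take kernels used here:

use std::sync::Arc;

use arrow::array::Int32Array;
use arrow::compute::{lexsort_to_indices, take, SortColumn, SortOptions};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;

fn main() -> arrow::error::Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(Int32Array::from(vec![3, 1, 2]))],
    )?;
    // Some(2) keeps only the first two sorted indices; sort_batch above
    // passes None, which sorts every row.
    let indices = lexsort_to_indices(
        &[SortColumn {
            values: batch.column(0).clone(),
            options: Some(SortOptions::default()),
        }],
        Some(2),
    )?;
    let columns = batch
        .columns()
        .iter()
        .map(|c| take(c.as_ref(), &indices, None))
        .collect::<arrow::error::Result<Vec<_>>>()?;
    let top2 = RecordBatch::try_new(schema, columns)?;
    assert_eq!(top2.num_rows(), 2);
    Ok(())
}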

pin_project! {
Expand All @@ -268,7 +241,6 @@ impl SortStream {
        sort_time: Arc<SQLMetric>,
    ) -> Self {
        let (tx, rx) = futures::channel::oneshot::channel();
-
        let schema = input.schema();
        tokio::spawn(async move {
            let schema = input.schema();
@@ -277,9 +249,14 @@ impl SortStream {
                .map_err(DataFusionError::into_arrow_external_error)
                .and_then(move |batches| {
                    let now = Instant::now();
-                    let result = sort_batches(&batches, &schema, &expr);
+                    // combine all record batches into one for each column
+                    let combined = common::combine_batches(&batches, schema.clone())?;
+                    // sort combined record batch
+                    let result = combined
+                        .map(|batch| sort_batch(batch, schema, &expr))
+                        .transpose()?;
                    sort_time.add(now.elapsed().as_nanos() as usize);
-                    result
+                    Ok(result)
                });

            tx.send(sorted_batch)
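
The surrounding SortStream::new body (its tail is collapsed in this view) follows a simple hand-off pattern: the collect, combine, and sort work runs on a spawned Tokio task, and the finished batch comes back through a futures oneshot channel that the stream polls. A stripped-down sketch of just that pattern, with a toy Vec<i32> standing in for the record batch (hypothetical names; assumes tokio with the rt and macros features):

use futures::channel::oneshot;

#[tokio::main]
async fn main() {
    let (tx, rx) = oneshot::channel::<Vec<i32>>();
    tokio::spawn(async move {
        let mut rows = vec![3, 1, 2];
        rows.sort(); // stands in for combine_batches + sort_batch
        // the receiver may already be gone, so the send result is ignored
        let _ = tx.send(rows);
    });
    // the stream side would poll the receiver from poll_next; awaiting it
    // here is the straight-line equivalent
    let sorted = rx.await.expect("sorting task was dropped");
    assert_eq!(sorted, vec![1, 2, 3]);
}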