Skip to content

Commit

Permalink
Add SessionContext::read_batches (#9197)
Browse files Browse the repository at this point in the history
* feat: issue #9157 adding read_batches for Vec<RecordBatch>

* fix bugs

* optimize code and tests

* optimize test

* optimize tests

* abandon useless schema

* collect all input into a single batch partition
  • Loading branch information
Lordworms authored Feb 14, 2024
1 parent 196b718 commit 61e9605
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 1 deletion.
25 changes: 24 additions & 1 deletion datafusion/core/src/execution/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ use crate::{
optimizer::optimizer::Optimizer,
physical_optimizer::optimizer::{PhysicalOptimizer, PhysicalOptimizerRule},
};
use arrow_schema::Schema;
use datafusion_common::{
alias::AliasGenerator,
exec_err, not_impl_err, plan_datafusion_err, plan_err,
Expand Down Expand Up @@ -934,7 +935,29 @@ impl SessionContext {
.build()?,
))
}

/// Create a [`DataFrame`] for reading a `Vec` of [`RecordBatch`]es.
///
/// All batches must share the same schema; an empty iterator produces
/// an empty `DataFrame` with an empty schema.
///
/// # Errors
///
/// Returns an error if the batches do not all have the same schema
/// (validated by [`MemTable::try_new`]).
pub fn read_batches(
    &self,
    batches: impl IntoIterator<Item = RecordBatch>,
) -> Result<DataFrame> {
    // Peek at the first batch to derive the table schema without
    // consuming it; `schema()` already returns an owned `SchemaRef`,
    // so no extra clone is needed.
    let mut batches = batches.into_iter().peekable();
    let schema = match batches.peek() {
        Some(batch) => batch.schema(),
        None => Arc::new(Schema::empty()),
    };
    // MemTable::try_new verifies every batch matches `schema`.
    let provider = MemTable::try_new(schema, vec![batches.collect()])?;
    Ok(DataFrame::new(
        self.state(),
        LogicalPlanBuilder::scan(
            UNNAMED_TABLE,
            provider_as_source(Arc::new(provider)),
            None,
        )?
        .build()?,
    ))
}
/// Registers a [`ListingTable`] that can assemble multiple files
/// from locations in an [`ObjectStore`] instance into a single
/// table.
Expand Down
66 changes: 66 additions & 0 deletions datafusion/core/tests/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use arrow::{
},
record_batch::RecordBatch,
};
use arrow_array::Float32Array;
use arrow_schema::ArrowError;
use std::sync::Arc;

Expand Down Expand Up @@ -1431,6 +1432,71 @@ async fn unnest_analyze_metrics() -> Result<()> {

Ok(())
}
#[tokio::test]
async fn test_read_batches() -> Result<()> {
    // Build a context with default config/runtime.
    let config = SessionConfig::new();
    let runtime = Arc::new(RuntimeEnv::default());
    let state = SessionState::new_with_config_rt(config, runtime);
    let ctx = SessionContext::new_with_state(state);

    let schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Int32, false),
        Field::new("number", DataType::Float32, false),
    ]));

    // Two batches with the same schema; read_batches should union them.
    let batches = vec![
        RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])),
                Arc::new(Float32Array::from(vec![1.12, 3.40, 2.33, 9.10, 6.66])),
            ],
        )?,
        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int32Array::from(vec![3, 4, 5])),
                Arc::new(Float32Array::from(vec![1.11, 2.22, 3.33])),
            ],
        )?,
    ];
    // Propagate errors with `?` instead of unwrap so failures report
    // through the test's Result, and avoid printing to stdout.
    let df = ctx.read_batches(batches)?;
    let result = df.collect().await?;
    let expected = [
        "+----+--------+",
        "| id | number |",
        "+----+--------+",
        "| 1  | 1.12   |",
        "| 2  | 3.4    |",
        "| 3  | 2.33   |",
        "| 4  | 9.1    |",
        "| 5  | 6.66   |",
        "| 3  | 1.11   |",
        "| 4  | 2.22   |",
        "| 5  | 3.33   |",
        "+----+--------+",
    ];
    assert_batches_sorted_eq!(expected, &result);
    Ok(())
}
#[tokio::test]
async fn test_read_batches_empty() -> Result<()> {
let config = SessionConfig::new();
let runtime = Arc::new(RuntimeEnv::default());
let state = SessionState::new_with_config_rt(config, runtime);
let ctx = SessionContext::new_with_state(state);

let batches = vec![];
let df = ctx.read_batches(batches).unwrap();
df.clone().show().await.unwrap();
let result = df.collect().await?;
let expected = ["++", "++"];
assert_batches_sorted_eq!(expected, &result);
Ok(())
}

#[tokio::test]
async fn consecutive_projection_same_schema() -> Result<()> {
Expand Down

0 comments on commit 61e9605

Please sign in to comment.