Skip to content

Commit

Permalink
Push SessionState into FileFormat (apache#4349) (apache#4699)
Browse files Browse the repository at this point in the history
* Push SessionState into FileFormat (apache#4349)

* Rename ctx to state

* More renames
  • Loading branch information
tustvold authored Dec 22, 2022
1 parent 4917235 commit c9d6118
Show file tree
Hide file tree
Showing 20 changed files with 266 additions and 199 deletions.
2 changes: 1 addition & 1 deletion datafusion/core/src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -797,7 +797,7 @@ impl TableProvider for DataFrame {

async fn scan(
&self,
_ctx: &SessionState,
_state: &SessionState,
projection: Option<&Vec<usize>>,
filters: &[Expr],
limit: Option<usize>,
Expand Down
4 changes: 2 additions & 2 deletions datafusion/core/src/datasource/datasource.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ pub trait TableProvider: Sync + Send {
/// parallelized or distributed.
async fn scan(
&self,
ctx: &SessionState,
state: &SessionState,
projection: Option<&Vec<usize>>,
filters: &[Expr],
// limit can be used to reduce the amount scanned
Expand Down Expand Up @@ -94,7 +94,7 @@ pub trait TableProviderFactory: Sync + Send {
/// Create a TableProvider with the given url
async fn create(
&self,
ctx: &SessionState,
state: &SessionState,
cmd: &CreateExternalTable,
) -> Result<Arc<dyn TableProvider>>;
}
2 changes: 1 addition & 1 deletion datafusion/core/src/datasource/empty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ impl TableProvider for EmptyTable {

async fn scan(
&self,
_ctx: &SessionState,
_state: &SessionState,
projection: Option<&Vec<usize>>,
_filters: &[Expr],
_limit: Option<usize>,
Expand Down
59 changes: 38 additions & 21 deletions datafusion/core/src/datasource/file_format/avro.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use object_store::{GetResult, ObjectMeta, ObjectStore};
use super::FileFormat;
use crate::avro_to_arrow::read_avro_schema_from_reader;
use crate::error::Result;
use crate::execution::context::SessionState;
use crate::logical_expr::Expr;
use crate::physical_plan::file_format::{AvroExec, FileScanConfig};
use crate::physical_plan::ExecutionPlan;
Expand All @@ -47,6 +48,7 @@ impl FileFormat for AvroFormat {

async fn infer_schema(
&self,
_state: &SessionState,
store: &Arc<dyn ObjectStore>,
objects: &[ObjectMeta],
) -> Result<SchemaRef> {
Expand All @@ -68,6 +70,7 @@ impl FileFormat for AvroFormat {

async fn infer_stats(
&self,
_state: &SessionState,
_store: &Arc<dyn ObjectStore>,
_table_schema: SchemaRef,
_object: &ObjectMeta,
Expand All @@ -77,6 +80,7 @@ impl FileFormat for AvroFormat {

async fn create_physical_plan(
&self,
_state: &SessionState,
conf: FileScanConfig,
_filters: &[Expr],
) -> Result<Arc<dyn ExecutionPlan>> {
Expand All @@ -101,10 +105,11 @@ mod tests {
#[tokio::test]
async fn read_small_batches() -> Result<()> {
let config = SessionConfig::new().with_batch_size(2);
let ctx = SessionContext::with_config(config);
let task_ctx = ctx.task_ctx();
let session_ctx = SessionContext::with_config(config);
let state = session_ctx.state();
let task_ctx = state.task_ctx();
let projection = None;
let exec = get_exec("alltypes_plain.avro", projection, None).await?;
let exec = get_exec(&state, "alltypes_plain.avro", projection, None).await?;
let stream = exec.execute(0, task_ctx)?;

let tt_batches = stream
Expand All @@ -124,9 +129,10 @@ mod tests {
#[tokio::test]
async fn read_limit() -> Result<()> {
let session_ctx = SessionContext::new();
let task_ctx = session_ctx.task_ctx();
let state = session_ctx.state();
let task_ctx = state.task_ctx();
let projection = None;
let exec = get_exec("alltypes_plain.avro", projection, Some(1)).await?;
let exec = get_exec(&state, "alltypes_plain.avro", projection, Some(1)).await?;
let batches = collect(exec, task_ctx).await?;
assert_eq!(1, batches.len());
assert_eq!(11, batches[0].num_columns());
Expand All @@ -138,9 +144,10 @@ mod tests {
#[tokio::test]
async fn read_alltypes_plain_avro() -> Result<()> {
let session_ctx = SessionContext::new();
let task_ctx = session_ctx.task_ctx();
let state = session_ctx.state();
let task_ctx = state.task_ctx();
let projection = None;
let exec = get_exec("alltypes_plain.avro", projection, None).await?;
let exec = get_exec(&state, "alltypes_plain.avro", projection, None).await?;

let x: Vec<String> = exec
.schema()
Expand Down Expand Up @@ -190,9 +197,10 @@ mod tests {
#[tokio::test]
async fn read_bool_alltypes_plain_avro() -> Result<()> {
let session_ctx = SessionContext::new();
let task_ctx = session_ctx.task_ctx();
let state = session_ctx.state();
let task_ctx = state.task_ctx();
let projection = Some(vec![1]);
let exec = get_exec("alltypes_plain.avro", projection, None).await?;
let exec = get_exec(&state, "alltypes_plain.avro", projection, None).await?;

let batches = collect(exec, task_ctx).await?;
assert_eq!(batches.len(), 1);
Expand All @@ -216,9 +224,10 @@ mod tests {
#[tokio::test]
async fn read_i32_alltypes_plain_avro() -> Result<()> {
let session_ctx = SessionContext::new();
let task_ctx = session_ctx.task_ctx();
let state = session_ctx.state();
let task_ctx = state.task_ctx();
let projection = Some(vec![0]);
let exec = get_exec("alltypes_plain.avro", projection, None).await?;
let exec = get_exec(&state, "alltypes_plain.avro", projection, None).await?;

let batches = collect(exec, task_ctx).await?;
assert_eq!(batches.len(), 1);
Expand All @@ -239,9 +248,10 @@ mod tests {
#[tokio::test]
async fn read_i96_alltypes_plain_avro() -> Result<()> {
let session_ctx = SessionContext::new();
let task_ctx = session_ctx.task_ctx();
let state = session_ctx.state();
let task_ctx = state.task_ctx();
let projection = Some(vec![10]);
let exec = get_exec("alltypes_plain.avro", projection, None).await?;
let exec = get_exec(&state, "alltypes_plain.avro", projection, None).await?;

let batches = collect(exec, task_ctx).await?;
assert_eq!(batches.len(), 1);
Expand All @@ -262,9 +272,10 @@ mod tests {
#[tokio::test]
async fn read_f32_alltypes_plain_avro() -> Result<()> {
let session_ctx = SessionContext::new();
let task_ctx = session_ctx.task_ctx();
let state = session_ctx.state();
let task_ctx = state.task_ctx();
let projection = Some(vec![6]);
let exec = get_exec("alltypes_plain.avro", projection, None).await?;
let exec = get_exec(&state, "alltypes_plain.avro", projection, None).await?;

let batches = collect(exec, task_ctx).await?;
assert_eq!(batches.len(), 1);
Expand All @@ -288,9 +299,10 @@ mod tests {
#[tokio::test]
async fn read_f64_alltypes_plain_avro() -> Result<()> {
let session_ctx = SessionContext::new();
let task_ctx = session_ctx.task_ctx();
let state = session_ctx.state();
let task_ctx = state.task_ctx();
let projection = Some(vec![7]);
let exec = get_exec("alltypes_plain.avro", projection, None).await?;
let exec = get_exec(&state, "alltypes_plain.avro", projection, None).await?;

let batches = collect(exec, task_ctx).await?;
assert_eq!(batches.len(), 1);
Expand All @@ -314,9 +326,10 @@ mod tests {
#[tokio::test]
async fn read_binary_alltypes_plain_avro() -> Result<()> {
let session_ctx = SessionContext::new();
let task_ctx = session_ctx.task_ctx();
let state = session_ctx.state();
let task_ctx = state.task_ctx();
let projection = Some(vec![9]);
let exec = get_exec("alltypes_plain.avro", projection, None).await?;
let exec = get_exec(&state, "alltypes_plain.avro", projection, None).await?;

let batches = collect(exec, task_ctx).await?;
assert_eq!(batches.len(), 1);
Expand All @@ -338,14 +351,15 @@ mod tests {
}

async fn get_exec(
state: &SessionState,
file_name: &str,
projection: Option<Vec<usize>>,
limit: Option<usize>,
) -> Result<Arc<dyn ExecutionPlan>> {
let testdata = crate::test_util::arrow_test_data();
let store_root = format!("{}/avro", testdata);
let format = AvroFormat {};
scan_format(&format, &store_root, file_name, projection, limit).await
scan_format(state, &format, &store_root, file_name, projection, limit).await
}
}

Expand All @@ -356,13 +370,16 @@ mod tests {

use super::super::test_util::scan_format;
use crate::error::DataFusionError;
use crate::prelude::SessionContext;

#[tokio::test]
async fn test() -> Result<()> {
let session_ctx = SessionContext::new();
let state = session_ctx.state();
let format = AvroFormat {};
let testdata = crate::test_util::arrow_test_data();
let filename = "avro/alltypes_plain.avro";
let result = scan_format(&format, &testdata, filename, None, None).await;
let result = scan_format(&state, &format, &testdata, filename, None, None).await;
assert!(matches!(
result,
Err(DataFusionError::NotImplemented(msg))
Expand Down
26 changes: 19 additions & 7 deletions datafusion/core/src/datasource/file_format/csv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ use super::FileFormat;
use crate::datasource::file_format::file_type::FileCompressionType;
use crate::datasource::file_format::DEFAULT_SCHEMA_INFER_MAX_RECORD;
use crate::error::Result;
use crate::execution::context::SessionState;
use crate::logical_expr::Expr;
use crate::physical_plan::file_format::{CsvExec, FileScanConfig};
use crate::physical_plan::ExecutionPlan;
Expand Down Expand Up @@ -113,6 +114,7 @@ impl FileFormat for CsvFormat {

async fn infer_schema(
&self,
_state: &SessionState,
store: &Arc<dyn ObjectStore>,
objects: &[ObjectMeta],
) -> Result<SchemaRef> {
Expand Down Expand Up @@ -150,6 +152,7 @@ impl FileFormat for CsvFormat {

async fn infer_stats(
&self,
_state: &SessionState,
_store: &Arc<dyn ObjectStore>,
_table_schema: SchemaRef,
_object: &ObjectMeta,
Expand All @@ -159,6 +162,7 @@ impl FileFormat for CsvFormat {

async fn create_physical_plan(
&self,
_state: &SessionState,
conf: FileScanConfig,
_filters: &[Expr],
) -> Result<Arc<dyn ExecutionPlan>> {
Expand All @@ -184,11 +188,12 @@ mod tests {
#[tokio::test]
async fn read_small_batches() -> Result<()> {
let config = SessionConfig::new().with_batch_size(2);
let ctx = SessionContext::with_config(config);
let session_ctx = SessionContext::with_config(config);
let state = session_ctx.state();
let task_ctx = state.task_ctx();
// skip column 9 that overflows the automaticly discovered column type of i64 (u64 would work)
let projection = Some(vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12]);
let exec = get_exec("aggregate_test_100.csv", projection, None).await?;
let task_ctx = ctx.task_ctx();
let exec = get_exec(&state, "aggregate_test_100.csv", projection, None).await?;
let stream = exec.execute(0, task_ctx)?;

let tt_batches: i32 = stream
Expand All @@ -212,9 +217,11 @@ mod tests {
#[tokio::test]
async fn read_limit() -> Result<()> {
let session_ctx = SessionContext::new();
let state = session_ctx.state();
let task_ctx = session_ctx.task_ctx();
let projection = Some(vec![0, 1, 2, 3]);
let exec = get_exec("aggregate_test_100.csv", projection, Some(1)).await?;
let exec =
get_exec(&state, "aggregate_test_100.csv", projection, Some(1)).await?;
let batches = collect(exec, task_ctx).await?;
assert_eq!(1, batches.len());
assert_eq!(4, batches[0].num_columns());
Expand All @@ -225,8 +232,11 @@ mod tests {

#[tokio::test]
async fn infer_schema() -> Result<()> {
let session_ctx = SessionContext::new();
let state = session_ctx.state();

let projection = None;
let exec = get_exec("aggregate_test_100.csv", projection, None).await?;
let exec = get_exec(&state, "aggregate_test_100.csv", projection, None).await?;

let x: Vec<String> = exec
.schema()
Expand Down Expand Up @@ -259,9 +269,10 @@ mod tests {
#[tokio::test]
async fn read_char_column() -> Result<()> {
let session_ctx = SessionContext::new();
let state = session_ctx.state();
let task_ctx = session_ctx.task_ctx();
let projection = Some(vec![0]);
let exec = get_exec("aggregate_test_100.csv", projection, None).await?;
let exec = get_exec(&state, "aggregate_test_100.csv", projection, None).await?;

let batches = collect(exec, task_ctx).await.expect("Collect batches");

Expand All @@ -281,12 +292,13 @@ mod tests {
}

async fn get_exec(
state: &SessionState,
file_name: &str,
projection: Option<Vec<usize>>,
limit: Option<usize>,
) -> Result<Arc<dyn ExecutionPlan>> {
let root = format!("{}/csv", crate::test_util::arrow_test_data());
let format = CsvFormat::default();
scan_format(&format, &root, file_name, projection, limit).await
scan_format(state, &format, &root, file_name, projection, limit).await
}
}
Loading

0 comments on commit c9d6118

Please sign in to comment.