Commit

fix NamedStructField should be rewritten in OperatorToFunction in subquery (#10103)
alamb authored Apr 17, 2024
1 parent 1fa25ae commit 9974cee
Showing 2 changed files with 144 additions and 40 deletions.
129 changes: 89 additions & 40 deletions datafusion/optimizer/src/analyzer/function_rewrite.rs
@@ -21,9 +21,10 @@ use super::AnalyzerRule;
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TreeNodeRewriter};
use datafusion_common::{DFSchema, Result};
use datafusion_expr::expr::{Exists, InSubquery};
use datafusion_expr::expr_rewriter::{rewrite_preserving_name, FunctionRewrite};
use datafusion_expr::utils::merge_schema;
use datafusion_expr::{Expr, LogicalPlan};
use datafusion_expr::{Expr, LogicalPlan, Subquery};
use std::sync::Arc;

/// Analyzer rule that invokes [`FunctionRewrite`]s on expressions
@@ -45,52 +46,66 @@ impl AnalyzerRule for ApplyFunctionRewrites {
}

fn analyze(&self, plan: LogicalPlan, options: &ConfigOptions) -> Result<LogicalPlan> {
self.analyze_internal(&plan, options)
analyze_internal(&plan, &self.function_rewrites, options)
}
}

impl ApplyFunctionRewrites {
fn analyze_internal(
&self,
plan: &LogicalPlan,
options: &ConfigOptions,
) -> Result<LogicalPlan> {
// optimize child plans first
let new_inputs = plan
.inputs()
.iter()
.map(|p| self.analyze_internal(p, options))
.collect::<Result<Vec<_>>>()?;

// get schema representing all available input fields. This is used for data type
// resolution only, so order does not matter here
let mut schema = merge_schema(new_inputs.iter().collect());

if let LogicalPlan::TableScan(ts) = plan {
let source_schema =
DFSchema::try_from_qualified_schema(&ts.table_name, &ts.source.schema())?;
schema.merge(&source_schema);
}
fn analyze_internal(
plan: &LogicalPlan,
function_rewrites: &[Arc<dyn FunctionRewrite + Send + Sync>],
options: &ConfigOptions,
) -> Result<LogicalPlan> {
// optimize child plans first
let new_inputs = plan
.inputs()
.iter()
.map(|p| analyze_internal(p, function_rewrites, options))
.collect::<Result<Vec<_>>>()?;

let mut expr_rewrite = OperatorToFunctionRewriter {
function_rewrites: &self.function_rewrites,
options,
schema: &schema,
};
// get schema representing all available input fields. This is used for data type
// resolution only, so order does not matter here
let mut schema = merge_schema(new_inputs.iter().collect());

let new_expr = plan
.expressions()
.into_iter()
.map(|expr| {
// ensure names don't change:
// https://github.com/apache/arrow-datafusion/issues/3555
rewrite_preserving_name(expr, &mut expr_rewrite)
})
.collect::<Result<Vec<_>>>()?;

plan.with_new_exprs(new_expr, new_inputs)
if let LogicalPlan::TableScan(ts) = plan {
let source_schema = DFSchema::try_from_qualified_schema(
ts.table_name.clone(),
&ts.source.schema(),
)?;
schema.merge(&source_schema);
}

let mut expr_rewrite = OperatorToFunctionRewriter {
function_rewrites,
options,
schema: &schema,
};

let new_expr = plan
.expressions()
.into_iter()
.map(|expr| {
// ensure names don't change:
// https://github.com/apache/arrow-datafusion/issues/3555
rewrite_preserving_name(expr, &mut expr_rewrite)
})
.collect::<Result<Vec<_>>>()?;

plan.with_new_exprs(new_expr, new_inputs)
}

fn rewrite_subquery(
mut subquery: Subquery,
function_rewrites: &[Arc<dyn FunctionRewrite + Send + Sync>],
options: &ConfigOptions,
) -> Result<Subquery> {
subquery.subquery = Arc::new(analyze_internal(
&subquery.subquery,
function_rewrites,
options,
)?);
Ok(subquery)
}

struct OperatorToFunctionRewriter<'a> {
function_rewrites: &'a [Arc<dyn FunctionRewrite + Send + Sync>],
options: &'a ConfigOptions,
@@ -111,6 +126,40 @@ impl<'a> TreeNodeRewriter for OperatorToFunctionRewriter<'a> {
expr = result.data
}

// recurse into subqueries if needed
let expr = match expr {
Expr::ScalarSubquery(subquery) => Expr::ScalarSubquery(rewrite_subquery(
subquery,
self.function_rewrites,
self.options,
)?),

Expr::Exists(Exists { subquery, negated }) => Expr::Exists(Exists {
subquery: rewrite_subquery(
subquery,
self.function_rewrites,
self.options,
)?,
negated,
}),

Expr::InSubquery(InSubquery {
expr,
subquery,
negated,
}) => Expr::InSubquery(InSubquery {
expr,
subquery: rewrite_subquery(
subquery,
self.function_rewrites,
self.options,
)?,
negated,
}),

expr => expr,
};

Ok(if transformed {
Transformed::yes(expr)
} else {
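The change above moves `analyze_internal` out of the `impl ApplyFunctionRewrites` block into a free function and adds a `rewrite_subquery` helper, so that `ScalarSubquery`, `Exists`, and `InSubquery` expressions have their nested plans analyzed as well. The sketch below is a minimal, self-contained illustration of that pattern using hypothetical toy types (not the DataFusion API): an expression rewriter that only walks the expressions of the current plan node never sees the plan stored inside a subquery expression, so the recursion has to be added explicitly.

```rust
// Toy model: a "plan" holds expressions and child plans, and a subquery
// expression carries its own nested plan that a plain expression walk
// would otherwise skip.
#[derive(Debug)]
enum Expr {
    Column(String),
    FieldAccess { base: Box<Expr>, name: String }, // like `struct(..)['c1']`
    GetField { base: Box<Expr>, name: String },    // rewritten form
    Subquery(Box<Plan>),                           // nested plan inside an expression
}

#[derive(Debug)]
struct Plan {
    exprs: Vec<Expr>,
    inputs: Vec<Plan>,
}

// Rewrite one expression, recursing into any subquery plan it carries.
fn rewrite_expr(expr: Expr) -> Expr {
    match expr {
        Expr::FieldAccess { base, name } => Expr::GetField {
            base: Box::new(rewrite_expr(*base)),
            name,
        },
        // Without this arm the nested plan would keep its unrewritten exprs.
        Expr::Subquery(plan) => Expr::Subquery(Box::new(rewrite_plan(*plan))),
        other => other,
    }
}

// Rewrite child plans first, then the expressions of this node.
fn rewrite_plan(plan: Plan) -> Plan {
    Plan {
        inputs: plan.inputs.into_iter().map(rewrite_plan).collect(),
        exprs: plan.exprs.into_iter().map(rewrite_expr).collect(),
    }
}

fn main() {
    // A plan whose only expression is a subquery that uses field access.
    let plan = Plan {
        exprs: vec![Expr::Subquery(Box::new(Plan {
            exprs: vec![Expr::FieldAccess {
                base: Box::new(Expr::Column("s".to_string())),
                name: "c1".to_string(),
            }],
            inputs: vec![],
        }))],
        inputs: vec![],
    };
    // The field access inside the subquery is now rewritten to GetField.
    println!("{:?}", rewrite_plan(plan));
}
```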
55 changes: 55 additions & 0 deletions datafusion/sqllogictest/test_files/subquery.slt
@@ -1060,3 +1060,58 @@
Projection: t.a / Int64(2)Int64(2)t.a AS t.a / Int64(2), t.a / Int64(2)Int64(2)t.a AS t.a / Int64(2) + Int64(1)
--Projection: t.a / Int64(2) AS t.a / Int64(2)Int64(2)t.a
----TableScan: t projection=[a]

###
## Ensure that operators are rewritten in subqueries
###

statement ok
create table foo(x int) as values (1);

# Show input data
query ?
select struct(1, 'b')
----
{c0: 1, c1: b}


query T
select (select struct(1, 'b')['c1']);
----
b

query T
select 'foo' || (select struct(1, 'b')['c1']);
----
foob

query I
SELECT * FROM (VALUES (1), (2))
WHERE column1 IN (SELECT struct(1, 'b')['c0']);
----
1

# also add an expression so the subquery is the output expr
query I
SELECT * FROM (VALUES (1), (2))
WHERE 1+2 = 3 AND column1 IN (SELECT struct(1, 'b')['c0']);
----
1


query I
SELECT * FROM foo
WHERE EXISTS (SELECT * FROM (values (1)) WHERE column1 = foo.x AND struct(1, 'b')['c0'] = 1);
----
1

# also add an expression so the subquery is the output expr
query I
SELECT * FROM foo
WHERE 1+2 = 3 AND EXISTS (SELECT * FROM (values (1)) WHERE column1 = foo.x AND struct(1, 'b')['c0'] = 1);
----
1


statement ok
drop table foo;
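The new tests above can also be exercised through DataFusion's public API. Below is a minimal sketch, assuming the standard `datafusion` and `tokio` crates; the harness is illustrative only and is not part of this commit.

```rust
use datafusion::error::Result;
use datafusion::prelude::SessionContext;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    ctx.sql("create table foo(x int) as values (1)").await?;

    // Field access inside a scalar subquery; before this fix the
    // struct-field rewrite reportedly did not run inside the subquery.
    let df = ctx.sql("select (select struct(1, 'b')['c1'])").await?;
    df.show().await?; // expected: b

    // Field access inside an IN subquery used as a filter.
    let df = ctx
        .sql("SELECT * FROM (VALUES (1), (2)) WHERE column1 IN (SELECT struct(1, 'b')['c0'])")
        .await?;
    df.show().await?; // expected: 1

    Ok(())
}
```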
