Skip to content

Commit

Permalink
Support unparsing Array plan to SQL string (#13418)
Browse files Browse the repository at this point in the history
* unparse construct and access

* add sql e2e roundtrip

* remove unused tests

* fix test and clippy

* fix clippy
  • Loading branch information
goldmedal authored Nov 17, 2024
1 parent 61fa572 commit b75563b
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 5 deletions.
1 change: 0 additions & 1 deletion datafusion/functions-nested/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,6 @@ impl ExprPlanner for NestedFunctionPlanner {

#[derive(Debug)]
pub struct FieldAccessPlanner;

impl ExprPlanner for FieldAccessPlanner {
fn plan_field_access(
&self,
Expand Down
47 changes: 45 additions & 2 deletions datafusion/sql/src/unparser/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
use datafusion_expr::expr::Unnest;
use sqlparser::ast::Value::SingleQuotedString;
use sqlparser::ast::{
self, BinaryOperator, Expr as AstExpr, Function, Ident, Interval, ObjectName,
TimezoneInfo, UnaryOperator,
self, Array, BinaryOperator, Expr as AstExpr, Function, Ident, Interval, ObjectName,
Subscript, TimezoneInfo, UnaryOperator,
};
use std::sync::Arc;
use std::vec;
Expand Down Expand Up @@ -476,6 +476,19 @@ impl Unparser<'_> {
&self,
func_name: &str,
args: &[Expr],
) -> Result<ast::Expr> {
match func_name {
"make_array" => self.make_array_to_sql(args),
"array_element" => self.array_element_to_sql(args),
// TODO: support for the construct and access functions of the `map` and `struct` types
_ => self.scalar_function_to_sql_internal(func_name, args),
}
}

fn scalar_function_to_sql_internal(
&self,
func_name: &str,
args: &[Expr],
) -> Result<ast::Expr> {
let args = self.function_args_to_sql(args)?;
Ok(ast::Expr::Function(Function {
Expand All @@ -496,6 +509,29 @@ impl Unparser<'_> {
}))
}

fn make_array_to_sql(&self, args: &[Expr]) -> Result<ast::Expr> {
let args = args
.iter()
.map(|e| self.expr_to_sql(e))
.collect::<Result<Vec<_>>>()?;
Ok(ast::Expr::Array(Array {
elem: args,
named: false,
}))
}

fn array_element_to_sql(&self, args: &[Expr]) -> Result<ast::Expr> {
if args.len() != 2 {
return internal_err!("array_element must have exactly 2 arguments");
}
let array = self.expr_to_sql(&args[0])?;
let index = self.expr_to_sql(&args[1])?;
Ok(ast::Expr::Subscript {
expr: Box::new(array),
subscript: Box::new(Subscript::Index { index }),
})
}

pub fn sort_to_sql(&self, sort: &Sort) -> Result<ast::OrderByExpr> {
let Sort {
expr,
Expand Down Expand Up @@ -1485,6 +1521,7 @@ mod tests {
use datafusion_expr::{interval_month_day_nano_lit, ExprFunctionExt};
use datafusion_functions_aggregate::count::count_udaf;
use datafusion_functions_aggregate::expr_fn::sum;
use datafusion_functions_nested::expr_fn::{array_element, make_array};
use datafusion_functions_window::row_number::row_number_udwf;

use crate::unparser::dialect::{
Expand Down Expand Up @@ -1889,6 +1926,12 @@ mod tests {
}),
r#"UNNEST("table".array_col)"#,
),
(make_array(vec![lit(1), lit(2), lit(3)]), "[1, 2, 3]"),
(array_element(col("array_col"), lit(1)), "array_col[1]"),
(
array_element(make_array(vec![lit(1), lit(2), lit(3)]), lit(1)),
"[1, 2, 3][1]",
),
];

for (expr, expected) in tests {
Expand Down
15 changes: 13 additions & 2 deletions datafusion/sql/tests/cases/plan_to_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ use datafusion_expr::builder::{
table_scan_with_filter_and_fetch, table_scan_with_filters,
};
use datafusion_functions::core::planner::CoreFunctionPlanner;
use datafusion_functions_nested::extract::array_element_udf;
use datafusion_functions_nested::planner::{FieldAccessPlanner, NestedFunctionPlanner};
use sqlparser::dialect::{Dialect, GenericDialect, MySqlDialect};
use sqlparser::parser::Parser;

Expand Down Expand Up @@ -182,6 +184,11 @@ fn roundtrip_statement() -> Result<()> {
SUM(id) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS running_total
FROM person
GROUP BY GROUPING SETS ((id, first_name, last_name), (first_name, last_name), (last_name))"#,
"SELECT ARRAY[1, 2, 3]",
"SELECT ARRAY[1, 2, 3][1]",
"SELECT [1, 2, 3]",
"SELECT [1, 2, 3][1]",
"SELECT left[1] FROM array"
];

// For each test sql string, we transform as follows:
Expand All @@ -195,10 +202,14 @@ fn roundtrip_statement() -> Result<()> {
.try_with_sql(query)?
.parse_statement()?;
let state = MockSessionState::default()
.with_scalar_function(make_array_udf())
.with_scalar_function(array_element_udf())
.with_aggregate_function(sum_udaf())
.with_aggregate_function(count_udaf())
.with_aggregate_function(max_udaf())
.with_expr_planner(Arc::new(CoreFunctionPlanner::default()));
.with_expr_planner(Arc::new(CoreFunctionPlanner::default()))
.with_expr_planner(Arc::new(NestedFunctionPlanner))
.with_expr_planner(Arc::new(FieldAccessPlanner));
let context = MockContextProvider { state };
let sql_to_rel = SqlToRel::new(&context);
let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap();
Expand Down Expand Up @@ -1239,6 +1250,6 @@ fn test_unnest_to_sql() {
sql_round_trip(
GenericDialect {},
r#"SELECT unnest(make_array(1, 2, 2, 5, NULL)) as u1"#,
r#"SELECT UNNEST(make_array(1, 2, 2, 5, NULL)) AS u1"#,
r#"SELECT UNNEST([1, 2, 2, 5, NULL]) AS u1"#,
);
}

0 comments on commit b75563b

Please sign in to comment.