From c533c8bf1f12076aa63e5adf89bb4c76ff56d045 Mon Sep 17 00:00:00 2001 From: Qianqian <130200611+Sevenannn@users.noreply.github.com> Date: Wed, 25 Sep 2024 03:31:55 -0700 Subject: [PATCH] Use value when comparing with dictionary column (#12610) --- datafusion/sql/src/unparser/expr.rs | 49 +++++++++++++++++++++++++---- 1 file changed, 43 insertions(+), 6 deletions(-) diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 516833a39f1e2..3b9284fb5dd21 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -219,12 +219,24 @@ impl Unparser<'_> { } Expr::Cast(Cast { expr, data_type }) => { let inner_expr = self.expr_to_sql_inner(expr)?; - Ok(ast::Expr::Cast { - kind: ast::CastKind::Cast, - expr: Box::new(inner_expr), - data_type: self.arrow_dtype_to_ast_dtype(data_type)?, - format: None, - }) + match data_type { + DataType::Dictionary(_, _) => match inner_expr { + // Dictionary values don't need to be cast to other types when rewritten back to sql + ast::Expr::Value(_) => Ok(inner_expr), + _ => Ok(ast::Expr::Cast { + kind: ast::CastKind::Cast, + expr: Box::new(inner_expr), + data_type: self.arrow_dtype_to_ast_dtype(data_type)?, + format: None, + }), + }, + _ => Ok(ast::Expr::Cast { + kind: ast::CastKind::Cast, + expr: Box::new(inner_expr), + data_type: self.arrow_dtype_to_ast_dtype(data_type)?, + format: None, + }), + } } Expr::Literal(value) => Ok(self.scalar_to_sql(value)?), Expr::Alias(Alias { expr, name: _, .. }) => self.expr_to_sql_inner(expr), @@ -2361,4 +2373,29 @@ mod tests { } Ok(()) } + + #[test] + fn test_cast_value_to_dict_expr() { + let tests = [( + Expr::Cast(Cast { + expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some( + "variation".to_string(), + )))), + data_type: DataType::Dictionary( + Box::new(DataType::Int8), + Box::new(DataType::Utf8), + ), + }), + "'variation'", + )]; + for (value, expected) in tests { + let dialect = CustomDialectBuilder::new().build(); + let unparser = Unparser::new(&dialect); + + let ast = unparser.expr_to_sql(&value).expect("to be unparsed"); + let actual = format!("{ast}"); + + assert_eq!(actual, expected); + } + } }