Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix projection name with DataFrame::with_column and window functions #12000

Merged
merged 9 commits into from
Aug 17, 2024
48 changes: 44 additions & 4 deletions datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1441,21 +1441,27 @@ impl DataFrame {
/// ```
pub fn with_column(self, name: &str, expr: Expr) -> Result<DataFrame> {
let window_func_exprs = find_window_exprs(&[expr.clone()]);
let plan = if window_func_exprs.is_empty() {
self.plan

let (plan, mut col_exists, window_func) = if window_func_exprs.is_empty() {
(self.plan, false, false)
} else {
LogicalPlanBuilder::window_plan(self.plan, window_func_exprs)?
(
LogicalPlanBuilder::window_plan(self.plan, window_func_exprs)?,
true,
true,
)
};

let new_column = expr.alias(name);
let mut col_exists = false;
let mut fields: Vec<Expr> = plan
.schema()
.iter()
.map(|(qualifier, field)| {
if field.name() == name {
col_exists = true;
new_column.clone()
} else if window_func && qualifier.is_none() {
col(Column::from((qualifier, field))).alias(name)
} else {
col(Column::from((qualifier, field)))
}
Expand Down Expand Up @@ -1703,6 +1709,7 @@ mod tests {
use arrow::array::{self, Int32Array};
use datafusion_common::{Constraint, Constraints, ScalarValue};
use datafusion_common_runtime::SpawnedTask;
use datafusion_expr::expr::WindowFunction;
use datafusion_expr::{
cast, create_udf, expr, lit, BuiltInWindowFunction, ExprFunctionExt,
ScalarFunctionImplementation, Volatility, WindowFunctionDefinition,
Expand Down Expand Up @@ -2869,6 +2876,39 @@ mod tests {
Ok(())
}

#[tokio::test]
async fn test_window_function_with_column() -> Result<()> {
let df = test_table().await?.select_columns(&["c1", "c2", "c3"])?;
let ctx = SessionContext::new();
let df_impl = DataFrame::new(ctx.state(), df.plan.clone());
let func = Expr::WindowFunction(WindowFunction::new(
WindowFunctionDefinition::BuiltInWindowFunction(
BuiltInWindowFunction::RowNumber,
),
vec![],
))
.alias("row_num");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you can use the expr fn here and make this more concise:

Suggested change
let func = Expr::WindowFunction(WindowFunction::new(
WindowFunctionDefinition::BuiltInWindowFunction(
BuiltInWindowFunction::RowNumber,
),
vec![],
))
.alias("row_num");
let func = row_number().alias("row_num");

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also ran this test without the code changes and it fails like this:


assertion `left == right` failed
  left: 4
 right: 5

Left:  4
Right: 5
<Click to see difference>

thread 'dataframe::tests::test_window_function_with_column' panicked at datafusion/core/src/dataframe/mod.rs:2882:9:
assertion `left == right` failed
  left: 4
 right: 5
stack backtrace:
   0: rust_begin_unwind
             at /rustc/3f5fd8dd41153bc5fdca9427e9e05be2c767ba23/library/std/src/panicking.rs:652:5
...

And the output was like

[
    "+----+----+-----+-----------------------------------------------------------------------+---+",
    "| c1 | c2 | c3  | ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING | r |",
    "+----+----+-----+-----------------------------------------------------------------------+---+",
    "| c  | 2  | 1   | 1                                                                     | 1 |",
    "| d  | 5  | -40 | 2                                                                     | 2 |",
    "+----+----+-----+-----------------------------------------------------------------------+---+",
]

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good I went ahead and used the more concise method call. Thanks!


// Should create an additional column with alias 'r' that has window func results
let df = df_impl.with_column("r", func)?.limit(0, Some(2))?;
assert_eq!(4, df.schema().fields().len());

let df_results = df.clone().collect().await?;
assert_batches_sorted_eq!(
[
"+----+----+-----+---+",
"| c1 | c2 | c3 | r |",
"+----+----+-----+---+",
"| c | 2 | 1 | 1 |",
"| d | 5 | -40 | 2 |",
"+----+----+-----+---+",
],
&df_results
);

Ok(())
}

// Test issue: https://github.com/apache/datafusion/issues/7790
// The join operation outputs two identical column names, but they belong to different relations.
#[tokio::test]
Expand Down