Skip to content

Commit

Permalink
handle 0 and NULL value of NTH_VALUE function (#12676)
Browse files Browse the repository at this point in the history
* handle 0 and NULL value of NTH_VALUE function

* use exec_err

* cargo fmt
  • Loading branch information
thinh2 authored Oct 1, 2024
1 parent 2c2e0e7 commit f54712d
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 25 deletions.
29 changes: 10 additions & 19 deletions datafusion/physical-expr/src/window/nth_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use crate::PhysicalExpr;
use arrow::array::{Array, ArrayRef};
use arrow::datatypes::{DataType, Field};
use datafusion_common::Result;
use datafusion_common::{exec_err, ScalarValue};
use datafusion_common::ScalarValue;
use datafusion_expr::window_state::WindowAggState;
use datafusion_expr::PartitionEvaluator;

Expand Down Expand Up @@ -86,16 +86,13 @@ impl NthValue {
n: i64,
ignore_nulls: bool,
) -> Result<Self> {
match n {
0 => exec_err!("NTH_VALUE expects n to be non-zero"),
_ => Ok(Self {
name: name.into(),
expr,
data_type,
kind: NthValueKind::Nth(n),
ignore_nulls,
}),
}
Ok(Self {
name: name.into(),
expr,
data_type,
kind: NthValueKind::Nth(n),
ignore_nulls,
})
}

/// Get the NTH_VALUE kind
Expand Down Expand Up @@ -188,10 +185,7 @@ impl PartitionEvaluator for NthValueEvaluator {
// Negative index represents reverse direction.
(n_range >= reverse_index, true)
}
Ordering::Equal => {
// The case n = 0 is not valid for the NTH_VALUE function.
unreachable!();
}
Ordering::Equal => (true, false),
}
}
};
Expand Down Expand Up @@ -298,10 +292,7 @@ impl PartitionEvaluator for NthValueEvaluator {
)
}
}
Ordering::Equal => {
// The case n = 0 is not valid for the NTH_VALUE function.
unreachable!();
}
Ordering::Equal => ScalarValue::try_from(arr.data_type()),
}
}
}
Expand Down
18 changes: 12 additions & 6 deletions datafusion/physical-plan/src/windows/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,20 +185,26 @@ fn get_scalar_value_from_args(
}

fn get_signed_integer(value: ScalarValue) -> Result<i64> {
if value.is_null() {
return Ok(0);
}

if !value.data_type().is_integer() {
return Err(DataFusionError::Execution(
"Expected an integer value".to_string(),
));
return exec_err!("Expected an integer value");
}

value.cast_to(&DataType::Int64)?.try_into()
}

fn get_unsigned_integer(value: ScalarValue) -> Result<u64> {
if value.is_null() {
return Ok(0);
}

if !value.data_type().is_integer() {
return Err(DataFusionError::Execution(
"Expected an integer value".to_string(),
));
return exec_err!("Expected an integer value");
}

value.cast_to(&DataType::UInt64)?.try_into()
}

Expand Down
39 changes: 39 additions & 0 deletions datafusion/sqllogictest/test_files/window.slt
Original file line number Diff line number Diff line change
Expand Up @@ -4894,3 +4894,42 @@ NULL a4 5

statement ok
drop table t

## test handle NULL and 0 value of nth_value
statement ok
CREATE TABLE t(v1 int, v2 int);

statement ok
INSERT INTO t VALUES (1,1), (1,2),(1,3),(2,1),(2,2);

query II
SELECT v1, NTH_VALUE(v2, null) OVER (PARTITION BY v1 ORDER BY v2) FROM t;
----
1 NULL
1 NULL
1 NULL
2 NULL
2 NULL

query II
SELECT v1, NTH_VALUE(v2, v2*null) OVER (PARTITION BY v1 ORDER BY v2) FROM t;
----
1 NULL
1 NULL
1 NULL
2 NULL
2 NULL

query II
SELECT v1, NTH_VALUE(v2, 0) OVER (PARTITION BY v1 ORDER BY v2) FROM t;
----
1 NULL
1 NULL
1 NULL
2 NULL
2 NULL

statement ok
DROP TABLE t;

## end test handle NULL and 0 of NTH_VALUE

0 comments on commit f54712d

Please sign in to comment.