Skip to content

Commit

Permalink
[SPARK-33443][SQL] LEAD/LAG should support [ IGNORE NULLS | RESPECT N…
Browse files Browse the repository at this point in the history
…ULLS ]

### What changes were proposed in this pull request?
The mainstream database support `[ IGNORE NULLS | RESPECT NULLS ]` for `LEAD`/`LAG`/`NTH_VALUE`/`FIRST_VALUE`/`LAST_VALUE`.
But the current implement of `LEAD`/`LAG` don't support this syntax.

**Oracle**
https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/LEAD.html#GUID-0A0481F1-E98F-4535-A739-FCCA8D1B5B77

**Presto**
https://prestodb.io/docs/current/functions/window.html

**Redshift**
https://docs.aws.amazon.com/redshift/latest/dg/r_WF_LEAD.html

**DB2**
https://www.ibm.com/support/knowledgecenter/SSGU8G_14.1.0/com.ibm.sqls.doc/ids_sqs_1513.htm

**Teradata**
https://docs.teradata.com/r/756LNiPSFdY~4JcCCcR5Cw/GjCT6l7trjkIEjt~7Dhx4w

**Snowflake**
https://docs.snowflake.com/en/sql-reference/functions/lead.html
https://docs.snowflake.com/en/sql-reference/functions/lag.html

### Why are the changes needed?
Support `[ IGNORE NULLS | RESPECT NULLS ]` for `LEAD`/`LAG` is very useful.

### Does this PR introduce _any_ user-facing change?
'Yes'.

### How was this patch tested?
Jenkins test.

Closes #30387 from beliefer/SPARK-33443.

Lead-authored-by: gengjiaan <gengjiaan@360.cn>
Co-authored-by: beliefer <beliefer@163.com>
Co-authored-by: Jiaan Geng <beliefer@163.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
2 people authored and cloud-fan committed Dec 24, 2020
1 parent 32d4a2b commit 3e9821e
Show file tree
Hide file tree
Showing 5 changed files with 238 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -387,8 +387,6 @@ abstract class FrameLessOffsetWindowFunction

override def nullable: Boolean = default == null || default.nullable || input.nullable

override val ignoreNulls = false

override lazy val frame: WindowFrame = fakeFrame

override def checkInputDataTypes(): TypeCheckResult = {
Expand Down Expand Up @@ -443,9 +441,13 @@ abstract class FrameLessOffsetWindowFunction
since = "2.0.0",
group = "window_funcs")
// scalastyle:on line.size.limit line.contains.tab
case class Lead(input: Expression, offset: Expression, default: Expression)
case class Lead(
input: Expression, offset: Expression, default: Expression, ignoreNulls: Boolean)
extends FrameLessOffsetWindowFunction {

def this(input: Expression, offset: Expression, default: Expression) =
this(input, offset, default, false)

def this(input: Expression, offset: Expression) = this(input, offset, Literal(null))

def this(input: Expression) = this(input, Literal(1))
Expand Down Expand Up @@ -485,10 +487,14 @@ case class Lead(input: Expression, offset: Expression, default: Expression)
since = "2.0.0",
group = "window_funcs")
// scalastyle:on line.size.limit line.contains.tab
case class Lag(input: Expression, inputOffset: Expression, default: Expression)
case class Lag(
input: Expression, inputOffset: Expression, default: Expression, ignoreNulls: Boolean)
extends FrameLessOffsetWindowFunction {

def this(input: Expression, offset: Expression) = this(input, offset, Literal(null))
def this(input: Expression, inputOffset: Expression, default: Expression) =
this(input, inputOffset, default, false)

def this(input: Expression, inputOffset: Expression) = this(input, inputOffset, Literal(null))

def this(input: Expression) = this(input, Literal(1))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,13 +119,21 @@ trait WindowExecBase extends UnaryExecNode {
* [[WindowExpression]]s and factory function for the [[WindowFrameFunction]].
*/
protected lazy val windowFrameExpressionFactoryPairs = {
type FrameKey = (String, FrameType, Expression, Expression)
type FrameKey = (String, FrameType, Expression, Expression, Seq[Expression])
type ExpressionBuffer = mutable.Buffer[Expression]
val framedFunctions = mutable.Map.empty[FrameKey, (ExpressionBuffer, ExpressionBuffer)]

// Add a function and its function to the map for a given frame.
def collect(tpe: String, fr: SpecifiedWindowFrame, e: Expression, fn: Expression): Unit = {
val key = (tpe, fr.frameType, fr.lower, fr.upper)
val key = fn match {
// This branch is used for Lead/Lag to support ignoring null.
// All window frames move in rows. If there are multiple Leads or Lags acting on a row
// and operating on different input expressions, they should not be moved uniformly
// by row. Therefore, we put these functions in different window frames.
case f: FrameLessOffsetWindowFunction if f.ignoreNulls =>
(tpe, fr.frameType, fr.lower, fr.upper, f.children.map(_.canonicalized))
case _ => (tpe, fr.frameType, fr.lower, fr.upper, Nil)
}
val (es, fns) = framedFunctions.getOrElseUpdate(
key, (ArrayBuffer.empty[Expression], ArrayBuffer.empty[Expression]))
es += e
Expand Down Expand Up @@ -183,7 +191,7 @@ trait WindowExecBase extends UnaryExecNode {
// Create the factory to produce WindowFunctionFrame.
val factory = key match {
// Frameless offset Frame
case ("FRAME_LESS_OFFSET", _, IntegerLiteral(offset), _) =>
case ("FRAME_LESS_OFFSET", _, IntegerLiteral(offset), _, expr) =>
target: InternalRow =>
new FrameLessOffsetWindowFunctionFrame(
target,
Expand All @@ -193,8 +201,9 @@ trait WindowExecBase extends UnaryExecNode {
child.output,
(expressions, schema) =>
MutableProjection.create(expressions, schema),
offset)
case ("UNBOUNDED_OFFSET", _, IntegerLiteral(offset), _) =>
offset,
expr.nonEmpty)
case ("UNBOUNDED_OFFSET", _, IntegerLiteral(offset), _, _) =>
target: InternalRow => {
new UnboundedOffsetWindowFunctionFrame(
target,
Expand All @@ -206,7 +215,7 @@ trait WindowExecBase extends UnaryExecNode {
MutableProjection.create(expressions, schema),
offset)
}
case ("UNBOUNDED_PRECEDING_OFFSET", _, IntegerLiteral(offset), _) =>
case ("UNBOUNDED_PRECEDING_OFFSET", _, IntegerLiteral(offset), _, _) =>
target: InternalRow => {
new UnboundedPrecedingOffsetWindowFunctionFrame(
target,
Expand All @@ -220,13 +229,13 @@ trait WindowExecBase extends UnaryExecNode {
}

// Entire Partition Frame.
case ("AGGREGATE", _, UnboundedPreceding, UnboundedFollowing) =>
case ("AGGREGATE", _, UnboundedPreceding, UnboundedFollowing, _) =>
target: InternalRow => {
new UnboundedWindowFunctionFrame(target, processor)
}

// Growing Frame.
case ("AGGREGATE", frameType, UnboundedPreceding, upper) =>
case ("AGGREGATE", frameType, UnboundedPreceding, upper, _) =>
target: InternalRow => {
new UnboundedPrecedingWindowFunctionFrame(
target,
Expand All @@ -235,7 +244,7 @@ trait WindowExecBase extends UnaryExecNode {
}

// Shrinking Frame.
case ("AGGREGATE", frameType, lower, UnboundedFollowing) =>
case ("AGGREGATE", frameType, lower, UnboundedFollowing, _) =>
target: InternalRow => {
new UnboundedFollowingWindowFunctionFrame(
target,
Expand All @@ -244,7 +253,7 @@ trait WindowExecBase extends UnaryExecNode {
}

// Moving Frame.
case ("AGGREGATE", frameType, lower, upper) =>
case ("AGGREGATE", frameType, lower, upper, _) =>
target: InternalRow => {
new SlidingWindowFunctionFrame(
target,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,15 @@ abstract class OffsetWindowFunctionFrameBase(
/** Index of the input row currently used for output. */
protected var inputIndex = 0

/** Attributes of the input row currently used for output. */
protected val inputAttrs = inputSchema.map(_.withNullability(true))

/**
* Create the projection used when the offset row exists.
* Please note that this project always respect null input values (like PostgreSQL).
*/
protected val projection = {
// Collect the expressions and bind them.
val inputAttrs = inputSchema.map(_.withNullability(true))
val boundExpressions = Seq.fill(ordinal)(NoOp) ++ bindReferences(
expressions.toSeq.map(_.input), inputAttrs)

Expand All @@ -114,7 +116,6 @@ abstract class OffsetWindowFunctionFrameBase(
/** Create the projection used when the offset row DOES NOT exists. */
protected val fillDefaultValue = {
// Collect the expressions and bind them.
val inputAttrs: AttributeSeq = inputSchema.map(_.withNullability(true))
val boundExpressions = Seq.fill(ordinal)(NoOp) ++ expressions.toSeq.map { e =>
if (e.default == null || e.default.foldable && e.default.eval() == null) {
// The default value is null.
Expand Down Expand Up @@ -147,31 +148,132 @@ class FrameLessOffsetWindowFunctionFrame(
expressions: Array[OffsetWindowFunction],
inputSchema: Seq[Attribute],
newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection,
offset: Int)
offset: Int,
ignoreNulls: Boolean = false)
extends OffsetWindowFunctionFrameBase(
target, ordinal, expressions, inputSchema, newMutableProjection, offset) {

/** Holder the UnsafeRow where the input operator by function is not null. */
private var nextSelectedRow = EmptyRow

// The number of rows skipped to get the next UnsafeRow where the input operator by function
// is not null.
private var skippedNonNullCount = 0

/** Create the projection to determine whether input is null. */
private val project = UnsafeProjection.create(Seq(IsNull(expressions.head.input)), inputSchema)

/** Check if the output value of the first index is null. */
private def nullCheck(row: InternalRow): Boolean = project(row).getBoolean(0)

/** find the offset row whose input is not null */
private def findNextRowWithNonNullInput(): Unit = {
while (skippedNonNullCount < offset && inputIndex < input.length) {
val r = WindowFunctionFrame.getNextOrNull(inputIterator)
if (!nullCheck(r)) {
nextSelectedRow = r
skippedNonNullCount += 1
}
inputIndex += 1
}
}

override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = {
input = rows
inputIterator = input.generateIterator()
// drain the first few rows if offset is larger than zero
inputIndex = 0
while (inputIndex < offset) {
if (inputIterator.hasNext) inputIterator.next()
inputIndex += 1
if (ignoreNulls) {
findNextRowWithNonNullInput()
} else {
while (inputIndex < offset) {
if (inputIterator.hasNext) inputIterator.next()
inputIndex += 1
}
inputIndex = offset
}
inputIndex = offset
}

private val doWrite = if (ignoreNulls && offset > 0) {
// For illustration, here is one example: the input data contains nine rows,
// and the input values of each row are: null, x, null, null, y, null, z, v, null.
// We use lead(input, 2) with IGNORE NULLS and the process is as follows:
// 1. current row -> null, next selected row -> y, output: y;
// 2. current row -> x, next selected row -> z, output: z;
// 3. current row -> null, next selected row -> z, output: z;
// 4. current row -> null, next selected row -> z, output: z;
// 5. current row -> y, next selected row -> v, output: v;
// 6. current row -> null, next selected row -> v, output: v;
// 7. current row -> z, next selected row -> empty, output: null;
// ... next selected row is empty, all following return null.
(current: InternalRow) =>
if (nextSelectedRow == EmptyRow) {
// Use default values since the offset row whose input value is not null does not exist.
fillDefaultValue(current)
} else {
if (nullCheck(current)) {
projection(nextSelectedRow)
} else {
skippedNonNullCount -= 1
findNextRowWithNonNullInput()
if (skippedNonNullCount == offset) {
projection(nextSelectedRow)
} else {
// Use default values since the offset row whose input value is not null does not exist.
fillDefaultValue(current)
nextSelectedRow = EmptyRow
}
}
}
} else if (ignoreNulls && offset < 0) {
// For illustration, here is one example: the input data contains nine rows,
// and the input values of each row are: null, x, null, null, y, null, z, v, null.
// We use lag(input, 1) with IGNORE NULLS and the process is as follows:
// 1. current row -> null, next selected row -> empty, output: null;
// 2. current row -> x, next selected row -> empty, output: null;
// 3. current row -> null, next selected row -> x, output: x;
// 4. current row -> null, next selected row -> x, output: x;
// 5. current row -> y, next selected row -> x, output: x;
// 6. current row -> null, next selected row -> y, output: y;
// 7. current row -> z, next selected row -> y, output: y;
// 8. current row -> v, next selected row -> z, output: z;
// 9. current row -> null, next selected row -> v, output: v;
val absOffset = Math.abs(offset)
(current: InternalRow) =>
if (skippedNonNullCount == absOffset) {
nextSelectedRow = EmptyRow
skippedNonNullCount -= 1
while (nextSelectedRow == EmptyRow && inputIndex < input.length) {
val r = WindowFunctionFrame.getNextOrNull(inputIterator)
if (!nullCheck(r)) {
nextSelectedRow = r
}
inputIndex += 1
}
}
if (nextSelectedRow == EmptyRow) {
// Use default values since the offset row whose input value is not null does not exist.
fillDefaultValue(current)
} else {
projection(nextSelectedRow)
}
if (!nullCheck(current)) {
skippedNonNullCount += 1
}
} else {
(current: InternalRow) =>
if (inputIndex >= 0 && inputIndex < input.length) {
val r = WindowFunctionFrame.getNextOrNull(inputIterator)
projection(r)
} else {
// Use default values since the offset row does not exist.
fillDefaultValue(current)
}
inputIndex += 1
}

override def write(index: Int, current: InternalRow): Unit = {
if (inputIndex >= 0 && inputIndex < input.length) {
val r = WindowFunctionFrame.getNextOrNull(inputIterator)
projection(r)
} else {
// Use default values since the offset row does not exist.
fillDefaultValue(current)
}
inputIndex += 1
doWrite(current)
}
}

Expand Down
40 changes: 36 additions & 4 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -937,8 +937,24 @@ object functions {
* @group window_funcs
* @since 1.4.0
*/
def lag(e: Column, offset: Int, defaultValue: Any): Column = withExpr {
Lag(e.expr, Literal(offset), Literal(defaultValue))
def lag(e: Column, offset: Int, defaultValue: Any): Column = {
lag(e, offset, defaultValue, false)
}

/**
* Window function: returns the value that is `offset` rows before the current row, and
* `defaultValue` if there is less than `offset` rows before the current row. `ignoreNulls`
* determines whether null values of row are included in or eliminated from the calculation.
* For example, an `offset` of one will return the previous row at any given point in the
* window partition.
*
* This is equivalent to the LAG function in SQL.
*
* @group window_funcs
* @since 3.2.0
*/
def lag(e: Column, offset: Int, defaultValue: Any, ignoreNulls: Boolean): Column = withExpr {
Lag(e.expr, Literal(offset), Literal(defaultValue), ignoreNulls)
}

/**
Expand Down Expand Up @@ -989,8 +1005,24 @@ object functions {
* @group window_funcs
* @since 1.4.0
*/
def lead(e: Column, offset: Int, defaultValue: Any): Column = withExpr {
Lead(e.expr, Literal(offset), Literal(defaultValue))
def lead(e: Column, offset: Int, defaultValue: Any): Column = {
lead(e, offset, defaultValue, false)
}

/**
* Window function: returns the value that is `offset` rows after the current row, and
* `defaultValue` if there is less than `offset` rows after the current row. `ignoreNulls`
* determines whether null values of row are included in or eliminated from the calculation.
* The default value of `ignoreNulls` is false. For example, an `offset` of one will return
* the next row at any given point in the window partition.
*
* This is equivalent to the LEAD function in SQL.
*
* @group window_funcs
* @since 3.2.0
*/
def lead(e: Column, offset: Int, defaultValue: Any, ignoreNulls: Boolean): Column = withExpr {
Lead(e.expr, Literal(offset), Literal(defaultValue), ignoreNulls)
}

/**
Expand Down
Loading

0 comments on commit 3e9821e

Please sign in to comment.