Skip to content

Commit

Permalink
JetBrains#623 Add support of window functions in Exposed DSL (JetBrai…
Browse files Browse the repository at this point in the history
…ns#1651)

* JetBrains#623 Add support of window functions in Exposed DSL

-Support of partition by and order by clauses
-Support of window frame clause (without EXCLUDE)
-Factories for common window functions
-Support for using aggregate functions as window functions

* JetBrains#623 Fix linter warnings

* Better name for NthValue class 'index' argument

Change argument to 'n' to conform with the most DB docs

Co-authored-by: Jocelyne <38375996+joc-a@users.noreply.github.com>

* Better name for nthValue function 'index' argument

Change argument to 'n' to conform with the most DB docs

Co-authored-by: Jocelyne <38375996+joc-a@users.noreply.github.com>

* JetBrains#623 Refactor window function definition

Window function definitions moved from top-level to ISqlExpressionBuilder
to eliminate their irrelevant appearance in code completion.

* JetBrains#623 Dump window functions related API changes

* JetBrains#623 Fix smart cast warning by replacing not-null assertion with null-safe call

* JetBrains#623 remove redundant DatabaseDialect.supportsWindowFunctions

This flag was used only in tests. Its value was false only for MySql < 8
which is covered by DatabaseTestsBase.kt Transaction.isOldMySql

---------

Co-authored-by: Dmitry Levin <dlevin@anylogic.com>
Co-authored-by: Jocelyne <38375996+joc-a@users.noreply.github.com>
  • Loading branch information
3 people authored and saral committed Oct 3, 2023
1 parent 4ec9b96 commit 1afe347
Show file tree
Hide file tree
Showing 10 changed files with 1,292 additions and 19 deletions.
256 changes: 247 additions & 9 deletions exposed-core/api/exposed-core.api

Large diffs are not rendered by default.

56 changes: 46 additions & 10 deletions exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Function.kt
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,12 @@ class Min<T : Comparable<T>, in S : T?>(
/** Returns the expression from which the minimum value is obtained. */
val expr: Expression<in S>,
columnType: IColumnType
) : Function<T?>(columnType) {
) : Function<T?>(columnType), WindowFunction<T?> {
override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit = queryBuilder { append("MIN(", expr, ")") }

override fun over(): WindowFunctionDefinition<T?> {
return WindowFunctionDefinition(columnType, this)
}
}

/**
Expand All @@ -174,8 +178,12 @@ class Max<T : Comparable<T>, in S : T?>(
/** Returns the expression from which the maximum value is obtained. */
val expr: Expression<in S>,
columnType: IColumnType
) : Function<T?>(columnType) {
) : Function<T?>(columnType), WindowFunction<T?> {
override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit = queryBuilder { append("MAX(", expr, ")") }

override fun over(): WindowFunctionDefinition<T?> {
return WindowFunctionDefinition(columnType, this)
}
}

/**
Expand All @@ -185,8 +193,12 @@ class Avg<T : Comparable<T>, in S : T?>(
/** Returns the expression from which the average is calculated. */
val expr: Expression<in S>,
scale: Int
) : Function<BigDecimal?>(DecimalColumnType(Int.MAX_VALUE, scale)) {
) : Function<BigDecimal?>(DecimalColumnType(Int.MAX_VALUE, scale)), WindowFunction<BigDecimal?> {
override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit = queryBuilder { append("AVG(", expr, ")") }

override fun over(): WindowFunctionDefinition<BigDecimal?> {
return WindowFunctionDefinition(columnType, this)
}
}

/**
Expand All @@ -196,8 +208,12 @@ class Sum<T>(
/** Returns the expression from which the sum is calculated. */
val expr: Expression<T>,
columnType: IColumnType
) : Function<T?>(columnType) {
) : Function<T?>(columnType), WindowFunction<T?> {
override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit = queryBuilder { append("SUM(", expr, ")") }

override fun over(): WindowFunctionDefinition<T?> {
return WindowFunctionDefinition(columnType, this)
}
}

/**
Expand All @@ -208,13 +224,17 @@ class Count(
val expr: Expression<*>,
/** Returns whether only distinct element should be count. */
val distinct: Boolean = false
) : Function<Long>(LongColumnType()) {
) : Function<Long>(LongColumnType()), WindowFunction<Long> {
override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit = queryBuilder {
+"COUNT("
if (distinct) +"DISTINCT "
+expr
+")"
}

override fun over(): WindowFunctionDefinition<Long> {
return WindowFunctionDefinition(LongColumnType(), this)
}
}

// Aggregate Functions for Statistics
Expand All @@ -227,7 +247,7 @@ class StdDevPop<T>(
/** Returns the expression from which the population standard deviation is calculated. */
val expression: Expression<T>,
scale: Int
) : Function<BigDecimal?>(DecimalColumnType(Int.MAX_VALUE, scale)) {
) : Function<BigDecimal?>(DecimalColumnType(Int.MAX_VALUE, scale)), WindowFunction<BigDecimal?> {
override fun toQueryBuilder(queryBuilder: QueryBuilder) {
queryBuilder {
val functionProvider = when (currentDialect.h2Mode) {
Expand All @@ -237,6 +257,10 @@ class StdDevPop<T>(
functionProvider.stdDevPop(expression, this)
}
}

override fun over(): WindowFunctionDefinition<BigDecimal?> {
return WindowFunctionDefinition(columnType, this)
}
}

/**
Expand All @@ -247,7 +271,7 @@ class StdDevSamp<T>(
/** Returns the expression from which the sample standard deviation is calculated. */
val expression: Expression<T>,
scale: Int
) : Function<BigDecimal?>(DecimalColumnType(Int.MAX_VALUE, scale)) {
) : Function<BigDecimal?>(DecimalColumnType(Int.MAX_VALUE, scale)), WindowFunction<BigDecimal?> {
override fun toQueryBuilder(queryBuilder: QueryBuilder) {
queryBuilder {
val functionProvider = when (currentDialect.h2Mode) {
Expand All @@ -257,6 +281,10 @@ class StdDevSamp<T>(
functionProvider.stdDevSamp(expression, this)
}
}

override fun over(): WindowFunctionDefinition<BigDecimal?> {
return WindowFunctionDefinition(columnType, this)
}
}

/**
Expand All @@ -267,7 +295,7 @@ class VarPop<T>(
/** Returns the expression from which the population variance is calculated. */
val expression: Expression<T>,
scale: Int
) : Function<BigDecimal?>(DecimalColumnType(Int.MAX_VALUE, scale)) {
) : Function<BigDecimal?>(DecimalColumnType(Int.MAX_VALUE, scale)), WindowFunction<BigDecimal?> {
override fun toQueryBuilder(queryBuilder: QueryBuilder) {
queryBuilder {
val functionProvider = when (currentDialect.h2Mode) {
Expand All @@ -277,6 +305,10 @@ class VarPop<T>(
functionProvider.varPop(expression, this)
}
}

override fun over(): WindowFunctionDefinition<BigDecimal?> {
return WindowFunctionDefinition(columnType, this)
}
}

/**
Expand All @@ -287,7 +319,7 @@ class VarSamp<T>(
/** Returns the expression from which the sample variance is calculated. */
val expression: Expression<T>,
scale: Int
) : Function<BigDecimal?>(DecimalColumnType(Int.MAX_VALUE, scale)) {
) : Function<BigDecimal?>(DecimalColumnType(Int.MAX_VALUE, scale)), WindowFunction<BigDecimal?> {
override fun toQueryBuilder(queryBuilder: QueryBuilder) {
queryBuilder {
val functionProvider = when (currentDialect.h2Mode) {
Expand All @@ -297,6 +329,10 @@ class VarSamp<T>(
functionProvider.varSamp(expression, this)
}
}

override fun over(): WindowFunctionDefinition<BigDecimal?> {
return WindowFunctionDefinition(columnType, this)
}
}

// JSON Functions
Expand Down Expand Up @@ -325,7 +361,7 @@ class JsonExtract<T>(
/**
* Represents an SQL function that advances the specified [seq] and returns the new value.
*/
sealed class NextVal<T> (
sealed class NextVal<T>(
/** Returns the sequence from which the next value is obtained. */
val seq: Sequence,
columnType: IColumnType
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,68 @@ interface ISqlExpressionBuilder {
caseSensitive: Boolean = true
): RegexpOp<T> = RegexpOp(this, pattern, caseSensitive)

// Window Functions

/** Returns the number of the current row within its partition, counting from 1. */
fun rowNumber(): RowNumber = RowNumber()

/** Returns the rank of the current row, with gaps; that is, the row_number of the first row in its peer group. */
fun rank(): Rank = Rank()

/** Returns the rank of the current row, without gaps; this function effectively counts peer groups. */
fun denseRank(): DenseRank = DenseRank()

/**
* Returns the relative rank of the current row, that is (rank - 1) / (total partition rows - 1).
* The value thus ranges from 0 to 1 inclusive.
*/
fun percentRank(): PercentRank = PercentRank()

/**
* Returns the cumulative distribution, that is (number of partition rows preceding or peers with current row) /
* (total partition rows). The value thus ranges from 1/N to 1.
*/
fun cumeDist(): CumeDist = CumeDist()

/** Returns an integer ranging from 1 to the [numBuckets], dividing the partition as equally as possible. */
fun ntile(numBuckets: ExpressionWithColumnType<Int>): Ntile = Ntile(numBuckets)

/**
* Returns value evaluated at the row that is [offset] rows before the current row within the partition;
* if there is no such row, instead returns [defaultValue].
* Both [offset] and [defaultValue] are evaluated with respect to the current row.
*/
fun <T> ExpressionWithColumnType<T>.lag(
offset: ExpressionWithColumnType<Int> = intLiteral(1),
defaultValue: ExpressionWithColumnType<T>? = null
): Lag<T> = Lag(this, offset, defaultValue)

/**
* Returns value evaluated at the row that is [offset] rows after the current row within the partition;
* if there is no such row, instead returns [defaultValue].
* Both [offset] and [defaultValue] are evaluated with respect to the current row.
*/
fun <T> ExpressionWithColumnType<T>.lead(
offset: ExpressionWithColumnType<Int> = intLiteral(1),
defaultValue: ExpressionWithColumnType<T>? = null
): Lead<T> = Lead(this, offset, defaultValue)

/**
* Returns value evaluated at the row that is the first row of the window frame.
*/
fun <T> ExpressionWithColumnType<T>.firstValue(): FirstValue<T> = FirstValue(this)

/**
* Returns value evaluated at the row that is the last row of the window frame.
*/
fun <T> ExpressionWithColumnType<T>.lastValue(): LastValue<T> = LastValue(this)

/**
* Returns value evaluated at the row that is the [n]'th row of the window frame
* (counting from 1); null if no such row.
*/
fun <T> ExpressionWithColumnType<T>.nthValue(n: ExpressionWithColumnType<Int>): NthValue<T> = NthValue(this, n)

// JSON Conditions

/**
Expand Down
Loading

0 comments on commit 1afe347

Please sign in to comment.