Skip to content

Commit

Permalink
Merge pull request #617 from Kotlin/pivot-then
Browse files Browse the repository at this point in the history
then operation in pivot column selection DSL inside aggregate
  • Loading branch information
koperagen authored Apr 2, 2024
2 parents 35792c9 + d3a9d11 commit ff139d7
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ public fun <G> GroupBy<*, G>.pivotCounts(vararg columns: KProperty<*>, inward: B

// region pivot

public fun <T> AggregateGroupedDsl<T>.pivot(inward: Boolean = true, columns: ColumnsSelector<T, *>): PivotGroupBy<T> =
public fun <T> AggregateGroupedDsl<T>.pivot(inward: Boolean = true, columns: PivotColumnsSelector<T, *>): PivotGroupBy<T> =
PivotInAggregateImpl(this, columns, inward)

public fun <T> AggregateGroupedDsl<T>.pivot(vararg columns: String, inward: Boolean = true): PivotGroupBy<T> =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@ import org.jetbrains.kotlinx.dataframe.ColumnsSelector
import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.aggregation.AggregateBody
import org.jetbrains.kotlinx.dataframe.aggregation.AggregateGroupedDsl
import org.jetbrains.kotlinx.dataframe.api.PivotColumnsSelector
import org.jetbrains.kotlinx.dataframe.api.PivotGroupBy
import org.jetbrains.kotlinx.dataframe.impl.api.AggregatedPivot
import org.jetbrains.kotlinx.dataframe.impl.api.aggregatePivot
import org.jetbrains.kotlinx.dataframe.impl.columns.toColumnSet

internal data class PivotInAggregateImpl<T>(
val aggregator: AggregateGroupedDsl<T>,
val columns: ColumnsSelector<T, *>,
val columns: PivotColumnsSelector<T, *>,
val inward: Boolean?,
val default: Any? = null
) : PivotGroupBy<T>, AggregatableInternal<T> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,4 +153,19 @@ class PivotTests {
1, -1, 5
)
}

@Test
fun `pivot then in aggregate`() {
val df = dataFrameOf(
"category1" to List(12) { it % 3 },
"category2" to List(12) { "category2_${it % 2}" },
"category3" to List(12) { "category3_${it % 5}" },
"value" to List(12) { it }
)

val df1 = df.groupBy("category1").aggregate {
pivot { "category2" then "category3" }.count()
}
df1 shouldBe df.pivot { "category2" then "category3" }.groupBy("category1").count()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ public fun <G> GroupBy<*, G>.pivotCounts(vararg columns: KProperty<*>, inward: B

// region pivot

public fun <T> AggregateGroupedDsl<T>.pivot(inward: Boolean = true, columns: ColumnsSelector<T, *>): PivotGroupBy<T> =
public fun <T> AggregateGroupedDsl<T>.pivot(inward: Boolean = true, columns: PivotColumnsSelector<T, *>): PivotGroupBy<T> =
PivotInAggregateImpl(this, columns, inward)

public fun <T> AggregateGroupedDsl<T>.pivot(vararg columns: String, inward: Boolean = true): PivotGroupBy<T> =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@ import org.jetbrains.kotlinx.dataframe.ColumnsSelector
import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.aggregation.AggregateBody
import org.jetbrains.kotlinx.dataframe.aggregation.AggregateGroupedDsl
import org.jetbrains.kotlinx.dataframe.api.PivotColumnsSelector
import org.jetbrains.kotlinx.dataframe.api.PivotGroupBy
import org.jetbrains.kotlinx.dataframe.impl.api.AggregatedPivot
import org.jetbrains.kotlinx.dataframe.impl.api.aggregatePivot
import org.jetbrains.kotlinx.dataframe.impl.columns.toColumnSet

internal data class PivotInAggregateImpl<T>(
val aggregator: AggregateGroupedDsl<T>,
val columns: ColumnsSelector<T, *>,
val columns: PivotColumnsSelector<T, *>,
val inward: Boolean?,
val default: Any? = null
) : PivotGroupBy<T>, AggregatableInternal<T> {
Expand Down
15 changes: 15 additions & 0 deletions core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/pivot.kt
Original file line number Diff line number Diff line change
Expand Up @@ -153,4 +153,19 @@ class PivotTests {
1, -1, 5
)
}

@Test
fun `pivot then in aggregate`() {
val df = dataFrameOf(
"category1" to List(12) { it % 3 },
"category2" to List(12) { "category2_${it % 2}" },
"category3" to List(12) { "category3_${it % 5}" },
"value" to List(12) { it }
)

val df1 = df.groupBy("category1").aggregate {
pivot { "category2" then "category3" }.count()
}
df1 shouldBe df.pivot { "category2" then "category3" }.groupBy("category1").count()
}
}

0 comments on commit ff139d7

Please sign in to comment.