Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend "unfold" operation and support it in the compiler plugin #742

Draft
wants to merge 5 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import org.jetbrains.kotlinx.dataframe.ColumnsContainer
import org.jetbrains.kotlinx.dataframe.ColumnsSelector
import org.jetbrains.kotlinx.dataframe.DataColumn
import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.annotations.Interpretable
import org.jetbrains.kotlinx.dataframe.columns.ColumnReference
import org.jetbrains.kotlinx.dataframe.columns.toColumnSet
import org.jetbrains.kotlinx.dataframe.get
Expand All @@ -14,6 +15,7 @@ import org.jetbrains.kotlinx.dataframe.impl.api.insertImpl
import org.jetbrains.kotlinx.dataframe.impl.api.removeImpl
import kotlin.reflect.KProperty

/**
 * Starts a replace operation on this [DataFrame].
 *
 * @param columns selector of the columns to be replaced.
 * @return a [ReplaceClause] holding this frame together with [columns].
 */
@Interpretable("Replace0")
public fun <T, C> DataFrame<T>.replace(columns: ColumnsSelector<T, C>): ReplaceClause<T, C> {
    return ReplaceClause(this, columns)
}

Expand Down
29 changes: 16 additions & 13 deletions core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/unfold.kt
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,25 @@ import org.jetbrains.kotlinx.dataframe.AnyColumnReference
import org.jetbrains.kotlinx.dataframe.ColumnsSelector
import org.jetbrains.kotlinx.dataframe.DataColumn
import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.columns.ColumnKind
import org.jetbrains.kotlinx.dataframe.annotations.Interpretable
import org.jetbrains.kotlinx.dataframe.annotations.Refine
import org.jetbrains.kotlinx.dataframe.columns.toColumnSet
import org.jetbrains.kotlinx.dataframe.impl.api.createDataFrameImpl
import org.jetbrains.kotlinx.dataframe.typeClass
import org.jetbrains.kotlinx.dataframe.impl.api.unfoldImpl
import kotlin.reflect.KProperty

/**
 * Unfolds this column into a column group built from the properties of its values.
 *
 * Group and frame columns, as well as columns for which `isPrimitive()` is true,
 * are returned unchanged.
 */
public inline fun <reified T> DataColumn<T>.unfold(): AnyCol {
    val columnKind = kind()
    if (columnKind == ColumnKind.Group || columnKind == ColumnKind.Frame) return this
    if (isPrimitive()) return this

    val unfolded = values().createDataFrameImpl(typeClass) {
        (this as CreateDataFrameDsl<T>).properties()
    }
    return unfolded.asColumnGroup(name()).asDataColumn()
}
/**
 * Unfolds this column by delegating to `unfoldImpl` with `skipPrimitive = true`.
 *
 * @param props forwarded as `roots` to `properties(...)` of the create-DataFrame DSL.
 * @param maxDepth forwarded to `properties(...)`; defaults to `0`.
 *
 * NOTE(review): a reviewer reported that calling this without `maxDepth` no longer
 * unfolds nested properties the way the previous no-argument `unfold()` did —
 * confirm that `0` is the intended default.
 */
public inline fun <reified T> DataColumn<T>.unfold(vararg props: KProperty<*>, maxDepth: Int = 0): AnyCol {
    return unfoldImpl(skipPrimitive = true) {
        properties(roots = props, maxDepth = maxDepth)
    }
}

/**
 * Unfolds this column using a custom create-DataFrame DSL [body].
 *
 * Delegates to `unfoldImpl` with `skipPrimitive = false`.
 */
public inline fun <reified T> DataColumn<T>.unfold(noinline body: CreateDataFrameDsl<T>.() -> Unit): AnyCol {
    return unfoldImpl(skipPrimitive = false, body = body)
}

public inline fun <T, reified C> ReplaceClause<T, C>.unfold(vararg props: KProperty<*>, maxDepth: Int = 0): DataFrame<T> =
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the name unfolding would read better, or byUnfolding/withUnfolded. replace {}.unfold {} doesn't read as a sentence anymore.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If possible, let's avoid drifting or gravitating toward natural-language phrasing; I believe that's not a goal.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, we should also use `KCallable` instead of `KProperty` to support Java classes :)

with { it.unfold(props = props, maxDepth) }

/**
 * Replaces the columns selected by this [ReplaceClause] with their unfolded form,
 * where each column is unfolded through `unfoldImpl(skipPrimitive = false, body)`.
 *
 * @param body create-DataFrame DSL block applied to each selected column.
 */
@Refine
@Interpretable("ReplaceUnfold1")
public inline fun <T, reified C> ReplaceClause<T, C>.unfold(noinline body: CreateDataFrameDsl<C>.() -> Unit): DataFrame<T> {
    return with { column -> column.unfoldImpl(skipPrimitive = false, body = body) }
}

/**
 * Replaces every column selected by [columns] with the result of calling `unfold()` on it.
 */
public fun <T> DataFrame<T>.unfold(columns: ColumnsSelector<T, *>): DataFrame<T> {
    return replace(columns).with { column -> column.unfold() }
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package org.jetbrains.kotlinx.dataframe.impl.api

import org.jetbrains.kotlinx.dataframe.AnyCol
import org.jetbrains.kotlinx.dataframe.DataColumn
import org.jetbrains.kotlinx.dataframe.api.CreateDataFrameDsl
import org.jetbrains.kotlinx.dataframe.api.asColumnGroup
import org.jetbrains.kotlinx.dataframe.api.asDataColumn
import org.jetbrains.kotlinx.dataframe.api.isPrimitive
import org.jetbrains.kotlinx.dataframe.columns.ColumnKind
import org.jetbrains.kotlinx.dataframe.typeClass

@PublishedApi
internal fun <T> DataColumn<T>.unfoldImpl(skipPrimitive: Boolean, body: CreateDataFrameDsl<T>.() -> Unit): AnyCol {
return when (kind()) {
ColumnKind.Group, ColumnKind.Frame -> this
else -> when {
skipPrimitive && isPrimitive() -> this
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was very confused, like how can you unfold a primitive? but it's an isPrimitive() which can be a collection too... Can we rename isPrimitive() to something like isPrimitiveOrListLike()? unfold seems to be the only operation using it

Copy link
Collaborator Author

@koperagen koperagen Jun 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can't. Have a look at the `unfold primitive` test. `skipPrimitive = false` is needed to make it work, and `skipPrimitive = true` is needed to avoid unpacking, for example, a column of String into a ColumnGroup with size: Int, the same as we do for `toDataFrame` with overloads.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but can you take a look at the isPrimitive() function? That function also returns true when you run it on a collection or an array. My suggestion was only to rename the isPrimitive() function.

else -> values().createDataFrameImpl(typeClass) {
body((this as CreateDataFrameDsl<T>))
}.asColumnGroup(name()).asDataColumn()
}
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package org.jetbrains.kotlinx.dataframe.api

import io.kotest.matchers.shouldBe
import org.jetbrains.kotlinx.dataframe.DataRow
import org.jetbrains.kotlinx.dataframe.io.readJsonStr
import org.junit.Test
import kotlin.reflect.typeOf

Expand All @@ -13,4 +15,61 @@ class ReplaceTests {
conv.columnNames() shouldBe listOf("b")
conv.columnTypes() shouldBe listOf(typeOf<Double>())
}

@Test
fun `unfold primitive`() {
val a by columnOf("123")
val df = dataFrameOf(a)

val conv = df.replace { a }.unfold {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we use replace and unfold independently? If so, could you please add a test for this? Or, if they only work together, could they be combined into one function?

"b" from { it }
"c" from { DataRow.readJsonStr("""{"prop": 1}""") }
}

val b = conv["a"]["b"]
b.type() shouldBe typeOf<String>()
b.values() shouldBe listOf("123")

val c = conv["a"]["c"]["prop"]
c.type() shouldBe typeOf<Int>()
c.values() shouldBe listOf(1)
}

@Test
fun `unfold properties`() {
val col by columnOf(A("1", 123, B(3.0)))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this case "More fine-grained toDataFrame. Instead of converting 20-30 properties to 2-3 level of nesting all at once user can choose to convert toDataFrame(maxDepth = 0) and unfold required properties to whatever level they need" covered here, in this test?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Technically, yes. I intend to have a more representative example as a part of compiler plugin demo. There's a tree of objects with many properties and potentially deep nesting from konsist library. It will be a good illustration. But here it merely unfolds one specific column up to 2 levels.

val df1 = dataFrameOf(col)
val conv = df1.replace { col }.unfold(maxDepth = 2)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not specifying maxDepth now breaks, while it worked before.
Try running df1.replace { col }.with { it.unfold() } before the PR and after it.
It works before, but now it gives: java.lang.UnsupportedOperationException: Can not get nested column 'd' from ValueColumn 'bb'


val a = conv["col"]["a"]
a.type() shouldBe typeOf<String>()
a.values() shouldBe listOf("1")

val b = conv["col"]["b"]
b.type() shouldBe typeOf<Int>()
b.values() shouldBe listOf(123)

val d = conv["col"]["bb"]["d"]
d.type() shouldBe typeOf<Double>()
d.values() shouldBe listOf(3.0)
}

// Test fixture with a single Double property `d`; used as the type of `A.bb` and as column values in the tests.
class B(val d: Double)
// Test fixture with two simple properties (`a`, `b`) and one nested object property (`bb`).
class A(val a: String, val b: Int, val bb: B)

@Test
fun `skip primitive`() {
    // The columns are looked up below as "col1" / "col2", so these property names must stay as they are.
    val col1 by columnOf("1", "2")
    val col2 by columnOf(B(1.0), B(2.0))
    val unfolded = dataFrameOf(col1, col2)
        .replace { nameStartsWith("col") }
        .unfold()

    // The String column is expected to pass through unchanged.
    val untouched = unfolded["col1"]
    untouched.type() shouldBe typeOf<String>()
    untouched.values() shouldBe listOf("1", "2")

    // The column of `B` objects is expected to expose `d` as a nested column.
    val nestedD = unfolded["col2"]["d"]
    nestedD.type() shouldBe typeOf<Double>()
    nestedD.values() shouldBe listOf(1.0, 2.0)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,19 @@

package org.jetbrains.kotlinx.dataframe.plugin

import org.jetbrains.kotlinx.dataframe.plugin.extensions.KotlinTypeFacade
import org.jetbrains.kotlinx.dataframe.plugin.impl.Interpreter
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.TraverseConfiguration
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.aggregate
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.toDataFrame
import org.jetbrains.kotlinx.dataframe.plugin.utils.Names.DF_CLASS_ID
import org.jetbrains.kotlin.fir.expressions.FirAnonymousFunctionExpression
import org.jetbrains.kotlin.fir.expressions.FirExpression
import org.jetbrains.kotlin.fir.expressions.FirFunctionCall
import org.jetbrains.kotlin.fir.expressions.FirLiteralExpression
import org.jetbrains.kotlin.fir.expressions.impl.FirResolvedArgumentList
import org.jetbrains.kotlin.fir.types.ConeClassLikeType
import org.jetbrains.kotlin.fir.types.ConeKotlinType
import org.jetbrains.kotlin.fir.types.ConeTypeProjection
import org.jetbrains.kotlin.fir.types.classId
import org.jetbrains.kotlin.fir.types.resolvedType
import org.jetbrains.kotlin.name.Name
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.CreateDataFrameDslImplApproximation
import org.jetbrains.kotlinx.dataframe.plugin.extensions.KotlinTypeFacade
import org.jetbrains.kotlinx.dataframe.plugin.impl.Interpreter
import org.jetbrains.kotlinx.dataframe.plugin.impl.PluginDataFrameSchema
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.aggregate
import org.jetbrains.kotlinx.dataframe.plugin.utils.Names.DF_CLASS_ID

fun KotlinTypeFacade.analyzeRefinedCallShape(call: FirFunctionCall, reporter: InterpretationErrorReporter): CallResult? {
val callReturnType = call.resolvedType
Expand All @@ -38,44 +32,6 @@ fun KotlinTypeFacade.analyzeRefinedCallShape(call: FirFunctionCall, reporter: In

val newSchema: PluginDataFrameSchema = call.interpreterName(session)?.let { name ->
when (name) {
"toDataFrameDsl" -> {
val list = call.argumentList as FirResolvedArgumentList
val lambda = (list.arguments.singleOrNull() as? FirAnonymousFunctionExpression)?.anonymousFunction
val statements = lambda?.body?.statements
if (statements != null) {
val receiver = CreateDataFrameDslImplApproximation()
statements.filterIsInstance<FirFunctionCall>().forEach {
val schemaProcessor = it.loadInterpreter() ?: return@forEach
interpret(
it,
schemaProcessor,
mapOf("dsl" to Interpreter.Success(receiver), "call" to Interpreter.Success(call)),
reporter
)
}
PluginDataFrameSchema(receiver.columns)
} else {
PluginDataFrameSchema(emptyList())
}
}
"toDataFrame" -> {
val list = call.argumentList as FirResolvedArgumentList
val argument = list.mapping.entries.firstOrNull { it.value.name == Name.identifier("maxDepth") }?.key
val maxDepth = when (argument) {
null -> 0
is FirLiteralExpression -> (argument.value as Number).toInt()
else -> null
}
if (maxDepth != null) {
toDataFrame(maxDepth, call, TraverseConfiguration())
} else {
PluginDataFrameSchema(emptyList())
}
}
"toDataFrameDefault" -> {
val maxDepth = 0
toDataFrame(maxDepth, call, TraverseConfiguration())
}
"Aggregate" -> {
val groupByCall = call.explicitReceiver as? FirFunctionCall
val interpreter = groupByCall?.loadInterpreter(session)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,5 +70,11 @@ fun <T> AbstractInterpreter<T>.kproperty(

internal fun <T> AbstractInterpreter<T>.string(
name: ArgumentName? = null
): ExpectedArgumentProvider<String
> = arg(name, lens = Interpreter.Value)
): ExpectedArgumentProvider<String> =
arg(name, lens = Interpreter.Value)

/**
 * Declares a DSL-lambda argument of an interpreter.
 *
 * The argument is read through the `Interpreter.Dsl` lens and is a function of two
 * parameters: an `Any` and a map from `String` to `Interpreter.Success` values.
 * When the argument is absent, a no-op function is used as the default.
 *
 * @param name optional explicit argument name passed through to `arg`.
 */
internal fun <T> AbstractInterpreter<T>.dsl(
    name: ArgumentName? = null
): ExpectedArgumentProvider<(Any, Map<String, Interpreter.Success<Any?>>) -> Unit> {
    val noOp: (Any, Map<String, Interpreter.Success<Any?>>) -> Unit = { _, _ -> }
    return arg(name, lens = Interpreter.Dsl, defaultValue = Present(value = noOp))
}

Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.PluginDataFrameSchema
import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleCol
import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleDataColumn
import org.jetbrains.kotlinx.dataframe.plugin.impl.dataFrame
import org.jetbrains.kotlinx.dataframe.plugin.impl.dsl
import org.jetbrains.kotlinx.dataframe.plugin.impl.string
import org.jetbrains.kotlinx.dataframe.plugin.impl.type

Expand Down Expand Up @@ -48,11 +49,11 @@ class AddDslApproximation(val columns: MutableList<SimpleCol>)

class AddWithDsl : AbstractSchemaModificationInterpreter() {
val Arguments.receiver: PluginDataFrameSchema by dataFrame()
val Arguments.body: (Any) -> Unit by arg(lens = Interpreter.Dsl)
val Arguments.body by dsl()

override fun Arguments.interpret(): PluginDataFrameSchema {
val addDsl = AddDslApproximation(receiver.columns().toMutableList())
body(addDsl)
body(addDsl, emptyMap())
return PluginDataFrameSchema(addDsl.columns)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@ class RenameInto : AbstractSchemaModificationInterpreter() {
override fun Arguments.interpret(): PluginDataFrameSchema {
require(receiver.columns.size == newNames.size)
var i = 0
return receiver.schema.map(receiver.columns.mapTo(mutableSetOf()) { it.path.path }, nextName = { newNames[i].also { i += 1 } })
return receiver.schema.rename(receiver.columns.mapTo(mutableSetOf()) { it.path.path }, nextName = { newNames[i].also { i += 1 } })
}
}

internal fun PluginDataFrameSchema.map(selected: ColumnsSet, nextName: () -> String): PluginDataFrameSchema {
internal fun PluginDataFrameSchema.rename(selected: ColumnsSet, nextName: () -> String): PluginDataFrameSchema {
return PluginDataFrameSchema(
f(columns(), nextName, selected, emptyList())
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import org.jetbrains.kotlin.fir.declarations.utils.effectiveVisibility
import org.jetbrains.kotlin.fir.declarations.utils.isEnumClass
import org.jetbrains.kotlin.fir.declarations.utils.isStatic
import org.jetbrains.kotlin.fir.expressions.FirCallableReferenceAccess
import org.jetbrains.kotlin.fir.expressions.FirFunctionCall
import org.jetbrains.kotlin.fir.expressions.FirExpression
import org.jetbrains.kotlin.fir.expressions.FirGetClassCall
import org.jetbrains.kotlin.fir.expressions.FirVarargArgumentsExpression
import org.jetbrains.kotlin.fir.java.JavaTypeParameterStack
Expand Down Expand Up @@ -46,6 +46,7 @@ import org.jetbrains.kotlin.name.Name
import org.jetbrains.kotlin.name.StandardClassIds
import org.jetbrains.kotlinx.dataframe.plugin.extensions.KotlinTypeFacade
import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractInterpreter
import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractSchemaModificationInterpreter
import org.jetbrains.kotlinx.dataframe.plugin.impl.Arguments
import org.jetbrains.kotlinx.dataframe.plugin.impl.Interpreter
import org.jetbrains.kotlinx.dataframe.plugin.impl.PluginDataFrameSchema
Expand All @@ -54,26 +55,68 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleCol
import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleDataColumn
import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleColumnGroup
import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleFrameColumn
import org.jetbrains.kotlinx.dataframe.plugin.impl.dsl
import org.jetbrains.kotlinx.dataframe.plugin.impl.simpleColumnOf
import org.jetbrains.kotlinx.dataframe.plugin.impl.type
import java.util.*

/**
 * Interpreter for the DSL form of `toDataFrame { ... }`.
 *
 * Takes the first type argument of the receiver expression's resolved type and, when it is a
 * class-like type, runs the DSL [body] against a fresh [CreateDataFrameDslImplApproximation],
 * passing that type as an extra argument under [Properties0.classExtraArgument].
 * The resulting schema consists of the columns collected by the DSL.
 *
 * An empty schema is returned when the receiver is missing, has no type arguments,
 * the first type argument is a star projection, or it is not a class-like type.
 */
class ToDataFrameDsl : AbstractSchemaModificationInterpreter() {
    val Arguments.receiver: FirExpression? by arg(lens = Interpreter.Id)
    val Arguments.body by dsl()

    override fun Arguments.interpret(): PluginDataFrameSchema {
        val dsl = CreateDataFrameDslImplApproximation()
        val receiverExpression = receiver ?: return PluginDataFrameSchema(emptyList())
        val elementType = receiverExpression.resolvedType.typeArguments.firstOrNull()
            ?: return PluginDataFrameSchema(emptyList())

        // Every bail-out is an explicit return. Previously the star-projection case sat in a
        // statement-position `when`, so the empty schema it created was silently discarded.
        if (elementType.isStarProjection) return PluginDataFrameSchema(emptyList())
        val classLike = elementType.type as? ConeClassLikeType ?: return PluginDataFrameSchema(emptyList())

        body(dsl, mapOf(Properties0.classExtraArgument to Interpreter.Success(classLike)))
        return PluginDataFrameSchema(dsl.columns)
    }
}

/**
 * Interpreter for `toDataFrame(maxDepth = ...)`.
 *
 * Reads the receiver expression and the `maxDepth` argument (falling back to
 * `DEFAULT_MAX_DEPTH` when absent) and delegates to `toDataFrame` with a default
 * [TraverseConfiguration].
 */
class ToDataFrame : AbstractSchemaModificationInterpreter() {
    val Arguments.receiver: FirExpression? by arg(lens = Interpreter.Id)
    val Arguments.maxDepth: Number by arg(defaultValue = Present(DEFAULT_MAX_DEPTH))

    override fun Arguments.interpret(): PluginDataFrameSchema {
        val depth = maxDepth.toInt()
        val source = receiver
        val traversal = TraverseConfiguration()
        return toDataFrame(depth, source, traversal)
    }
}

/**
 * Interpreter for the argument-less `toDataFrame()`.
 *
 * Always uses `DEFAULT_MAX_DEPTH` and a default [TraverseConfiguration].
 */
class ToDataFrameDefault : AbstractSchemaModificationInterpreter() {
    val Arguments.receiver: FirExpression? by arg(lens = Interpreter.Id)

    override fun Arguments.interpret(): PluginDataFrameSchema =
        toDataFrame(DEFAULT_MAX_DEPTH, receiver, TraverseConfiguration())
}

private const val DEFAULT_MAX_DEPTH = 0

class Properties0 : AbstractInterpreter<Unit>() {
companion object {
const val classExtraArgument = "explicitReceiver"
}

val Arguments.dsl: CreateDataFrameDslImplApproximation by arg()
val Arguments.call: FirFunctionCall by arg()
val Arguments.coneKotlinType: ConeKotlinType by arg(name = name(classExtraArgument))
val Arguments.maxDepth: Int by arg()
val Arguments.body: (Any) -> Unit by arg(lens = Interpreter.Dsl, defaultValue = Present(value = {}))
val Arguments.body by dsl()

override fun Arguments.interpret() {
dsl.configuration.maxDepth = maxDepth
body(dsl.configuration.traverseConfiguration)
val schema = toDataFrame(dsl.configuration.maxDepth, call, dsl.configuration.traverseConfiguration)
body(dsl.configuration.traverseConfiguration, emptyMap())
val schema = toDataFrame(dsl.configuration.maxDepth, coneKotlinType, dsl.configuration.traverseConfiguration)
dsl.columns.addAll(schema.columns())
}
}

class CreateDataFrameConfiguration {
var maxDepth = 0
var maxDepth = DEFAULT_MAX_DEPTH
var traverseConfiguration: TraverseConfiguration = TraverseConfiguration()
}

Expand Down Expand Up @@ -123,7 +166,7 @@ class Exclude1 : AbstractInterpreter<Unit>() {
@OptIn(SymbolInternals::class)
internal fun KotlinTypeFacade.toDataFrame(
maxDepth: Int,
call: FirFunctionCall,
classLikeType: ConeKotlinType,
traverseConfiguration: TraverseConfiguration
): PluginDataFrameSchema {
fun ConeKotlinType.isValueType() =
Expand Down Expand Up @@ -238,14 +281,21 @@ internal fun KotlinTypeFacade.toDataFrame(
}
}

val receiver = call.explicitReceiver ?: return PluginDataFrameSchema(emptyList())
return PluginDataFrameSchema(convert(classLikeType, 0))
}

internal fun KotlinTypeFacade.toDataFrame(
maxDepth: Int,
explicitReceiver: FirExpression?,
traverseConfiguration: TraverseConfiguration
): PluginDataFrameSchema {
val receiver = explicitReceiver ?: return PluginDataFrameSchema(emptyList())
val arg = receiver.resolvedType.typeArguments.firstOrNull() ?: return PluginDataFrameSchema(emptyList())
return when {
arg.isStarProjection -> PluginDataFrameSchema(emptyList())
else -> {
val classLike = arg.type as? ConeClassLikeType ?: return PluginDataFrameSchema(emptyList())
val columns = convert(classLike, 0)
PluginDataFrameSchema(columns)
return toDataFrame(maxDepth, classLike, traverseConfiguration)
}
}
}
Expand Down
Loading
Loading