Skip to content

Commit

Permalink
Merge pull request #783 from Kotlin/plugin-column-selection-dsl
Browse files Browse the repository at this point in the history
Add initial support for CS DSL in the compiler plugin
  • Loading branch information
koperagen authored Jul 16, 2024
2 parents f5a06eb + e6b4152 commit 5efb5ff
Show file tree
Hide file tree
Showing 25 changed files with 201 additions and 53 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.DataRow
import org.jetbrains.kotlinx.dataframe.Predicate
import org.jetbrains.kotlinx.dataframe.RowFilter
import org.jetbrains.kotlinx.dataframe.annotations.Interpretable
import org.jetbrains.kotlinx.dataframe.api.AllColumnsSelectionDsl.CommonAllSubsetDocs.BehaviorArg
import org.jetbrains.kotlinx.dataframe.api.AllColumnsSelectionDsl.CommonAllSubsetDocs.ColumnDoesNotExistArg
import org.jetbrains.kotlinx.dataframe.api.AllColumnsSelectionDsl.CommonAllSubsetDocs.ExampleArg
Expand Down Expand Up @@ -300,6 +301,7 @@ public interface AllColumnsSelectionDsl<out _UNUSED> {
*
* `df.`[select][DataFrame.select]` { `[all][ColumnsSelectionDsl.all]`() }`
*/
@Interpretable("All0")
public fun ColumnsSelectionDsl<*>.all(): TransformableColumnSet<*> =
asSingleColumn().allColumnsInternal()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import org.jetbrains.kotlinx.dataframe.ColumnFilter
import org.jetbrains.kotlinx.dataframe.DataColumn
import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.DataRow
import org.jetbrains.kotlinx.dataframe.annotations.Interpretable
import org.jetbrains.kotlinx.dataframe.api.ColsAtAnyDepthColumnsSelectionDsl.Grammar
import org.jetbrains.kotlinx.dataframe.api.ColsAtAnyDepthColumnsSelectionDsl.Grammar.ColumnGroupName
import org.jetbrains.kotlinx.dataframe.api.ColsAtAnyDepthColumnsSelectionDsl.Grammar.ColumnSetName
Expand Down Expand Up @@ -138,6 +139,7 @@ public interface ColsAtAnyDepthColumnsSelectionDsl {
*
* `df.`[select][DataFrame.select]` { `[colsAtAnyDepth][ColumnsSelectionDsl.colsAtAnyDepth]` { !it.`[isColumnGroup][DataColumn.isColumnGroup]` } }`
*/
@Interpretable("ColsAtAnyDepth0")
public fun ColumnsSelectionDsl<*>.colsAtAnyDepth(predicate: ColumnFilter<*> = { true }): ColumnSet<*> =
asSingleColumn().colsAtAnyDepthInternal(predicate)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import org.jetbrains.kotlinx.dataframe.ColumnFilter
import org.jetbrains.kotlinx.dataframe.DataColumn
import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.DataRow
import org.jetbrains.kotlinx.dataframe.annotations.Interpretable
import org.jetbrains.kotlinx.dataframe.api.ColsOfColumnsSelectionDsl.Grammar
import org.jetbrains.kotlinx.dataframe.api.ColsOfColumnsSelectionDsl.Grammar.ColumnGroupName
import org.jetbrains.kotlinx.dataframe.api.ColsOfColumnsSelectionDsl.Grammar.ColumnSetName
Expand Down Expand Up @@ -203,6 +204,7 @@ public fun <C> ColumnSet<*>.colsOf(
* @include [ColsOfColumnsSelectionDsl.CommonColsOfDocs.FilterParam]
* @include [ColsOfColumnsSelectionDsl.CommonColsOfDocs.Return]
*/
@Interpretable("ColsOf1")
public inline fun <reified C> ColumnSet<*>.colsOf(
noinline filter: ColumnFilter<C> = { true },
): TransformableColumnSet<C> = colsOf(typeOf<C>(), filter)
Expand All @@ -228,6 +230,7 @@ public fun <C> ColumnsSelectionDsl<*>.colsOf(
* @include [ColsOfColumnsSelectionDsl.CommonColsOfDocs.FilterParam]
* @include [ColsOfColumnsSelectionDsl.CommonColsOfDocs.Return]
*/
@Interpretable("ColsOf0")
public inline fun <reified C> ColumnsSelectionDsl<*>.colsOf(
noinline filter: ColumnFilter<C> = { true },
): TransformableColumnSet<C> = asSingleColumn().colsOf(typeOf<C>(), filter)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import org.jetbrains.kotlinx.dataframe.AnyFrame
import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.DataRow
import org.jetbrains.kotlinx.dataframe.Predicate
import org.jetbrains.kotlinx.dataframe.annotations.Interpretable
import org.jetbrains.kotlinx.dataframe.api.FrameColsColumnsSelectionDsl.Grammar.ColumnGroupName
import org.jetbrains.kotlinx.dataframe.api.FrameColsColumnsSelectionDsl.Grammar.ColumnSetName
import org.jetbrains.kotlinx.dataframe.api.FrameColsColumnsSelectionDsl.Grammar.PlainDslName
Expand Down Expand Up @@ -111,6 +112,7 @@ public interface FrameColsColumnsSelectionDsl {
*
* `df.`[select][DataFrame.select]` { `[frameCols][ColumnsSelectionDsl.frameCols]` { it.`[name][ColumnReference.name]`.`[startsWith][String.startsWith]`("my") } }`
*/
@Interpretable("FrameCols0")
public fun ColumnSet<*>.frameCols(filter: Predicate<FrameColumn<*>> = { true }): TransformableColumnSet<DataFrame<*>> =
frameColumnsInternal(filter)

Expand Down
2 changes: 1 addition & 1 deletion plugins/kotlin-dataframe/gradle.properties
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
kotlin.code.style=official
kotlinVersion=2.0.20-dev-5379
kotlinVersion=2.0.20-Beta2-78
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ private class Checker(val cache: FirCache<String, PluginDataFrameSchema, KotlinT
val targetProjection = expression.typeArguments.getOrNull(0) as? FirTypeProjectionWithVariance ?: return
val targetType = targetProjection.typeRef.coneType as? ConeClassLikeType ?: return
val target = pluginDataFrameSchema(targetType)
val sourceColumns = source.flatten()
val targetColumns = target.flatten()
val sourceColumns = source.flatten(includeFrames = true)
val targetColumns = target.flatten(includeFrames = true)
val sourceMap = sourceColumns.associate { it.path.path to it.column }
val missingColumns = mutableListOf<String>()
var valid = true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ import org.jetbrains.kotlin.fir.extensions.FirFunctionCallRefinementExtension
import org.jetbrains.kotlin.fir.moduleData
import org.jetbrains.kotlin.fir.references.FirResolvedNamedReference
import org.jetbrains.kotlin.fir.references.builder.buildResolvedNamedReference
import org.jetbrains.kotlin.fir.resolve.calls.CallInfo
import org.jetbrains.kotlin.fir.resolve.calls.candidate.CallInfo
import org.jetbrains.kotlin.fir.resolve.defaultType
import org.jetbrains.kotlin.fir.resolve.fqName
import org.jetbrains.kotlin.fir.resolve.providers.symbolProvider
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,29 @@ package org.jetbrains.kotlinx.dataframe.plugin.impl.api

import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractInterpreter
import org.jetbrains.kotlinx.dataframe.plugin.impl.Arguments
import org.jetbrains.kotlinx.dataframe.plugin.impl.PluginDataFrameSchema
import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleCol
import org.jetbrains.kotlinx.dataframe.plugin.impl.data.ColumnWithPathApproximation

internal class And10 : AbstractInterpreter<List<ColumnWithPathApproximation>>() {
val Arguments.other: List<ColumnWithPathApproximation> by arg()
val Arguments.receiver: List<ColumnWithPathApproximation> by arg()
internal class And10 : AbstractInterpreter<ColumnsResolver>() {
val Arguments.other: ColumnsResolver by arg()
val Arguments.receiver: ColumnsResolver by arg()

override fun Arguments.interpret(): List<ColumnWithPathApproximation> {
return receiver + other
override fun Arguments.interpret(): ColumnsResolver {
return object : ColumnsResolver {
override fun resolve(df: PluginDataFrameSchema): List<ColumnWithPathApproximation> {
return receiver.resolve(df) + other.resolve(df)
}
}
}
}

class SingleColumnApproximation(val col: ColumnWithPathApproximation) : ColumnsResolver {
override fun resolve(df: PluginDataFrameSchema): List<ColumnWithPathApproximation> {
return listOf(col)
}
}

interface ColumnsResolver {
fun resolve(df: PluginDataFrameSchema): List<ColumnWithPathApproximation>
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.type
import org.jetbrains.kotlinx.dataframe.plugin.impl.varargString

internal class Convert0 : AbstractInterpreter<ConvertApproximation>() {
val Arguments.columns: List<ColumnWithPathApproximation> by arg()
val Arguments.columns: ColumnsResolver by arg()
val Arguments.receiver: PluginDataFrameSchema by dataFrame()
override val Arguments.startingSchema get() = receiver

override fun Arguments.interpret(): ConvertApproximation {
return ConvertApproximation(receiver, columns.map { it.path.path })
return ConvertApproximation(receiver, columns.resolve(receiver).map { it.path.path })
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.dataFrame

class DropNulls0 : AbstractSchemaModificationInterpreter() {
val Arguments.receiver: PluginDataFrameSchema by dataFrame()
val Arguments.columns: List<ColumnWithPathApproximation> by arg()
val Arguments.columns: ColumnsResolver by arg()

override fun Arguments.interpret(): PluginDataFrameSchema {
return PluginDataFrameSchema(fillNullsImpl(receiver.columns(), columns.mapTo(mutableSetOf()) { it.path.path }, emptyList()))
return PluginDataFrameSchema(fillNullsImpl(receiver.columns(), columns.resolve(receiver).mapTo(mutableSetOf()) { it.path.path }, emptyList()))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.dataFrame
internal class Explode0 : AbstractInterpreter<PluginDataFrameSchema>() {
val Arguments.dropEmpty: Boolean by arg(defaultValue = Present(true))
val Arguments.receiver: PluginDataFrameSchema by dataFrame()
val Arguments.selector: List<ColumnWithPathApproximation>? by arg(defaultValue = Present(null))
val Arguments.selector: ColumnsResolver? by arg(defaultValue = Present(null))
override val Arguments.startingSchema get() = receiver

override fun Arguments.interpret(): PluginDataFrameSchema {
val columns = selector ?: TODO()
return receiver.explodeImpl(dropEmpty, columns.map { ColumnPathApproximation(it.path.path) })
return receiver.explodeImpl(dropEmpty, columns.resolve(receiver).map { ColumnPathApproximation(it.path.path) })
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,24 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleFrameColumn
import org.jetbrains.kotlinx.dataframe.plugin.impl.data.ColumnPathApproximation
import org.jetbrains.kotlinx.dataframe.plugin.impl.data.ColumnWithPathApproximation

fun PluginDataFrameSchema.flatten(): List<ColumnWithPathApproximation> {
fun PluginDataFrameSchema.flatten(includeFrames: Boolean): List<ColumnWithPathApproximation> {
if (columns().isEmpty()) return emptyList()
val columns = mutableListOf<ColumnWithPathApproximation>()
flattenImpl(columns(), emptyList(), columns)
flattenImpl(columns(), emptyList(), columns, includeFrames)
return columns
}

fun flattenImpl(columns: List<SimpleCol>, path: List<String>, flatList: MutableList<ColumnWithPathApproximation>) {
fun flattenImpl(columns: List<SimpleCol>, path: List<String>, flatList: MutableList<ColumnWithPathApproximation>, includeFrames: Boolean) {
columns.forEach { column ->
val fullPath = path + listOf(column.name)
when (column) {
is SimpleColumnGroup -> {
flatList.add(ColumnWithPathApproximation(ColumnPathApproximation(fullPath), column))
flattenImpl(column.columns(), fullPath, flatList)
flattenImpl(column.columns(), fullPath, flatList, includeFrames)
}
is SimpleFrameColumn -> {
flatList.add(ColumnWithPathApproximation(ColumnPathApproximation(fullPath), column))
flattenImpl(column.columns(), fullPath, flatList)
flattenImpl(column.columns(), fullPath, flatList, includeFrames)
}
is SimpleDataColumn -> {
flatList.add(ColumnWithPathApproximation(ColumnPathApproximation(fullPath), column))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,21 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.dataFrame

class Group0 : AbstractInterpreter<GroupClauseApproximation>() {
val Arguments.receiver: PluginDataFrameSchema by dataFrame()
val Arguments.columns: List<ColumnWithPathApproximation> by arg()
val Arguments.columns: ColumnsResolver by arg()

override fun Arguments.interpret(): GroupClauseApproximation {
return GroupClauseApproximation(receiver, columns)
}
}

class GroupClauseApproximation(val df: PluginDataFrameSchema, val columns: List<ColumnWithPathApproximation>)
class GroupClauseApproximation(val df: PluginDataFrameSchema, val columns: ColumnsResolver)

class Into0 : AbstractSchemaModificationInterpreter() {
val Arguments.receiver: GroupClauseApproximation by arg()
val Arguments.column: String by arg()

override fun Arguments.interpret(): PluginDataFrameSchema {
val grouped = groupImpl(receiver.df.columns(), receiver.columns.mapTo(mutableSetOf()) { it.path.path }, column)
val grouped = groupImpl(receiver.df.columns(), receiver.columns.resolve(receiver.df).mapTo(mutableSetOf()) { it.path.path }, column)
return grouped
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ class GroupBy(val df: PluginDataFrameSchema, val keys: List<ColumnWithPathApprox
class DataFrameGroupBy : AbstractInterpreter<GroupBy>() {
val Arguments.receiver: PluginDataFrameSchema by dataFrame()
val Arguments.moveToTop: Boolean by arg(defaultValue = Present(true))
val Arguments.cols: List<ColumnWithPathApproximation> by arg()
val Arguments.cols: ColumnsResolver by arg()

override fun Arguments.interpret(): GroupBy {
return GroupBy(receiver, cols, moveToTop)
return GroupBy(receiver, cols.resolve(receiver), moveToTop)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ internal class Join0 : AbstractInterpreter<PluginDataFrameSchema>() {
override fun Arguments.interpret(): PluginDataFrameSchema {
val nameGenerator = ColumnNameGenerator()
val left = receiver.columns()
val right = removeImpl(other.columns(), setOf(selector.right.path.path)).updatedColumns
val right = removeImpl(other.columns(), setOf(selector.right.resolve(receiver).single().path.path)).updatedColumns

val rightColumnGroups = right.filterIsInstance<SimpleColumnGroup>().associateBy { it.name }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@ package org.jetbrains.kotlinx.dataframe.plugin.impl.api

import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractInterpreter
import org.jetbrains.kotlinx.dataframe.plugin.impl.Arguments
import org.jetbrains.kotlinx.dataframe.plugin.impl.data.ColumnWithPathApproximation

internal data class ColumnMatchApproximation(val left: ColumnWithPathApproximation, val right: ColumnWithPathApproximation)
internal data class ColumnMatchApproximation(val left: ColumnsResolver, val right: ColumnsResolver)

internal class Match0 : AbstractInterpreter<ColumnMatchApproximation>() {
val Arguments.receiver: List<ColumnWithPathApproximation> by arg()
val Arguments.other: List<ColumnWithPathApproximation> by arg()
val Arguments.receiver: ColumnsResolver by arg()
val Arguments.other: ColumnsResolver by arg()

override fun Arguments.interpret(): ColumnMatchApproximation {
return ColumnMatchApproximation(receiver.single(), other.single())
return ColumnMatchApproximation(receiver, other)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.dataFrame

class Remove0 : AbstractSchemaModificationInterpreter() {
val Arguments.receiver: PluginDataFrameSchema by dataFrame()
val Arguments.columns: List<ColumnWithPathApproximation> by arg()
val Arguments.columns: ColumnsResolver by arg()

override fun Arguments.interpret(): PluginDataFrameSchema {
val removeResult = removeImpl(receiver.columns(), columns.mapTo(mutableSetOf()) { it.path.path })
val removeResult = removeImpl(receiver.columns(), columns.resolve(receiver).mapTo(mutableSetOf()) { it.path.path })
return PluginDataFrameSchema(removeResult.updatedColumns)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,23 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.varargString

class Rename : AbstractInterpreter<RenameClauseApproximation>() {
private val Arguments.receiver by dataFrame()
private val Arguments.columns: List<ColumnWithPathApproximation> by arg()
private val Arguments.columns: ColumnsResolver by arg()
override fun Arguments.interpret(): RenameClauseApproximation {
return RenameClauseApproximation(receiver, columns)
}
}

class RenameClauseApproximation(val schema: PluginDataFrameSchema, val columns: List<ColumnWithPathApproximation>)
class RenameClauseApproximation(val schema: PluginDataFrameSchema, val columns: ColumnsResolver)

class RenameInto : AbstractSchemaModificationInterpreter() {
val Arguments.receiver: RenameClauseApproximation by arg()
val Arguments.newNames: List<String> by varargString()

override fun Arguments.interpret(): PluginDataFrameSchema {
require(receiver.columns.size == newNames.size)
val columns = receiver.columns.resolve(receiver.schema)
require(columns.size == newNames.size)
var i = 0
return receiver.schema.map(receiver.columns.mapTo(mutableSetOf()) { it.path.path }, nextName = { newNames[i].also { i += 1 } })
return receiver.schema.map(columns.mapTo(mutableSetOf()) { it.path.path }, nextName = { newNames[i].also { i += 1 } })
}
}

Expand Down
Loading

0 comments on commit 5efb5ff

Please sign in to comment.