Skip to content

Commit

Permalink
[Compiler plugin] Support dataFrameOf(Pair<String, List<T>)
Browse files Browse the repository at this point in the history
  • Loading branch information
koperagen committed Oct 3, 2024
1 parent fea737e commit 3596f05
Show file tree
Hide file tree
Showing 9 changed files with 146 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,8 @@ public inline fun <reified C> dataFrameOf(vararg header: String, fill: (String)

public fun dataFrameOf(header: Iterable<String>): DataFrameBuilder = DataFrameBuilder(header.asList())

@Refine
@Interpretable("DataFrameOf3")
public fun dataFrameOf(vararg columns: Pair<String, List<Any?>>): DataFrame<*> =
columns.map { it.second.toColumn(it.first, Infer.Type) }.toDataFrame()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ class FunctionCallTransformer(
val tokenFir = token.toClassSymbol(session)!!.fir
tokenFir.callShapeData = CallShapeData.RefinedType(dataSchemaApis.map { it.scope.symbol })

return buildLetCall(call, originalSymbol, dataSchemaApis, listOf(tokenFir))
return buildScopeFunctionCall(call, originalSymbol, dataSchemaApis, listOf(tokenFir))
}
}

Expand Down Expand Up @@ -253,7 +253,7 @@ class FunctionCallTransformer(
val keyToken = groupMarker.toClassSymbol(session)!!.fir
keyToken.callShapeData = CallShapeData.RefinedType(groupApis.map { it.scope.symbol })

return buildLetCall(call, originalSymbol, keyApis + groupApis, additionalDeclarations = listOf(groupToken, keyToken))
return buildScopeFunctionCall(call, originalSymbol, keyApis + groupApis, additionalDeclarations = listOf(groupToken, keyToken))
}
}

Expand Down Expand Up @@ -305,18 +305,17 @@ class FunctionCallTransformer(
private fun Name.asTokenName() = identifierOrNullIfSpecial?.titleCase() ?: DEFAULT_NAME

@OptIn(SymbolInternals::class)
private fun buildLetCall(
private fun buildScopeFunctionCall(
call: FirFunctionCall,
originalSymbol: FirNamedFunctionSymbol,
dataSchemaApis: List<DataSchemaApi>,
additionalDeclarations: List<FirClass>
): FirFunctionCall {

val explicitReceiver = call.explicitReceiver ?: return call
val receiverType = explicitReceiver.resolvedType
val explicitReceiver = call.explicitReceiver
val receiverType = explicitReceiver?.resolvedType
val returnType = call.resolvedType
val resolvedLet = findLet()
val parameter = resolvedLet.valueParameterSymbols[0]
val scopeFunction = if (explicitReceiver != null) findLet() else findRun()

// original call is inserted later
call.transformCalleeReference(object : FirTransformer<Nothing?>() {
Expand Down Expand Up @@ -350,20 +349,23 @@ class FunctionCallTransformer(
returnTypeRef = buildResolvedTypeRef {
type = returnType
}
val itName = Name.identifier("it")
val parameterSymbol = FirValueParameterSymbol(itName)
valueParameters += buildValueParameter {
moduleData = session.moduleData
origin = FirDeclarationOrigin.Source
returnTypeRef = buildResolvedTypeRef {
type = receiverType
val parameterSymbol = receiverType?.let {
val itName = Name.identifier("it")
val parameterSymbol = FirValueParameterSymbol(itName)
valueParameters += buildValueParameter {
moduleData = session.moduleData
origin = FirDeclarationOrigin.Source
returnTypeRef = buildResolvedTypeRef {
type = receiverType
}
this.name = itName
this.symbol = parameterSymbol
containingFunctionSymbol = fSymbol
isCrossinline = false
isNoinline = false
isVararg = false
}
this.name = itName
this.symbol = parameterSymbol
containingFunctionSymbol = fSymbol
isCrossinline = false
isNoinline = false
isVararg = false
parameterSymbol
}
body = buildBlock {
this.coneTypeOrNull = returnType
Expand All @@ -375,20 +377,23 @@ class FunctionCallTransformer(
statements += additionalDeclarations

statements += buildReturnExpression {
val itPropertyAccess = buildPropertyAccessExpression {
coneTypeOrNull = receiverType
calleeReference = buildResolvedNamedReference {
name = itName
resolvedSymbol = parameterSymbol
if (parameterSymbol != null) {
val itPropertyAccess = buildPropertyAccessExpression {
coneTypeOrNull = receiverType
calleeReference = buildResolvedNamedReference {
name = parameterSymbol.name
resolvedSymbol = parameterSymbol
}
}
if (callDispatchReceiver != null) {
call.replaceDispatchReceiver(itPropertyAccess)
}
call.replaceExplicitReceiver(itPropertyAccess)
if (callExtensionReceiver != null) {
call.replaceExtensionReceiver(itPropertyAccess)
}
}
if (callDispatchReceiver != null) {
call.replaceDispatchReceiver(itPropertyAccess)
}
call.replaceExplicitReceiver(itPropertyAccess)
if (callExtensionReceiver != null) {
call.replaceExtensionReceiver(itPropertyAccess)
}

result = call
this.target = target
}
Expand All @@ -397,11 +402,19 @@ class FunctionCallTransformer(
isLambda = true
hasExplicitParameterList = false
typeRef = buildResolvedTypeRef {
type = ConeClassLikeTypeImpl(
ConeClassLikeLookupTagImpl(ClassId(FqName("kotlin"), Name.identifier("Function1"))),
typeArguments = arrayOf(receiverType, returnType),
isNullable = false
)
type = if (receiverType != null) {
ConeClassLikeTypeImpl(
ConeClassLikeLookupTagImpl(ClassId(FqName("kotlin"), Name.identifier("Function1"))),
typeArguments = arrayOf(receiverType, returnType),
isNullable = false
)
} else {
ConeClassLikeTypeImpl(
ConeClassLikeLookupTagImpl(ClassId(FqName("kotlin"), Name.identifier("Function0"))),
typeArguments = arrayOf(returnType),
isNullable = false
)
}
}
invocationKind = EventOccurrencesRange.EXACTLY_ONCE
inlineStatus = InlineStatus.Inline
Expand All @@ -413,11 +426,13 @@ class FunctionCallTransformer(
val newCall1 = buildFunctionCall {
source = call.source
this.coneTypeOrNull = returnType
typeArguments += buildTypeProjectionWithVariance {
typeRef = buildResolvedTypeRef {
type = receiverType
if (receiverType != null) {
typeArguments += buildTypeProjectionWithVariance {
typeRef = buildResolvedTypeRef {
type = receiverType
}
variance = Variance.INVARIANT
}
variance = Variance.INVARIANT
}

typeArguments += buildTypeProjectionWithVariance {
Expand All @@ -429,11 +444,14 @@ class FunctionCallTransformer(
dispatchReceiver = null
this.explicitReceiver = callExplicitReceiver
extensionReceiver = callExtensionReceiver ?: callDispatchReceiver
argumentList = buildResolvedArgumentList(original = null, linkedMapOf(argument to parameter.fir))
argumentList = buildResolvedArgumentList(
original = null,
linkedMapOf(argument to scopeFunction.valueParameterSymbols[0].fir)
)
calleeReference = buildResolvedNamedReference {
source = call.calleeReference.source
this.name = Name.identifier("let")
resolvedSymbol = resolvedLet
this.name = scopeFunction.name
resolvedSymbol = scopeFunction
}
}
return newCall1
Expand Down Expand Up @@ -565,5 +583,9 @@ class FunctionCallTransformer(
return session.symbolProvider.getTopLevelFunctionSymbols(FqName("kotlin"), Name.identifier("let")).single()
}

private fun findRun(): FirFunctionSymbol<*> {
return session.symbolProvider.getTopLevelFunctionSymbols(FqName("kotlin"), Name.identifier("run")).single { it.typeParameterSymbols.size == 1 }
}

private fun String.titleCase() = replaceFirstChar { it.uppercaseChar() }
}
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
@file:Suppress("INVISIBLE_REFERENCE", "INVISIBLE_MEMBER")
package org.jetbrains.kotlinx.dataframe.plugin.impl.api

import org.jetbrains.kotlin.fir.expressions.FirExpression
import org.jetbrains.kotlin.fir.expressions.FirLiteralExpression
import org.jetbrains.kotlin.fir.expressions.FirVarargArgumentsExpression
import org.jetbrains.kotlin.fir.types.commonSuperTypeOrNull
import org.jetbrains.kotlin.fir.types.resolvedType
import org.jetbrains.kotlin.fir.types.type
import org.jetbrains.kotlin.fir.types.typeContext
import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractInterpreter
import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractSchemaModificationInterpreter
Expand Down Expand Up @@ -36,3 +39,18 @@ class DataFrameBuilderInvoke0 : AbstractSchemaModificationInterpreter() {
return PluginDataFrameSchema(columns)
}
}

class DataFrameOf3 : AbstractSchemaModificationInterpreter() {
val Arguments.columns: List<Interpreter.Success<Pair<*, *>>> by arg()

override fun Arguments.interpret(): PluginDataFrameSchema {
val res = columns.map {
val it = it.value
val name = (it.first as? FirLiteralExpression)?.value as? String
val type = (it.second as? FirExpression)?.resolvedType?.typeArguments?.getOrNull(0)?.type
if (name == null || type == null) return PluginDataFrameSchema(emptyList())
simpleColumnOf(name, type)
}
return PluginDataFrameSchema(res)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package org.jetbrains.kotlinx.dataframe.plugin.impl.api

import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractInterpreter
import org.jetbrains.kotlinx.dataframe.plugin.impl.Arguments
import org.jetbrains.kotlinx.dataframe.plugin.impl.Interpreter

class PairConstructor : AbstractInterpreter<Pair<*, *>>() {
val Arguments.receiver: Any? by arg(lens = Interpreter.Id)
val Arguments.that: Any? by arg(lens = Interpreter.Id)
override fun Arguments.interpret(): Pair<*, *> {
return receiver to that
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,11 @@ fun <T> KotlinTypeFacade.interpret(
is FirCallableReferenceAccess -> {
toKPropertyApproximation(it, session)
}

is FirFunctionCall -> {
it.loadInterpreter()?.let { processor ->
interpret(it, processor, emptyMap(), reporter)
}
}
else -> null
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,20 @@ import org.jetbrains.kotlin.fir.expressions.FirFunctionCall
import org.jetbrains.kotlin.fir.expressions.FirGetClassCall
import org.jetbrains.kotlin.fir.expressions.FirLiteralExpression
import org.jetbrains.kotlin.fir.expressions.FirResolvedQualifier
import org.jetbrains.kotlin.fir.expressions.UnresolvedExpressionTypeAccess
import org.jetbrains.kotlin.fir.references.FirResolvedNamedReference
import org.jetbrains.kotlin.fir.references.resolved
import org.jetbrains.kotlin.fir.references.symbol
import org.jetbrains.kotlin.fir.references.toResolvedNamedFunctionSymbol
import org.jetbrains.kotlin.fir.resolve.fqName
import org.jetbrains.kotlin.fir.symbols.impl.FirCallableSymbol
import org.jetbrains.kotlin.fir.types.classId
import org.jetbrains.kotlin.fir.types.coneType
import org.jetbrains.kotlin.name.CallableId
import org.jetbrains.kotlin.name.ClassId
import org.jetbrains.kotlin.name.FqName
import org.jetbrains.kotlin.name.Name
import org.jetbrains.kotlin.name.StandardClassIds
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.AddDslStringInvoke
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.AddId
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Aggregate
Expand All @@ -76,12 +83,14 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ColsOf0
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ColsOf1
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.DataFrameBuilderInvoke0
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.DataFrameOf0
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.DataFrameOf3
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.FillNulls0
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Flatten0
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.FlattenDefault
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.FrameCols0
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.MapToFrame
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Move0
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.PairConstructor
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ReadExcel
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ToDataFrame
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ToDataFrameColumn
Expand All @@ -91,8 +100,16 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ToDataFrameFrom
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ToTop
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Update0
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.UpdateWith0
import org.jetbrains.kotlinx.dataframe.plugin.utils.Names

@OptIn(UnresolvedExpressionTypeAccess::class)
internal fun FirFunctionCall.loadInterpreter(session: FirSession): Interpreter<*>? {
if (
calleeReference.toResolvedNamedFunctionSymbol()?.callableId == Names.TO &&
coneTypeOrNull?.classId == Names.PAIR
) {
return PairConstructor()
}
val symbol =
(calleeReference as? FirResolvedNamedReference)?.resolvedSymbol as? FirCallableSymbol ?: return null
val argName = Name.identifier("interpreter")
Expand Down Expand Up @@ -208,6 +225,7 @@ internal inline fun <reified T> String.load(): T {
"ToTop" -> ToTop()
"Update0" -> Update0()
"Aggregate" -> Aggregate()
"DataFrameOf3" -> DataFrameOf3()
else -> error("$this")
} as T
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.jetbrains.kotlinx.dataframe.plugin.utils

import org.jetbrains.kotlin.name.CallableId
import org.jetbrains.kotlin.name.ClassId
import org.jetbrains.kotlin.name.FqName
import org.jetbrains.kotlin.name.Name
Expand Down Expand Up @@ -50,6 +51,9 @@ object Names {
val LOCAL_DATE_CLASS_ID = kotlinx.datetime.LocalDate::class.classId()
val LOCAL_DATE_TIME_CLASS_ID = kotlinx.datetime.LocalDateTime::class.classId()
val INSTANT_CLASS_ID = kotlinx.datetime.Instant::class.classId()

val PAIR = ClassId(FqName("kotlin"), Name.identifier("Pair"))
val TO = CallableId(FqName("kotlin"), Name.identifier("to"))
}

private fun KClass<*>.classId(): ClassId {
Expand Down
14 changes: 14 additions & 0 deletions plugins/kotlin-dataframe/testData/box/dataFrameOf_to.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import org.jetbrains.kotlinx.dataframe.*
import org.jetbrains.kotlinx.dataframe.annotations.*
import org.jetbrains.kotlinx.dataframe.api.*
import org.jetbrains.kotlinx.dataframe.io.*

fun box(): String {
val df = dataFrameOf(
"a" to listOf(1, 2),
"b" to listOf("str1", "str2"),
)
val i: Int = df.a[0]
val str: String = df.b[0]
return "OK"
}
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,12 @@ public void testDataFrameOf() {
runTest("testData/box/dataFrameOf.kt");
}

@Test
@TestMetadata("dataFrameOf_to.kt")
public void testDataFrameOf_to() {
runTest("testData/box/dataFrameOf_to.kt");
}

@Test
@TestMetadata("dataFrameOf_vararg.kt")
public void testDataFrameOf_vararg() {
Expand Down

0 comments on commit 3596f05

Please sign in to comment.