IR: copy type parameters for local functions in LocalDeclarationLowering

Local functions raised in LocalDeclarationLowering continue to refer to
type parameters that are no longer visible to them.
This commit only adds new type parameters to their declarations, which
makes JVM accept those declarations. The generated IR is still
semantically incorrect (needs further fix), but code generation seems
to proceed nevertheless.
This commit is contained in:
Georgy Bronnikov
2019-12-24 18:25:57 +03:00
parent 01da7f289b
commit 8d0ffa1444
9 changed files with 220 additions and 35 deletions
@@ -23,17 +23,18 @@ import org.jetbrains.kotlin.backend.common.push
import org.jetbrains.kotlin.descriptors.Visibilities
import org.jetbrains.kotlin.ir.IrElement
import org.jetbrains.kotlin.ir.declarations.*
import org.jetbrains.kotlin.ir.expressions.IrFunctionAccessExpression
import org.jetbrains.kotlin.ir.expressions.IrFunctionReference
import org.jetbrains.kotlin.ir.expressions.IrPropertyReference
import org.jetbrains.kotlin.ir.expressions.IrValueAccessExpression
import org.jetbrains.kotlin.ir.expressions.*
import org.jetbrains.kotlin.ir.symbols.IrTypeParameterSymbol
import org.jetbrains.kotlin.ir.symbols.IrValueSymbol
import org.jetbrains.kotlin.ir.types.IrSimpleType
import org.jetbrains.kotlin.ir.types.IrType
import org.jetbrains.kotlin.ir.types.IrTypeProjection
import org.jetbrains.kotlin.ir.util.isLocal
import org.jetbrains.kotlin.ir.visitors.IrElementVisitorVoid
import org.jetbrains.kotlin.ir.visitors.acceptChildrenVoid
import kotlin.collections.set
class Closure(val capturedValues: List<IrValueSymbol> = emptyList())
data class Closure(val capturedValues: List<IrValueSymbol>, val capturedTypeParameters: List<IrTypeParameter>)
class ClosureAnnotator(irFile: IrFile) {
private val closureBuilders = mutableMapOf<IrDeclaration, ClosureBuilder>()
@@ -58,6 +59,9 @@ class ClosureAnnotator(irFile: IrFile) {
private val declaredValues = mutableSetOf<IrValueDeclaration>()
private val includes = mutableSetOf<ClosureBuilder>()
private val potentiallyCapturedTypeParameters = mutableSetOf<IrTypeParameter>()
private val capturedTypeParameters = mutableSetOf<IrTypeParameter>()
var processed = false
/*
@@ -74,7 +78,7 @@ class ClosureAnnotator(irFile: IrFile) {
}
}
// TODO: We can save the closure and reuse it.
return Closure(result.toList())
return Closure(result.toList(), capturedTypeParameters.toList())
}
@@ -95,6 +99,31 @@ class ClosureAnnotator(irFile: IrFile) {
fun isExternal(valueDeclaration: IrValueDeclaration): Boolean {
return !declaredValues.contains(valueDeclaration)
}
fun isExternal(typeParameter: IrTypeParameter): Boolean {
return potentiallyCapturedTypeParameters.contains(typeParameter)
}
fun addPotentiallyCapturedTypeParameter(param: IrTypeParameter) {
potentiallyCapturedTypeParameters.add(param)
}
fun seeType(type: IrType) {
if (type !is IrSimpleType) return
(type.classifier as? IrTypeParameterSymbol)?.let { typeParameterSymbol ->
val typeParameter = typeParameterSymbol.owner
if (isExternal(typeParameter))
capturedTypeParameters.add(typeParameter)
}
for (arg in type.arguments) {
(arg as? IrTypeProjection)?.let { seeType(arg.type) }
}
type.abbreviation?.let { abbreviation ->
for (arg in abbreviation.arguments) {
(arg as? IrTypeProjection)?.let { seeType(arg.type) }
}
}
}
}
private inner class ClosureCollectorVisitor : IrElementVisitorVoid {
@@ -129,6 +158,8 @@ class ClosureAnnotator(irFile: IrFile) {
constructor.valueParameters.forEach { v -> closureBuilder.declareVariable(v) }
}
collectPotentiallyCapturedTypeParameters(closureBuilder)
closuresStack.push(closureBuilder)
declaration.acceptChildrenVoid(this)
closuresStack.pop()
@@ -154,6 +185,7 @@ class ClosureAnnotator(irFile: IrFile) {
}
}
collectPotentiallyCapturedTypeParameters(closureBuilder)
closuresStack.push(closureBuilder)
declaration.acceptChildrenVoid(this)
@@ -188,6 +220,16 @@ class ClosureAnnotator(irFile: IrFile) {
expression.setter?.let { processMemberAccess(it.owner) }
}
override fun visitExpression(expression: IrExpression) {
super.visitExpression(expression)
val typeParameterContainerScopeBuilder = closuresStack.peek()?.let {
if (it.owner is IrConstructor) {
closuresStack[closuresStack.size - 2]
} else it
}
typeParameterContainerScopeBuilder?.seeType(expression.type)
}
private fun processMemberAccess(declaration: IrDeclaration) {
if (declaration.isLocal) {
if (declaration is IrSimpleFunction && declaration.visibility != Visibilities.LOCAL) {
@@ -200,5 +242,15 @@ class ClosureAnnotator(irFile: IrFile) {
}
}
}
private fun collectPotentiallyCapturedTypeParameters(closureBuilder: ClosureBuilder) {
closuresStack.takeLastWhile { it.owner !is IrClass }.forEach {
(it.owner as? IrTypeParametersContainer)?.let { container ->
for (tp in container.typeParameters) {
closureBuilder.addPotentiallyCapturedTypeParameter(tp)
}
}
}
}
}
}
@@ -30,12 +30,17 @@ import org.jetbrains.kotlin.ir.descriptors.WrappedSimpleFunctionDescriptor
import org.jetbrains.kotlin.ir.descriptors.WrappedValueParameterDescriptor
import org.jetbrains.kotlin.ir.expressions.*
import org.jetbrains.kotlin.ir.expressions.impl.*
import org.jetbrains.kotlin.ir.symbols.IrTypeParameterSymbol
import org.jetbrains.kotlin.ir.symbols.IrValueSymbol
import org.jetbrains.kotlin.ir.symbols.impl.IrConstructorSymbolImpl
import org.jetbrains.kotlin.ir.symbols.impl.IrFieldSymbolImpl
import org.jetbrains.kotlin.ir.symbols.impl.IrSimpleFunctionSymbolImpl
import org.jetbrains.kotlin.ir.symbols.impl.IrValueParameterSymbolImpl
import org.jetbrains.kotlin.ir.types.IrType
import org.jetbrains.kotlin.ir.types.*
import org.jetbrains.kotlin.ir.types.impl.IrSimpleTypeImpl
import org.jetbrains.kotlin.ir.types.impl.IrTypeAbbreviationImpl
import org.jetbrains.kotlin.ir.types.impl.IrUninitializedType
import org.jetbrains.kotlin.ir.types.impl.makeTypeProjection
import org.jetbrains.kotlin.ir.util.constructedClass
import org.jetbrains.kotlin.ir.util.dump
import org.jetbrains.kotlin.ir.util.parentAsClass
@@ -77,6 +82,14 @@ object BOUND_VALUE_PARAMETER : IrDeclarationOriginImpl("BOUND_VALUE_PARAMETER")
object BOUND_RECEIVER_PARAMETER : IrDeclarationOriginImpl("BOUND_RECEIVER_PARAMETER")
/*
Local functions raised in LocalDeclarationLowering continue to refer to
type parameters no longer visible to them.
We add new type parameters to their declarations, which
makes JVM accept those declarations. The generated IR is still
semantically incorrect (TODO: needs further fix), but code generation seems
to proceed nevertheless.
*/
class LocalDeclarationsLowering(
val context: BackendContext,
val localNameProvider: LocalNameProvider = LocalNameProvider.DEFAULT,
@@ -98,6 +111,8 @@ class LocalDeclarationsLowering(
}
private abstract class LocalContext {
val capturedTypeParameterToTypeParameter: MutableMap<IrTypeParameter, IrTypeParameter> = mutableMapOf()
/**
* @return the expression to get the value for given declaration, or `null` if [IrGetValue] should be used.
*/
@@ -165,6 +180,29 @@ class LocalDeclarationsLowering(
}
private fun LocalContext.remapType(type: IrType): IrType {
if (type !is IrSimpleType) return type
val classifier = (type.classifier as? IrTypeParameterSymbol)?.let { capturedTypeParameterToTypeParameter[it.owner]?.symbol }
?: type.classifier
val arguments = type.arguments.map { remapTypeArgument(it) }
return IrSimpleTypeImpl(
classifier, type.hasQuestionMark, arguments, type.annotations,
type.abbreviation?.let { remapTypeAbbreviation(it) }
)
}
private fun LocalContext.remapTypeArgument(argument: IrTypeArgument) =
(argument as? IrTypeProjection)?.let { makeTypeProjection(remapType(it.type), it.variance) }
?: argument
private fun LocalContext.remapTypeAbbreviation(abbreviation: IrTypeAbbreviation): IrTypeAbbreviation =
IrTypeAbbreviationImpl(
abbreviation.typeAlias, // TODO: if/when the language gets local or nested type aliases, this will need remapping.
abbreviation.hasQuestionMark,
abbreviation.arguments.map { remapTypeArgument(it) },
abbreviation.annotations
)
private inner class LocalDeclarationsTransformer(val irFile: IrFile) {
val localFunctions: MutableMap<IrFunction, LocalFunctionContext> = LinkedHashMap()
val localClasses: MutableMap<IrClass, LocalClassContext> = LinkedHashMap()
@@ -349,11 +387,12 @@ class LocalDeclarationsLowering(
expression.startOffset, expression.endOffset,
expression.type, // TODO functional type for transformed descriptor
newCallee.symbol,
expression.typeArgumentsCount,
newCallee.typeParameters.size,
expression.origin
).also {
it.fillArguments2(expression, newCallee)
it.copyTypeArgumentsFrom(expression)
it.setLocalTypeArguments(oldCallee)
it.copyTypeArgumentsFrom(expression, shift = newCallee.typeParameters.size - expression.typeArgumentsCount)
it.copyAttributes(expression)
}
}
@@ -447,11 +486,12 @@ class LocalDeclarationsLowering(
oldCall.startOffset, oldCall.endOffset,
newCallee.returnType,
newCallee.symbol,
oldCall.typeArgumentsCount,
newCallee.typeParameters.size,
oldCall.origin,
oldCall.superQualifierSymbol
).also {
it.copyTypeArgumentsFrom(oldCall)
it.setLocalTypeArguments(oldCall.symbol.owner)
it.copyTypeArgumentsFrom(oldCall, shift = newCallee.typeParameters.size - oldCall.typeArgumentsCount)
}
private fun createNewCall(oldCall: IrConstructorCall, newCallee: IrConstructor) =
@@ -465,6 +505,13 @@ class LocalDeclarationsLowering(
it.copyTypeArgumentsFrom(oldCall)
}
private fun IrMemberAccessExpression.setLocalTypeArguments(callee: IrFunction) {
val context = localFunctions[callee] ?: return
for ((outerTypeParameter, innerTypeParameter) in context.capturedTypeParameterToTypeParameter) {
putTypeArgument(innerTypeParameter.index, outerTypeParameter.defaultType) // TODO: remap default type!
}
}
private fun transformDeclarations() {
localFunctions.values.forEach {
createLiftedDeclaration(it)
@@ -502,6 +549,7 @@ class LocalDeclarationsLowering(
private fun createLiftedDeclaration(localFunctionContext: LocalFunctionContext) {
val oldDeclaration = localFunctionContext.declaration
assert(oldDeclaration.dispatchReceiverParameter == null)
val memberOwner = localFunctionContext.ownerForLoweredDeclaration
val newDescriptor = WrappedSimpleFunctionDescriptor(oldDeclaration.descriptor)
@@ -512,10 +560,8 @@ class LocalDeclarationsLowering(
throw AssertionError("local functions must not have dispatch receiver")
}
val newDispatchReceiverParameter = null
// TODO: consider using fields to access the closure of enclosing class.
val capturedValues = localFunctionContext.closure.capturedValues
val (capturedValues, capturedTypeParameters) = localFunctionContext.closure
val newDeclaration = IrFunctionImpl(
oldDeclaration.startOffset,
@@ -525,7 +571,7 @@ class LocalDeclarationsLowering(
newName,
Visibilities.PRIVATE,
Modality.FINAL,
oldDeclaration.returnType,
returnType = IrUninitializedType,
isInline = oldDeclaration.isInline,
isExternal = oldDeclaration.isExternal,
isTailrec = oldDeclaration.isTailrec,
@@ -538,17 +584,29 @@ class LocalDeclarationsLowering(
localFunctionContext.transformedDeclaration = newDeclaration
newDeclaration.parent = memberOwner
val newTypeParameters = newDeclaration.copyTypeParameters(capturedTypeParameters)
localFunctionContext.capturedTypeParameterToTypeParameter.putAll(
capturedTypeParameters.zip(newTypeParameters)
)
newDeclaration.copyTypeParametersFrom(oldDeclaration)
newDeclaration.dispatchReceiverParameter = newDispatchReceiverParameter
// Type parameters of oldDeclaration may depend on captured type parameters, so deal with that after copying.
newDeclaration.typeParameters.drop(newTypeParameters.size).forEach { tp ->
tp.superTypes.replaceAll { localFunctionContext.remapType(it) }
}
newDeclaration.parent = memberOwner
newDeclaration.returnType = localFunctionContext.remapType(oldDeclaration.returnType)
newDeclaration.dispatchReceiverParameter = null
newDeclaration.extensionReceiverParameter = oldDeclaration.extensionReceiverParameter?.run {
copyTo(newDeclaration).also {
copyTo(newDeclaration, type = localFunctionContext.remapType(this.type)).also {
newParameterToOld.putAbsentOrSame(it, this)
}
}
newDeclaration.copyAttributes(oldDeclaration)
newDeclaration.valueParameters += createTransformedValueParameters(capturedValues, oldDeclaration, newDeclaration)
newDeclaration.valueParameters += createTransformedValueParameters(
capturedValues, localFunctionContext, oldDeclaration, newDeclaration
)
newDeclaration.recordTransformedValueParameters(localFunctionContext)
newDeclaration.annotations.addAll(oldDeclaration.annotations)
@@ -558,6 +616,7 @@ class LocalDeclarationsLowering(
private fun createTransformedValueParameters(
capturedValues: List<IrValueSymbol>,
localFunctionContext: LocalContext,
oldDeclaration: IrFunction,
newDeclaration: IrFunction
) = ArrayList<IrValueParameter>(capturedValues.size + oldDeclaration.valueParameters.size).apply {
@@ -573,7 +632,7 @@ class LocalDeclarationsLowering(
IrValueParameterSymbolImpl(parameterDescriptor),
suggestNameForCapturedValue(p, generatedNames),
i,
p.type,
localFunctionContext.remapType(p.type),
null,
isCrossinline = false,
isNoinline = false
@@ -585,7 +644,11 @@ class LocalDeclarationsLowering(
}
oldDeclaration.valueParameters.mapTo(this) { v ->
v.copyTo(newDeclaration, index = v.index + capturedValues.size).also {
v.copyTo(
newDeclaration,
index = v.index + capturedValues.size,
type = localFunctionContext.remapType(v.type)
).also {
newParameterToOld.putAbsentOrSame(it, v)
}
}
@@ -644,7 +707,7 @@ class LocalDeclarationsLowering(
throw AssertionError("Local class constructor can't have extension receiver: ${ir2string(oldDeclaration)}")
}
newDeclaration.valueParameters += createTransformedValueParameters(capturedValues, oldDeclaration, newDeclaration)
newDeclaration.valueParameters += createTransformedValueParameters(capturedValues, localClassContext, oldDeclaration, newDeclaration)
newDeclaration.recordTransformedValueParameters(constructorContext)
newDeclaration.metadata = oldDeclaration.metadata
@@ -78,8 +78,8 @@ private class AddContinuationLowering(private val context: JvmBackendContext) :
"Inconsistency between callable reference to suspend lambda and the corresponding continuation"
}
+irCall(constructor.symbol).apply {
for (tp in info.constructor.parentAsClass.typeParameters) {
putTypeArgument(tp.index, expression.getTypeArgument(tp.index))
for (typeParameter in info.constructor.parentAsClass.typeParameters) {
putTypeArgument(typeParameter.index, expression.getTypeArgument(typeParameter.index))
}
expressionArguments.forEachIndexed { index, argument ->
putValueArgument(index, argument)
@@ -134,7 +134,7 @@ private class AddContinuationLowering(private val context: JvmBackendContext) :
if (info.arity <= 1) {
val singleParameterField = receiverField ?: parametersWithoutArguments.singleOrNull()
val create = addCreate(constructor, suspendLambda, info, parametersWithArguments, singleParameterField)
addInvokeCallingCreate(info.function, create, invokeSuspend, invokeToOverride, singleParameterField)
addInvokeCallingCreate(create, invokeSuspend, invokeToOverride, singleParameterField)
} else {
addInvokeCallingConstructor(
constructor,
@@ -242,7 +242,6 @@ private class AddContinuationLowering(private val context: JvmBackendContext) :
// 2) starting newly created coroutine by calling `invokeSuspend`.
// Thus, it creates a clone of suspend lambda and starts it.
private fun IrClass.addInvokeCallingCreate(
suspendLambdaFunction: IrFunction,
create: IrFunction,
invokeSuspend: IrFunction,
invokeToOverride: IrSimpleFunctionSymbol,
@@ -289,8 +288,8 @@ private class AddContinuationLowering(private val context: JvmBackendContext) :
function.body = context.createIrBuilder(function.symbol).irBlockBody {
// Create a copy
val newlyCreatedObject = irTemporary(irCall(constructor).also { constructorCall ->
for (tp in typeParameters) {
constructorCall.putTypeArgument(tp.index, tp.defaultType)
for (typeParameter in typeParameters) {
constructorCall.putTypeArgument(typeParameter.index, typeParameter.defaultType)
}
for ((index, field) in parametersWithArguments.withIndex()) {
constructorCall.putValueArgument(index, irGetField(irGet(function.dispatchReceiverParameter!!), field))
@@ -328,8 +327,8 @@ private class AddContinuationLowering(private val context: JvmBackendContext) :
function.body = context.createIrBuilder(function.symbol).irBlockBody {
var index = 0
val constructorCall = irCall(constructor).also {
for (tp in typeParameters) {
it.putTypeArgument(tp.index, tp.defaultType)
for (typeParameter in typeParameters) {
it.putTypeArgument(typeParameter.index, typeParameter.defaultType)
}
for ((i, field) in parametersWithArguments.withIndex()) {
if (info.reference.getValueArgument(i) == null) continue
@@ -36,12 +36,12 @@ interface IrMemberAccessExpression : IrExpression, IrDeclarationReference {
fun IrMemberAccessExpression.getTypeArgument(typeParameterDescriptor: TypeParameterDescriptor): IrType? =
getTypeArgument(typeParameterDescriptor.index)
fun IrMemberAccessExpression.copyTypeArgumentsFrom(other: IrMemberAccessExpression) {
assert(typeArgumentsCount == other.typeArgumentsCount) {
"Mismatching type arguments: $typeArgumentsCount vs ${other.typeArgumentsCount} "
fun IrMemberAccessExpression.copyTypeArgumentsFrom(other: IrMemberAccessExpression, shift: Int = 0) {
assert(typeArgumentsCount == other.typeArgumentsCount + shift) {
"Mismatching type arguments: $typeArgumentsCount vs ${other.typeArgumentsCount} + $shift"
}
for (i in 0 until typeArgumentsCount) {
putTypeArgument(i, other.getTypeArgument(i))
for (i in 0 until other.typeArgumentsCount) {
putTypeArgument(i + shift, other.getTypeArgument(i))
}
}
@@ -0,0 +1,7 @@
package test;
class TypeParamInInner {
void check() {
TypeParamInInnerKt.outer("OK");
}
}
@@ -0,0 +1,31 @@
package test
class outerClass<T>(val t: T) {
inner class innerClass {
fun getT() = t
}
}
fun <T> outer(arg: T): T {
class localClass(val v: T) {
init {
fun innerFunInLocalClass() = v
val vv = innerFunInLocalClass()
}
fun member() = v
}
fun innerFun(): T {
class localClassInLocalFunction {
val v = arg
}
return localClass(arg).member()
}
fun <X> innerFunWithOwnTypeParam(x: X) = x
innerFunWithOwnTypeParam(arg)
return innerFun()
}
@@ -0,0 +1,18 @@
package test
public fun </*0*/ T> outer(/*0*/ T): T
public/*package*/ open class TypeParamInInner {
public/*package*/ constructor TypeParamInInner()
public/*package*/ open fun check(): kotlin.Unit
}
public final class outerClass</*0*/ T> {
public constructor outerClass</*0*/ T>(/*0*/ T)
public final val t: T
public final inner class innerClass /*captured type parameters: /*0*/ T*/ {
public constructor innerClass()
public final fun getT(): T
}
}
@@ -313,6 +313,11 @@ public class CompileJavaAgainstKotlinTestGenerated extends AbstractCompileJavaAg
runTest("compiler/testData/compileJavaAgainstKotlin/method/TraitImpl.kt");
}
@TestMetadata("TypeParamInInner.kt")
public void testTypeParamInInner() throws Exception {
runTest("compiler/testData/compileJavaAgainstKotlin/method/TypeParamInInner.kt");
}
@TestMetadata("Vararg.kt")
public void testVararg() throws Exception {
runTest("compiler/testData/compileJavaAgainstKotlin/method/Vararg.kt");
@@ -956,6 +961,11 @@ public class CompileJavaAgainstKotlinTestGenerated extends AbstractCompileJavaAg
runTest("compiler/testData/compileJavaAgainstKotlin/method/TraitImpl.kt");
}
@TestMetadata("TypeParamInInner.kt")
public void testTypeParamInInner() throws Exception {
runTest("compiler/testData/compileJavaAgainstKotlin/method/TypeParamInInner.kt");
}
@TestMetadata("Vararg.kt")
public void testVararg() throws Exception {
runTest("compiler/testData/compileJavaAgainstKotlin/method/Vararg.kt");
@@ -311,6 +311,11 @@ public class IrCompileJavaAgainstKotlinTestGenerated extends AbstractIrCompileJa
runTest("compiler/testData/compileJavaAgainstKotlin/method/TraitImpl.kt");
}
@TestMetadata("TypeParamInInner.kt")
public void testTypeParamInInner() throws Exception {
runTest("compiler/testData/compileJavaAgainstKotlin/method/TypeParamInInner.kt");
}
@TestMetadata("Vararg.kt")
public void testVararg() throws Exception {
runTest("compiler/testData/compileJavaAgainstKotlin/method/Vararg.kt");