JVM_IR KT-44974 fix SAM-converted capturing extension lambda

This commit is contained in:
Dmitry Petrov
2021-02-16 17:47:11 +03:00
parent 83ed67546b
commit 56a104dda9
10 changed files with 121 additions and 20 deletions
@@ -19987,6 +19987,12 @@ public class FirBlackBoxCodegenTestGenerated extends AbstractFirBlackBoxCodegenT
runTest("compiler/testData/codegen/box/invokedynamic/sam/samConversionOnFunctionReference.kt");
}
@Test
@TestMetadata("samExtFunWithCapturingLambda.kt")
public void testSamExtFunWithCapturingLambda() throws Exception {
runTest("compiler/testData/codegen/box/invokedynamic/sam/samExtFunWithCapturingLambda.kt");
}
@Test
@TestMetadata("simpleFunInterfaceConstructor.kt")
public void testSimpleFunInterfaceConstructor() throws Exception {
@@ -134,6 +134,8 @@ internal class FunctionReferenceLowering(private val context: JvmBackendContext)
return super.visitTypeOperator(expression)
}
val samSuperType = expression.typeOperand
val invokable = expression.argument
val reference = if (invokable is IrFunctionReference) {
invokable
@@ -145,8 +147,6 @@ internal class FunctionReferenceLowering(private val context: JvmBackendContext)
}
reference.transformChildrenVoid()
val samSuperType = expression.typeOperand
if (shouldGenerateIndySamConversions) {
val lambdaMetafactoryArguments =
LambdaMetafactoryArgumentsBuilder(context, crossinlineLambdas)
@@ -169,15 +169,23 @@ internal class FunctionReferenceLowering(private val context: JvmBackendContext)
wrapWithIndySamConversion(samType, lambdaMetafactoryArguments)
}
is IrBlock -> {
val indySamConversion = wrapWithIndySamConversion(samType, lambdaMetafactoryArguments)
argument.statements[argument.statements.size - 1] = indySamConversion
argument.type = indySamConversion.type
return argument
wrapFunctionReferenceInsideBlockWithIndySamConversion(samType, lambdaMetafactoryArguments, argument)
}
else -> throw AssertionError("Block or function reference expected: ${expression.render()}")
}
}
private fun wrapFunctionReferenceInsideBlockWithIndySamConversion(
samType: IrType,
lambdaMetafactoryArguments: LambdaMetafactoryArguments,
block: IrBlock
): IrExpression {
val indySamConversion = wrapWithIndySamConversion(samType, lambdaMetafactoryArguments)
block.statements[block.statements.size - 1] = indySamConversion
block.type = indySamConversion.type
return block
}
private val jvmIndyLambdaMetafactoryIntrinsic = context.ir.symbols.indyLambdaMetafactoryIntrinsic
private val specialNullabilityAnnotationsFqNames =
@@ -690,3 +698,5 @@ internal class FunctionReferenceLowering(private val context: JvmBackendContext)
declaration.parent.let { it is IrClass && it.isFileClass }
}
}
@@ -6,6 +6,7 @@
package org.jetbrains.kotlin.backend.jvm.lower
import org.jetbrains.kotlin.backend.common.ir.allOverridden
import org.jetbrains.kotlin.backend.common.lower.VariableRemapper
import org.jetbrains.kotlin.backend.common.lower.parents
import org.jetbrains.kotlin.backend.jvm.JvmBackendContext
import org.jetbrains.kotlin.backend.jvm.ir.erasedUpperBound
@@ -14,13 +15,19 @@ import org.jetbrains.kotlin.backend.jvm.ir.isCompiledToJvmDefault
import org.jetbrains.kotlin.builtins.functions.BuiltInFunctionArity
import org.jetbrains.kotlin.descriptors.Modality
import org.jetbrains.kotlin.ir.builders.declarations.buildClass
import org.jetbrains.kotlin.ir.builders.declarations.buildValueParameter
import org.jetbrains.kotlin.ir.declarations.*
import org.jetbrains.kotlin.ir.expressions.IrFunctionReference
import org.jetbrains.kotlin.ir.expressions.IrStatementOrigin
import org.jetbrains.kotlin.ir.expressions.impl.IrFunctionReferenceImpl
import org.jetbrains.kotlin.ir.overrides.buildFakeOverrideMember
import org.jetbrains.kotlin.ir.symbols.impl.IrSimpleFunctionSymbolImpl
import org.jetbrains.kotlin.ir.types.*
import org.jetbrains.kotlin.ir.util.*
import org.jetbrains.kotlin.ir.util.dump
import org.jetbrains.kotlin.ir.util.functions
import org.jetbrains.kotlin.ir.util.getInlineClassUnderlyingType
import org.jetbrains.kotlin.ir.util.render
import org.jetbrains.kotlin.ir.visitors.transformChildrenVoid
import org.jetbrains.kotlin.name.Name
class LambdaMetafactoryArguments(
@@ -192,10 +199,12 @@ class LambdaMetafactoryArgumentsBuilder(
adaptLambdaSignature(implLambda, fakeInstanceMethod, signatureAdaptationConstraints)
val newReference = remapExtensionLambda(implLambda, reference)
if (samMethod.isFakeOverride && nonFakeOverriddenFuns.size == 1) {
return LambdaMetafactoryArguments(nonFakeOverriddenFuns.single(), fakeInstanceMethod, reference, listOf())
return LambdaMetafactoryArguments(nonFakeOverriddenFuns.single(), fakeInstanceMethod, newReference, listOf())
}
return LambdaMetafactoryArguments(samMethod, fakeInstanceMethod, reference, nonFakeOverriddenFuns)
return LambdaMetafactoryArguments(samMethod, fakeInstanceMethod, newReference, nonFakeOverriddenFuns)
}
private fun adaptLambdaSignature(
@@ -223,8 +232,51 @@ class LambdaMetafactoryArgumentsBuilder(
if (constraints.returnType == TypeAdaptationConstraint.FORCE_BOXING) {
lambda.returnType = lambda.returnType.makeNullable()
}
}
private fun remapExtensionLambda(lambda: IrSimpleFunction, reference: IrFunctionReference): IrFunctionReference {
val oldExtensionReceiver = lambda.extensionReceiverParameter
?: return reference
val newValueParameters = ArrayList<IrValueParameter>()
val oldToNew = HashMap<IrValueParameter, IrValueParameter>()
var newParameterIndex = 0
newValueParameters.add(
oldExtensionReceiver.copy(lambda, newParameterIndex++, Name.identifier("\$receiver")).also {
oldToNew[oldExtensionReceiver] = it
}
)
lambda.valueParameters.mapTo(newValueParameters) { oldParameter ->
oldParameter.copy(lambda, newParameterIndex++).also {
oldToNew[oldParameter] = it
}
}
lambda.body?.transformChildrenVoid(VariableRemapper(oldToNew))
lambda.extensionReceiverParameter = null
lambda.valueParameters = newValueParameters
return IrFunctionReferenceImpl(
reference.startOffset, reference.endOffset, reference.type,
lambda.symbol,
typeArgumentsCount = 0,
valueArgumentsCount = newValueParameters.size,
reflectionTarget = null,
origin = reference.origin
)
}
private fun IrValueParameter.copy(parent: IrSimpleFunction, newIndex: Int, newName: Name = this.name): IrValueParameter =
buildValueParameter(parent) {
updateFrom(this@copy)
index = newIndex
name = newName
}
private fun adaptFakeInstanceMethodSignature(fakeInstanceMethod: IrSimpleFunction, constraints: SignatureAdaptationConstraints) {
for ((valueParameter, constraint) in constraints.valueParameters) {
if (valueParameter.parent != fakeInstanceMethod)
@@ -406,13 +458,13 @@ class LambdaMetafactoryArgumentsBuilder(
private fun IrType.isJvmPrimitiveType() =
isBoolean() || isChar() || isByte() || isShort() || isInt() || isLong() || isFloat() || isDouble()
}
private fun collectValueParameters(irFun: IrFunction): List<IrValueParameter> {
if (irFun.extensionReceiverParameter == null)
return irFun.valueParameters
return ArrayList<IrValueParameter>().apply {
add(irFun.extensionReceiverParameter!!)
addAll(irFun.valueParameters)
}
fun collectValueParameters(irFun: IrFunction): List<IrValueParameter> {
if (irFun.extensionReceiverParameter == null)
return irFun.valueParameters
return ArrayList<IrValueParameter>().apply {
add(irFun.extensionReceiverParameter!!)
addAll(irFun.valueParameters)
}
}
@@ -299,13 +299,14 @@ private class TypeOperatorLowering(private val context: JvmBackendContext) : Fil
dynamicCallArguments.add(extensionReceiver)
}
val samMethodValueParametersCount = samMethod.valueParameters.size
val samMethodValueParametersCount = samMethod.valueParameters.size +
if (samMethod.extensionReceiverParameter != null && irFunRef.extensionReceiver == null) 1 else 0
val targetFunValueParametersCount = targetFun.valueParameters.size
for (i in 0 until targetFunValueParametersCount - samMethodValueParametersCount) {
val targetFunValueParameter = targetFun.valueParameters[i]
addValueParameter("p${syntheticParameterIndex++}", targetFunValueParameter.type)
val capturedValueArgument = irFunRef.getValueArgument(i)
?: fail("Captured value argument #$i (${targetFunValueParameter.name} not provided")
?: fail("Captured value argument #$i (${targetFunValueParameter.name}) not provided")
dynamicCallArguments.add(capturedValueArgument)
}
}
@@ -17,7 +17,7 @@ class Test {
}
}
// FILE: test.kt
// FILE: samFunExpression.kt
import java.lang.reflect.Method
import kotlin.test.assertEquals
@@ -17,7 +17,7 @@ class Test {
}
}
// FILE: test.kt
// FILE: samFunReference.kt
import java.lang.reflect.Method
import kotlin.test.assertEquals
@@ -0,0 +1,15 @@
// TARGET_BACKEND: JVM
// JVM_TARGET: 1.8
// SAM_CONVERSIONS: INDY
fun interface IExtFun {
fun String.foo(s: String) : String
}
fun box(): String {
val oChar = 'O'
val iExtFun = IExtFun { this + oChar.toString() + it }
return with(iExtFun) {
"".foo("K")
}
}
@@ -19987,6 +19987,12 @@ public class BlackBoxCodegenTestGenerated extends AbstractBlackBoxCodegenTest {
runTest("compiler/testData/codegen/box/invokedynamic/sam/samConversionOnFunctionReference.kt");
}
@Test
@TestMetadata("samExtFunWithCapturingLambda.kt")
public void testSamExtFunWithCapturingLambda() throws Exception {
runTest("compiler/testData/codegen/box/invokedynamic/sam/samExtFunWithCapturingLambda.kt");
}
@Test
@TestMetadata("simpleFunInterfaceConstructor.kt")
public void testSimpleFunInterfaceConstructor() throws Exception {
@@ -19987,6 +19987,12 @@ public class IrBlackBoxCodegenTestGenerated extends AbstractIrBlackBoxCodegenTes
runTest("compiler/testData/codegen/box/invokedynamic/sam/samConversionOnFunctionReference.kt");
}
@Test
@TestMetadata("samExtFunWithCapturingLambda.kt")
public void testSamExtFunWithCapturingLambda() throws Exception {
runTest("compiler/testData/codegen/box/invokedynamic/sam/samExtFunWithCapturingLambda.kt");
}
@Test
@TestMetadata("simpleFunInterfaceConstructor.kt")
public void testSimpleFunInterfaceConstructor() throws Exception {
@@ -16758,6 +16758,11 @@ public class LightAnalysisModeTestGenerated extends AbstractLightAnalysisModeTes
runTest("compiler/testData/codegen/box/invokedynamic/sam/samConversionOnFunctionReference.kt");
}
@TestMetadata("samExtFunWithCapturingLambda.kt")
public void testSamExtFunWithCapturingLambda() throws Exception {
runTest("compiler/testData/codegen/box/invokedynamic/sam/samExtFunWithCapturingLambda.kt");
}
@TestMetadata("simpleFunInterfaceConstructor.kt")
public void testSimpleFunInterfaceConstructor() throws Exception {
runTest("compiler/testData/codegen/box/invokedynamic/sam/simpleFunInterfaceConstructor.kt");