diff --git a/compiler/ir/backend.jvm/src/org/jetbrains/kotlin/backend/jvm/lower/AddContinuationLowering.kt b/compiler/ir/backend.jvm/src/org/jetbrains/kotlin/backend/jvm/lower/AddContinuationLowering.kt index 030698017bd..6e60a7356a4 100644 --- a/compiler/ir/backend.jvm/src/org/jetbrains/kotlin/backend/jvm/lower/AddContinuationLowering.kt +++ b/compiler/ir/backend.jvm/src/org/jetbrains/kotlin/backend/jvm/lower/AddContinuationLowering.kt @@ -54,12 +54,17 @@ internal val addContinuationPhase = makeIrFilePhase( private class AddContinuationLowering(private val context: JvmBackendContext) : FileLoweringPass { override fun lower(irFile: IrFile) { val suspendLambdas = findSuspendAndInlineLambdas(irFile) - addContinuationObjectAndContinuationParameterToSuspendFunctions(irFile) transformSuspendLambdasIntoContinuations(irFile, suspendLambdas) - addContinuationParameterToSuspendCallsAndUpdateNonLocalReturns(irFile, suspendLambdas) + // This should be done after converting lambdas into classes to avoid breaking the invariant that + // each lambda is referenced at most once while creating `$$forInline` methods. + addContinuationObjectAndContinuationParameterToSuspendFunctions(irFile) + addContinuationParameterToSuspendCalls(irFile) + // This should be done after adding continuation parameters so that `attributeContainerId` links + // inside `$$forInline` copies of `invokeSuspend` do not confuse the previous passes. + fillInvokeSuspendForInlineBodies(irFile) } - private fun addContinuationParameterToSuspendCallsAndUpdateNonLocalReturns(irFile: IrFile, suspendLambdas: List) { + private fun addContinuationParameterToSuspendCalls(irFile: IrFile) { irFile.transformChildrenVoid(object : IrElementTransformerVoid() { val functionStack = mutableListOf() @@ -75,13 +80,6 @@ private class AddContinuationLowering(private val context: JvmBackendContext) : return (super.visitCall(expression) as IrCall) .createSuspendFunctionCallViewIfNeeded(context, caller) } - - override fun visitReturn(expression: IrReturn): IrExpression { - val ret = super.visitReturn(expression) as IrReturn - val irFunction = expression.returnTargetSymbol.owner as? IrFunction ?: return ret - val target = suspendLambdas.find { it.function == irFunction }?.invokeSuspend ?: irFunction - return IrReturnImpl(ret.startOffset, ret.endOffset, ret.type, target.symbol, ret.value) - } }) } @@ -168,9 +166,8 @@ private class AddContinuationLowering(private val context: JvmBackendContext) : } val invokeSuspend = addInvokeSuspendForLambda(info.function, parametersFields, receiverField) if (info.capturesCrossinline) { - addInvokeSuspendForInlineForLambda(invokeSuspend, info.function, parametersFields, receiverField) + addInvokeSuspendForInlineForLambda(invokeSuspend) } - info.invokeSuspend = invokeSuspend info.function.parentAsClass.declarations.remove(info.function) if (info.arity <= 1) { val singleParameterField = receiverField ?: parametersWithoutArguments.singleOrNull() @@ -201,15 +198,23 @@ private class AddContinuationLowering(private val context: JvmBackendContext) : it.owner.name.asString() == INVOKE_SUSPEND_METHOD_NAME && it.owner.valueParameters.size == 1 && it.owner.valueParameters[0].type.isKotlinResult() }.owner - return addFunctionOverride(superMethod).also { it.copySuspendLambdaBodyFrom(irFunction, receiverField, fields) } + return addFunctionOverride(superMethod).apply { + val parametersToFields = mutableMapOf() + assert(irFunction.dispatchReceiverParameter == null) // LocalDeclarationsLowering-generated methods are static + irFunction.extensionReceiverParameter?.let { parametersToFields[it] = receiverField!! } + irFunction.valueParameters.zip(fields).toMap(parametersToFields) + body = irFunction.moveBodyTo(this, mapOf()) + body?.transformChildrenVoid(object : IrElementTransformerVoid() { + override fun visitGetValue(expression: IrGetValue): IrExpression { + val field = parametersToFields[expression.symbol.owner] ?: return expression + val receiver = IrGetValueImpl(UNDEFINED_OFFSET, UNDEFINED_OFFSET, dispatchReceiverParameter!!.symbol) + return IrGetFieldImpl(expression.startOffset, expression.endOffset, field.symbol, field.type, receiver) + } + }) + } } - private fun IrClass.addInvokeSuspendForInlineForLambda( - invokeSuspend: IrFunction, - irFunction: IrFunction, - fields: List, - receiverField: IrField? - ): IrFunction { + private fun IrClass.addInvokeSuspendForInlineForLambda(invokeSuspend: IrFunction): IrFunction { return addFunction( INVOKE_SUSPEND_METHOD_NAME + FOR_INLINE_SUFFIX, context.irBuiltIns.anyNType, @@ -217,47 +222,19 @@ private class AddContinuationLowering(private val context: JvmBackendContext) : origin = JvmLoweredDeclarationOrigin.FOR_INLINE_STATE_MACHINE_TEMPLATE ).apply { valueParameters += invokeSuspend.valueParameters.map { it.copyTo(this) } - }.also { it.copySuspendLambdaBodyFrom(irFunction, receiverField, fields) } + } } - private fun IrSimpleFunction.copySuspendLambdaBodyFrom( - irFunction: IrFunction, - receiverField: IrField?, - fields: List - ) { - body = irFunction.body?.deepCopyWithSymbols(this) - body?.transformChildrenVoid(object : IrElementTransformerVoid() { - override fun visitGetValue(expression: IrGetValue): IrExpression { - if (expression.symbol.owner == irFunction.extensionReceiverParameter) { - assert(receiverField != null) - return IrGetFieldImpl( - expression.startOffset, - expression.endOffset, - receiverField!!.symbol, - receiverField.type - ).also { - it.receiver = - IrGetValueImpl(UNDEFINED_OFFSET, UNDEFINED_OFFSET, dispatchReceiverParameter!!.symbol) + private fun fillInvokeSuspendForInlineBodies(irFile: IrFile) { + irFile.transformChildrenVoid(object : IrElementTransformerVoid() { + override fun visitClass(declaration: IrClass): IrStatement = declaration.transformPostfix { + if (origin == JvmLoweredDeclarationOrigin.SUSPEND_LAMBDA) { + for (function in functions) { + if (function.origin == JvmLoweredDeclarationOrigin.FOR_INLINE_STATE_MACHINE_TEMPLATE) { + function.body = functions.single { it.name.asString() == "invokeSuspend" }.copyBodyTo(function) + } } - } else if (expression.symbol.owner == irFunction.dispatchReceiverParameter) { - return IrGetValueImpl(expression.startOffset, expression.endOffset, dispatchReceiverParameter!!.symbol) } - val field = fields.find { it.name == expression.symbol.owner.name } ?: return expression - return IrGetFieldImpl(expression.startOffset, expression.endOffset, field.symbol, field.type).also { - it.receiver = IrGetValueImpl(UNDEFINED_OFFSET, UNDEFINED_OFFSET, dispatchReceiverParameter!!.symbol) - } - } - - // If the suspend lambda body contains declarations of other classes (for other lambdas), - // do not rewrite those. In particular, that could lead to rewriting of returns in nested - // lambdas to unintended non-local returns. - override fun visitClass(declaration: IrClass): IrStatement { - return declaration - } - - override fun visitReturn(expression: IrReturn): IrExpression { - val ret = super.visitReturn(expression) as IrReturn - return IrReturnImpl(ret.startOffset, ret.endOffset, ret.type, symbol, ret.value) } }) } @@ -619,12 +596,9 @@ private class AddContinuationLowering(private val context: JvmBackendContext) : else JvmLoweredDeclarationOrigin.FOR_INLINE_STATE_MACHINE_TEMPLATE_CAPTURES_CROSSINLINE }.apply { annotations += view.annotations.map { it.deepCopyWithSymbols(this) } - copyTypeParameters(view.typeParameters) - dispatchReceiverParameter = view.dispatchReceiverParameter?.copyTo(this) - extensionReceiverParameter = view.extensionReceiverParameter?.copyTo(this) - valueParameters += view.valueParameters.map { it.copyTo(this) } - body = view.copyBodyTo(this) + copyParameterDeclarationsFrom(view) copyAttributes(view) + body = view.copyBodyTo(this) } registerNewFunction(newFunction) } @@ -728,7 +702,6 @@ private class AddContinuationLowering(private val context: JvmBackendContext) : val capturesCrossinline: Boolean ) { lateinit var constructor: IrConstructor - lateinit var invokeSuspend: IrFunction } } diff --git a/compiler/testData/codegen/box/reflection/lambdaClasses/reflectOnLambdaInSuspend.kt b/compiler/testData/codegen/box/reflection/lambdaClasses/reflectOnLambdaInSuspend.kt new file mode 100644 index 00000000000..a4030bd2f3b --- /dev/null +++ b/compiler/testData/codegen/box/reflection/lambdaClasses/reflectOnLambdaInSuspend.kt @@ -0,0 +1,19 @@ +// IGNORE_BACKEND_FIR: JVM_IR +// TARGET_BACKEND: JVM +// WITH_REFLECT +// WITH_COROUTINES +// FILE: a.kt + +import helpers.* +import kotlin.coroutines.* +import kotlin.reflect.jvm.reflect + +suspend fun f() = { OK: String -> } + +fun box(): String { + lateinit var x: (String) -> Unit + suspend { + x = f() + }.startCoroutine(EmptyContinuation) + return x.reflect()?.parameters?.singleOrNull()?.name ?: "null" +} diff --git a/compiler/testData/codegen/box/reflection/lambdaClasses/reflectOnLambdaInSuspendLambda.kt b/compiler/testData/codegen/box/reflection/lambdaClasses/reflectOnLambdaInSuspendLambda.kt new file mode 100644 index 00000000000..f0115bcf8d0 --- /dev/null +++ b/compiler/testData/codegen/box/reflection/lambdaClasses/reflectOnLambdaInSuspendLambda.kt @@ -0,0 +1,17 @@ +// IGNORE_BACKEND_FIR: JVM_IR +// TARGET_BACKEND: JVM +// WITH_REFLECT +// WITH_COROUTINES +// FILE: a.kt + +import helpers.* +import kotlin.coroutines.* +import kotlin.reflect.jvm.reflect + +fun box(): String { + lateinit var x: (String) -> Unit + suspend { + x = { OK: String -> } + }.startCoroutine(EmptyContinuation) + return x.reflect()?.parameters?.singleOrNull()?.name ?: "null" +} diff --git a/compiler/tests/org/jetbrains/kotlin/codegen/BlackBoxCodegenTestGenerated.java b/compiler/tests/org/jetbrains/kotlin/codegen/BlackBoxCodegenTestGenerated.java index 6a5ed1aa5f3..991c50117fd 100644 --- a/compiler/tests/org/jetbrains/kotlin/codegen/BlackBoxCodegenTestGenerated.java +++ b/compiler/tests/org/jetbrains/kotlin/codegen/BlackBoxCodegenTestGenerated.java @@ -24445,6 +24445,16 @@ public class BlackBoxCodegenTestGenerated extends AbstractBlackBoxCodegenTest { public void testReflectOnLambdaInStaticField() throws Exception { runTest("compiler/testData/codegen/box/reflection/lambdaClasses/reflectOnLambdaInStaticField.kt"); } + + @TestMetadata("reflectOnLambdaInSuspend.kt") + public void testReflectOnLambdaInSuspend() throws Exception { + runTest("compiler/testData/codegen/box/reflection/lambdaClasses/reflectOnLambdaInSuspend.kt"); + } + + @TestMetadata("reflectOnLambdaInSuspendLambda.kt") + public void testReflectOnLambdaInSuspendLambda() throws Exception { + runTest("compiler/testData/codegen/box/reflection/lambdaClasses/reflectOnLambdaInSuspendLambda.kt"); + } } @TestMetadata("compiler/testData/codegen/box/reflection/mapping") diff --git a/compiler/tests/org/jetbrains/kotlin/codegen/LightAnalysisModeTestGenerated.java b/compiler/tests/org/jetbrains/kotlin/codegen/LightAnalysisModeTestGenerated.java index 3adadddf36d..cb9171c27c9 100644 --- a/compiler/tests/org/jetbrains/kotlin/codegen/LightAnalysisModeTestGenerated.java +++ b/compiler/tests/org/jetbrains/kotlin/codegen/LightAnalysisModeTestGenerated.java @@ -23262,6 +23262,16 @@ public class LightAnalysisModeTestGenerated extends AbstractLightAnalysisModeTes public void testReflectOnLambdaInStaticField() throws Exception { runTest("compiler/testData/codegen/box/reflection/lambdaClasses/reflectOnLambdaInStaticField.kt"); } + + @TestMetadata("reflectOnLambdaInSuspend.kt") + public void testReflectOnLambdaInSuspend() throws Exception { + runTest("compiler/testData/codegen/box/reflection/lambdaClasses/reflectOnLambdaInSuspend.kt"); + } + + @TestMetadata("reflectOnLambdaInSuspendLambda.kt") + public void testReflectOnLambdaInSuspendLambda() throws Exception { + runTest("compiler/testData/codegen/box/reflection/lambdaClasses/reflectOnLambdaInSuspendLambda.kt"); + } } @TestMetadata("compiler/testData/codegen/box/reflection/mapping") diff --git a/compiler/tests/org/jetbrains/kotlin/codegen/ir/FirBlackBoxCodegenTestGenerated.java b/compiler/tests/org/jetbrains/kotlin/codegen/ir/FirBlackBoxCodegenTestGenerated.java index ece90e8944e..0c7dab69208 100644 --- a/compiler/tests/org/jetbrains/kotlin/codegen/ir/FirBlackBoxCodegenTestGenerated.java +++ b/compiler/tests/org/jetbrains/kotlin/codegen/ir/FirBlackBoxCodegenTestGenerated.java @@ -22954,6 +22954,16 @@ public class FirBlackBoxCodegenTestGenerated extends AbstractFirBlackBoxCodegenT public void testReflectOnLambdaInStaticField() throws Exception { runTest("compiler/testData/codegen/box/reflection/lambdaClasses/reflectOnLambdaInStaticField.kt"); } + + @TestMetadata("reflectOnLambdaInSuspend.kt") + public void testReflectOnLambdaInSuspend() throws Exception { + runTest("compiler/testData/codegen/box/reflection/lambdaClasses/reflectOnLambdaInSuspend.kt"); + } + + @TestMetadata("reflectOnLambdaInSuspendLambda.kt") + public void testReflectOnLambdaInSuspendLambda() throws Exception { + runTest("compiler/testData/codegen/box/reflection/lambdaClasses/reflectOnLambdaInSuspendLambda.kt"); + } } @TestMetadata("compiler/testData/codegen/box/reflection/mapping") diff --git a/compiler/tests/org/jetbrains/kotlin/codegen/ir/IrBlackBoxCodegenTestGenerated.java b/compiler/tests/org/jetbrains/kotlin/codegen/ir/IrBlackBoxCodegenTestGenerated.java index 550e68bc7e2..820f56e0c25 100644 --- a/compiler/tests/org/jetbrains/kotlin/codegen/ir/IrBlackBoxCodegenTestGenerated.java +++ b/compiler/tests/org/jetbrains/kotlin/codegen/ir/IrBlackBoxCodegenTestGenerated.java @@ -22954,6 +22954,16 @@ public class IrBlackBoxCodegenTestGenerated extends AbstractIrBlackBoxCodegenTes public void testReflectOnLambdaInStaticField() throws Exception { runTest("compiler/testData/codegen/box/reflection/lambdaClasses/reflectOnLambdaInStaticField.kt"); } + + @TestMetadata("reflectOnLambdaInSuspend.kt") + public void testReflectOnLambdaInSuspend() throws Exception { + runTest("compiler/testData/codegen/box/reflection/lambdaClasses/reflectOnLambdaInSuspend.kt"); + } + + @TestMetadata("reflectOnLambdaInSuspendLambda.kt") + public void testReflectOnLambdaInSuspendLambda() throws Exception { + runTest("compiler/testData/codegen/box/reflection/lambdaClasses/reflectOnLambdaInSuspendLambda.kt"); + } } @TestMetadata("compiler/testData/codegen/box/reflection/mapping")