diff --git a/compiler/ir/backend.js/src/org/jetbrains/kotlin/ir/backend/js/lower/InteropCallableReferenceLowering.kt b/compiler/ir/backend.js/src/org/jetbrains/kotlin/ir/backend/js/lower/InteropCallableReferenceLowering.kt index ddd2e52c54b..5717939012e 100644 --- a/compiler/ir/backend.js/src/org/jetbrains/kotlin/ir/backend/js/lower/InteropCallableReferenceLowering.kt +++ b/compiler/ir/backend.js/src/org/jetbrains/kotlin/ir/backend/js/lower/InteropCallableReferenceLowering.kt @@ -34,38 +34,66 @@ class InteropCallableReferenceLowering(val context: JsIrBackendContext) : BodyLo override fun lower(irFile: IrFile) { val ctorToFactoryMap = mutableMapOf() - irFile.transform(CallableReferenceClassTransformer(ctorToFactoryMap), null) + val ctorToFreeFunctionMap = mutableMapOf() + irFile.transform(CallableReferenceClassTransformer(ctorToFactoryMap, ctorToFreeFunctionMap), null) irFile.transformChildrenVoid(object : IrElementTransformerVoid() { override fun visitConstructorCall(expression: IrConstructorCall): IrExpression { expression.transformChildrenVoid() if (expression.origin != JsStatementOrigins.CALLABLE_REFERENCE_CREATE) return expression - return ctorToFactoryMap[expression.symbol]?.let { factory -> - val newCall = expression.run { - IrCallImpl(startOffset, endOffset, type, factory, typeArgumentsCount, valueArgumentsCount, origin) - } - newCall.dispatchReceiver = expression.dispatchReceiver - newCall.extensionReceiver = expression.extensionReceiver + ctorToFreeFunctionMap[expression.symbol]?.let { liftedLambda -> + return replaceLambdaConstructorCallWithReferenceToLiftedLambda(expression, liftedLambda) + } - for (i in 0 until expression.typeArgumentsCount) { - newCall.putTypeArgument(i, expression.getTypeArgument(i)) - } + ctorToFactoryMap[expression.symbol]?.let { factory -> + return replaceLambdaConstructorCallWithFactoryCall(expression, factory) + } - for (i in 0 until expression.valueArgumentsCount) { - newCall.putValueArgument(i, expression.getValueArgument(i)) - } - - newCall - } ?: expression + return expression } }) } + private fun replaceLambdaConstructorCallWithFactoryCall( + expression: IrConstructorCall, + factory: IrSimpleFunctionSymbol + ): IrCall { + val newCall = expression.run { + IrCallImpl(startOffset, endOffset, type, factory, typeArgumentsCount, valueArgumentsCount, origin) + } + + newCall.dispatchReceiver = expression.dispatchReceiver + newCall.extensionReceiver = expression.extensionReceiver + + for (i in 0 until expression.typeArgumentsCount) { + newCall.putTypeArgument(i, expression.getTypeArgument(i)) + } + + for (i in 0 until expression.valueArgumentsCount) { + newCall.putValueArgument(i, expression.getValueArgument(i)) + } + + return newCall + } + + private fun replaceLambdaConstructorCallWithReferenceToLiftedLambda( + expression: IrConstructorCall, + liftedLambda: IrSimpleFunctionSymbol + ): IrRawFunctionReference = IrRawFunctionReferenceImpl( + expression.startOffset, + expression.endOffset, + expression.type, + liftedLambda, + ) + override fun lower(irBody: IrBody, container: IrDeclaration) { compilationException("Unreachable", irBody) } - private inner class CallableReferenceClassTransformer(private val ctorToFactoryMap: MutableMap) : IrElementTransformerVoid() { + private inner class CallableReferenceClassTransformer( + private val ctorToFactoryMap: MutableMap, + private val ctorToFreeFunctionMap: MutableMap + ) : IrElementTransformerVoid() { override fun visitFile(declaration: IrFile): IrFile { declaration.transformChildrenVoid() declaration.transformDeclarationsFlat { it.transformCallableReference() } @@ -100,7 +128,17 @@ class InteropCallableReferenceLowering(val context: JsIrBackendContext) : BodyLo } private fun replaceWithFactory(lambdaClass: IrClass): List { - return buildFactoryFunction(lambdaClass, ctorToFactoryMap).onEach { it.parent = lambdaClass.parent } + val lambdaInfo = LambdaInfo(lambdaClass) + + // Optimization: + // If the lambda has no context, we lift it, i.e. instead of generating a factory function that creates lambda objects, + // we generate a named free function. The usage of the lambda is then replaced with a reference to the free function. + // This allows us to avoid allocating a new object each time the lambda is created. + return if (lambdaClass.origin == CallableReferenceLowering.Companion.LAMBDA_IMPL && !lambdaInfo.isSuspendLambda && lambdaClass.fields.none()) { + liftLambda(ctorToFreeFunctionMap, lambdaInfo) + } else { + buildFactoryFunction(ctorToFactoryMap, lambdaInfo) + }.onEach { it.parent = lambdaClass.parent } } } @@ -213,47 +251,42 @@ class InteropCallableReferenceLowering(val context: JsIrBackendContext) : BodyLo return returnStmt.value } - private fun buildFactoryBody( - factoryFunction: IrSimpleFunction, - lambdaClass: IrClass, - newDeclarations: MutableList - ): IrBlockBody { + private class LambdaInfo(val lambdaClass: IrClass) { val invokeFun = lambdaClass.invokeFun!! val superInvokeFun = invokeFun.overriddenSymbols.first { it.owner.isSuspend == invokeFun.isSuspend }.owner - val lambdaName = Name.identifier("${lambdaClass.name.asString()}\$lambda") + val isSuspendLambda = invokeFun.overriddenSymbols.any { it.owner.isSuspend } - val superClass = superInvokeFun.parentAsClass - val anyNType = context.irBuiltIns.anyNType - val lambdaDeclaration = context.irFactory.buildFun { - startOffset = invokeFun.startOffset - endOffset = invokeFun.endOffset - // Since box/unbox is done on declaration side in case of suspend function use the specified type - returnType = if (invokeFun.isSuspend) invokeFun.returnType else anyNType - visibility = DescriptorVisibilities.LOCAL - name = lambdaName - isSuspend = invokeFun.isSuspend - } + fun createOldToNewInvokeParametersMapping(lambdaDeclaration: IrSimpleFunction) = + invokeFun.valueParameters.associateBy({ it.symbol }, { lambdaDeclaration.valueParameters[it.index].symbol }) - lambdaDeclaration.parent = factoryFunction + fun lambdaInnerClasses() = + lambdaClass.declarations.filter { it is IrClass || (it is IrSimpleFunction && it.dispatchReceiverParameter == null) } + } - lambdaDeclaration.valueParameters = superInvokeFun.valueParameters.mapIndexed { id, vp -> - vp.copyTo(lambdaDeclaration, type = anyNType, name = invokeFun.valueParameters[id].name) - } + private fun buildFactoryBody( + factoryFunction: IrSimpleFunction, + newDeclarations: MutableList, + lambdaInfo: LambdaInfo + ): IrBlockBody { + val superClass = lambdaInfo.superInvokeFun.parentAsClass + val lambdaName = Name.identifier("${lambdaInfo.lambdaClass.name.asString()}\$lambda") + + val lambdaDeclaration = + createLambdaDeclaration(lambdaInfo.invokeFun, lambdaName, factoryFunction, lambdaInfo.superInvokeFun) val statements = ArrayList(4) - val isSuspendLambda = invokeFun.overriddenSymbols.any { it.owner.isSuspend } - val constructor = lambdaClass.declarations.firstNotNullOf { it as? IrConstructor } + val constructor = lambdaInfo.lambdaClass.declarations.firstNotNullOf { it as? IrConstructor } - if (isSuspendLambda) { + if (lambdaInfo.isSuspendLambda) { // Due to suspend lambda is a class itself it's not easy to inline it correctly and moreover I see no reason to do so - val lambdaType = lambdaClass.defaultType + val lambdaType = lambdaInfo.lambdaClass.defaultType val instanceVal = JsIrBuilder.buildVar(lambdaType, factoryFunction, "i").apply { val newCtorCall = IrConstructorCallImpl( - lambdaClass.startOffset, - lambdaClass.endOffset, + lambdaInfo.lambdaClass.startOffset, + lambdaInfo.lambdaClass.endOffset, lambdaType, constructor.symbol, - lambdaClass.typeParameters.size, + lambdaInfo.lambdaClass.typeParameters.size, constructor.typeParameters.size, constructor.valueParameters.size ) @@ -267,31 +300,25 @@ class InteropCallableReferenceLowering(val context: JsIrBackendContext) : BodyLo statements.add(instanceVal) - lambdaDeclaration.body = buildLambdaBody(instanceVal, lambdaDeclaration, invokeFun) + lambdaDeclaration.body = buildLambdaBody(instanceVal, lambdaDeclaration, lambdaInfo.invokeFun) - newDeclarations.add(lambdaClass) + newDeclarations.add(lambdaInfo.lambdaClass) } else { val fieldToParameterMapping = capturedFieldsToParametersMap(constructor, factoryFunction) - val oldToNewInvokeParametersMapping = mutableMapOf() - invokeFun.valueParameters.forEach { - oldToNewInvokeParametersMapping[it.symbol] = lambdaDeclaration.valueParameters[it.index].symbol - } + val oldToNewInvokeParametersMapping = lambdaInfo.createOldToNewInvokeParametersMapping(lambdaDeclaration) lambdaDeclaration.body = - inlineLambdaBody(lambdaDeclaration, invokeFun, oldToNewInvokeParametersMapping, fieldToParameterMapping) + inlineLambdaBody(lambdaDeclaration, lambdaInfo.invokeFun, oldToNewInvokeParametersMapping, fieldToParameterMapping) - // lambdas could contain another lambdas and local classes in so let do not lose them - val lambdaInnerClasses = - lambdaClass.declarations.filter { it is IrClass || (it is IrSimpleFunction && it.dispatchReceiverParameter == null) } - - newDeclarations.addAll(lambdaInnerClasses) + // lambdas can contain another lambdas and local classes in so let's not lose them + newDeclarations.addAll(lambdaInfo.lambdaInnerClasses()) } - val lambdaType = lambdaClass.superTypes.single { it.classifierOrNull === superClass.symbol } - val functionExpression = lambdaClass.run { + val lambdaType = lambdaInfo.lambdaClass.superTypes.single { it.classifierOrNull === superClass.symbol } + val functionExpression = lambdaInfo.lambdaClass.run { IrFunctionExpressionImpl(startOffset, endOffset, lambdaType, lambdaDeclaration, JsStatementOrigins.CALLABLE_REFERENCE_CREATE) } - val nameGetter = context.mapping.reflectedNameAccessor[lambdaClass] + val nameGetter = context.mapping.reflectedNameAccessor[lambdaInfo.lambdaClass] if (nameGetter != null || lambdaDeclaration.isSuspend) { val tmpVar = JsIrBuilder.buildVar(functionExpression.type, factoryFunction, "l", initializer = functionExpression) @@ -326,23 +353,48 @@ class InteropCallableReferenceLowering(val context: JsIrBackendContext) : BodyLo statements.add(JsIrBuilder.buildReturn(factoryFunction.symbol, functionExpression, context.irBuiltIns.nothingType)) } - return context.irFactory.createBlockBody(lambdaClass.startOffset, lambdaClass.endOffset, statements) + return context.irFactory.createBlockBody(lambdaInfo.lambdaClass.startOffset, lambdaInfo.lambdaClass.endOffset, statements) + } + + private fun createLambdaDeclaration( + invokeFun: IrSimpleFunction, + lambdaName: Name, + parent: IrDeclarationParent, + superInvokeFun: IrSimpleFunction + ): IrSimpleFunction { + val anyNType = context.irBuiltIns.anyNType + val lambdaDeclaration = context.irFactory.buildFun { + startOffset = invokeFun.startOffset + endOffset = invokeFun.endOffset + // Since box/unbox is done on declaration side in case of suspend function use the specified type + returnType = if (invokeFun.isSuspend) invokeFun.returnType else anyNType + visibility = DescriptorVisibilities.LOCAL + name = lambdaName + isSuspend = invokeFun.isSuspend + } + + lambdaDeclaration.parent = parent + + lambdaDeclaration.valueParameters = superInvokeFun.valueParameters.mapIndexed { id, vp -> + vp.copyTo(lambdaDeclaration, type = anyNType, name = invokeFun.valueParameters[id].name) + } + return lambdaDeclaration } private fun buildFactoryFunction( - lambdaClass: IrClass, - ctorToFactoryMap: MutableMap + ctorToFactoryMap: MutableMap, + lambdaInfo: LambdaInfo ): List { val newDeclarations = mutableListOf() - val constructor = lambdaClass.declarations.single { it is IrConstructor } as IrConstructor + val constructor = lambdaInfo.lambdaClass.constructors.single() - val factoryDeclaration = context.irFactory.stageController.restrictTo(lambdaClass) { + val factoryDeclaration = context.irFactory.stageController.restrictTo(lambdaInfo.lambdaClass) { context.irFactory.buildFun { - startOffset = lambdaClass.startOffset - endOffset = lambdaClass.endOffset - visibility = lambdaClass.visibility - returnType = lambdaClass.defaultType - name = lambdaClass.name + startOffset = lambdaInfo.lambdaClass.startOffset + endOffset = lambdaInfo.lambdaClass.endOffset + visibility = lambdaInfo.lambdaClass.visibility + returnType = lambdaInfo.lambdaClass.defaultType + name = lambdaInfo.lambdaClass.name origin = JsStatementOrigins.FACTORY_ORIGIN } } @@ -355,7 +407,7 @@ class InteropCallableReferenceLowering(val context: JsIrBackendContext) : BodyLo } } - factoryDeclaration.body = buildFactoryBody(factoryDeclaration, lambdaClass, newDeclarations) + factoryDeclaration.body = buildFactoryBody(factoryDeclaration, newDeclarations, lambdaInfo) newDeclarations.add(factoryDeclaration) ctorToFactoryMap[constructor.symbol] = factoryDeclaration.symbol @@ -363,6 +415,38 @@ class InteropCallableReferenceLowering(val context: JsIrBackendContext) : BodyLo return newDeclarations } + /** + * Replaces a contextless lambda class with a free function. + */ + private fun liftLambda( + ctorToFreeFunctionMap: MutableMap, + lambdaInfo: LambdaInfo + ): List { + val constructor = lambdaInfo.lambdaClass.constructors.single() + val newDeclarations = mutableListOf() + val freeFunctionDeclaration = createLambdaDeclaration( + lambdaInfo.invokeFun, + lambdaInfo.lambdaClass.name, + lambdaInfo.lambdaClass.parent, + lambdaInfo.superInvokeFun + ) + + freeFunctionDeclaration.body = inlineLambdaBody( + freeFunctionDeclaration, + lambdaInfo.invokeFun, + lambdaInfo.createOldToNewInvokeParametersMapping(freeFunctionDeclaration), + emptyMap() + ) + + newDeclarations.add(freeFunctionDeclaration) + + // lambdas can contain another lambdas and local classes in so let's not lose them + newDeclarations.addAll(lambdaInfo.lambdaInnerClasses()) + + ctorToFreeFunctionMap[constructor.symbol] = freeFunctionDeclaration.symbol + + return newDeclarations + } private fun setDynamicProperty(r: IrValueSymbol, property: String, value: IrExpression): IrStatement { return IrDynamicOperatorExpressionImpl(UNDEFINED_OFFSET, UNDEFINED_OFFSET, context.irBuiltIns.unitType, IrDynamicOperator.EQ).apply { diff --git a/compiler/testData/codegen/box/when/enumOptimization/manyWhensWithinClass.kt b/compiler/testData/codegen/box/when/enumOptimization/manyWhensWithinClass.kt index c95b3bffa65..ab5be3c8390 100644 --- a/compiler/testData/codegen/box/when/enumOptimization/manyWhensWithinClass.kt +++ b/compiler/testData/codegen/box/when/enumOptimization/manyWhensWithinClass.kt @@ -2,7 +2,7 @@ // CHECK_CASES_COUNT: function=bar1_u51tkt$ count=3 TARGET_BACKENDS=JS // CHECK_IF_COUNT: function=bar1_u51tkt$ count=0 TARGET_BACKENDS=JS // CHECK_CASES_COUNT: function=A$bar2$lambda count=3 TARGET_BACKENDS=JS -// CHECK_CASES_COUNT: function=A$bar2$lambda count=0 IGNORED_BACKENDS=JS +// CHECK_CASES_COUNT: function=A$bar2$lambda count=4 IGNORED_BACKENDS=JS // CHECK_IF_COUNT: function=A$bar2$lambda count=0 import kotlin.test.assertEquals diff --git a/compiler/testData/codegen/boxInline/defaultValues/lambdaInlining/defaultLambdaInNoInline.kt b/compiler/testData/codegen/boxInline/defaultValues/lambdaInlining/defaultLambdaInNoInline.kt index f7077d79057..5de2082c361 100644 --- a/compiler/testData/codegen/boxInline/defaultValues/lambdaInlining/defaultLambdaInNoInline.kt +++ b/compiler/testData/codegen/boxInline/defaultValues/lambdaInlining/defaultLambdaInNoInline.kt @@ -9,7 +9,7 @@ inline fun inlineFun(crossinline inlineLambda: () -> String = { "OK" }, noinline // FILE: 2.kt // CHECK_CALLED_IN_SCOPE: function=inlineFun$lambda_0 scope=box TARGET_BACKENDS=JS -// CHECK_CALLED_IN_SCOPE: function=box$lambda scope=box IGNORED_BACKENDS=JS +// HAS_NO_CAPTURED_VARS: function=box except=box$lambda IGNORED_BACKENDS=JS import test.* fun box(): String { diff --git a/compiler/testData/codegen/boxInline/defaultValues/lambdaInlining/noInline.kt b/compiler/testData/codegen/boxInline/defaultValues/lambdaInlining/noInline.kt index fbd60d336e2..087d767ac4b 100644 --- a/compiler/testData/codegen/boxInline/defaultValues/lambdaInlining/noInline.kt +++ b/compiler/testData/codegen/boxInline/defaultValues/lambdaInlining/noInline.kt @@ -10,7 +10,7 @@ fun call(lambda: () -> String ) = lambda() // FILE: 2.kt // CHECK_CALLED_IN_SCOPE: function=inlineFun$lambda scope=box TARGET_BACKENDS=JS -// CHECK_CALLED_IN_SCOPE: function=box$lambda scope=box IGNORED_BACKENDS=JS +// HAS_NO_CAPTURED_VARS: function=box except=box$lambda;call IGNORED_BACKENDS=JS // CHECK_CALLED_IN_SCOPE: function=call scope=box import test.* diff --git a/js/js.translator/testData/box/inline/lambdaInLambda.kt b/js/js.translator/testData/box/inline/lambdaInLambda.kt index 770edaea1ee..62559ce992d 100644 --- a/js/js.translator/testData/box/inline/lambdaInLambda.kt +++ b/js/js.translator/testData/box/inline/lambdaInLambda.kt @@ -1,7 +1,9 @@ // EXPECTED_REACHABLE_NODES: 1284 package foo -// CHECK_CALLED_IN_SCOPE: scope=multiplyBy2 function=multiplyBy2$lambda +// CHECK_FUNCTION_EXISTS: multiplyBy2$lambda +// CHECK_CALLED_IN_SCOPE: scope=multiplyBy2 function=multiplyBy2$lambda TARGET_BACKENDS=JS +// HAS_NO_CAPTURED_VARS: function=multiplyBy2 except=multiplyBy2$lambda // CHECK_NOT_CALLED_IN_SCOPE: scope=multiplyBy2 function=multiplyBy2$lambda_0 // CHECK_NOT_CALLED_IN_SCOPE: scope=multiplyBy2 function=run @@ -19,4 +21,4 @@ fun box(): String { assertEquals(8, multiplyBy2(4)) return "OK" -} \ No newline at end of file +} diff --git a/js/js.translator/testData/box/inline/noInlineLambda.kt b/js/js.translator/testData/box/inline/noInlineLambda.kt index cb7cf6a1f9f..540c8531e0b 100644 --- a/js/js.translator/testData/box/inline/noInlineLambda.kt +++ b/js/js.translator/testData/box/inline/noInlineLambda.kt @@ -1,7 +1,9 @@ // EXPECTED_REACHABLE_NODES: 1284 package foo -// CHECK_CALLED_IN_SCOPE: scope=multiplyBy2 function=multiplyBy2$lambda +// CHECK_FUNCTION_EXISTS: multiplyBy2$lambda +// CHECK_CALLED_IN_SCOPE: scope=multiplyBy2 function=multiplyBy2$lambda TARGET_BACKENDS=JS +// HAS_NO_CAPTURED_VARS: function=multiplyBy2 except=multiplyBy2$lambda // CHECK_NOT_CALLED_IN_SCOPE: scope=multiplyBy2 function=run internal inline fun run(noinline func: (T) -> T, arg: T): T { @@ -18,4 +20,4 @@ fun box(): String { assertEquals(8, multiplyBy2(4)) return "OK" -} \ No newline at end of file +}