diff --git a/compiler/ir/backend.jvm/codegen/src/org/jetbrains/kotlin/backend/jvm/codegen/ClassCodegen.kt b/compiler/ir/backend.jvm/codegen/src/org/jetbrains/kotlin/backend/jvm/codegen/ClassCodegen.kt index 57bb02f0bdc..36339960f2c 100644 --- a/compiler/ir/backend.jvm/codegen/src/org/jetbrains/kotlin/backend/jvm/codegen/ClassCodegen.kt +++ b/compiler/ir/backend.jvm/codegen/src/org/jetbrains/kotlin/backend/jvm/codegen/ClassCodegen.kt @@ -57,6 +57,7 @@ import org.jetbrains.kotlin.utils.addToStdlib.safeAs import org.jetbrains.org.objectweb.asm.* import org.jetbrains.org.objectweb.asm.commons.Method import java.io.File +import java.lang.RuntimeException class ClassCodegen private constructor( val irClass: IrClass, @@ -141,7 +142,8 @@ class ClassCodegen private constructor( for (method in irClass.declarations.filterIsInstance()) { if (method.name.asString() != "" && method.origin != JvmLoweredDeclarationOrigin.INLINE_LAMBDA && - method.origin != IrDeclarationOrigin.ADAPTER_FOR_FUN_INTERFACE_CONSTRUCTOR + method.origin != IrDeclarationOrigin.ADAPTER_FOR_FUN_INTERFACE_CONSTRUCTOR && + !(method.origin == IrDeclarationOrigin.ADAPTER_FOR_CALLABLE_REFERENCE && method.body == null) ) { generateMethod(method, smap) } diff --git a/compiler/ir/backend.jvm/codegen/src/org/jetbrains/kotlin/backend/jvm/codegen/ExpressionCodegen.kt b/compiler/ir/backend.jvm/codegen/src/org/jetbrains/kotlin/backend/jvm/codegen/ExpressionCodegen.kt index 42755a4cdb1..6dc018a7456 100644 --- a/compiler/ir/backend.jvm/codegen/src/org/jetbrains/kotlin/backend/jvm/codegen/ExpressionCodegen.kt +++ b/compiler/ir/backend.jvm/codegen/src/org/jetbrains/kotlin/backend/jvm/codegen/ExpressionCodegen.kt @@ -481,7 +481,9 @@ class ExpressionCodegen( fun handleValueParameter(i: Int, irParameter: IrValueParameter) { val arg = expression.getValueArgument(i) val parameterType = callable.valueParameterTypes[i] - require(arg != null) { "Null argument in ExpressionCodegen for parameter ${irParameter.render()}" } + require(arg != null) { + "No argument for parameter ${irParameter.render()}:\n${expression.dump()}" + } callGenerator.genValueAndPut(irParameter, arg, parameterType, this, data) } diff --git a/compiler/ir/backend.jvm/lower/src/org/jetbrains/kotlin/backend/jvm/lower/FunctionReferenceLowering.kt b/compiler/ir/backend.jvm/lower/src/org/jetbrains/kotlin/backend/jvm/lower/FunctionReferenceLowering.kt index cb9a9abaf38..b89e3a47a6e 100644 --- a/compiler/ir/backend.jvm/lower/src/org/jetbrains/kotlin/backend/jvm/lower/FunctionReferenceLowering.kt +++ b/compiler/ir/backend.jvm/lower/src/org/jetbrains/kotlin/backend/jvm/lower/FunctionReferenceLowering.kt @@ -9,6 +9,7 @@ import org.jetbrains.kotlin.backend.common.FileLoweringPass import org.jetbrains.kotlin.backend.common.IrElementTransformerVoidWithContext import org.jetbrains.kotlin.backend.common.ir.* import org.jetbrains.kotlin.backend.common.lower.SamEqualsHashCodeMethodsGenerator +import org.jetbrains.kotlin.backend.common.lower.VariableRemapper import org.jetbrains.kotlin.backend.common.lower.parents import org.jetbrains.kotlin.backend.common.phaser.makeIrFilePhase import org.jetbrains.kotlin.backend.jvm.JvmBackendContext @@ -573,16 +574,16 @@ internal class FunctionReferenceLowering(private val context: JvmBackendContext) } private val adaptedReferenceOriginalTarget: IrFunction? = adapteeCall?.symbol?.owner - private val isAdaptedFunInterfaceConstructorReference = + private val isFunInterfaceConstructorReference = callee.origin == IrDeclarationOrigin.ADAPTER_FOR_FUN_INTERFACE_CONSTRUCTOR private val constructedFunInterfaceSymbol: IrClassSymbol? = - if (isAdaptedFunInterfaceConstructorReference) + if (isFunInterfaceConstructorReference) callee.returnType.classOrNull ?: throw AssertionError("Fun interface type expected: ${callee.returnType.render()}") else null private val isAdaptedReference = - isAdaptedFunInterfaceConstructorReference || adaptedReferenceOriginalTarget != null + isFunInterfaceConstructorReference || adaptedReferenceOriginalTarget != null private val samInterface = samSuperType?.getClass() private val isKotlinFunInterface = samInterface != null && !samInterface.isFromJava() @@ -594,7 +595,7 @@ internal class FunctionReferenceLowering(private val context: JvmBackendContext) samSuperType ?: when { isLambda -> context.ir.symbols.lambdaClass - isAdaptedFunInterfaceConstructorReference -> context.ir.symbols.funInterfaceConstructorReferenceClass + isFunInterfaceConstructorReference -> context.ir.symbols.funInterfaceConstructorReferenceClass useOptimizedSuperClass -> when { isAdaptedReference -> context.ir.symbols.adaptedFunctionReference else -> context.ir.symbols.functionReferenceImpl @@ -755,7 +756,7 @@ internal class FunctionReferenceLowering(private val context: JvmBackendContext) // arity, [receiver] val constructor = when { - isAdaptedFunInterfaceConstructorReference -> + isFunInterfaceConstructorReference -> context.ir.symbols.funInterfaceConstructorReferenceClass.owner.constructors.single() samSuperType != null -> context.irBuiltIns.anyClass.owner.constructors.single() @@ -785,7 +786,7 @@ internal class FunctionReferenceLowering(private val context: JvmBackendContext) call: IrFunctionAccessExpression, generateBoundReceiver: IrBuilder.() -> IrExpression ) { - if (isAdaptedFunInterfaceConstructorReference) { + if (isFunInterfaceConstructorReference) { val funInterfaceKClassRef = kClassReference(constructedFunInterfaceSymbol!!.owner.defaultType) val funInterfaceJavaClassRef = kClassToJavaClass(funInterfaceKClassRef) call.putValueArgument(0, funInterfaceJavaClassRef) @@ -850,12 +851,15 @@ internal class FunctionReferenceLowering(private val context: JvmBackendContext) IrDeclarationOrigin.INSTANCE_RECEIVER, functionReferenceClass.symbol.defaultType ) - if (isLambda) - createLambdaInvokeMethod() - else if (isAdaptedFunInterfaceConstructorReference) - createFunInterfaceConstructorInvokeMethod() - else - createFunctionReferenceInvokeMethod(receiverVar) + + when { + isLambda -> + createLambdaInvokeMethod() + isFunInterfaceConstructorReference -> + createFunInterfaceConstructorInvokeMethod() + else -> + createFunctionReferenceInvokeMethod(receiverVar) + } } // Inline the body of an anonymous function into the generated lambda subclass. @@ -866,7 +870,6 @@ internal class FunctionReferenceLowering(private val context: JvmBackendContext) } valueParameters += valueParameterMap.values body = callee.moveBodyTo(this, valueParameterMap) - } private fun IrSimpleFunction.createFunInterfaceConstructorInvokeMethod() { @@ -921,10 +924,100 @@ internal class FunctionReferenceLowering(private val context: JvmBackendContext) }?.let { putArgument(callee, parameter, it) } } } - irExprBody(call) + irExprBody( + inlineAdapterCallIfPossible(call, this@createFunctionReferenceInvokeMethod) + ) } } + private fun inlineAdapterCallIfPossible( + expression: IrFunctionAccessExpression, + invokeMethod: IrSimpleFunction + ): IrExpression { + val irCall = expression as? IrCall + ?: return expression + val callee = irCall.symbol.owner + if (callee.origin != IrDeclarationOrigin.ADAPTER_FOR_CALLABLE_REFERENCE) + return expression + + // TODO fix testSuspendUnitConversion + if (callee.isSuspend) return expression + + // Callable reference adapter is a simple function that delegates to callable reference target, + // adapting its signature for required functional type. + // Usually it simply forwards arguments to target function. + // It also passes 'receiver' field for bound references, with downcast to the actual receiver type. + // In any case, adapter itself is synthetic and is not necessarily debuggable, so we can reuse variables freely. + // Inlining adapter into 'invoke' saves us two methods (adapter & synthetic accessor). + val adapterBody = callee.body as? IrBlockBody + if (adapterBody == null || adapterBody.statements.size != 1) + throw AssertionError("Unexpected adapter body: ${callee.dump()}") + val resultStatement = adapterBody.statements[0] + val resultExpression: IrExpression = + when { + resultStatement is IrReturn -> + resultStatement.value + resultStatement is IrTypeOperatorCall && resultStatement.operator == IrTypeOperator.IMPLICIT_COERCION_TO_UNIT -> + resultStatement + resultStatement is IrCall -> + resultStatement + resultStatement is IrConstructorCall -> + resultStatement + else -> + throw AssertionError("Unexpected adapter body: ${callee.dump()}") + } + + val startOffset = irCall.startOffset + val endOffset = irCall.endOffset + + val callArguments = LinkedHashMap() + val inlinedAdapterBlock = IrBlockImpl(startOffset, endOffset, irCall.type, origin = null) + var tmpVarIndex = 0 + + fun wrapIntoTemporaryVariableIfNecessary(expression: IrExpression): IrValueDeclaration { + if (expression is IrGetValue) + return expression.symbol.owner + if (expression !is IrTypeOperatorCall || expression.argument !is IrGetField) + throw AssertionError("Unexpected adapter argument:\n${expression.dump()}") + val temporaryVar = IrVariableImpl( + startOffset, endOffset, IrDeclarationOrigin.IR_TEMPORARY_VARIABLE, + IrVariableSymbolImpl(), + Name.identifier("tmp_${tmpVarIndex++}"), + expression.type, + isVar = false, isConst = false, isLateinit = false + ) + temporaryVar.parent = invokeMethod + temporaryVar.initializer = expression + inlinedAdapterBlock.statements.add(temporaryVar) + return temporaryVar + } + + callee.dispatchReceiverParameter?.let { + callArguments[it] = wrapIntoTemporaryVariableIfNecessary( + irCall.dispatchReceiver + ?: throw AssertionError("No dispatch receiver in adapter call: ${irCall.dump()}") + ) + } + callee.extensionReceiverParameter?.let { + callArguments[it] = wrapIntoTemporaryVariableIfNecessary( + irCall.extensionReceiver + ?: throw AssertionError("No extension receiver in adapter call: ${irCall.dump()}") + ) + } + for (valueParameter in callee.valueParameters) { + callArguments[valueParameter] = wrapIntoTemporaryVariableIfNecessary( + irCall.getValueArgument(valueParameter.index) + ?: throw AssertionError("No value argument #${valueParameter.index} in adapter call: ${irCall.dump()}") + ) + } + + val inlinedAdapterResult = resultExpression.transform(VariableRemapper(callArguments), null) + inlinedAdapterBlock.statements.add(inlinedAdapterResult) + + callee.body = null + return inlinedAdapterBlock + } + private fun buildOverride(superFunction: IrSimpleFunction, newReturnType: IrType = superFunction.returnType): IrSimpleFunction = functionReferenceClass.addFunction { setSourceRange(irFunctionReference) diff --git a/compiler/testData/codegen/bytecodeListing/callableReference/adaptedReference_ir.txt b/compiler/testData/codegen/bytecodeListing/callableReference/adaptedReference_ir.txt index 4c025cfaf48..c203b865e6e 100644 --- a/compiler/testData/codegen/bytecodeListing/callableReference/adaptedReference_ir.txt +++ b/compiler/testData/codegen/bytecodeListing/callableReference/adaptedReference_ir.txt @@ -26,13 +26,11 @@ public final class A { inner (anonymous) class A$testDefaultArguments$1 inner (anonymous) class A$testDefaultArguments$2 public method (): void - public synthetic final static method access$testDefaultArguments$defaultArgs(p0: A): java.lang.String public synthetic final static method access$testDefaultArguments$defaultArgs-0(p0: A, p1: kotlin.coroutines.Continuation): java.lang.Object synthetic static method defaultArgs$default(p0: A, p1: int, p2: java.lang.String, p3: int, p4: java.lang.Object): java.lang.String private final method defaultArgs(p0: int, p1: java.lang.String): java.lang.String private final method myApply(p0: kotlin.jvm.functions.Function0): void private final method myApplySuspend(p0: kotlin.jvm.functions.Function1): void - private synthetic final static method testDefaultArguments$defaultArgs(p0: A): java.lang.String private synthetic final static method testDefaultArguments$defaultArgs-0(p0: A, p1: kotlin.coroutines.Continuation): java.lang.Object public final method testDefaultArguments(): void }