[IR] Merged K/N inliner with the common one

This commit is contained in:
Igor Chevdar
2019-08-15 17:49:04 +03:00
parent 32153c26a8
commit ad8bcda99e
6 changed files with 183 additions and 78 deletions
@@ -6,6 +6,7 @@
package org.jetbrains.kotlin.backend.common.ir
import org.jetbrains.kotlin.backend.common.CommonBackendContext
import org.jetbrains.kotlin.builtins.KOTLIN_REFLECT_FQ_NAME
import org.jetbrains.kotlin.builtins.KotlinBuiltIns
import org.jetbrains.kotlin.builtins.PrimitiveType
import org.jetbrains.kotlin.builtins.UnsignedType
@@ -15,6 +16,7 @@ import org.jetbrains.kotlin.descriptors.findClassAcrossModuleDependencies
import org.jetbrains.kotlin.incremental.components.NoLookupLocation
import org.jetbrains.kotlin.ir.declarations.IrFunction
import org.jetbrains.kotlin.ir.declarations.IrModuleFragment
import org.jetbrains.kotlin.ir.declarations.IrPackageFragment
import org.jetbrains.kotlin.ir.symbols.IrClassSymbol
import org.jetbrains.kotlin.ir.symbols.IrFunctionSymbol
import org.jetbrains.kotlin.ir.symbols.IrSimpleFunctionSymbol
@@ -246,5 +248,12 @@ abstract class Symbols<out T : CommonBackendContext>(val context: T, private val
receiverClass?.fqNameWhenAvailable?.toUnsafe() == KotlinBuiltIns.FQ_NAMES.kProperty0
}
}
fun isTypeOfIntrinsic(symbol: IrFunctionSymbol): Boolean =
symbol is IrSimpleFunctionSymbol && symbol.owner.let { function ->
function.name.asString() == "typeOf" &&
function.valueParameters.isEmpty() &&
(function.parent as? IrPackageFragment)?.fqName == KOTLIN_REFLECT_FQ_NAME
}
}
}
@@ -18,6 +18,11 @@ import org.jetbrains.kotlin.ir.util.statements
// Return the underlying function for a lambda argument without bound or default parameters or varargs.
fun IrExpression.asSimpleLambda(): IrSimpleFunction? {
if (this is IrFunctionExpression) {
if (function.valueParameters.any { it.isVararg || it.defaultValue != null })
return null
return function
}
// A lambda is represented as a block with a function declaration and a reference to it.
if (this !is IrBlock || statements.size != 2)
return null
@@ -12,17 +12,22 @@ import org.jetbrains.kotlin.backend.common.ir.createTemporaryVariableWithWrapped
import org.jetbrains.kotlin.backend.common.lower.CoroutineIntrinsicLambdaOrigin
import org.jetbrains.kotlin.backend.common.lower.createIrBuilder
import org.jetbrains.kotlin.config.languageVersionSettings
import org.jetbrains.kotlin.descriptors.ValueDescriptor
import org.jetbrains.kotlin.ir.IrElement
import org.jetbrains.kotlin.ir.IrStatement
import org.jetbrains.kotlin.ir.UNDEFINED_OFFSET
import org.jetbrains.kotlin.ir.builders.irReturn
import org.jetbrains.kotlin.ir.declarations.*
import org.jetbrains.kotlin.ir.expressions.*
import org.jetbrains.kotlin.ir.expressions.impl.*
import org.jetbrains.kotlin.ir.symbols.IrFunctionSymbol
import org.jetbrains.kotlin.ir.symbols.IrValueSymbol
import org.jetbrains.kotlin.ir.symbols.impl.IrReturnableBlockSymbolImpl
import org.jetbrains.kotlin.ir.types.IrType
import org.jetbrains.kotlin.ir.types.isNullable
import org.jetbrains.kotlin.ir.util.*
import org.jetbrains.kotlin.ir.visitors.IrElementTransformerVoid
import org.jetbrains.kotlin.ir.visitors.IrElementVisitor
import org.jetbrains.kotlin.ir.visitors.transformChildrenVoid
import org.jetbrains.kotlin.util.OperatorNameConventions
@@ -41,12 +46,13 @@ class FunctionInlining(val context: CommonBackendContext) : IrElementTransformer
return expression
if (Symbols.isLateinitIsInitializedPropertyGetter(callee.symbol))
return expression
if (Symbols.isTypeOfIntrinsic(callee.symbol))
return expression
val actualCallee = getFunctionDeclaration(callee.symbol)
actualCallee.transformChildrenVoid(this) // Process recursive inline.
val parent = allScopes.map { it.irElement }.filterIsInstance<IrDeclarationParent>().lastOrNull()
val inliner = Inliner(expression, actualCallee, currentScope!!, parent, context)
return inliner.inline()
}
@@ -69,13 +75,15 @@ class FunctionInlining(val context: CommonBackendContext) : IrElementTransformer
}
}
private val IrFunction.needsInlining get() = (this.isInline && !this.isExternal)
private val IrFunction.needsInlining get() = this.isInline && !this.isExternal
private inner class Inliner(val callSite: IrFunctionAccessExpression,
val callee: IrFunction,
val currentScope: ScopeWithIr,
val parent: IrDeclarationParent?,
val context: CommonBackendContext) {
private inner class Inliner(
val callSite: IrFunctionAccessExpression,
val callee: IrFunction,
val currentScope: ScopeWithIr,
val parent: IrDeclarationParent?,
val context: CommonBackendContext
) {
val copyIrElement = run {
val typeParameters =
@@ -91,7 +99,7 @@ class FunctionInlining(val context: CommonBackendContext) : IrElementTransformer
val substituteMap = mutableMapOf<IrValueParameter, IrExpression>()
fun inline() = inlineFunction(callSite, callee)
fun inline() = inlineFunction(callSite, callee, true)
/**
* TODO: JVM inliner crashed on attempt inline this function from transform.kt with:
@@ -104,27 +112,36 @@ class FunctionInlining(val context: CommonBackendContext) : IrElementTransformer
}
}
private fun inlineFunction(callSite: IrFunctionAccessExpression, callee: IrFunction): IrReturnableBlock {
val copiedCallee = copyIrElement.copy(callee) as IrFunction
private fun inlineFunction(
callSite: IrFunctionAccessExpression,
callee: IrFunction,
performRecursiveInline: Boolean
): IrReturnableBlock {
val copiedCallee = if (performRecursiveInline)
visitElement(copyIrElement.copy(callee)) as IrFunction
else copyIrElement.copy(callee) as IrFunction
val evaluationStatements = evaluateArguments(callSite, copiedCallee)
val statements = (copiedCallee.body as IrBlockBody).statements
val irReturnableBlockSymbol = IrReturnableBlockSymbolImpl(copiedCallee.descriptor.original)
val startOffset = callee.startOffset
val endOffset = callee.endOffset
val irBuilder = context.createIrBuilder(irReturnableBlockSymbol, startOffset, endOffset)
/* creates irBuilder appending to the end of the given returnable block: thus why we initialize
* irBuilder with (..., endOffset, endOffset).
*/
val irBuilder = context.createIrBuilder(irReturnableBlockSymbol, endOffset, endOffset)
val transformer = ParameterSubstitutor()
statements.transform { it.transform(transformer, data = null) }
statements.addAll(0, evaluationStatements)
val isCoroutineIntrinsicCall = callSite.descriptor.isBuiltInSuspendCoroutineUninterceptedOrReturn(
context.configuration.languageVersionSettings)
context.configuration.languageVersionSettings
)
return IrReturnableBlockImpl(
startOffset = startOffset,
endOffset = endOffset,
startOffset = callSite.startOffset,
endOffset = callSite.endOffset,
type = callSite.type,
symbol = irReturnableBlockSymbol,
origin = if (isCoroutineIntrinsicCall) CoroutineIntrinsicLambdaOrigin else null,
@@ -140,6 +157,7 @@ class FunctionInlining(val context: CommonBackendContext) : IrElementTransformer
return expression
}
})
patchDeclarationParents(parent) // TODO: Why it is not enough to just run SetDeclarationsParentVisitor?
}
}
@@ -152,7 +170,10 @@ class FunctionInlining(val context: CommonBackendContext) : IrElementTransformer
val argument = substituteMap[newExpression.symbol.owner] ?: return newExpression
argument.transformChildrenVoid(this) // Default argument can contain subjects for substitution.
return copyIrElement.copy(argument) as IrExpression
return if (argument is IrGetValueWithoutLocation)
argument.withLocation(newExpression.startOffset, newExpression.endOffset)
else (copyIrElement.copy(argument) as IrExpression)
}
//-----------------------------------------------------------------//
@@ -167,6 +188,8 @@ class FunctionInlining(val context: CommonBackendContext) : IrElementTransformer
return super.visitCall(expression)
if (functionArgument is IrFunctionReference) {
functionArgument.transformChildrenVoid(this)
val function = functionArgument.symbol.owner
val functionParameters = function.explicitParameters
val boundFunctionParameters = functionArgument.getArgumentsWithIr()
@@ -178,37 +201,54 @@ class FunctionInlining(val context: CommonBackendContext) : IrElementTransformer
val valueParameters = expression.getArgumentsWithIr().drop(1) // Skip dispatch receiver.
val immediateCall = with(expression) {
if (function is IrConstructor) {
IrConstructorCallImpl.fromSymbolOwner(startOffset, endOffset, type, function.symbol)
} else {
IrCallImpl(startOffset, endOffset, type, functionArgument.symbol)
}
if (function is IrConstructor)
IrConstructorCallImpl.fromSymbolOwner(startOffset, endOffset, function.returnType, function.symbol)
else
IrCallImpl(startOffset, endOffset, function.returnType, functionArgument.symbol)
}.apply {
functionParameters.forEach {
val argument =
if (it !in unboundArgsSet)
boundFunctionParametersMap[it]!!
else
valueParameters.getOrNull(unboundIndex++)?.second
if (unboundArgsSet.contains(it)) {
assert(unboundIndex < valueParameters.size) {
"Attempt to use unbound parameter outside of the callee's value parameters"
}
valueParameters[unboundIndex++].second
} else {
val arg = boundFunctionParametersMap[it]!!
if (arg is IrGetValueWithoutLocation)
arg.withLocation(expression.startOffset, expression.endOffset)
else arg
}
when (it) {
function.dispatchReceiverParameter -> this.dispatchReceiver = argument
function.extensionReceiverParameter -> this.extensionReceiver = argument
else -> putValueArgument(it.index, argument)
function.dispatchReceiverParameter ->
this.dispatchReceiver = argument.implicitCastIfNeededTo(function.dispatchReceiverParameter!!.type)
function.extensionReceiverParameter ->
this.extensionReceiver = argument.implicitCastIfNeededTo(function.extensionReceiverParameter!!.type)
else -> putValueArgument(it.index, argument.implicitCastIfNeededTo(function.valueParameters[it.index].type))
}
}
assert(unboundIndex >= valueParameters.size) { "Not all arguments of <invoke> are used" }
assert(unboundIndex == valueParameters.size) { "Not all arguments of the callee are used" }
for (index in 0 until functionArgument.typeArgumentsCount)
putTypeArgument(index, functionArgument.getTypeArgument(index))
}
}.implicitCastIfNeededTo(expression.type)
return this@FunctionInlining.visitExpression(super.visitExpression(immediateCall))
}
if (functionArgument !is IrBlock)
if (functionArgument !is IrFunctionExpression)
return super.visitCall(expression)
val functionDeclaration = functionArgument.statements[0] as IrFunction
val newExpression = inlineFunction(expression, functionDeclaration) // Inline the lambda. Lambda parameters will be substituted with lambda arguments.
return newExpression.transform(this, null) // Substitute lambda arguments with target function arguments.
// Inline the lambda. Lambda parameters will be substituted with lambda arguments.
val newExpression = inlineFunction(
expression,
functionArgument.function,
false
)
// Substitute lambda arguments with target function arguments.
return newExpression.transform(
this,
null
)
}
//-----------------------------------------------------------------//
@@ -216,7 +256,13 @@ class FunctionInlining(val context: CommonBackendContext) : IrElementTransformer
override fun visitElement(element: IrElement) = element.accept(this, null)
}
private fun isLambdaCall(irCall: IrFunctionAccessExpression): Boolean {
private fun IrExpression.implicitCastIfNeededTo(type: IrType) =
if (type == this.type)
this
else
IrTypeOperatorCallImpl(startOffset, endOffset, type, IrTypeOperator.IMPLICIT_CAST, type, this)
private fun isLambdaCall(irCall: IrCall): Boolean {
val callee = irCall.symbol.owner
val dispatchReceiver = callee.dispatchReceiverParameter ?: return false
assert(!dispatchReceiver.type.isKFunction())
@@ -231,25 +277,15 @@ class FunctionInlining(val context: CommonBackendContext) : IrElementTransformer
private fun IrValueParameter.isInlineParameter() =
!isNoinline && !type.isNullable() && type.isFunctionOrKFunction()
private inner class ParameterToArgument(val parameter: IrValueParameter,
val argumentExpression: IrExpression) {
private inner class ParameterToArgument(
val parameter: IrValueParameter,
val argumentExpression: IrExpression
) {
val isInlinableLambdaArgument: Boolean
get() {
if (!parameter.isInlineParameter()) return false
if (argumentExpression is IrFunctionReference) return true
// Do pattern-matching on IR.
if (argumentExpression !is IrBlock) return false
if (argumentExpression.origin != IrStatementOrigin.LAMBDA &&
argumentExpression.origin != IrStatementOrigin.ANONYMOUS_FUNCTION) return false
val statements = argumentExpression.statements
val irFunction = statements[0]
val irCallableReference = statements[1]
if (irFunction !is IrFunction) return false
if (irCallableReference !is IrCallableReference) return false
return true
}
get() = parameter.isInlineParameter() &&
(argumentExpression is IrFunctionReference
|| argumentExpression is IrFunctionExpression)
val isImmutableVariableLoad: Boolean
get() = argumentExpression.let {
@@ -257,14 +293,12 @@ class FunctionInlining(val context: CommonBackendContext) : IrElementTransformer
}
}
//-------------------------------------------------------------------------//
// callee might be a copied version of callsite.symbol.owner
private fun buildParameterToArgument(callSite: IrFunctionAccessExpression, callee: IrFunction): List<ParameterToArgument> {
val parameterToArgument = mutableListOf<ParameterToArgument>()
if (callSite.dispatchReceiver != null && // Only if there are non null dispatch receivers both
callee.dispatchReceiverParameter != null) // on call site and in function declaration.
if (callSite.dispatchReceiver != null && callee.dispatchReceiverParameter != null)
parameterToArgument += ParameterToArgument(
parameter = callee.dispatchReceiverParameter!!,
argumentExpression = callSite.dispatchReceiver!!
@@ -328,18 +362,53 @@ class FunctionInlining(val context: CommonBackendContext) : IrElementTransformer
}
}
}
return parameterToArgument + parametersWithDefaultToArgument // All arguments except default are evaluated at callsite,
// All arguments except default are evaluated at callsite,
// but default arguments are evaluated inside callee.
return parameterToArgument + parametersWithDefaultToArgument
}
//-------------------------------------------------------------------------//
private fun evaluateArguments(callSite: IrFunctionAccessExpression, callee: IrFunction): List<IrStatement> {
val parameterToArgumentOld = buildParameterToArgument(callSite, callee)
private fun evaluateArguments(functionReference: IrFunctionReference): List<IrStatement> {
val arguments = functionReference.getArgumentsWithIr().map { ParameterToArgument(it.first, it.second) }
val evaluationStatements = mutableListOf<IrStatement>()
val substitutor = ParameterSubstitutor()
parameterToArgumentOld.forEach {
val referenced = functionReference.symbol.owner
arguments.forEach {
val newArgument = if (it.isImmutableVariableLoad) {
it.argumentExpression.transform( // Arguments may reference the previous ones - substitute them.
substitutor,
data = null
)
} else {
val newVariable =
currentScope.scope.createTemporaryVariableWithWrappedDescriptor(
irExpression = it.argumentExpression.transform( // Arguments may reference the previous ones - substitute them.
substitutor,
data = null
),
nameHint = callee.symbol.owner.name.toString(),
isMutable = false
)
evaluationStatements.add(newVariable)
IrGetValueWithoutLocation(newVariable.symbol)
}
when (it.parameter) {
referenced.dispatchReceiverParameter -> functionReference.dispatchReceiver = newArgument
referenced.extensionReceiverParameter -> functionReference.extensionReceiver = newArgument
else -> functionReference.putValueArgument(it.parameter.index, newArgument)
}
}
return evaluationStatements
}
private fun evaluateArguments(callSite: IrFunctionAccessExpression, callee: IrFunction): List<IrStatement> {
val arguments = buildParameterToArgument(callSite, callee)
val evaluationStatements = mutableListOf<IrStatement>()
val substitutor = ParameterSubstitutor()
arguments.forEach {
/*
* We need to create temporary variable for each argument except inlinable lambda arguments.
* For simplicity and to produce simpler IR we don't create temporaries for every immutable variable,
@@ -347,29 +416,52 @@ class FunctionInlining(val context: CommonBackendContext) : IrElementTransformer
*/
if (it.isInlinableLambdaArgument) {
substituteMap[it.parameter] = it.argumentExpression
(it.argumentExpression as? IrFunctionReference)?.let { evaluationStatements += evaluateArguments(it) }
return@forEach
}
if (it.isImmutableVariableLoad) {
substituteMap[it.parameter] = it.argumentExpression.transform(substitutor, data = null) // Arguments may reference the previous ones - substitute them.
substituteMap[it.parameter] =
it.argumentExpression.transform( // Arguments may reference the previous ones - substitute them.
substitutor,
data = null
)
return@forEach
}
val newVariable = currentScope.scope.createTemporaryVariableWithWrappedDescriptor( // Create new variable and init it with the parameter expression.
irExpression = it.argumentExpression.transform(substitutor, data = null), // Arguments may reference the previous ones - substitute them.
nameHint = callee.symbol.owner.name.toString(),
isMutable = false)
val newVariable =
currentScope.scope.createTemporaryVariableWithWrappedDescriptor(
irExpression = it.argumentExpression.transform( // Arguments may reference the previous ones - substitute them.
substitutor,
data = null
),
nameHint = callee.symbol.owner.name.toString(),
isMutable = false
)
evaluationStatements.add(newVariable)
val getVal = IrGetValueImpl(
startOffset = currentScope.irElement.startOffset,
endOffset = currentScope.irElement.endOffset,
type = newVariable.type,
symbol = newVariable.symbol
)
substituteMap[it.parameter] = getVal
substituteMap[it.parameter] = IrGetValueWithoutLocation(newVariable.symbol)
}
return evaluationStatements
}
}
private class IrGetValueWithoutLocation(
symbol: IrValueSymbol,
override val origin: IrStatementOrigin? = null
) : IrTerminalDeclarationReferenceBase<IrValueSymbol, ValueDescriptor>(
UNDEFINED_OFFSET, UNDEFINED_OFFSET,
symbol.owner.type,
symbol, symbol.descriptor
), IrGetValue {
override fun <R, D> accept(visitor: IrElementVisitor<R, D>, data: D) =
visitor.visitGetValue(this, data)
override fun copy(): IrGetValue {
TODO("not implemented")
}
fun withLocation(startOffset: Int, endOffset: Int) =
IrGetValueImpl(startOffset, endOffset, type, symbol, origin)
}
}
@@ -403,9 +403,9 @@ val jsPhases = namedIrModulePhase(
testGenerationPhase then
expectDeclarationsRemovingPhase then
stripTypeAliasDeclarationsPhase then
provisionalFunctionExpressionPhase then
arrayConstructorPhase then
functionInliningPhase then
provisionalFunctionExpressionPhase then
lateinitLoweringPhase then
tailrecLoweringPhase then
enumClassConstructorLoweringPhase then
@@ -1,5 +1,5 @@
// !LANGUAGE: +NewInference +FunctionReferenceWithDefaultValueAsOtherType
// IGNORE_BACKEND: JS, JVM_IR
// IGNORE_BACKEND: JS, JVM_IR, JS_IR
fun foo(vararg l: Long, s: String = "OK"): String =
if (l.size == 0) s else "Fail"
@@ -1,4 +1,3 @@
// IGNORE_BACKEND: JS_IR
// FILE: 1.kt
package test