From 7c63d50d1c7bc1c99be36620ebeb6c56510d4399 Mon Sep 17 00:00:00 2001 From: pyos Date: Mon, 6 Sep 2021 12:15:20 +0200 Subject: [PATCH] IR: create more temporary `val`s when optimizing tailrec calls This is needed so that SharedVariablesLowering doesn't get confused, and SharedVariablesLowering should run after TailrecLowering to properly optimize tailrec calls in inline lambdas. --- .../common/TailRecursionCallsCollector.kt | 9 +-- .../backend/common/lower/TailrecLowering.kt | 63 +++++++------------ .../jetbrains/kotlin/backend/jvm/JvmLower.kt | 4 +- .../codegen/box/defaultArguments/kt36853.kt | 2 - .../defaultArguments/kt36853_nestedObject.kt | 2 - .../codegen/box/defaultArguments/kt36853a.kt | 2 - .../codegen/box/defaultArguments/kt46189.kt | 3 - 7 files changed, 26 insertions(+), 59 deletions(-) diff --git a/compiler/ir/backend.common/src/org/jetbrains/kotlin/backend/common/TailRecursionCallsCollector.kt b/compiler/ir/backend.common/src/org/jetbrains/kotlin/backend/common/TailRecursionCallsCollector.kt index 29d98ac788c..26c97f32455 100644 --- a/compiler/ir/backend.common/src/org/jetbrains/kotlin/backend/common/TailRecursionCallsCollector.kt +++ b/compiler/ir/backend.common/src/org/jetbrains/kotlin/backend/common/TailRecursionCallsCollector.kt @@ -24,7 +24,6 @@ import org.jetbrains.kotlin.ir.expressions.* import org.jetbrains.kotlin.ir.types.classOrNull import org.jetbrains.kotlin.ir.types.IdSignatureValues import org.jetbrains.kotlin.ir.types.isUnit -import org.jetbrains.kotlin.ir.util.parentClassOrNull import org.jetbrains.kotlin.ir.util.usesDefaultArguments import org.jetbrains.kotlin.ir.visitors.IrElementVisitor @@ -99,13 +98,7 @@ fun collectTailRecursionCalls(irFunction: IrFunction): Set { } private fun IrExpression.isUnitRead(): Boolean = - when (this) { - is IrGetObjectValue -> symbol - // On the JVM, if SingletonReferencesLowering has already finished, a `Unit` reference - // is now an IrGetField to the INSTANCE field. - is IrGetField -> symbol.owner.parentClassOrNull?.symbol - else -> null - }?.signature == IdSignatureValues.unit + this is IrGetObjectValue && symbol.signature == IdSignatureValues.unit override fun visitWhen(expression: IrWhen, data: ElementKind) { expression.branches.forEach { diff --git a/compiler/ir/backend.common/src/org/jetbrains/kotlin/backend/common/lower/TailrecLowering.kt b/compiler/ir/backend.common/src/org/jetbrains/kotlin/backend/common/lower/TailrecLowering.kt index d3f0740e42e..87d4bd84e07 100644 --- a/compiler/ir/backend.common/src/org/jetbrains/kotlin/backend/common/lower/TailrecLowering.kt +++ b/compiler/ir/backend.common/src/org/jetbrains/kotlin/backend/common/lower/TailrecLowering.kt @@ -24,12 +24,13 @@ import org.jetbrains.kotlin.ir.IrElement import org.jetbrains.kotlin.ir.builders.* import org.jetbrains.kotlin.ir.declarations.* import org.jetbrains.kotlin.ir.expressions.* -import org.jetbrains.kotlin.ir.expressions.impl.IrGetValueImpl +import org.jetbrains.kotlin.ir.expressions.impl.IrConstImpl import org.jetbrains.kotlin.ir.symbols.IrValueParameterSymbol import org.jetbrains.kotlin.ir.transformStatement +import org.jetbrains.kotlin.ir.types.makeNullable import org.jetbrains.kotlin.ir.util.explicitParameters import org.jetbrains.kotlin.ir.util.getArgumentsWithIr -import org.jetbrains.kotlin.ir.visitors.IrElementTransformerVoid +import org.jetbrains.kotlin.ir.util.patchDeclarationParents import org.jetbrains.kotlin.ir.visitors.IrElementVisitorVoid import org.jetbrains.kotlin.ir.visitors.acceptChildrenVoid import org.jetbrains.kotlin.ir.visitors.transformChildrenVoid @@ -124,68 +125,52 @@ private class BodyTransformer( val parameterToVariable: Map, val tailRecursionCalls: Set, val properComputationOrderOfTailrecDefaultParameters: Boolean -) : IrElementTransformerVoid() { +) : VariableRemapper(parameterToNew) { val parameters = irFunction.explicitParameters - override fun visitGetValue(expression: IrGetValue): IrExpression { - expression.transformChildrenVoid(this) - val value = parameterToNew[expression.symbol.owner] ?: return expression - return builder.at(expression).irGet(value) - } - override fun visitCall(expression: IrCall): IrExpression { expression.transformChildrenVoid(this) if (expression !in tailRecursionCalls) { return expression } - return builder.at(expression).genTailCall(expression) } private fun IrBuilderWithScope.genTailCall(expression: IrCall) = this.irBlock(expression) { // Get all specified arguments: - val parameterToArgument = expression.getArgumentsWithIr().map { (parameter, argument) -> - parameter to argument + val parameterToArgument = expression.getArgumentsWithIr().associateTo(mutableMapOf()) { (parameter, argument) -> + // Note that we create `val`s for those parameters so that if some default value contains an object + // that captures another parameter, it won't capture it as a mutable ref. + parameter to irTemporary(argument) } - // For each specified argument set the corresponding variable to it in the correct order: - parameterToArgument.forEach { (parameter, argument) -> - at(argument) - // Note that argument can use values of parameters, so it is important that - // references to parameters are mapped using `parameterToNew`, not `parameterToVariable`. - +irSet(parameterToVariable[parameter]!!.symbol, argument) - } - - val specifiedParameters = parameterToArgument.map { (parameter, _) -> parameter }.toSet() - // For each unspecified argument set the corresponding variable to default: parameters - .filter { it !in specifiedParameters } + .filter { it !in parameterToArgument } .let { if (properComputationOrderOfTailrecDefaultParameters) it else it.asReversed() } - .forEach { parameter -> - + .associateWithTo(parameterToArgument) { parameter -> val originalDefaultValue = parameter.defaultValue?.expression ?: throw Error("no argument specified for $parameter") - // Copy default value, mapping parameters to variables containing freshly computed arguments: val defaultValue = originalDefaultValue - .deepCopyWithVariables() - .transform(object : IrElementTransformerVoid() { - + .deepCopyWithVariables().patchDeclarationParents(parent) + .transform(object : VariableRemapper(parameterToArgument) { override fun visitGetValue(expression: IrGetValue): IrExpression { - expression.transformChildrenVoid(this) - - val variable = parameterToVariable[expression.symbol.owner] ?: return expression - return IrGetValueImpl( - expression.startOffset, expression.endOffset, variable.type, - variable.symbol, expression.origin - ) + // If this parameter references a different parameter declared later, produce null: + if (expression.symbol.owner.let { it is IrValueParameter && it.parent == irFunction && it !in parameterToArgument }) + return IrConstImpl.defaultValueForType(startOffset, endOffset, expression.type.makeNullable()) + return super.visitGetValue(expression) } - }, data = null) - - +irSet(parameterToVariable[parameter]!!.symbol, defaultValue) + }, null) + irTemporary(defaultValue) } + // Copy the new `val`s into the `var`s declared outside the loop: + parameterToArgument.forEach { (parameter, argument) -> + at(argument) + +irSet(parameterToVariable[parameter]!!.symbol, irGet(argument)) + } + // Jump to the entry: +irContinue(loop) } diff --git a/compiler/ir/backend.jvm/lower/src/org/jetbrains/kotlin/backend/jvm/JvmLower.kt b/compiler/ir/backend.jvm/lower/src/org/jetbrains/kotlin/backend/jvm/JvmLower.kt index ae2503d01ac..147edcda600 100644 --- a/compiler/ir/backend.jvm/lower/src/org/jetbrains/kotlin/backend/jvm/JvmLower.kt +++ b/compiler/ir/backend.jvm/lower/src/org/jetbrains/kotlin/backend/jvm/JvmLower.kt @@ -269,7 +269,6 @@ private val tailrecPhase = makeIrFilePhase( ::JvmTailrecLowering, name = "Tailrec", description = "Handle tailrec calls", - prerequisite = setOf(localDeclarationsPhase) ) private val kotlinNothingValueExceptionPhase = makeIrFilePhase( @@ -354,6 +353,7 @@ private val jvmFilePhases = listOf( forLoopsPhase, collectionStubMethodLowering, jvmInlineClassPhase, + tailrecPhase, makePatchParentsPhase(1), enumWhenPhase, @@ -364,8 +364,6 @@ private val jvmFilePhases = listOf( returnableBlocksPhase, sharedVariablesPhase, localDeclarationsPhase, - - tailrecPhase, makePatchParentsPhase(2), jvmLocalClassExtractionPhase, diff --git a/compiler/testData/codegen/box/defaultArguments/kt36853.kt b/compiler/testData/codegen/box/defaultArguments/kt36853.kt index 464cf736792..47cde6752e2 100644 --- a/compiler/testData/codegen/box/defaultArguments/kt36853.kt +++ b/compiler/testData/codegen/box/defaultArguments/kt36853.kt @@ -1,5 +1,3 @@ -// IGNORE_BACKEND: JS_IR, WASM - interface IFoo { fun foo(): String } diff --git a/compiler/testData/codegen/box/defaultArguments/kt36853_nestedObject.kt b/compiler/testData/codegen/box/defaultArguments/kt36853_nestedObject.kt index bfb2e57260a..5f8c37cc90b 100644 --- a/compiler/testData/codegen/box/defaultArguments/kt36853_nestedObject.kt +++ b/compiler/testData/codegen/box/defaultArguments/kt36853_nestedObject.kt @@ -1,5 +1,3 @@ -// IGNORE_BACKEND: JS_IR, WASM - interface IFoo { fun foo(): String } diff --git a/compiler/testData/codegen/box/defaultArguments/kt36853a.kt b/compiler/testData/codegen/box/defaultArguments/kt36853a.kt index ccf8259fac3..5751487976c 100644 --- a/compiler/testData/codegen/box/defaultArguments/kt36853a.kt +++ b/compiler/testData/codegen/box/defaultArguments/kt36853a.kt @@ -1,5 +1,3 @@ -// IGNORE_BACKEND: JS_IR - tailrec fun tailrecDefault(fake: Int, fn: () -> String = { "OK" }): String { return if (fake == 0) tailrecDefault(1) diff --git a/compiler/testData/codegen/box/defaultArguments/kt46189.kt b/compiler/testData/codegen/box/defaultArguments/kt46189.kt index 001d0d20018..10089944a97 100644 --- a/compiler/testData/codegen/box/defaultArguments/kt46189.kt +++ b/compiler/testData/codegen/box/defaultArguments/kt46189.kt @@ -1,6 +1,3 @@ -// IGNORE_BACKEND: WASM -// IGNORE_BACKEND: JS_IR - class C fun box(): String =