From 5cdf053c8edd1ab2b34c047f9abe4dba478fbb5c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Steven=20Sch=C3=A4fer?= Date: Thu, 25 Jun 2020 16:04:11 +0200 Subject: [PATCH] Coroutines: Fix RedundantLocalsEliminationMethodTransformer - Take control flow into account when collecting usage information - Don't remove stores to local variables --- ...ndantLocalsEliminationMethodTransformer.kt | 259 ++++++++---------- .../ir/FirBlackBoxCodegenTestGenerated.java | 5 + .../unitTypeReturn/inlineUnitFunction.kt | 7 + .../codegen/BlackBoxCodegenTestGenerated.java | 5 + .../LightAnalysisModeTestGenerated.java | 5 + .../ir/IrBlackBoxCodegenTestGenerated.java | 5 + .../IrJsCodegenBoxES6TestGenerated.java | 5 + .../IrJsCodegenBoxTestGenerated.java | 5 + .../semantics/JsCodegenBoxTestGenerated.java | 5 + 9 files changed, 155 insertions(+), 146 deletions(-) create mode 100644 compiler/testData/codegen/box/coroutines/unitTypeReturn/inlineUnitFunction.kt diff --git a/compiler/backend/src/org/jetbrains/kotlin/codegen/coroutines/RedundantLocalsEliminationMethodTransformer.kt b/compiler/backend/src/org/jetbrains/kotlin/codegen/coroutines/RedundantLocalsEliminationMethodTransformer.kt index dbac2f9132e..ac7ae83c121 100644 --- a/compiler/backend/src/org/jetbrains/kotlin/codegen/coroutines/RedundantLocalsEliminationMethodTransformer.kt +++ b/compiler/backend/src/org/jetbrains/kotlin/codegen/coroutines/RedundantLocalsEliminationMethodTransformer.kt @@ -1,11 +1,10 @@ /* - * Copyright 2010-2019 JetBrains s.r.o. and Kotlin Programming Language contributors. + * Copyright 2010-2020 JetBrains s.r.o. and Kotlin Programming Language contributors. * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. */ package org.jetbrains.kotlin.codegen.coroutines -import org.jetbrains.kotlin.codegen.inline.nodeText import org.jetbrains.kotlin.codegen.optimization.boxing.isUnitInstance import org.jetbrains.kotlin.codegen.optimization.common.MethodAnalyzer import org.jetbrains.kotlin.codegen.optimization.common.asSequence @@ -13,168 +12,136 @@ import org.jetbrains.kotlin.codegen.optimization.common.removeAll import org.jetbrains.kotlin.codegen.optimization.fixStack.top import org.jetbrains.kotlin.codegen.optimization.transformer.MethodTransformer import org.jetbrains.kotlin.resolve.jvm.AsmTypes +import org.jetbrains.kotlin.utils.addToStdlib.safeAs import org.jetbrains.org.objectweb.asm.Opcodes -import org.jetbrains.org.objectweb.asm.Type -import org.jetbrains.org.objectweb.asm.tree.* +import org.jetbrains.org.objectweb.asm.tree.AbstractInsnNode +import org.jetbrains.org.objectweb.asm.tree.LabelNode +import org.jetbrains.org.objectweb.asm.tree.MethodNode +import org.jetbrains.org.objectweb.asm.tree.VarInsnNode import org.jetbrains.org.objectweb.asm.tree.analysis.BasicInterpreter import org.jetbrains.org.objectweb.asm.tree.analysis.BasicValue +import org.jetbrains.org.objectweb.asm.tree.analysis.Frame +import java.util.* -private class PossibleSpilledValue(val source: AbstractInsnNode, type: Type?) : BasicValue(type) { - val usages = mutableSetOf() - - override fun toString(): String = when { - source.opcode == Opcodes.ALOAD -> "" + (source as VarInsnNode).`var` - source.isUnitInstance() -> "U" - else -> error("unreachable") - } - - override fun equals(other: Any?): Boolean = - other is PossibleSpilledValue && source == other.source - - override fun hashCode(): Int = super.hashCode() xor source.hashCode() -} - -private object NonSpillableValue : BasicValue(AsmTypes.OBJECT_TYPE) { - override fun equals(other: Any?): Boolean = other is NonSpillableValue - - override fun toString(): String = "N" -} - -private object ConstructedValue : BasicValue(AsmTypes.OBJECT_TYPE) { - override fun equals(other: Any?): Boolean = other is ConstructedValue - - override fun toString(): String = "C" -} - -fun BasicValue?.nonspillable(): BasicValue? = if (this?.type?.sort == Type.OBJECT) NonSpillableValue else this - -private class RedundantSpillingInterpreter : BasicInterpreter(Opcodes.API_VERSION) { - val possibleSpilledValues = mutableSetOf() - - override fun newOperation(insn: AbstractInsnNode): BasicValue? { - if (insn.opcode == Opcodes.NEW) return ConstructedValue - val basicValue = super.newOperation(insn) - return if (insn.isUnitInstance()) - // Unit instances come from inlining suspend functions returning Unit. - // They can be spilled before they are eventually popped. - // Track them. - PossibleSpilledValue(insn, basicValue.type).also { possibleSpilledValues += it } - else basicValue.nonspillable() - } - - override fun copyOperation(insn: AbstractInsnNode, value: BasicValue?): BasicValue? = - when (value) { - is ConstructedValue -> value - is PossibleSpilledValue -> { - value.usages += insn - if (insn.opcode == Opcodes.ALOAD || insn.opcode == Opcodes.ASTORE) value - else value.nonspillable() - } - else -> value?.nonspillable() - } - - override fun naryOperation(insn: AbstractInsnNode, values: MutableList): BasicValue? { - for (value in values.filterIsInstance()) { - value.usages += insn - } - return super.naryOperation(insn, values)?.nonspillable() - } - - override fun merge(v: BasicValue?, w: BasicValue?): BasicValue? = - if (v is PossibleSpilledValue && w is PossibleSpilledValue && v.source == w.source) v - else v?.nonspillable() -} - -// Inliner emits a lot of locals during inlining. -// Remove all of them since these locals are -// 1) going to be spilled into continuation object -// 2) breaking tail-call elimination +/** + * This pass removes unused Unit values. These typically occur as a result of inlining and could end up spilling + * into the continuation object or break tail-call elimination. + * + * Concretely, we remove "GETSTATIC kotlin/Unit.INSTANCE" instructions if they are unused, or all uses are either + * POP instructions, or ASTORE instructions to locals which are never read and are not named local variables. + * + * This pass does not touch [suspensionPoints], as later passes rely on the bytecode patterns around suspension points. + */ internal class RedundantLocalsEliminationMethodTransformer(private val suspensionPoints: List) : MethodTransformer() { override fun transform(internalClassName: String, methodNode: MethodNode) { - val interpreter = RedundantSpillingInterpreter() - val frames = MethodAnalyzer(internalClassName, methodNode, interpreter).analyze() + val interpreter = UnitSourceInterpreter(methodNode.localVariables?.mapTo(mutableSetOf()) { it.index } ?: setOf()) + val frames = interpreter.run(internalClassName, methodNode) + // Mark all unused instructions for deletion (except for labels which may be used in debug information) val toDelete = mutableSetOf() - for (spilledValue in interpreter.possibleSpilledValues.filter { it.usages.isNotEmpty() }) { - @Suppress("UNCHECKED_CAST") - val aloads = spilledValue.usages.filter { it.opcode == Opcodes.ALOAD } as List - - if (aloads.isEmpty()) continue - - val slot = aloads.first().`var` - - if (aloads.any { it.`var` != slot }) continue - for (aload in aloads) { - methodNode.instructions.set(aload, spilledValue.source.clone()) - } - - toDelete.addAll(spilledValue.usages.filter { it.opcode == Opcodes.ASTORE }) - toDelete.add(spilledValue.source) + methodNode.instructions.asSequence().zip(frames.asSequence()).mapNotNullTo(toDelete) { (insn, frame) -> + insn.takeIf { frame == null && insn !is LabelNode } } - for (pop in methodNode.instructions.asSequence().filter { it.opcode == Opcodes.POP }) { - val value = (frames[methodNode.instructions.indexOf(pop)]?.top() as? PossibleSpilledValue) ?: continue - if (value.usages.isEmpty() && value.source !in suspensionPoints) { - toDelete.add(pop) - toDelete.add(value.source) - } - } - - // Remove unreachable instructions to simplify further analyses - for (index in frames.indices) { - if (frames[index] == null) { - val insn = methodNode.instructions[index] - if (insn !is LabelNode) { - toDelete.add(insn) - } + // Mark all spillable "GETSTATIC kotlin/Unit.INSTANCE" instructions for deletion + for ((unit, uses) in interpreter.unitUsageInformation) { + if (unit !in interpreter.unspillableUnitValues && unit !in suspensionPoints) { + toDelete += unit + toDelete += uses } } methodNode.instructions.removeAll(toDelete) } - - private fun AbstractInsnNode.clone() = when (this) { - is FieldInsnNode -> FieldInsnNode(opcode, owner, name, desc) - is VarInsnNode -> VarInsnNode(opcode, `var`) - is InsnNode -> InsnNode(opcode) - is TypeInsnNode -> TypeInsnNode(opcode, desc) - else -> error("clone of $this is not implemented yet") - } } -// Handy debugging routing -@Suppress("unused") -fun MethodNode.nodeTextWithFrames(frames: Array<*>): String { - var insns = nodeText.split("\n") - val first = insns.indexOfLast { it.trim().startsWith("@") } + 1 - var last = insns.indexOfFirst { it.trim().startsWith("LOCALVARIABLE") } - if (last < 0) last = insns.size - val prefix = insns.subList(0, first).joinToString(separator = "\n") - val postfix = insns.subList(last, insns.size).joinToString(separator = "\n") - insns = insns.subList(first, last) - if (insns.any { it.contains("TABLESWITCH") }) { - var insideTableSwitch = false - var buffer = "" - val res = arrayListOf() - for (insn in insns) { - if (insn.contains("TABLESWITCH")) { - insideTableSwitch = true - } - if (insideTableSwitch) { - buffer += insn - if (insn.contains("default")) { - insideTableSwitch = false - res += buffer - buffer = "" - continue - } - } else { - res += insn +// A version of SourceValue which inherits from BasicValue and is only used for Unit values. +private class UnitValue(val insns: Set) : BasicValue(AsmTypes.OBJECT_TYPE) { + constructor(insn: AbstractInsnNode) : this(Collections.singleton(insn)) + + override fun equals(other: Any?): Boolean = other is UnitValue && insns == other.insns + override fun hashCode() = Objects.hash(insns) + override fun toString() = "U" +} + +// A specialized SourceInterpreter which only keeps track of the use sites for Unit values which are exclusively used as +// arguments to POP and unused ASTORE instructions. +private class UnitSourceInterpreter(private val localVariables: Set) : BasicInterpreter(Opcodes.API_VERSION) { + // All unit values with visible use-sites. + val unspillableUnitValues = mutableSetOf() + + // Map from unit values to ASTORE/POP use-sites. + val unitUsageInformation = mutableMapOf>() + + private fun markUnspillable(value: BasicValue?) = + value?.safeAs()?.let { unspillableUnitValues += it.insns } + + private fun collectUnitUsage(use: AbstractInsnNode, value: UnitValue) { + for (def in value.insns) { + if (def !in unspillableUnitValues) { + unitUsageInformation.getOrPut(def) { mutableSetOf() } += use } } - insns = res } - return prefix + "\n" + insns.withIndex().joinToString(separator = "\n") { (index, insn) -> - if (index >= frames.size) "N/A\t$insn" else "${frames[index]}\t$insn" - } + "\n" + postfix + + fun run(internalClassName: String, methodNode: MethodNode): Array?> { + val frames = MethodAnalyzer(internalClassName, methodNode, this).analyze() + // The ASM analyzer does not visit POP instructions, so we do so here. + for ((insn, frame) in methodNode.instructions.asSequence().zip(frames.asSequence())) { + if (frame != null && insn.opcode == Opcodes.POP) { + val value = frame.top() + value.safeAs()?.let { collectUnitUsage(insn, it) } + } + } + return frames + } + + override fun newOperation(insn: AbstractInsnNode?): BasicValue = + if (insn?.isUnitInstance() == true) UnitValue(insn) else super.newOperation(insn) + + override fun copyOperation(insn: AbstractInsnNode, value: BasicValue?): BasicValue? { + if (value is UnitValue) { + if (insn is VarInsnNode && insn.opcode == Opcodes.ASTORE && insn.`var` !in localVariables) { + collectUnitUsage(insn, value) + // We track the stored value in case it is subsequently read. + return value + } + unspillableUnitValues += value.insns + } + return super.copyOperation(insn, value) + } + + override fun unaryOperation(insn: AbstractInsnNode, value: BasicValue?): BasicValue? { + markUnspillable(value) + return super.unaryOperation(insn, value) + } + + override fun binaryOperation(insn: AbstractInsnNode, value1: BasicValue?, value2: BasicValue?): BasicValue? { + markUnspillable(value1) + markUnspillable(value2) + return super.binaryOperation(insn, value1, value2) + } + + override fun ternaryOperation(insn: AbstractInsnNode, value1: BasicValue?, value2: BasicValue?, value3: BasicValue?): BasicValue? { + markUnspillable(value1) + markUnspillable(value2) + markUnspillable(value3) + return super.ternaryOperation(insn, value1, value2, value3) + } + + override fun naryOperation(insn: AbstractInsnNode, values: List?): BasicValue? { + values?.forEach(this::markUnspillable) + return super.naryOperation(insn, values) + } + + override fun merge(value1: BasicValue?, value2: BasicValue?): BasicValue? = + if (value1 is UnitValue && value2 is UnitValue) { + UnitValue(value1.insns.union(value2.insns)) + } else { + // Mark unit values as unspillable if we merge them with non-unit values here. + // This is conservative since the value could turn out to be unused. + markUnspillable(value1) + markUnspillable(value2) + super.merge(value1, value2) + } } diff --git a/compiler/fir/fir2ir/tests/org/jetbrains/kotlin/codegen/ir/FirBlackBoxCodegenTestGenerated.java b/compiler/fir/fir2ir/tests/org/jetbrains/kotlin/codegen/ir/FirBlackBoxCodegenTestGenerated.java index 05b6272a354..20948f6e472 100644 --- a/compiler/fir/fir2ir/tests/org/jetbrains/kotlin/codegen/ir/FirBlackBoxCodegenTestGenerated.java +++ b/compiler/fir/fir2ir/tests/org/jetbrains/kotlin/codegen/ir/FirBlackBoxCodegenTestGenerated.java @@ -8306,6 +8306,11 @@ public class FirBlackBoxCodegenTestGenerated extends AbstractFirBlackBoxCodegenT runTestWithPackageReplacement("compiler/testData/codegen/box/coroutines/unitTypeReturn/coroutineReturn.kt", "kotlin.coroutines"); } + @TestMetadata("inlineUnitFunction.kt") + public void testInlineUnitFunction() throws Exception { + runTest("compiler/testData/codegen/box/coroutines/unitTypeReturn/inlineUnitFunction.kt"); + } + @TestMetadata("interfaceDelegation.kt") public void testInterfaceDelegation() throws Exception { runTest("compiler/testData/codegen/box/coroutines/unitTypeReturn/interfaceDelegation.kt"); diff --git a/compiler/testData/codegen/box/coroutines/unitTypeReturn/inlineUnitFunction.kt b/compiler/testData/codegen/box/coroutines/unitTypeReturn/inlineUnitFunction.kt new file mode 100644 index 00000000000..9c5f974cda9 --- /dev/null +++ b/compiler/testData/codegen/box/coroutines/unitTypeReturn/inlineUnitFunction.kt @@ -0,0 +1,7 @@ +suspend fun f(x: Any?) { + x?.let { Unit } ?: Unit +} + +fun box(): String { + return "OK" +} diff --git a/compiler/tests/org/jetbrains/kotlin/codegen/BlackBoxCodegenTestGenerated.java b/compiler/tests/org/jetbrains/kotlin/codegen/BlackBoxCodegenTestGenerated.java index a973f033404..ce639473203 100644 --- a/compiler/tests/org/jetbrains/kotlin/codegen/BlackBoxCodegenTestGenerated.java +++ b/compiler/tests/org/jetbrains/kotlin/codegen/BlackBoxCodegenTestGenerated.java @@ -9501,6 +9501,11 @@ public class BlackBoxCodegenTestGenerated extends AbstractBlackBoxCodegenTest { runTestWithPackageReplacement("compiler/testData/codegen/box/coroutines/unitTypeReturn/coroutineReturn.kt", "kotlin.coroutines"); } + @TestMetadata("inlineUnitFunction.kt") + public void testInlineUnitFunction() throws Exception { + runTest("compiler/testData/codegen/box/coroutines/unitTypeReturn/inlineUnitFunction.kt"); + } + @TestMetadata("interfaceDelegation.kt") public void testInterfaceDelegation() throws Exception { runTest("compiler/testData/codegen/box/coroutines/unitTypeReturn/interfaceDelegation.kt"); diff --git a/compiler/tests/org/jetbrains/kotlin/codegen/LightAnalysisModeTestGenerated.java b/compiler/tests/org/jetbrains/kotlin/codegen/LightAnalysisModeTestGenerated.java index de930cff700..2a705ec6d24 100644 --- a/compiler/tests/org/jetbrains/kotlin/codegen/LightAnalysisModeTestGenerated.java +++ b/compiler/tests/org/jetbrains/kotlin/codegen/LightAnalysisModeTestGenerated.java @@ -9501,6 +9501,11 @@ public class LightAnalysisModeTestGenerated extends AbstractLightAnalysisModeTes runTestWithPackageReplacement("compiler/testData/codegen/box/coroutines/unitTypeReturn/coroutineReturn.kt", "kotlin.coroutines"); } + @TestMetadata("inlineUnitFunction.kt") + public void testInlineUnitFunction() throws Exception { + runTest("compiler/testData/codegen/box/coroutines/unitTypeReturn/inlineUnitFunction.kt"); + } + @TestMetadata("interfaceDelegation.kt") public void testInterfaceDelegation() throws Exception { runTest("compiler/testData/codegen/box/coroutines/unitTypeReturn/interfaceDelegation.kt"); diff --git a/compiler/tests/org/jetbrains/kotlin/codegen/ir/IrBlackBoxCodegenTestGenerated.java b/compiler/tests/org/jetbrains/kotlin/codegen/ir/IrBlackBoxCodegenTestGenerated.java index 2ae8369c362..9751131d1e0 100644 --- a/compiler/tests/org/jetbrains/kotlin/codegen/ir/IrBlackBoxCodegenTestGenerated.java +++ b/compiler/tests/org/jetbrains/kotlin/codegen/ir/IrBlackBoxCodegenTestGenerated.java @@ -8306,6 +8306,11 @@ public class IrBlackBoxCodegenTestGenerated extends AbstractIrBlackBoxCodegenTes runTestWithPackageReplacement("compiler/testData/codegen/box/coroutines/unitTypeReturn/coroutineReturn.kt", "kotlin.coroutines"); } + @TestMetadata("inlineUnitFunction.kt") + public void testInlineUnitFunction() throws Exception { + runTest("compiler/testData/codegen/box/coroutines/unitTypeReturn/inlineUnitFunction.kt"); + } + @TestMetadata("interfaceDelegation.kt") public void testInterfaceDelegation() throws Exception { runTest("compiler/testData/codegen/box/coroutines/unitTypeReturn/interfaceDelegation.kt"); diff --git a/js/js.tests/test/org/jetbrains/kotlin/js/test/es6/semantics/IrJsCodegenBoxES6TestGenerated.java b/js/js.tests/test/org/jetbrains/kotlin/js/test/es6/semantics/IrJsCodegenBoxES6TestGenerated.java index c7e95737e00..475c7dad761 100644 --- a/js/js.tests/test/org/jetbrains/kotlin/js/test/es6/semantics/IrJsCodegenBoxES6TestGenerated.java +++ b/js/js.tests/test/org/jetbrains/kotlin/js/test/es6/semantics/IrJsCodegenBoxES6TestGenerated.java @@ -7031,6 +7031,11 @@ public class IrJsCodegenBoxES6TestGenerated extends AbstractIrJsCodegenBoxES6Tes runTestWithPackageReplacement("compiler/testData/codegen/box/coroutines/unitTypeReturn/coroutineReturn.kt", "kotlin.coroutines"); } + @TestMetadata("inlineUnitFunction.kt") + public void testInlineUnitFunction() throws Exception { + runTest("compiler/testData/codegen/box/coroutines/unitTypeReturn/inlineUnitFunction.kt"); + } + @TestMetadata("interfaceDelegation.kt") public void testInterfaceDelegation() throws Exception { runTest("compiler/testData/codegen/box/coroutines/unitTypeReturn/interfaceDelegation.kt"); diff --git a/js/js.tests/test/org/jetbrains/kotlin/js/test/ir/semantics/IrJsCodegenBoxTestGenerated.java b/js/js.tests/test/org/jetbrains/kotlin/js/test/ir/semantics/IrJsCodegenBoxTestGenerated.java index 8a90db2d98a..515ca2b97a7 100644 --- a/js/js.tests/test/org/jetbrains/kotlin/js/test/ir/semantics/IrJsCodegenBoxTestGenerated.java +++ b/js/js.tests/test/org/jetbrains/kotlin/js/test/ir/semantics/IrJsCodegenBoxTestGenerated.java @@ -7041,6 +7041,11 @@ public class IrJsCodegenBoxTestGenerated extends AbstractIrJsCodegenBoxTest { runTestWithPackageReplacement("compiler/testData/codegen/box/coroutines/unitTypeReturn/coroutineReturn.kt", "kotlin.coroutines"); } + @TestMetadata("inlineUnitFunction.kt") + public void testInlineUnitFunction() throws Exception { + runTest("compiler/testData/codegen/box/coroutines/unitTypeReturn/inlineUnitFunction.kt"); + } + @TestMetadata("interfaceDelegation.kt") public void testInterfaceDelegation() throws Exception { runTest("compiler/testData/codegen/box/coroutines/unitTypeReturn/interfaceDelegation.kt"); diff --git a/js/js.tests/test/org/jetbrains/kotlin/js/test/semantics/JsCodegenBoxTestGenerated.java b/js/js.tests/test/org/jetbrains/kotlin/js/test/semantics/JsCodegenBoxTestGenerated.java index bd79b3f240d..bb045d6721e 100644 --- a/js/js.tests/test/org/jetbrains/kotlin/js/test/semantics/JsCodegenBoxTestGenerated.java +++ b/js/js.tests/test/org/jetbrains/kotlin/js/test/semantics/JsCodegenBoxTestGenerated.java @@ -7041,6 +7041,11 @@ public class JsCodegenBoxTestGenerated extends AbstractJsCodegenBoxTest { runTestWithPackageReplacement("compiler/testData/codegen/box/coroutines/unitTypeReturn/coroutineReturn.kt", "kotlin.coroutines"); } + @TestMetadata("inlineUnitFunction.kt") + public void testInlineUnitFunction() throws Exception { + runTest("compiler/testData/codegen/box/coroutines/unitTypeReturn/inlineUnitFunction.kt"); + } + @TestMetadata("interfaceDelegation.kt") public void testInterfaceDelegation() throws Exception { runTest("compiler/testData/codegen/box/coroutines/unitTypeReturn/interfaceDelegation.kt");