diff --git a/compiler/backend/src/org/jetbrains/kotlin/codegen/optimization/boxing/PopBackwardPropagationTransformer.kt b/compiler/backend/src/org/jetbrains/kotlin/codegen/optimization/boxing/PopBackwardPropagationTransformer.kt index 3c63eab07a1..00360890a1e 100644 --- a/compiler/backend/src/org/jetbrains/kotlin/codegen/optimization/boxing/PopBackwardPropagationTransformer.kt +++ b/compiler/backend/src/org/jetbrains/kotlin/codegen/optimization/boxing/PopBackwardPropagationTransformer.kt @@ -17,7 +17,6 @@ package org.jetbrains.kotlin.codegen.optimization.boxing import org.jetbrains.kotlin.codegen.optimization.OptimizationMethodVisitor -import org.jetbrains.kotlin.codegen.optimization.common.debugText import org.jetbrains.kotlin.codegen.optimization.common.isLoadOperation import org.jetbrains.kotlin.codegen.optimization.fixStack.peekWords import org.jetbrains.kotlin.codegen.optimization.fixStack.top @@ -52,17 +51,38 @@ class PopBackwardPropagationTransformer : MethodTransformer() { private val dontTouchInsnIndices = BitSet(insns.size) - private val transformations = hashMapOf() - private val frames = analyzeMethodBody() - fun transform() { + val frames = Analyzer(HazardsTrackingInterpreter()).analyze("fake", methodNode) for ((i, insn) in insns.withIndex()) { - if (insn.opcode == Opcodes.POP && frames[i] != null) { - val inputTop = getInputTop(insn) - val sources = inputTop.insns - if (sources.none { isDontTouch(it) } && sources.any { isTransformablePopOperand(it) }) { + val frame = frames[i] ?: continue + when (insn.opcode) { + Opcodes.POP -> + frame.top()?.let { input -> + // If this POP won't be removed, other POPs that touch the same values have to stay as well. + if (input.insns.any { it.shouldKeep() } || input.longerWhenFusedWithPop()) { + input.insns.markAsDontTouch() + } + } + Opcodes.POP2 -> frame.peekWords(2)?.forEach { it.insns.markAsDontTouch() } + Opcodes.DUP_X1 -> frame.peekWords(1, 1)?.forEach { it.insns.markAsDontTouch() } + Opcodes.DUP2_X1 -> frame.peekWords(2, 1)?.forEach { it.insns.markAsDontTouch() } + Opcodes.DUP_X2 -> frame.peekWords(1, 2)?.forEach { it.insns.markAsDontTouch() } + Opcodes.DUP2_X2 -> frame.peekWords(2, 2)?.forEach { it.insns.markAsDontTouch() } + } + } + + val transformations = hashMapOf() + for ((i, insn) in insns.withIndex()) { + val frame = frames[i] ?: continue + if (insn.opcode == Opcodes.POP) { + val input = frame.top() ?: continue + if (input.insns.none { it.shouldKeep() }) { transformations[insn] = REPLACE_WITH_NOP - sources.forEach { propagatePopBackwards(it, inputTop.size) } + input.insns.forEach { + if (it !in transformations) { + transformations[it] = it.combineWithPop(frames, input.size) + } + } } } } @@ -71,43 +91,6 @@ class PopBackwardPropagationTransformer : MethodTransformer() { } } - private fun analyzeMethodBody(): Array?> { - val frames = Analyzer(HazardsTrackingInterpreter()).analyze("fake", methodNode) - val insns = methodNode.instructions.toArray() - for (i in frames.indices) { - val frame = frames[i] ?: continue - val insn = insns[i] - - when (insn.opcode) { - Opcodes.POP2 -> { - val top2 = frame.peekWords(2) ?: throwIncorrectBytecode(insn, frame) - top2.forEach { it.insns.markAsDontTouch() } - } - Opcodes.DUP_X1 -> { - val top2 = frame.peekWords(1, 1) ?: throwIncorrectBytecode(insn, frame) - top2.forEach { it.insns.markAsDontTouch() } - } - Opcodes.DUP2_X1 -> { - val top3 = frame.peekWords(2, 1) ?: throwIncorrectBytecode(insn, frame) - top3.forEach { it.insns.markAsDontTouch() } - } - Opcodes.DUP_X2 -> { - val top3 = frame.peekWords(1, 2) ?: throwIncorrectBytecode(insn, frame) - top3.forEach { it.insns.markAsDontTouch() } - } - Opcodes.DUP2_X2 -> { - val top4 = frame.peekWords(2, 2) ?: throwIncorrectBytecode(insn, frame) - top4.forEach { it.insns.markAsDontTouch() } - } - } - } - return frames - } - - private fun throwIncorrectBytecode(insn: AbstractInsnNode?, frame: Frame): Nothing { - throw AssertionError("Incorrect bytecode at ${methodNode.instructions.indexOf(insn)}: ${insn.debugText} $frame") - } - private inner class HazardsTrackingInterpreter : SourceInterpreter(Opcodes.API_VERSION) { override fun naryOperation(insn: AbstractInsnNode, values: MutableList): SourceValue { for (value in values) { @@ -122,9 +105,7 @@ class PopBackwardPropagationTransformer : MethodTransformer() { } override fun unaryOperation(insn: AbstractInsnNode, value: SourceValue): SourceValue { - if (!insn.isPrimitiveTypeConversion()) { - value.insns.markAsDontTouch() - } + value.insns.markAsDontTouch() return super.unaryOperation(insn, value) } @@ -153,59 +134,38 @@ class PopBackwardPropagationTransformer : MethodTransformer() { } } - private fun propagatePopBackwards(insn: AbstractInsnNode, poppedValueSize: Int) { - if (transformations.containsKey(insn)) return - + private fun SourceValue.longerWhenFusedWithPop() = insns.fold(0) { x, insn -> when { - insn.isPrimitiveBoxing() -> - transformations[insn] = replaceWithPopTransformation(getInputTop(insn).size) + insn.isPurePush() -> x - 1 + insn.isPrimitiveBoxing() || insn.isPrimitiveTypeConversion() -> x + else -> x + 1 + } + } > 0 - insn.isPurePush() -> - transformations[insn] = REPLACE_WITH_NOP - - insn.isPrimitiveTypeConversion() -> { - val inputTop = getInputTop(insn) - val sources = inputTop.insns - if (sources.none { isDontTouch(it) }) { - transformations[insn] = REPLACE_WITH_NOP - sources.forEach { propagatePopBackwards(it, inputTop.size) } - } else { - transformations[insn] = replaceWithPopTransformation(inputTop.size) + private fun AbstractInsnNode.combineWithPop(frames: Array?>, resultSize: Int): Transformation = + when { + isPurePush() -> REPLACE_WITH_NOP + isPrimitiveBoxing() || isPrimitiveTypeConversion() -> { + val index = insnList.indexOf(this) + val frame = frames[index] ?: throw AssertionError("dead instruction #$index used by non-dead instruction") + val input = frame.top() ?: throw AssertionError("coercion instruction at #$index has no input") + when (input.size) { + 1 -> REPLACE_WITH_POP1 + 2 -> REPLACE_WITH_POP2 + else -> throw AssertionError("Unexpected pop value size: ${input.size}") } } - else -> - transformations[insn] = insertPopAfterTransformation(poppedValueSize) - } - } - - private fun replaceWithPopTransformation(size: Int): Transformation = - when (size) { - 1 -> REPLACE_WITH_POP1 - 2 -> REPLACE_WITH_POP2 - else -> throw AssertionError("Unexpected pop value size: $size") + when (resultSize) { + 1 -> INSERT_POP1_AFTER + 2 -> INSERT_POP2_AFTER + else -> throw AssertionError("Unexpected pop value size: $resultSize") + } } - private fun insertPopAfterTransformation(size: Int): Transformation = - when (size) { - 1 -> INSERT_POP1_AFTER - 2 -> INSERT_POP2_AFTER - else -> throw AssertionError("Unexpected pop value size: $size") - } - - private fun getInputTop(insn: AbstractInsnNode): SourceValue { - val i = insnList.indexOf(insn) - val frame = frames[i] ?: throw AssertionError("Unexpected dead instruction #$i") - return frame.top() ?: throw AssertionError("Instruction #$i has empty stack on input") - } - - private fun isTransformablePopOperand(insn: AbstractInsnNode) = - insn.isPrimitiveBoxing() || insn.isPurePush() - - private fun isDontTouch(insn: AbstractInsnNode) = - dontTouchInsnIndices[insnList.indexOf(insn)] + private fun AbstractInsnNode.shouldKeep() = + dontTouchInsnIndices[insnList.indexOf(this)] } - } fun AbstractInsnNode.isPurePush() = diff --git a/compiler/fir/fir2ir/tests-gen/org/jetbrains/kotlin/test/runners/codegen/FirBlackBoxCodegenTestGenerated.java b/compiler/fir/fir2ir/tests-gen/org/jetbrains/kotlin/test/runners/codegen/FirBlackBoxCodegenTestGenerated.java index ff0782803c6..c628cc2b910 100644 --- a/compiler/fir/fir2ir/tests-gen/org/jetbrains/kotlin/test/runners/codegen/FirBlackBoxCodegenTestGenerated.java +++ b/compiler/fir/fir2ir/tests-gen/org/jetbrains/kotlin/test/runners/codegen/FirBlackBoxCodegenTestGenerated.java @@ -27271,6 +27271,12 @@ public class FirBlackBoxCodegenTestGenerated extends AbstractFirBlackBoxCodegenT public void testKt20844() throws Exception { runTest("compiler/testData/codegen/box/optimizations/kt20844.kt"); } + + @Test + @TestMetadata("kt46921.kt") + public void testKt46921() throws Exception { + runTest("compiler/testData/codegen/box/optimizations/kt46921.kt"); + } } @Nested diff --git a/compiler/testData/codegen/box/optimizations/kt46921.kt b/compiler/testData/codegen/box/optimizations/kt46921.kt new file mode 100644 index 00000000000..48a3f5cfdaf --- /dev/null +++ b/compiler/testData/codegen/box/optimizations/kt46921.kt @@ -0,0 +1,12 @@ +// TARGET_BACKEND: JVM +// FILE: I.java +public interface I { + public T create(); +} + +// FILE: box.kt +// A specific bytecode pattern here may confuse POP propagation. +inline fun I.bar(default: T, crossinline baz: V.(T) -> T) = + u@{ it: Any? -> create().baz(it as? T ?: return@u default) } + +fun box() = I { "O" }.bar("fail") { this + it }("K") diff --git a/compiler/testData/codegen/bytecodeText/boxedNotNumberTypeOnUnboxing.kt b/compiler/testData/codegen/bytecodeText/boxedNotNumberTypeOnUnboxing.kt index 93dacfb6f3c..729fb509c51 100644 --- a/compiler/testData/codegen/bytecodeText/boxedNotNumberTypeOnUnboxing.kt +++ b/compiler/testData/codegen/bytecodeText/boxedNotNumberTypeOnUnboxing.kt @@ -1,34 +1,34 @@ fun test(p: Int?) { if (p != null) { - p.toByte() //intValue & I2B - p.toShort() //intValue & I2S - p.toInt() //intValue - p.toLong() //intValue & I2L - p.toFloat() //intValue & I2F - p.toDouble() //intValue & I2D + val a = p.toByte() //intValue & I2B + val b = p.toShort() //intValue & I2S + val c = p.toInt() //intValue + val d = p.toLong() //intValue & I2L + val e = p.toFloat() //intValue & I2F + val f = p.toDouble() //intValue & I2D } } fun test(p: Byte?) { if (p != null) { - p.toByte() //byteValue - p.toShort() //byteValue & I2S - p.toInt() //byteValue - p.toLong() //byteValue & I2L - p.toFloat() //byteValue & I2F - p.toDouble() //byteValue & I2D + val a = p.toByte() //byteValue + val b = p.toShort() //byteValue & I2S + val c = p.toInt() //byteValue + val d = p.toLong() //byteValue & I2L + val e = p.toFloat() //byteValue & I2F + val f = p.toDouble() //byteValue & I2D } } fun test(p: Char?) { if (p != null) { - p.toByte() //charValue & I2B - p.toShort() //charValue & I2S - p.toInt() //charValue - p.toLong() //charValue & I2L - p.toFloat() //charValue & I2F - p.toDouble() //charValue & I2D + val a = p.toByte() //charValue & I2B + val b = p.toShort() //charValue & I2S + val c = p.toInt() //charValue + val d = p.toLong() //charValue & I2L + val e = p.toFloat() //charValue & I2F + val f = p.toDouble() //charValue & I2D } } diff --git a/compiler/testData/codegen/bytecodeText/boxingOptimization/casts.kt b/compiler/testData/codegen/bytecodeText/boxingOptimization/casts.kt index 35ebf26ebf3..53c99f3b110 100644 --- a/compiler/testData/codegen/bytecodeText/boxingOptimization/casts.kt +++ b/compiler/testData/codegen/bytecodeText/boxingOptimization/casts.kt @@ -4,12 +4,12 @@ inline fun foo(x : R?, block : (R?) -> T) : T { } fun bar() { - foo(1) { x -> x!!.toLong() } - foo(1) { x -> x!!.toShort() } - foo(1L) { x -> x!!.toByte() } - foo(1L) { x -> x!!.toShort() } - foo('a') { x -> x!!.toDouble() } - foo(1.0) { x -> x!!.toInt() } + val a = foo(1) { x -> x!!.toLong() } + val b = foo(1) { x -> x!!.toShort() } + val c = foo(1L) { x -> x!!.toByte() } + val d = foo(1L) { x -> x!!.toShort() } + val e = foo('a') { x -> x!!.toDouble() } + val f = foo(1.0) { x -> x!!.toInt() } } // 0 valueOf diff --git a/compiler/tests-common-new/tests-gen/org/jetbrains/kotlin/test/runners/codegen/BlackBoxCodegenTestGenerated.java b/compiler/tests-common-new/tests-gen/org/jetbrains/kotlin/test/runners/codegen/BlackBoxCodegenTestGenerated.java index e65e5008e14..3dbc32d8aed 100644 --- a/compiler/tests-common-new/tests-gen/org/jetbrains/kotlin/test/runners/codegen/BlackBoxCodegenTestGenerated.java +++ b/compiler/tests-common-new/tests-gen/org/jetbrains/kotlin/test/runners/codegen/BlackBoxCodegenTestGenerated.java @@ -27241,6 +27241,12 @@ public class BlackBoxCodegenTestGenerated extends AbstractBlackBoxCodegenTest { public void testKt20844() throws Exception { runTest("compiler/testData/codegen/box/optimizations/kt20844.kt"); } + + @Test + @TestMetadata("kt46921.kt") + public void testKt46921() throws Exception { + runTest("compiler/testData/codegen/box/optimizations/kt46921.kt"); + } } @Nested diff --git a/compiler/tests-common-new/tests-gen/org/jetbrains/kotlin/test/runners/codegen/IrBlackBoxCodegenTestGenerated.java b/compiler/tests-common-new/tests-gen/org/jetbrains/kotlin/test/runners/codegen/IrBlackBoxCodegenTestGenerated.java index 52e7e661251..00d8fa73838 100644 --- a/compiler/tests-common-new/tests-gen/org/jetbrains/kotlin/test/runners/codegen/IrBlackBoxCodegenTestGenerated.java +++ b/compiler/tests-common-new/tests-gen/org/jetbrains/kotlin/test/runners/codegen/IrBlackBoxCodegenTestGenerated.java @@ -27271,6 +27271,12 @@ public class IrBlackBoxCodegenTestGenerated extends AbstractIrBlackBoxCodegenTes public void testKt20844() throws Exception { runTest("compiler/testData/codegen/box/optimizations/kt20844.kt"); } + + @Test + @TestMetadata("kt46921.kt") + public void testKt46921() throws Exception { + runTest("compiler/testData/codegen/box/optimizations/kt46921.kt"); + } } @Nested diff --git a/compiler/tests-gen/org/jetbrains/kotlin/codegen/LightAnalysisModeTestGenerated.java b/compiler/tests-gen/org/jetbrains/kotlin/codegen/LightAnalysisModeTestGenerated.java index 3a6a627c381..668759668b2 100644 --- a/compiler/tests-gen/org/jetbrains/kotlin/codegen/LightAnalysisModeTestGenerated.java +++ b/compiler/tests-gen/org/jetbrains/kotlin/codegen/LightAnalysisModeTestGenerated.java @@ -23110,6 +23110,11 @@ public class LightAnalysisModeTestGenerated extends AbstractLightAnalysisModeTes public void testKt20844() throws Exception { runTest("compiler/testData/codegen/box/optimizations/kt20844.kt"); } + + @TestMetadata("kt46921.kt") + public void testKt46921() throws Exception { + runTest("compiler/testData/codegen/box/optimizations/kt46921.kt"); + } } @TestMetadata("compiler/testData/codegen/box/package")