diff --git a/compiler/ir/backend.common/src/org/jetbrains/kotlin/backend/common/lower/EnumWhenLowering.kt b/compiler/ir/backend.common/src/org/jetbrains/kotlin/backend/common/lower/EnumWhenLowering.kt new file mode 100644 index 00000000000..e5d05ac1c50 --- /dev/null +++ b/compiler/ir/backend.common/src/org/jetbrains/kotlin/backend/common/lower/EnumWhenLowering.kt @@ -0,0 +1,109 @@ +/* + * Copyright 2010-2019 JetBrains s.r.o. and Kotlin Programming Language contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + */ + +package org.jetbrains.kotlin.backend.common.lower + +import org.jetbrains.kotlin.backend.common.* +import org.jetbrains.kotlin.descriptors.ClassKind +import org.jetbrains.kotlin.ir.builders.* +import org.jetbrains.kotlin.ir.declarations.IrEnumEntry +import org.jetbrains.kotlin.ir.declarations.IrFile +import org.jetbrains.kotlin.ir.declarations.IrVariable +import org.jetbrains.kotlin.ir.expressions.* +import org.jetbrains.kotlin.ir.expressions.impl.IrCallImpl +import org.jetbrains.kotlin.ir.expressions.impl.IrConstImpl +import org.jetbrains.kotlin.ir.expressions.impl.IrGetValueImpl +import org.jetbrains.kotlin.ir.types.classifierOrNull +import org.jetbrains.kotlin.ir.types.getClass +import org.jetbrains.kotlin.ir.util.* +import org.jetbrains.kotlin.ir.visitors.transformChildrenVoid + +/** Look for when-constructs where subject is enum entry. + * Replace branches that are comparisons with compile-time known enum entries + * with comparisons of ordinals. + */ +class EnumWhenLowering(private val context: CommonBackendContext) : IrElementTransformerVoidWithContext(), FileLoweringPass { + + private val subjectWithOrdinalStack = mutableListOf>>() + + private val areEqual = context.irBuiltIns.eqeqSymbol + + private fun IrEnumEntry.ordinal(): Int { + val result = parentAsClass.declarations.filterIsInstance().indexOf(this) + assert(result >= 0) { "enum entry ${symbol.owner.dump()} not in parent class" } + return result + } + + override fun lower(irFile: IrFile) { + visitFile(irFile) + } + + override fun visitBlock(expression: IrBlock): IrExpression { + // NB: See BranchingExpressionGenerator to get insight about `when` block translation to IR. + if (expression.origin != IrStatementOrigin.WHEN) { + return super.visitBlock(expression) + } + // when-block with subject should have two children: temporary variable and when itself. + if (expression.statements.size != 2) { + return super.visitBlock(expression) + } + val subject = expression.statements[0] + if (subject !is IrVariable || subject.type.getClass()?.kind != ClassKind.ENUM_CLASS) { + return super.visitBlock(expression) + } + // Will be initialized only when we found a branch that compares + // subject with compile-time known enum entry. + val subjectOrdinalProvider = lazy { + val ordinalPropertyGetter = subject.type.getClass()!!.symbol.getPropertyGetter("ordinal")!! + context.createIrBuilder(currentScope!!.scope.scopeOwnerSymbol, subject.startOffset, subject.endOffset).run { + val ordinal = irCall(ordinalPropertyGetter.owner).apply { dispatchReceiver = irGet(subject) } + val integer = if (subject.type.isNullable()) + irIfNull(ordinal.type, irGet(subject), irInt(-1), ordinal) + else + ordinal + scope.createTemporaryVariable(integer).also { + expression.statements.add(1, it) + } + } + } + subjectWithOrdinalStack.push(Pair(subject, subjectOrdinalProvider)) + try { + // Process nested `when` and comparisons. + expression.statements[1].transformChildrenVoid(this) + } finally { + subjectWithOrdinalStack.pop() + } + return expression + } + + override fun visitCall(expression: IrCall): IrExpression { + // We are looking for branch that is a comparison of the subject and another enum entry. + if (expression.symbol != context.irBuiltIns.eqeqSymbol) { + return super.visitCall(expression) + } + val lhs = expression.getValueArgument(0)!! + val rhs = expression.getValueArgument(1)!! + + val (topmostSubject, topmostOrdinalProvider) = subjectWithOrdinalStack.peek() + ?: return super.visitCall(expression) + val other = when { + lhs is IrGetValue && lhs.symbol.owner == topmostSubject -> rhs + rhs is IrGetValue && rhs.symbol.owner == topmostSubject -> lhs + else -> return super.visitCall(expression) + } + val entryOrdinal = when { + other is IrGetEnumValue && topmostSubject.type.classifierOrNull?.owner == other.symbol.owner.parent -> + other.symbol.owner.ordinal() + other.isNullConst() -> + -1 + else -> return super.visitCall(expression) + } + val subjectOrdinal = topmostOrdinalProvider.value + return IrCallImpl(expression.startOffset, expression.endOffset, expression.type, expression.symbol).apply { + putValueArgument(0, IrGetValueImpl(lhs.startOffset, lhs.endOffset, subjectOrdinal.type, subjectOrdinal.symbol)) + putValueArgument(1, IrConstImpl.int(rhs.startOffset, rhs.endOffset, context.irBuiltIns.intType, entryOrdinal)) + } + } +} diff --git a/compiler/ir/backend.jvm/src/org/jetbrains/kotlin/backend/jvm/JvmLower.kt b/compiler/ir/backend.jvm/src/org/jetbrains/kotlin/backend/jvm/JvmLower.kt index a5cf494bd97..7f62d7649fc 100644 --- a/compiler/ir/backend.jvm/src/org/jetbrains/kotlin/backend/jvm/JvmLower.kt +++ b/compiler/ir/backend.jvm/src/org/jetbrains/kotlin/backend/jvm/JvmLower.kt @@ -61,6 +61,12 @@ private val propertiesPhase = makeIrFilePhase( stickyPostconditions = setOf((PropertiesLowering)::checkNoProperties) ) +private val enumWhenPhase = makeIrFilePhase( + ::EnumWhenLowering, + name = "EnumWhenLowering", + description = "Replace `when` subjects of enum types with their ordinals" +) + val jvmPhases = namedIrFilePhase( name = "IrLowering", description = "IR lowering", @@ -88,6 +94,7 @@ val jvmPhases = namedIrFilePhase( makePatchParentsPhase(1) then + enumWhenPhase then singletonReferencesPhase then jvmLocalDeclarationsPhase then singleAbstractMethodPhase then diff --git a/compiler/testData/codegen/bytecodeText/when/exhaustiveWhenUnit.kt b/compiler/testData/codegen/bytecodeText/when/exhaustiveWhenUnit.kt index 5f84864c3fc..17f33e1fa5f 100644 --- a/compiler/testData/codegen/bytecodeText/when/exhaustiveWhenUnit.kt +++ b/compiler/testData/codegen/bytecodeText/when/exhaustiveWhenUnit.kt @@ -1,4 +1,3 @@ -// IGNORE_BACKEND: JVM_IR enum class AccessMode { READ, WRITE, EXECUTE } fun whenExpr(access: AccessMode) { diff --git a/compiler/testData/codegen/bytecodeText/whenEnumOptimization/bigEnum.kt b/compiler/testData/codegen/bytecodeText/whenEnumOptimization/bigEnum.kt index b09a3db6177..230eae422bb 100644 --- a/compiler/testData/codegen/bytecodeText/whenEnumOptimization/bigEnum.kt +++ b/compiler/testData/codegen/bytecodeText/whenEnumOptimization/bigEnum.kt @@ -1,4 +1,3 @@ -// IGNORE_BACKEND: JVM_IR import kotlin.test.assertEquals enum class BigEnum { diff --git a/compiler/testData/codegen/bytecodeText/whenEnumOptimization/duplicatingItems.kt b/compiler/testData/codegen/bytecodeText/whenEnumOptimization/duplicatingItems.kt index 975f58bf719..87e9d97c05a 100644 --- a/compiler/testData/codegen/bytecodeText/whenEnumOptimization/duplicatingItems.kt +++ b/compiler/testData/codegen/bytecodeText/whenEnumOptimization/duplicatingItems.kt @@ -1,4 +1,3 @@ -// IGNORE_BACKEND: JVM_IR import kotlin.test.assertEquals enum class Season { diff --git a/compiler/testData/codegen/bytecodeText/whenEnumOptimization/expression.kt b/compiler/testData/codegen/bytecodeText/whenEnumOptimization/expression.kt index 9c07da6a5f2..a1c4bc676c5 100644 --- a/compiler/testData/codegen/bytecodeText/whenEnumOptimization/expression.kt +++ b/compiler/testData/codegen/bytecodeText/whenEnumOptimization/expression.kt @@ -1,4 +1,3 @@ -// IGNORE_BACKEND: JVM_IR import kotlin.test.assertEquals enum class Season { diff --git a/compiler/testData/codegen/bytecodeText/whenEnumOptimization/functionLiteralInTopLevel.kt b/compiler/testData/codegen/bytecodeText/whenEnumOptimization/functionLiteralInTopLevel.kt index 041c66634e0..9bcebaf912e 100644 --- a/compiler/testData/codegen/bytecodeText/whenEnumOptimization/functionLiteralInTopLevel.kt +++ b/compiler/testData/codegen/bytecodeText/whenEnumOptimization/functionLiteralInTopLevel.kt @@ -1,4 +1,3 @@ -// IGNORE_BACKEND: JVM_IR import kotlin.test.assertEquals enum class Season { @@ -14,6 +13,7 @@ fun box() : String { return foo(Season.SPRING) { x -> when (x) { Season.SPRING -> "OK" + Season.SUMMER -> "fail" // redundant branch to force use of TABLESWITCH instead of IF_ICMPNE else -> "fail" } } diff --git a/compiler/testData/codegen/bytecodeText/whenEnumOptimization/importedEnumEntry.kt b/compiler/testData/codegen/bytecodeText/whenEnumOptimization/importedEnumEntry.kt index f11ef808bb1..17176b04d73 100644 --- a/compiler/testData/codegen/bytecodeText/whenEnumOptimization/importedEnumEntry.kt +++ b/compiler/testData/codegen/bytecodeText/whenEnumOptimization/importedEnumEntry.kt @@ -1,4 +1,3 @@ -// IGNORE_BACKEND: JVM_IR import Color.RED enum class Color { RED, GREEN, BLUE } diff --git a/compiler/testData/codegen/bytecodeText/whenEnumOptimization/manyWhensWithinClass.kt b/compiler/testData/codegen/bytecodeText/whenEnumOptimization/manyWhensWithinClass.kt index 2ac8a5d6af2..34dad145a32 100644 --- a/compiler/testData/codegen/bytecodeText/whenEnumOptimization/manyWhensWithinClass.kt +++ b/compiler/testData/codegen/bytecodeText/whenEnumOptimization/manyWhensWithinClass.kt @@ -1,4 +1,3 @@ -// IGNORE_BACKEND: JVM_IR package abc.foo enum class Season { diff --git a/compiler/testData/codegen/bytecodeText/whenEnumOptimization/nullability.kt b/compiler/testData/codegen/bytecodeText/whenEnumOptimization/nullability.kt index fb97dcb8d2c..4a4a6cc6340 100644 --- a/compiler/testData/codegen/bytecodeText/whenEnumOptimization/nullability.kt +++ b/compiler/testData/codegen/bytecodeText/whenEnumOptimization/nullability.kt @@ -1,4 +1,3 @@ -// IGNORE_BACKEND: JVM_IR enum class Season { WINTER, SPRING, diff --git a/compiler/testData/codegen/bytecodeText/whenEnumOptimization/withoutElse.kt b/compiler/testData/codegen/bytecodeText/whenEnumOptimization/withoutElse.kt index f1c6aad502a..09e8377c0b4 100644 --- a/compiler/testData/codegen/bytecodeText/whenEnumOptimization/withoutElse.kt +++ b/compiler/testData/codegen/bytecodeText/whenEnumOptimization/withoutElse.kt @@ -1,4 +1,3 @@ -// IGNORE_BACKEND: JVM_IR import kotlin.test.assertEquals enum class Season {