diff --git a/compiler/ir/backend.common/src/org/jetbrains/kotlin/backend/common/lower/EnumWhenLowering.kt b/compiler/ir/backend.common/src/org/jetbrains/kotlin/backend/common/lower/EnumWhenLowering.kt index e5d05ac1c50..78b28b34788 100644 --- a/compiler/ir/backend.common/src/org/jetbrains/kotlin/backend/common/lower/EnumWhenLowering.kt +++ b/compiler/ir/backend.common/src/org/jetbrains/kotlin/backend/common/lower/EnumWhenLowering.kt @@ -8,9 +8,7 @@ package org.jetbrains.kotlin.backend.common.lower import org.jetbrains.kotlin.backend.common.* import org.jetbrains.kotlin.descriptors.ClassKind import org.jetbrains.kotlin.ir.builders.* -import org.jetbrains.kotlin.ir.declarations.IrEnumEntry -import org.jetbrains.kotlin.ir.declarations.IrFile -import org.jetbrains.kotlin.ir.declarations.IrVariable +import org.jetbrains.kotlin.ir.declarations.* import org.jetbrains.kotlin.ir.expressions.* import org.jetbrains.kotlin.ir.expressions.impl.IrCallImpl import org.jetbrains.kotlin.ir.expressions.impl.IrConstImpl @@ -24,17 +22,16 @@ import org.jetbrains.kotlin.ir.visitors.transformChildrenVoid * Replace branches that are comparisons with compile-time known enum entries * with comparisons of ordinals. */ -class EnumWhenLowering(private val context: CommonBackendContext) : IrElementTransformerVoidWithContext(), FileLoweringPass { - +open class EnumWhenLowering(protected val context: CommonBackendContext) : IrElementTransformerVoidWithContext(), FileLoweringPass { private val subjectWithOrdinalStack = mutableListOf>>() - private val areEqual = context.irBuiltIns.eqeqSymbol + protected open fun mapConstEnumEntry(entry: IrEnumEntry): Int = + entry.parentAsClass.declarations.filterIsInstance().indexOf(this).also { + assert(it >= 0) { "enum entry ${entry.dump()} not in parent class" } + } - private fun IrEnumEntry.ordinal(): Int { - val result = parentAsClass.declarations.filterIsInstance().indexOf(this) - assert(result >= 0) { "enum entry ${symbol.owner.dump()} not in parent class" } - return result - } + protected open fun mapRuntimeEnumEntry(builder: IrBuilderWithScope, subject: IrExpression): IrExpression = + builder.irCall(subject.type.getClass()!!.symbol.getPropertyGetter("ordinal")!!).apply { dispatchReceiver = subject } override fun lower(irFile: IrFile) { visitFile(irFile) @@ -56,13 +53,11 @@ class EnumWhenLowering(private val context: CommonBackendContext) : IrElementTra // Will be initialized only when we found a branch that compares // subject with compile-time known enum entry. val subjectOrdinalProvider = lazy { - val ordinalPropertyGetter = subject.type.getClass()!!.symbol.getPropertyGetter("ordinal")!! context.createIrBuilder(currentScope!!.scope.scopeOwnerSymbol, subject.startOffset, subject.endOffset).run { - val ordinal = irCall(ordinalPropertyGetter.owner).apply { dispatchReceiver = irGet(subject) } val integer = if (subject.type.isNullable()) - irIfNull(ordinal.type, irGet(subject), irInt(-1), ordinal) + irIfNull(context.irBuiltIns.intType, irGet(subject), irInt(-1), mapRuntimeEnumEntry(this, irGet(subject))) else - ordinal + mapRuntimeEnumEntry(this, irGet(subject)) scope.createTemporaryVariable(integer).also { expression.statements.add(1, it) } @@ -95,7 +90,7 @@ class EnumWhenLowering(private val context: CommonBackendContext) : IrElementTra } val entryOrdinal = when { other is IrGetEnumValue && topmostSubject.type.classifierOrNull?.owner == other.symbol.owner.parent -> - other.symbol.owner.ordinal() + mapConstEnumEntry(other.symbol.owner) other.isNullConst() -> -1 else -> return super.visitCall(expression) diff --git a/compiler/ir/backend.jvm/src/org/jetbrains/kotlin/backend/jvm/DeclarationOrigins.kt b/compiler/ir/backend.jvm/src/org/jetbrains/kotlin/backend/jvm/DeclarationOrigins.kt index e8d919f2138..95f796f6fb1 100644 --- a/compiler/ir/backend.jvm/src/org/jetbrains/kotlin/backend/jvm/DeclarationOrigins.kt +++ b/compiler/ir/backend.jvm/src/org/jetbrains/kotlin/backend/jvm/DeclarationOrigins.kt @@ -35,6 +35,7 @@ interface JvmLoweredDeclarationOrigin : IrDeclarationOrigin { IrDeclarationOriginImpl("SYNTHETIC_METHOD_FOR_PROPERTY_ANNOTATIONS", isSynthetic = true) object GENERATED_PROPERTY_REFERENCE : IrDeclarationOriginImpl("GENERATED_PROPERTY_REFERENCE", isSynthetic = true) object GENERATED_SAM_IMPLEMENTATION : IrDeclarationOriginImpl("GENERATED_SAM_IMPLEMENTATION", isSynthetic = true) + object ENUM_MAPPINGS_FOR_WHEN : IrDeclarationOriginImpl("ENUM_MAPPINGS_FOR_WHEN", isSynthetic = true) } interface JvmLoweredStatementOrigin : IrStatementOrigin { diff --git a/compiler/ir/backend.jvm/src/org/jetbrains/kotlin/backend/jvm/JvmLower.kt b/compiler/ir/backend.jvm/src/org/jetbrains/kotlin/backend/jvm/JvmLower.kt index 7f62d7649fc..7e619f9947b 100644 --- a/compiler/ir/backend.jvm/src/org/jetbrains/kotlin/backend/jvm/JvmLower.kt +++ b/compiler/ir/backend.jvm/src/org/jetbrains/kotlin/backend/jvm/JvmLower.kt @@ -61,12 +61,6 @@ private val propertiesPhase = makeIrFilePhase( stickyPostconditions = setOf((PropertiesLowering)::checkNoProperties) ) -private val enumWhenPhase = makeIrFilePhase( - ::EnumWhenLowering, - name = "EnumWhenLowering", - description = "Replace `when` subjects of enum types with their ordinals" -) - val jvmPhases = namedIrFilePhase( name = "IrLowering", description = "IR lowering", diff --git a/compiler/ir/backend.jvm/src/org/jetbrains/kotlin/backend/jvm/lower/MappedEnumWhenLowering.kt b/compiler/ir/backend.jvm/src/org/jetbrains/kotlin/backend/jvm/lower/MappedEnumWhenLowering.kt new file mode 100644 index 00000000000..57eac2aeb7b --- /dev/null +++ b/compiler/ir/backend.jvm/src/org/jetbrains/kotlin/backend/jvm/lower/MappedEnumWhenLowering.kt @@ -0,0 +1,136 @@ +/* + * Copyright 2010-2019 JetBrains s.r.o. and Kotlin Programming Language contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + */ + +package org.jetbrains.kotlin.backend.jvm.lower + +import org.jetbrains.kotlin.backend.common.CommonBackendContext +import org.jetbrains.kotlin.backend.common.ir.createImplicitParameterDeclarationWithWrappedDescriptor +import org.jetbrains.kotlin.backend.common.lower.EnumWhenLowering +import org.jetbrains.kotlin.backend.common.lower.createIrBuilder +import org.jetbrains.kotlin.backend.common.phaser.makeIrFilePhase +import org.jetbrains.kotlin.backend.jvm.JvmLoweredDeclarationOrigin +import org.jetbrains.kotlin.ir.IrStatement +import org.jetbrains.kotlin.ir.UNDEFINED_OFFSET +import org.jetbrains.kotlin.ir.builders.* +import org.jetbrains.kotlin.ir.builders.declarations.addField +import org.jetbrains.kotlin.ir.builders.declarations.buildClass +import org.jetbrains.kotlin.ir.declarations.IrClass +import org.jetbrains.kotlin.ir.declarations.IrEnumEntry +import org.jetbrains.kotlin.ir.declarations.IrField +import org.jetbrains.kotlin.ir.expressions.IrExpression +import org.jetbrains.kotlin.ir.expressions.impl.IrGetEnumValueImpl +import org.jetbrains.kotlin.ir.types.getClass +import org.jetbrains.kotlin.ir.util.* +import org.jetbrains.kotlin.ir.visitors.transformChildrenVoid +import org.jetbrains.kotlin.name.Name +import org.jetbrains.kotlin.util.OperatorNameConventions + +internal val enumWhenPhase = makeIrFilePhase( + ::MappedEnumWhenLowering, + name = "EnumWhenLowering", + description = "Replace `when` subjects of enum types with their ordinals" +) + +// A version of EnumWhenLowering that is more friendly to incremental compilation. For example, +// suppose the code initially looks like this: +// +// // 1.kt +// enum E { X } +// +// // 2.kt +// fun f(e: E) = when (e) { E.X -> 1 } +// +// EnumWhenLowering would transform 2.kt into this: +// +// fun f(e: E) = when (e.ordinal()) { 0 -> 1 } +// +// While this lowering would generate (approximately) this instead: +// +// fun f(e: E) = when (WhenMappings.$EnumSwitchMapping$0[e.ordinal()]) { 1 -> 1 } +// +// object WhenMappings { +// // Note the runtime call to ordinal(): 0 is not hardcoded. +// val $EnumSwitchMapping$0 = IntArray(E.values().size).also { it[E.X.ordinal()] = 1 } +// } +// +// The latter would not need to be recompiled if new entries were added before `X` +// at the negligible cost of an additional initializer per run + one array read per call. +// +private class MappedEnumWhenLowering(context: CommonBackendContext) : EnumWhenLowering(context) { + private val intArray = context.irBuiltIns.primitiveArrayForType.getValue(context.irBuiltIns.intType) + private val intArrayConstructor = intArray.constructors.single { it.owner.valueParameters.size == 1 } + private val intArrayGet = intArray.functions.single { it.owner.name == OperatorNameConventions.GET } + private val intArraySet = intArray.functions.single { it.owner.name == OperatorNameConventions.SET } + private val refArraySize = context.irBuiltIns.arrayClass.owner.properties.single { it.name.toString() == "size" }.getter!! + + // To avoid visibility-related issues, classes containing the mappings are direct children + // of the classes in which they are used. This field tracks which container is the innermost one. + private var state: EnumMappingState? = null + + private class EnumMappingState { + val mappings = mutableMapOf, IrField>>() + val mappingsClass by lazy { + buildClass { + name = Name.identifier("WhenMappings") + origin = JvmLoweredDeclarationOrigin.ENUM_MAPPINGS_FOR_WHEN + }.apply { + createImplicitParameterDeclarationWithWrappedDescriptor() + } + } + } + + override fun mapConstEnumEntry(entry: IrEnumEntry): Int { + val (mapping, _) = state!!.mappings.getOrPut(entry.parentAsClass) { + mutableMapOf() to state!!.mappingsClass.addField { + name = Name.identifier("\$EnumSwitchMapping\$${state!!.mappings.size}") + type = intArray.owner.defaultType + origin = JvmLoweredDeclarationOrigin.ENUM_MAPPINGS_FOR_WHEN + isFinal = true + isStatic = true + } + } + // Index 0 (default value for integers) is reserved for unknown ordinals. + return mapping.getOrPut(entry) { mapping.size + 1 } + } + + override fun mapRuntimeEnumEntry(builder: IrBuilderWithScope, subject: IrExpression): IrExpression = + builder.irCall(intArrayGet).apply { + val (_, field) = state!!.mappings[subject.type.getClass()!!] + ?: throw AssertionError("no values mapped for enum class ${subject.type}") + dispatchReceiver = builder.irGetField(null, field) + putValueArgument(0, super.mapRuntimeEnumEntry(builder, subject)) + } + + override fun visitClassNew(declaration: IrClass): IrStatement { + val oldState = state + state = EnumMappingState() + super.visitClassNew(declaration) + + for ((enum, mappingAndField) in state!!.mappings) { + val (mapping, field) = mappingAndField + val builder = context.createIrBuilder(state!!.mappingsClass.symbol) + val enumValues = enum.functions.single { it.name.toString() == "values" } + field.initializer = builder.irExprBody(builder.irBlock { + val enumSize = irCall(refArraySize).apply { dispatchReceiver = irCall(enumValues) } + val result = irTemporary(irCall(intArrayConstructor).apply { putValueArgument(0, enumSize) }) + for ((entry, index) in mapping) { + val runtimeEntry = IrGetEnumValueImpl(UNDEFINED_OFFSET, UNDEFINED_OFFSET, enum.defaultType, entry.symbol) + +irCall(intArraySet).apply { + dispatchReceiver = irGet(result) + putValueArgument(0, super.mapRuntimeEnumEntry(builder, runtimeEntry)) // .ordinal() + putValueArgument(1, irInt(index)) + } + } + +irGet(result) + }) + } + + if (state!!.mappings.isNotEmpty()) { + declaration.declarations += state!!.mappingsClass.apply { parent = declaration } + } + state = oldState + return declaration + } +} diff --git a/compiler/ir/ir.tree/src/org/jetbrains/kotlin/ir/descriptors/IrBuiltIns.kt b/compiler/ir/ir.tree/src/org/jetbrains/kotlin/ir/descriptors/IrBuiltIns.kt index 0ee1f87bcca..7db180a7613 100644 --- a/compiler/ir/ir.tree/src/org/jetbrains/kotlin/ir/descriptors/IrBuiltIns.kt +++ b/compiler/ir/ir.tree/src/org/jetbrains/kotlin/ir/descriptors/IrBuiltIns.kt @@ -185,6 +185,7 @@ class IrBuiltIns( val primitiveFloatingPointTypes = listOf(float, double) val primitiveArrays = PrimitiveType.values().map { builtIns.getPrimitiveArrayClassDescriptor(it).toIrSymbol() } val primitiveArrayElementTypes = primitiveArrays.zip(primitiveIrTypes).toMap() + val primitiveArrayForType = primitiveArrayElementTypes.asSequence().associate { it.value to it.key } val primitiveTypeToIrType = mapOf( PrimitiveType.BOOLEAN to booleanType, diff --git a/compiler/testData/codegen/bytecodeText/inline/whenMappingOnCallSite.kt b/compiler/testData/codegen/bytecodeText/inline/whenMappingOnCallSite.kt index 43f2f6aa405..b15b2922494 100644 --- a/compiler/testData/codegen/bytecodeText/inline/whenMappingOnCallSite.kt +++ b/compiler/testData/codegen/bytecodeText/inline/whenMappingOnCallSite.kt @@ -1,4 +1,3 @@ -// IGNORE_BACKEND: JVM_IR package test enum class X { diff --git a/compiler/testData/writeFlags/class/accessFlags/mappingWhen.kt b/compiler/testData/writeFlags/class/accessFlags/mappingWhen.kt index 1fec51abcd0..654230c2e59 100644 --- a/compiler/testData/writeFlags/class/accessFlags/mappingWhen.kt +++ b/compiler/testData/writeFlags/class/accessFlags/mappingWhen.kt @@ -1,5 +1,3 @@ -// IGNORE_BACKEND: JVM_IR - enum class BigEnum { ITEM1, ITEM2