diff --git a/compiler/testData/codegen/box/when/enumOptimization/bigEnum.kt b/compiler/testData/codegen/box/when/enumOptimization/bigEnum.kt index fb76910b59d..629b6d1a32a 100644 --- a/compiler/testData/codegen/box/when/enumOptimization/bigEnum.kt +++ b/compiler/testData/codegen/box/when/enumOptimization/bigEnum.kt @@ -1,4 +1,8 @@ // WITH_RUNTIME +// CHECK_CASES_COUNT: function=bar1 count=6 +// CHECK_IF_COUNT: function=bar1 count=0 +// CHECK_CASES_COUNT: function=bar2 count=6 +// CHECK_IF_COUNT: function=bar2 count=0 import kotlin.test.assertEquals diff --git a/compiler/testData/codegen/box/when/enumOptimization/duplicatingItems.kt b/compiler/testData/codegen/box/when/enumOptimization/duplicatingItems.kt index 2760d6daa88..0df6b309bf6 100644 --- a/compiler/testData/codegen/box/when/enumOptimization/duplicatingItems.kt +++ b/compiler/testData/codegen/box/when/enumOptimization/duplicatingItems.kt @@ -1,4 +1,6 @@ // WITH_RUNTIME +// CHECK_CASES_COUNT: function=bar count=3 +// CHECK_IF_COUNT: function=bar count=0 import kotlin.test.assertEquals diff --git a/compiler/testData/codegen/box/when/enumOptimization/enumInsideClassObject.kt b/compiler/testData/codegen/box/when/enumOptimization/enumInsideClassObject.kt index 18131fed078..1ce3c8075a0 100644 --- a/compiler/testData/codegen/box/when/enumOptimization/enumInsideClassObject.kt +++ b/compiler/testData/codegen/box/when/enumOptimization/enumInsideClassObject.kt @@ -1,4 +1,6 @@ // WITH_RUNTIME +// CHECK_CASES_COUNT: function=foo count=3 +// CHECK_IF_COUNT: function=foo count=0 import kotlin.test.assertEquals diff --git a/compiler/testData/codegen/box/when/enumOptimization/expression.kt b/compiler/testData/codegen/box/when/enumOptimization/expression.kt index e1774794dea..418f1860d09 100644 --- a/compiler/testData/codegen/box/when/enumOptimization/expression.kt +++ b/compiler/testData/codegen/box/when/enumOptimization/expression.kt @@ -1,4 +1,8 @@ // WITH_RUNTIME +// CHECK_CASES_COUNT: function=bar1 count=3 +// CHECK_IF_COUNT: function=bar1 count=0 +// CHECK_CASES_COUNT: function=bar2 count=4 +// CHECK_IF_COUNT: function=bar2 count=0 import kotlin.test.assertEquals diff --git a/compiler/testData/codegen/box/when/enumOptimization/functionLiteralInTopLevel.kt b/compiler/testData/codegen/box/when/enumOptimization/functionLiteralInTopLevel.kt index 9b97152b2f3..2266ce69a30 100644 --- a/compiler/testData/codegen/box/when/enumOptimization/functionLiteralInTopLevel.kt +++ b/compiler/testData/codegen/box/when/enumOptimization/functionLiteralInTopLevel.kt @@ -1,3 +1,6 @@ +// CHECK_CASES_COUNT: function=box$lambda count=0 +// CHECK_IF_COUNT: function=box$lambda count=1 + enum class Season { WINTER, SPRING, diff --git a/compiler/testData/codegen/box/when/enumOptimization/kt14597.kt b/compiler/testData/codegen/box/when/enumOptimization/kt14597.kt index bd278f2a72e..50202f79404 100644 --- a/compiler/testData/codegen/box/when/enumOptimization/kt14597.kt +++ b/compiler/testData/codegen/box/when/enumOptimization/kt14597.kt @@ -1,3 +1,6 @@ +// CHECK_CASES_COUNT: function=box count=6 +// CHECK_IF_COUNT: function=box count=1 + enum class En { A, B, С } fun box(): String { diff --git a/compiler/testData/codegen/box/when/enumOptimization/kt14597_full.kt b/compiler/testData/codegen/box/when/enumOptimization/kt14597_full.kt index 4b45ed81a12..f69d587ee7f 100644 --- a/compiler/testData/codegen/box/when/enumOptimization/kt14597_full.kt +++ b/compiler/testData/codegen/box/when/enumOptimization/kt14597_full.kt @@ -1,3 +1,6 @@ +// CHECK_CASES_COUNT: function=box count=18 +// CHECK_IF_COUNT: function=box count=3 + enum class En { A, B, С } fun box(): String { diff --git a/compiler/testData/codegen/box/when/enumOptimization/kt14802.kt b/compiler/testData/codegen/box/when/enumOptimization/kt14802.kt index bdfc991c967..b2fd2b389c6 100644 --- a/compiler/testData/codegen/box/when/enumOptimization/kt14802.kt +++ b/compiler/testData/codegen/box/when/enumOptimization/kt14802.kt @@ -1,3 +1,6 @@ +// CHECK_CASES_COUNT: function=crash count=2 +// CHECK_IF_COUNT: function=crash count=1 + class EncapsulatedEnum>(val value: T) enum class MyEnum(val value: String) { diff --git a/compiler/testData/codegen/box/when/enumOptimization/kt15806.kt b/compiler/testData/codegen/box/when/enumOptimization/kt15806.kt index ad603bf7a5f..54d4d6c5cd9 100644 --- a/compiler/testData/codegen/box/when/enumOptimization/kt15806.kt +++ b/compiler/testData/codegen/box/when/enumOptimization/kt15806.kt @@ -1,3 +1,5 @@ +// CHECK_CASES_COUNT: function=doTheThing count=2 +// CHECK_IF_COUNT: function=doTheThing count=2 private fun Any?.doTheThing(): String { when (this) { diff --git a/compiler/testData/codegen/box/when/enumOptimization/manyWhensWithinClass.kt b/compiler/testData/codegen/box/when/enumOptimization/manyWhensWithinClass.kt index 3a29f080cb3..35b69efcfdd 100644 --- a/compiler/testData/codegen/box/when/enumOptimization/manyWhensWithinClass.kt +++ b/compiler/testData/codegen/box/when/enumOptimization/manyWhensWithinClass.kt @@ -1,4 +1,8 @@ // WITH_RUNTIME +// CHECK_CASES_COUNT: function=bar1_u51tkt$ count=3 +// CHECK_IF_COUNT: function=bar1_u51tkt$ count=0 +// CHECK_CASES_COUNT: function=A$bar2$lambda count=3 +// CHECK_IF_COUNT: function=A$bar2$lambda count=0 import kotlin.test.assertEquals diff --git a/compiler/testData/codegen/box/when/enumOptimization/nonConstantEnum.kt b/compiler/testData/codegen/box/when/enumOptimization/nonConstantEnum.kt index a1440b4712b..f613b0c7f4f 100644 --- a/compiler/testData/codegen/box/when/enumOptimization/nonConstantEnum.kt +++ b/compiler/testData/codegen/box/when/enumOptimization/nonConstantEnum.kt @@ -1,3 +1,6 @@ +// CHECK_CASES_COUNT: function=box count=0 +// CHECK_IF_COUNT: function=box count=1 + enum class Season { WINTER, SPRING, diff --git a/compiler/testData/codegen/box/when/enumOptimization/nullability.kt b/compiler/testData/codegen/box/when/enumOptimization/nullability.kt index d8734f77526..35beca78fac 100644 --- a/compiler/testData/codegen/box/when/enumOptimization/nullability.kt +++ b/compiler/testData/codegen/box/when/enumOptimization/nullability.kt @@ -1,4 +1,8 @@ // WITH_RUNTIME +// CHECK_CASES_COUNT: function=foo1 count=0 +// CHECK_IF_COUNT: function=foo1 count=2 +// CHECK_CASES_COUNT: function=foo2 count=0 +// CHECK_IF_COUNT: function=foo2 count=2 import kotlin.test.assertEquals diff --git a/compiler/testData/codegen/box/when/enumOptimization/nullableEnum.kt b/compiler/testData/codegen/box/when/enumOptimization/nullableEnum.kt index 457f6281e1e..20d5f740782 100644 --- a/compiler/testData/codegen/box/when/enumOptimization/nullableEnum.kt +++ b/compiler/testData/codegen/box/when/enumOptimization/nullableEnum.kt @@ -1,3 +1,6 @@ +// CHECK_CASES_COUNT: function=test count=0 +// CHECK_IF_COUNT: function=test count=3 + enum class E { A, B diff --git a/compiler/testData/codegen/box/when/enumOptimization/subjectAny.kt b/compiler/testData/codegen/box/when/enumOptimization/subjectAny.kt index 0f21003c5f6..1a18c698481 100644 --- a/compiler/testData/codegen/box/when/enumOptimization/subjectAny.kt +++ b/compiler/testData/codegen/box/when/enumOptimization/subjectAny.kt @@ -1,4 +1,6 @@ // WITH_RUNTIME +// CHECK_CASES_COUNT: function=foo count=0 +// CHECK_IF_COUNT: function=foo count=3 import kotlin.test.assertEquals diff --git a/compiler/testData/codegen/box/when/enumOptimization/withoutElse.kt b/compiler/testData/codegen/box/when/enumOptimization/withoutElse.kt index abaea695949..a1f39c321ec 100644 --- a/compiler/testData/codegen/box/when/enumOptimization/withoutElse.kt +++ b/compiler/testData/codegen/box/when/enumOptimization/withoutElse.kt @@ -1,4 +1,8 @@ // WITH_RUNTIME +// CHECK_CASES_COUNT: function=bar1 count=3 +// CHECK_IF_COUNT: function=bar1 count=0 +// CHECK_CASES_COUNT: function=bar2 count=4 +// CHECK_IF_COUNT: function=bar2 count=0 import kotlin.test.assertEquals diff --git a/js/js.translator/src/org/jetbrains/kotlin/js/translate/expression/WhenTranslator.kt b/js/js.translator/src/org/jetbrains/kotlin/js/translate/expression/WhenTranslator.kt index 87eb5721e44..352fab12658 100644 --- a/js/js.translator/src/org/jetbrains/kotlin/js/translate/expression/WhenTranslator.kt +++ b/js/js.translator/src/org/jetbrains/kotlin/js/translate/expression/WhenTranslator.kt @@ -18,6 +18,8 @@ package org.jetbrains.kotlin.js.translate.expression import org.jetbrains.kotlin.backend.common.CodegenUtil import org.jetbrains.kotlin.builtins.KotlinBuiltIns +import org.jetbrains.kotlin.descriptors.ClassDescriptor +import org.jetbrains.kotlin.descriptors.ClassKind import org.jetbrains.kotlin.js.backend.ast.* import org.jetbrains.kotlin.js.translate.context.Namer import org.jetbrains.kotlin.js.translate.context.TranslationContext @@ -29,16 +31,26 @@ import org.jetbrains.kotlin.js.translate.utils.JsAstUtils.not import org.jetbrains.kotlin.js.translate.utils.mutator.CoercionMutator import org.jetbrains.kotlin.js.translate.utils.mutator.LastExpressionMutator import org.jetbrains.kotlin.lexer.KtTokens +import org.jetbrains.kotlin.name.Name import org.jetbrains.kotlin.psi.* import org.jetbrains.kotlin.psi.psiUtil.getTextWithLocation +import org.jetbrains.kotlin.resolve.DescriptorUtils +import org.jetbrains.kotlin.resolve.bindingContextUtil.getDataFlowInfoBefore +import org.jetbrains.kotlin.resolve.calls.smartcasts.DataFlowValueFactory +import org.jetbrains.kotlin.resolve.constants.CompileTimeConstant +import org.jetbrains.kotlin.resolve.constants.EnumValue import org.jetbrains.kotlin.resolve.constants.evaluate.ConstantExpressionEvaluator +import org.jetbrains.kotlin.resolve.descriptorUtil.getSuperClassOrAny import org.jetbrains.kotlin.types.KotlinType +private typealias EntryWithConstants = Pair, KtWhenEntry> + class WhenTranslator private constructor(private val whenExpression: KtWhenExpression, context: TranslationContext) : AbstractTranslator(context) { private val expressionToMatch: JsExpression? private val type: KotlinType? private val uniqueConstants = mutableSetOf() + private val uniqueEnumNames = mutableSetOf() private val isExhaustive: Boolean get() { @@ -108,44 +120,26 @@ private constructor(private val whenExpression: KtWhenExpression, context: Trans } private fun translateAsSwitch(fromIndex: Int): Pair? { - val expectedType = type ?: return null + val ktSubject = whenExpression.subjectExpression ?: return null + val subjectType = bindingContext().getType(ktSubject) ?: return null + + val dataFlow = DataFlowValueFactory.createDataFlowValue( + ktSubject, subjectType, bindingContext(), context().declarationDescriptor ?: context().currentModule) + val expectedTypes = bindingContext().getDataFlowInfoBefore(ktSubject).getStableTypes(dataFlow) + setOf(subjectType) val subject = expressionToMatch ?: return null + var subjectSupplier = { subject } - val entries = whenExpression.entries - val entriesForSwitch = mutableListOf, KtWhenEntry>>() - var i = fromIndex - while (i < entries.size) { - val entry = entries[i] - if (entry.isElse) break - - var hasImproperConstants = false - val constantValues = entry.conditions.mapNotNull { condition -> - val expression = (condition as? KtWhenConditionWithExpression)?.expression - expression?.let { ConstantExpressionEvaluator.getConstant(it, bindingContext())?.getValue(expectedType) } ?: run { - hasImproperConstants = true - null - } + val enumClass = expectedTypes.asSequence().mapNotNull { it.getEnumClass() }.firstOrNull() + val (entriesForSwitch, nextIndex) = if (enumClass != null) { + subjectSupplier = { + val enumBaseClass = enumClass.getSuperClassOrAny() + val nameProperty = DescriptorUtils.getPropertyByName(enumBaseClass.unsubstitutedMemberScope, Name.identifier("name")) + JsNameRef(context().getNameForDescriptor(nameProperty), subject) } - if (hasImproperConstants) break - - val constants = constantValues.filter { uniqueConstants.add(it) }.mapNotNull { value -> - when (value) { - is String -> JsStringLiteral(value) - is Int -> JsIntLiteral(value) - is Short -> JsIntLiteral(value.toInt()) - is Byte -> JsIntLiteral(value.toInt()) - else -> { - hasImproperConstants = true - null - } - } - } - if (hasImproperConstants) break - - if (constants.isNotEmpty()) { - entriesForSwitch += Pair(constants, entry) - } - i++ + collectEnumEntries(fromIndex, whenExpression.entries, enumClass.defaultType) + } + else { + collectPrimitiveConstantEntries(fromIndex, whenExpression.entries, expectedTypes) } return if (entriesForSwitch.asSequence().map { it.first.size }.sum() > 1) { @@ -164,13 +158,94 @@ private constructor(private val whenExpression: KtWhenExpression, context: Trans lastEntry.statements += JsBreak().apply { source = entry } members } - Pair(JsSwitch(subject, switchEntries).apply { source = expression }, i) + Pair(JsSwitch(subjectSupplier(), switchEntries).apply { source = expression }, nextIndex) } else { null } } + private fun collectPrimitiveConstantEntries( + fromIndex: Int, + entries: List, + expectedTypes: Set + ): Pair, Int> { + return collectConstantEntries( + fromIndex, entries, + { constant -> expectedTypes.asSequence().mapNotNull { constant.getValue(it) }.firstOrNull() }, + { uniqueConstants.add(it) }, + { + when (it) { + is String -> JsStringLiteral(it) + is Int -> JsIntLiteral(it) + is Short -> JsIntLiteral(it.toInt()) + is Byte -> JsIntLiteral(it.toInt()) + is Char -> JsIntLiteral(it.toInt()) + else -> null + } + } + ) + } + + private fun collectEnumEntries( + fromIndex: Int, + entries: List, + expectedType: KotlinType + ): Pair, Int> { + return collectConstantEntries( + fromIndex, entries, + { (it.toConstantValue(expectedType) as? EnumValue)?.value?.name?.identifier }, + { uniqueEnumNames.add(it) }, + { JsStringLiteral(it) } + ) + } + + private fun collectConstantEntries( + fromIndex: Int, + entries: List, + extractor: (CompileTimeConstant<*>) -> T?, + filter: (T) -> Boolean, + wrapper: (T) -> JsExpression? + ): Pair, Int> { + val entriesForSwitch = mutableListOf() + var i = fromIndex + while (i < entries.size) { + val entry = entries[i] + if (entry.isElse) break + + var hasImproperConstants = false + val constantValues = entry.conditions.mapNotNull { condition -> + val expression = (condition as? KtWhenConditionWithExpression)?.expression + expression?.let { ConstantExpressionEvaluator.getConstant(it, bindingContext()) }?.let(extractor) ?: run { + hasImproperConstants = true + null + } + } + if (hasImproperConstants) break + + val constants = constantValues.filter(filter).mapNotNull { + wrapper(it) ?: run { + hasImproperConstants = true + null + } + } + if (hasImproperConstants) break + + if (constants.isNotEmpty()) { + entriesForSwitch += Pair(constants, entry) + } + i++ + } + + return Pair(entriesForSwitch, i) + } + + private fun KotlinType.getEnumClass(): ClassDescriptor? { + if (isMarkedNullable) return null + val classDescriptor = (constructor.declarationDescriptor as? ClassDescriptor) + return if (classDescriptor?.kind == ClassKind.ENUM_CLASS) classDescriptor else null + } + private fun translateEntryExpression( entry: KtWhenEntry, context: TranslationContext, diff --git a/js/js.translator/testData/box/expression/when/exhaustiveCheckException.kt b/js/js.translator/testData/box/expression/when/exhaustiveCheckException.kt index 2c5529349ab..c480bc8e64c 100644 --- a/js/js.translator/testData/box/expression/when/exhaustiveCheckException.kt +++ b/js/js.translator/testData/box/expression/when/exhaustiveCheckException.kt @@ -29,9 +29,9 @@ enum class E { X, Y } -private inline fun createWrongC(): C = js("void 0").unsafeCast() +private inline fun createWrongC(): C = js("{ name: 'Z' }").unsafeCast() -private inline fun createWrongE(): E = js("void 0").unsafeCast() +private inline fun createWrongE(): E = js("{ name: 'Z' }").unsafeCast() fun box(): String { checkThrown(createWrongC()) {