diff --git a/compiler/fir/analysis-tests/testData/resolve/exhaustiveness/positive/exhaustiveness_sealedSubClass.fir.txt b/compiler/fir/analysis-tests/testData/resolve/exhaustiveness/positive/exhaustiveness_sealedSubClass.fir.txt index 1e4580f5d2f..db6cbd744a5 100644 --- a/compiler/fir/analysis-tests/testData/resolve/exhaustiveness/positive/exhaustiveness_sealedSubClass.fir.txt +++ b/compiler/fir/analysis-tests/testData/resolve/exhaustiveness/positive/exhaustiveness_sealedSubClass.fir.txt @@ -107,10 +107,10 @@ FILE: exhaustiveness_sealedSubClass.kt } } .#(Int(0)) - lval d: R|ERROR CLASS: Unresolved name: plus| = when (R|/e|) { + lval d: R|kotlin/Int| = when (R|/e|) { ($subj$ is R|C|) -> { Int(1) } } - .#(Int(0)) + .R|kotlin/Int.plus|(Int(0)) } diff --git a/compiler/fir/analysis-tests/testData/resolve/exhaustiveness/positive/exhaustiveness_sealedSubClass.kt b/compiler/fir/analysis-tests/testData/resolve/exhaustiveness/positive/exhaustiveness_sealedSubClass.kt index 7ed050ea810..9beeb8a7276 100644 --- a/compiler/fir/analysis-tests/testData/resolve/exhaustiveness/positive/exhaustiveness_sealedSubClass.kt +++ b/compiler/fir/analysis-tests/testData/resolve/exhaustiveness/positive/exhaustiveness_sealedSubClass.kt @@ -48,7 +48,7 @@ fun test_2(e: A) { is D -> 2 }.plus(0) - val d = when (e) { + val d = when (e) { is C -> 1 - }.plus(0) + }.plus(0) } diff --git a/compiler/fir/checkers/src/org/jetbrains/kotlin/fir/analysis/checkers/expression/FirExhaustiveWhenChecker.kt b/compiler/fir/checkers/src/org/jetbrains/kotlin/fir/analysis/checkers/expression/FirExhaustiveWhenChecker.kt index 3f229f9497b..ace54dbcc92 100644 --- a/compiler/fir/checkers/src/org/jetbrains/kotlin/fir/analysis/checkers/expression/FirExhaustiveWhenChecker.kt +++ b/compiler/fir/checkers/src/org/jetbrains/kotlin/fir/analysis/checkers/expression/FirExhaustiveWhenChecker.kt @@ -11,19 +11,19 @@ import org.jetbrains.kotlin.fir.analysis.checkers.context.CheckerContext import org.jetbrains.kotlin.fir.analysis.diagnostics.DiagnosticReporter import org.jetbrains.kotlin.fir.analysis.diagnostics.FirErrors import org.jetbrains.kotlin.fir.analysis.diagnostics.reportOn +import org.jetbrains.kotlin.fir.expressions.ExhaustivenessStatus import org.jetbrains.kotlin.fir.expressions.FirWhenExpression import org.jetbrains.kotlin.fir.expressions.isExhaustive object FirExhaustiveWhenChecker : FirWhenExpressionChecker() { override fun check(expression: FirWhenExpression, context: CheckerContext, reporter: DiagnosticReporter) { - // TODO: add reporting of proper missing clauses, see class WhenMissingCase if (expression.usedAsExpression && !expression.isExhaustive) { - val factory = if (expression.source?.isIfExpression == true) { - FirErrors.INVALID_IF_AS_EXPRESSION + if (expression.source?.isIfExpression == true) { + reporter.reportOn(expression.source, FirErrors.INVALID_IF_AS_EXPRESSION, context) } else { - FirErrors.NO_ELSE_IN_WHEN + val missingCases = (expression.exhaustivenessStatus as ExhaustivenessStatus.NotExhaustive).reasons + reporter.reportOn(expression.source, FirErrors.NO_ELSE_IN_WHEN, missingCases, context) } - reporter.reportOn(expression.source, factory, context) } } diff --git a/compiler/fir/checkers/src/org/jetbrains/kotlin/fir/analysis/diagnostics/FirDefaultErrorMessages.kt b/compiler/fir/checkers/src/org/jetbrains/kotlin/fir/analysis/diagnostics/FirDefaultErrorMessages.kt index 6b5de471d3b..f1c711dfbea 100644 --- a/compiler/fir/checkers/src/org/jetbrains/kotlin/fir/analysis/diagnostics/FirDefaultErrorMessages.kt +++ b/compiler/fir/checkers/src/org/jetbrains/kotlin/fir/analysis/diagnostics/FirDefaultErrorMessages.kt @@ -18,6 +18,7 @@ import org.jetbrains.kotlin.fir.analysis.diagnostics.FirDiagnosticRenderers.SYMB import org.jetbrains.kotlin.fir.analysis.diagnostics.FirDiagnosticRenderers.SYMBOLS import org.jetbrains.kotlin.fir.analysis.diagnostics.FirDiagnosticRenderers.TO_STRING import org.jetbrains.kotlin.fir.analysis.diagnostics.FirDiagnosticRenderers.VISIBILITY +import org.jetbrains.kotlin.fir.analysis.diagnostics.FirDiagnosticRenderers.WHEN_MISSING_CASES import org.jetbrains.kotlin.fir.analysis.diagnostics.FirErrors.ABSTRACT_DELEGATED_PROPERTY import org.jetbrains.kotlin.fir.analysis.diagnostics.FirErrors.ABSTRACT_FUNCTION_IN_NON_ABSTRACT_CLASS import org.jetbrains.kotlin.fir.analysis.diagnostics.FirErrors.ABSTRACT_FUNCTION_WITH_BODY @@ -504,7 +505,7 @@ class FirDefaultErrorMessages : DefaultErrorMessages.Extension { //) // When expressions - map.put(NO_ELSE_IN_WHEN, "''when'' expression must be exhaustive") + map.put(NO_ELSE_IN_WHEN, "''when'' expression must be exhaustive, add necessary {0}", WHEN_MISSING_CASES) map.put(INVALID_IF_AS_EXPRESSION, "'if' must have both main and 'else' branches if used as an expression") // Extended checkers group diff --git a/compiler/fir/checkers/src/org/jetbrains/kotlin/fir/analysis/diagnostics/FirDiagnosticRenderers.kt b/compiler/fir/checkers/src/org/jetbrains/kotlin/fir/analysis/diagnostics/FirDiagnosticRenderers.kt index 3e35205c1a9..b0627cd5e12 100644 --- a/compiler/fir/checkers/src/org/jetbrains/kotlin/fir/analysis/diagnostics/FirDiagnosticRenderers.kt +++ b/compiler/fir/checkers/src/org/jetbrains/kotlin/fir/analysis/diagnostics/FirDiagnosticRenderers.kt @@ -10,6 +10,7 @@ import org.jetbrains.kotlin.diagnostics.rendering.Renderer import org.jetbrains.kotlin.fir.FirElement import org.jetbrains.kotlin.fir.FirRenderer import org.jetbrains.kotlin.fir.declarations.* +import org.jetbrains.kotlin.fir.expressions.WhenMissingCase import org.jetbrains.kotlin.fir.render import org.jetbrains.kotlin.fir.renderWithType import org.jetbrains.kotlin.fir.symbols.AbstractFirBasedSymbol @@ -88,4 +89,16 @@ object FirDiagnosticRenderers { SYMBOL.render(symbol) } } + + private const val WHEN_MISSING_LIMIT = 7 + + val WHEN_MISSING_CASES = Renderer { missingCases: List -> + if (missingCases.firstOrNull() == WhenMissingCase.Unknown) { + "'else' branch" + } else { + val list = missingCases.joinToString(", ", limit = WHEN_MISSING_LIMIT) { "'$it'" } + val branches = if (missingCases.size > 1) "branches" else "branch" + "$list $branches or 'else' branch instead" + } + } } diff --git a/compiler/fir/checkers/src/org/jetbrains/kotlin/fir/analysis/diagnostics/FirErrors.kt b/compiler/fir/checkers/src/org/jetbrains/kotlin/fir/analysis/diagnostics/FirErrors.kt index 8dbba130ef8..94b33f8e56a 100644 --- a/compiler/fir/checkers/src/org/jetbrains/kotlin/fir/analysis/diagnostics/FirErrors.kt +++ b/compiler/fir/checkers/src/org/jetbrains/kotlin/fir/analysis/diagnostics/FirErrors.kt @@ -14,6 +14,7 @@ import org.jetbrains.kotlin.fir.FirSourceElement import org.jetbrains.kotlin.fir.declarations.FirCallableDeclaration import org.jetbrains.kotlin.fir.declarations.FirClass import org.jetbrains.kotlin.fir.declarations.FirMemberDeclaration +import org.jetbrains.kotlin.fir.expressions.WhenMissingCase import org.jetbrains.kotlin.fir.symbols.AbstractFirBasedSymbol import org.jetbrains.kotlin.fir.symbols.impl.FirClassLikeSymbol import org.jetbrains.kotlin.fir.symbols.impl.FirPropertySymbol @@ -231,7 +232,7 @@ object FirErrors { // TODO: val UNEXPECTED_SAFE_CALL by ... // When expressions - val NO_ELSE_IN_WHEN by error0(SourceElementPositioningStrategies.WHEN_EXPRESSION) + val NO_ELSE_IN_WHEN by error1>(SourceElementPositioningStrategies.WHEN_EXPRESSION) val INVALID_IF_AS_EXPRESSION by error0(SourceElementPositioningStrategies.IF_EXPRESSION) // Extended checkers group diff --git a/compiler/fir/resolve/src/org/jetbrains/kotlin/fir/resolve/transformers/FirWhenExhaustivenessTransformer.kt b/compiler/fir/resolve/src/org/jetbrains/kotlin/fir/resolve/transformers/FirWhenExhaustivenessTransformer.kt index 5d79a86a14a..238d60a951c 100644 --- a/compiler/fir/resolve/src/org/jetbrains/kotlin/fir/resolve/transformers/FirWhenExhaustivenessTransformer.kt +++ b/compiler/fir/resolve/src/org/jetbrains/kotlin/fir/resolve/transformers/FirWhenExhaustivenessTransformer.kt @@ -8,21 +8,34 @@ package org.jetbrains.kotlin.fir.resolve.transformers import org.jetbrains.kotlin.descriptors.ClassKind import org.jetbrains.kotlin.descriptors.Modality import org.jetbrains.kotlin.fir.FirElement +import org.jetbrains.kotlin.fir.FirSession import org.jetbrains.kotlin.fir.declarations.* import org.jetbrains.kotlin.fir.expressions.* +import org.jetbrains.kotlin.fir.expressions.LogicOperationKind.OR import org.jetbrains.kotlin.fir.expressions.impl.FirElseIfTrueCondition -import org.jetbrains.kotlin.fir.references.FirResolvedNamedReference import org.jetbrains.kotlin.fir.resolve.BodyResolveComponents -import org.jetbrains.kotlin.fir.resolve.getSymbolByLookupTag -import org.jetbrains.kotlin.fir.resolve.toSymbol -import org.jetbrains.kotlin.fir.symbols.ConeClassLikeLookupTag +import org.jetbrains.kotlin.fir.resolve.fullyExpandedType +import org.jetbrains.kotlin.fir.resolve.symbolProvider +import org.jetbrains.kotlin.fir.symbols.AbstractFirBasedSymbol +import org.jetbrains.kotlin.fir.symbols.StandardClassIds import org.jetbrains.kotlin.fir.symbols.impl.FirClassSymbol +import org.jetbrains.kotlin.fir.symbols.impl.FirRegularClassSymbol import org.jetbrains.kotlin.fir.symbols.impl.FirVariableSymbol import org.jetbrains.kotlin.fir.types.* -import org.jetbrains.kotlin.fir.visitors.* -import org.jetbrains.kotlin.name.ClassId +import org.jetbrains.kotlin.fir.visitors.CompositeTransformResult +import org.jetbrains.kotlin.fir.visitors.FirTransformer +import org.jetbrains.kotlin.fir.visitors.FirVisitor +import org.jetbrains.kotlin.fir.visitors.compose class FirWhenExhaustivenessTransformer(private val bodyResolveComponents: BodyResolveComponents) : FirTransformer() { + companion object { + private val exhaustivenessCheckers = listOf( + WhenOnBooleanExhaustivenessChecker, + WhenOnEnumExhaustivenessChecker, + WhenOnSealedClassExhaustivenessChecker + ) + } + override fun transformElement(element: E, data: Nothing?): CompositeTransformResult { throw IllegalArgumentException("Should not be there") } @@ -32,220 +45,279 @@ class FirWhenExhaustivenessTransformer(private val bodyResolveComponents: BodyRe return whenExpression.compose() } + @OptIn(ExperimentalStdlibApi::class) private fun processExhaustivenessCheck(whenExpression: FirWhenExpression) { if (whenExpression.branches.any { it.condition is FirElseIfTrueCondition }) { whenExpression.replaceExhaustivenessStatus(ExhaustivenessStatus.Exhaustive) return } - val typeRef = whenExpression.subjectVariable?.returnTypeRef - ?: whenExpression.subject?.typeRef - ?: return - - // TODO: add some report logic about flexible type (see WHEN_ENUM_CAN_BE_NULL_IN_JAVA diagnostic in old frontend) - val type = typeRef.coneType.lowerBoundIfFlexible() - val lookupTag = (type as? ConeLookupTagBasedType)?.lookupTag ?: return - val nullable = type.nullability == ConeNullability.NULLABLE - val isExhaustive = when { - ((lookupTag as? ConeClassLikeLookupTag)?.classId == bodyResolveComponents.session.builtinTypes.booleanType.id) -> { - checkBooleanExhaustiveness(whenExpression, nullable) - } - - whenExpression.branches.isEmpty() -> false - - else -> { - val klass = lookupTag.toSymbol(bodyResolveComponents.session)?.fir as? FirRegularClass ?: return - when { - klass.classKind == ClassKind.ENUM_CLASS -> checkEnumExhaustiveness(whenExpression, klass, nullable) - klass.modality == Modality.SEALED -> checkSealedClassExhaustiveness(whenExpression, klass, nullable) - else -> return - } - } - } - if (isExhaustive) { - whenExpression.replaceExhaustivenessStatus(ExhaustivenessStatus.Exhaustive) - } else { - whenExpression.replaceExhaustivenessStatus(ExhaustivenessStatus.NotExhaustive(listOf())) - } - } - - // ------------------------ Enum exhaustiveness ------------------------ - - private fun checkEnumExhaustiveness(whenExpression: FirWhenExpression, enum: FirRegularClass, nullable: Boolean): Boolean { - val data = EnumExhaustivenessData( - enum.collectEnumEntries().map { it.symbol }.toMutableSet(), - !nullable - ) - for (branch in whenExpression.branches) { - branch.condition.accept(EnumExhaustivenessVisitor, data) - } - return data.containsNull && data.remainingEntries.isEmpty() - } - - private class EnumExhaustivenessData(val remainingEntries: MutableSet>, var containsNull: Boolean) - - private object EnumExhaustivenessVisitor : FirVisitor() { - override fun visitElement(element: FirElement, data: EnumExhaustivenessData) {} - - override fun visitEqualityOperatorCall(equalityOperatorCall: FirEqualityOperatorCall, data: EnumExhaustivenessData) { - if (equalityOperatorCall.operation.let { it == FirOperation.EQ || it == FirOperation.IDENTITY }) { - when (val argument = equalityOperatorCall.arguments[1]) { - is FirConstExpression<*> -> { - if (argument.value == null) { - data.containsNull = true - } - } - is FirQualifiedAccessExpression -> { - if (argument.typeRef.isNullableNothing) { - data.containsNull = true - return - } - val reference = argument.calleeReference as? FirResolvedNamedReference ?: return - val symbol = (reference.resolvedSymbol.fir as? FirEnumEntry)?.symbol ?: return - - data.remainingEntries.remove(symbol) - } - } - } - } - - override fun visitBinaryLogicExpression(binaryLogicExpression: FirBinaryLogicExpression, data: EnumExhaustivenessData) { - if (binaryLogicExpression.kind == LogicOperationKind.OR) { - binaryLogicExpression.acceptChildren(this, data) - } - } - } - - // ------------------------ Sealed class exhaustiveness ------------------------ - - private fun checkSealedClassExhaustiveness( - whenExpression: FirWhenExpression, - sealedClass: FirRegularClass, - nullable: Boolean - ): Boolean { - if (sealedClass.sealedInheritors.isNullOrEmpty()) return true - - val data = SealedExhaustivenessData(sealedClass, !nullable) - for (branch in whenExpression.branches) { - branch.condition.accept(SealedExhaustivenessVisitor, data) - } - return data.isExhaustive() - } - - private inner class SealedExhaustivenessData(regularClass: FirRegularClass, var containsNull: Boolean) { - val symbolProvider = bodyResolveComponents.symbolProvider - private val rootNode = SealedClassInheritors(regularClass.classId, regularClass.sealedInheritors.mapToSealedInheritors()) - - private fun List?.mapToSealedInheritors(): MutableSet? { - if (this.isNullOrEmpty()) return null - - return this.mapNotNull { - val inheritor = symbolProvider.getClassLikeSymbolByFqName(it)?.fir as? FirRegularClass ?: return@mapNotNull null - SealedClassInheritors(inheritor.classId, inheritor.sealedInheritors.mapToSealedInheritors()) - }.takeIf { it.isNotEmpty() }?.toMutableSet() - } - - fun removeInheritor(classId: ClassId) { - if (rootNode.classId == classId) { - rootNode.inheritors?.clear() + val subjectType = whenExpression.subjectVariable?.returnTypeRef?.coneType + ?: whenExpression.subject?.typeRef?.coneType + ?: run { + whenExpression.replaceExhaustivenessStatus(ExhaustivenessStatus.NotExhaustive.NO_ELSE_BRANCH) return } - rootNode.removeInheritor(classId) - } + val session = bodyResolveComponents.session + val cleanSubjectType = subjectType.fullyExpandedType(session).lowerBoundIfFlexible() - fun isExhaustive() = containsNull && rootNode.isEmpty() - } - - private data class SealedClassInheritors(val classId: ClassId, val inheritors: MutableSet? = null) { - fun removeInheritor(classId: ClassId): Boolean { - return inheritors != null && (inheritors.removeIf { it.classId == classId } || inheritors.any { it.removeInheritor(classId) }) - } - - fun isEmpty(): Boolean { - return inheritors != null && inheritors.all { it.isEmpty() } - } - } - - private object SealedExhaustivenessVisitor : FirDefaultVisitor() { - override fun visitElement(element: FirElement, data: SealedExhaustivenessData) {} - - override fun visitTypeOperatorCall(typeOperatorCall: FirTypeOperatorCall, data: SealedExhaustivenessData) { - if (typeOperatorCall.operation == FirOperation.IS) { - typeOperatorCall.conversionTypeRef.accept(this, data) + val checkers = buildList { + exhaustivenessCheckers.filterTo(this) { it.isApplicable(cleanSubjectType, session) } + if (isNotEmpty() && cleanSubjectType.isMarkedNullable) { + add(WhenOnNullableExhaustivenessChecker) } } - override fun visitEqualityOperatorCall(equalityOperatorCall: FirEqualityOperatorCall, data: SealedExhaustivenessData) { - if (equalityOperatorCall.operation.let { it == FirOperation.EQ || it == FirOperation.IDENTITY }) { - when (val argument = equalityOperatorCall.arguments[1]) { - is FirConstExpression<*> -> { - if (argument.value == null) { - data.containsNull = true - } - } - - is FirResolvedQualifier -> { - argument.typeRef.accept(this, data) - } - - is FirQualifiedAccessExpression -> { - if (argument.typeRef.isNullableNothing) { - data.containsNull = true - } - } - } - } + if (checkers.isEmpty()) { + whenExpression.replaceExhaustivenessStatus(ExhaustivenessStatus.NotExhaustive.NO_ELSE_BRANCH) + return + } + val whenMissingCases = mutableListOf() + for (checker in checkers) { + checker.computeMissingCases(whenExpression, cleanSubjectType, session, whenMissingCases) + } + if (whenMissingCases.isEmpty() && whenExpression.branches.isEmpty()) { + whenMissingCases.add(WhenMissingCase.Unknown) } - override fun visitResolvedTypeRef(resolvedTypeRef: FirResolvedTypeRef, data: SealedExhaustivenessData) { - val lookupTag = (resolvedTypeRef.type as? ConeLookupTagBasedType)?.lookupTag ?: return - val symbol = data.symbolProvider.getSymbolByLookupTag(lookupTag) as? FirClassSymbol ?: return - data.removeInheritor(symbol.classId) + val status = if (whenMissingCases.isEmpty()) { + ExhaustivenessStatus.Exhaustive + } else { + ExhaustivenessStatus.NotExhaustive(whenMissingCases) + } + whenExpression.replaceExhaustivenessStatus(status) + } +} + +private sealed class WhenExhaustivenessChecker { + abstract fun isApplicable(subjectType: ConeKotlinType, session: FirSession): Boolean + abstract fun computeMissingCases( + whenExpression: FirWhenExpression, + subjectType: ConeKotlinType, + session: FirSession, + destination: MutableCollection + ) + + protected abstract class AbstractConditionChecker : FirVisitor() { + override fun visitElement(element: FirElement, data: D) {} + + override fun visitWhenExpression(whenExpression: FirWhenExpression, data: D) { + whenExpression.branches.forEach { it.accept(this, data) } } - override fun visitBinaryLogicExpression(binaryLogicExpression: FirBinaryLogicExpression, data: SealedExhaustivenessData) { - if (binaryLogicExpression.kind == LogicOperationKind.OR) { + override fun visitWhenBranch(whenBranch: FirWhenBranch, data: D) { + whenBranch.condition.accept(this, data) + } + + override fun visitBinaryLogicExpression(binaryLogicExpression: FirBinaryLogicExpression, data: D) { + if (binaryLogicExpression.kind == OR) { binaryLogicExpression.acceptChildren(this, data) } } } +} - // ------------------------ Boolean exhaustiveness ------------------------ - - private fun checkBooleanExhaustiveness(whenExpression: FirWhenExpression, nullable: Boolean): Boolean { - val flags = BooleanExhaustivenessFlags(!nullable) - for (branch in whenExpression.branches) { - branch.condition.accept(BooleanExhaustivenessVisitor, flags) - } - return flags.containsTrue && flags.containsFalse && flags.containsNull +private object WhenOnNullableExhaustivenessChecker : WhenExhaustivenessChecker() { + override fun isApplicable(subjectType: ConeKotlinType, session: FirSession): Boolean { + return subjectType.isNullable } - private class BooleanExhaustivenessFlags(var containsNull: Boolean) { + override fun computeMissingCases( + whenExpression: FirWhenExpression, + subjectType: ConeKotlinType, + session: FirSession, + destination: MutableCollection + ) { + val flags = Flags() + whenExpression.accept(ConditionChecker, flags) + if (!flags.containsNull) { + destination.add(WhenMissingCase.NullIsMissing) + } + } + + private class Flags { + var containsNull = false + } + + private object ConditionChecker : AbstractConditionChecker() { + override fun visitEqualityOperatorCall(equalityOperatorCall: FirEqualityOperatorCall, data: Flags) { + val argument = equalityOperatorCall.arguments[1] + if (argument.typeRef.isNullableNothing) { + data.containsNull = true + } + } + + override fun visitTypeOperatorCall(typeOperatorCall: FirTypeOperatorCall, data: Flags) { + if (typeOperatorCall.conversionTypeRef.coneType.isNullable) { + data.containsNull = true + } + } + } +} + +private object WhenOnBooleanExhaustivenessChecker : WhenExhaustivenessChecker() { + override fun isApplicable(subjectType: ConeKotlinType, session: FirSession): Boolean { + return subjectType.classId == StandardClassIds.Boolean + } + + private class Flags { var containsTrue = false var containsFalse = false } - private object BooleanExhaustivenessVisitor : FirVisitor() { - override fun visitElement(element: FirElement, data: BooleanExhaustivenessFlags) {} + override fun computeMissingCases( + whenExpression: FirWhenExpression, + subjectType: ConeKotlinType, + session: FirSession, + destination: MutableCollection + ) { + val flags = Flags() + whenExpression.accept(ConditionChecker, flags) + if (!flags.containsFalse) { + destination.add(WhenMissingCase.BooleanIsMissing.False) + } + if (!flags.containsTrue) { + destination.add(WhenMissingCase.BooleanIsMissing.True) + } + } - override fun visitEqualityOperatorCall(equalityOperatorCall: FirEqualityOperatorCall, data: BooleanExhaustivenessFlags) { + private object ConditionChecker : AbstractConditionChecker() { + override fun visitEqualityOperatorCall(equalityOperatorCall: FirEqualityOperatorCall, data: Flags) { if (equalityOperatorCall.operation.let { it == FirOperation.EQ || it == FirOperation.IDENTITY }) { val argument = equalityOperatorCall.arguments[1] if (argument is FirConstExpression<*>) { when (argument.value) { true -> data.containsTrue = true false -> data.containsFalse = true - null -> data.containsNull = true } } } } + } +} - override fun visitBinaryLogicExpression(binaryLogicExpression: FirBinaryLogicExpression, data: BooleanExhaustivenessFlags) { - if (binaryLogicExpression.kind == LogicOperationKind.OR) { - binaryLogicExpression.acceptChildren(this, data) - } +private object WhenOnEnumExhaustivenessChecker : WhenExhaustivenessChecker() { + override fun isApplicable(subjectType: ConeKotlinType, session: FirSession): Boolean { + val symbol = subjectType.toSymbol(session) as? FirRegularClassSymbol ?: return false + return symbol.fir.classKind == ClassKind.ENUM_CLASS + } + + override fun computeMissingCases( + whenExpression: FirWhenExpression, + subjectType: ConeKotlinType, + session: FirSession, + destination: MutableCollection + ) { + val enumClass = (subjectType.toSymbol(session) as FirRegularClassSymbol).fir + val allEntries = enumClass.declarations.mapNotNullTo(mutableSetOf()) { it as? FirEnumEntry } + val checkedEntries = mutableSetOf() + whenExpression.accept(ConditionChecker, checkedEntries) + val notCheckedEntries = allEntries - checkedEntries + notCheckedEntries.mapTo(destination) { WhenMissingCase.EnumCheckIsMissing(it.symbol.callableId) } + } + + private object ConditionChecker : AbstractConditionChecker>() { + override fun visitEqualityOperatorCall(equalityOperatorCall: FirEqualityOperatorCall, data: MutableSet) { + if (!equalityOperatorCall.operation.let { it == FirOperation.EQ || it == FirOperation.IDENTITY }) return + val argument = equalityOperatorCall.arguments[1] + val symbol = argument.toResolvedCallableReference()?.resolvedSymbol as? FirVariableSymbol<*> ?: return + val checkedEnumEntry = symbol.fir as? FirEnumEntry ?: return + data.add(checkedEnumEntry) + } + } +} + +private object WhenOnSealedClassExhaustivenessChecker : WhenExhaustivenessChecker() { + override fun isApplicable(subjectType: ConeKotlinType, session: FirSession): Boolean { + return (subjectType.toSymbol(session)?.fir as? FirRegularClass)?.modality == Modality.SEALED + } + + override fun computeMissingCases( + whenExpression: FirWhenExpression, + subjectType: ConeKotlinType, + session: FirSession, + destination: MutableCollection + ) { + val allSubclasses = subjectType.toSymbol(session)?.collectAllSubclasses(session) ?: return + val checkedSubclasses = mutableSetOf>() + whenExpression.accept(ConditionChecker, Flags(allSubclasses, checkedSubclasses, session)) + (allSubclasses - checkedSubclasses).mapNotNullTo(destination) { + when (it) { + is FirClassSymbol<*> -> WhenMissingCase.IsTypeCheckIsMissing(it.classId, it.fir.classKind.isSingleton) + is FirVariableSymbol<*> -> WhenMissingCase.EnumCheckIsMissing(it.callableId) + else -> null + } + } + } + + private class Flags( + val allSubclasses: Set>, + val checkedSubclasses: MutableSet>, + val session: FirSession + ) + + private object ConditionChecker : AbstractConditionChecker() { + override fun visitEqualityOperatorCall(equalityOperatorCall: FirEqualityOperatorCall, data: Flags) { + val isNegated = when (equalityOperatorCall.operation) { + FirOperation.EQ, FirOperation.IDENTITY -> false + FirOperation.NOT_EQ, FirOperation.NOT_IDENTITY -> true + else -> return + } + + val symbol = when (val argument = equalityOperatorCall.arguments[1]) { + is FirResolvedQualifier -> { + val firClass = (argument.symbol as? FirRegularClassSymbol)?.fir + if (firClass?.classKind == ClassKind.OBJECT) { + firClass.symbol + } else { + firClass?.companionObject?.symbol + } + } + else -> { + argument.toResolvedCallableSymbol()?.takeIf { it.fir is FirEnumEntry } + } + } ?: return + processBranch(symbol, isNegated, data) + } + + override fun visitTypeOperatorCall(typeOperatorCall: FirTypeOperatorCall, data: Flags) { + val isNegated = when (typeOperatorCall.operation) { + FirOperation.IS -> false + FirOperation.NOT_IS -> true + else -> return + } + val symbol = typeOperatorCall.conversionTypeRef.coneType.fullyExpandedType(data.session).toSymbol(data.session) ?: return + processBranch(symbol, isNegated, data) + } + + private fun processBranch(symbolToCheck: AbstractFirBasedSymbol<*>, isNegated: Boolean, flags: Flags) { + val subclassesOfType = symbolToCheck.collectAllSubclasses(flags.session) + if (subclassesOfType.none { it in flags.allSubclasses }) { + return + } + val checkedSubclasses = if (isNegated) flags.allSubclasses - subclassesOfType else subclassesOfType + flags.checkedSubclasses.addAll(checkedSubclasses) + } + } + + + private fun AbstractFirBasedSymbol<*>.collectAllSubclasses(session: FirSession): Set> { + return mutableSetOf>().apply { collectAllSubclassesTo(this, session) } + } + + private fun AbstractFirBasedSymbol<*>.collectAllSubclassesTo(destination: MutableSet>, session: FirSession) { + if (this !is FirRegularClassSymbol) { + destination.add(this) + return + } + when { + fir.modality == Modality.SEALED -> fir.sealedInheritors?.forEach { + val symbol = session.symbolProvider.getClassLikeSymbolByFqName(it) as? FirRegularClassSymbol + symbol?.collectAllSubclassesTo(destination, session) + } + fir.classKind == ClassKind.ENUM_CLASS -> fir.collectEnumEntries().mapTo(destination) { it.symbol } + else -> destination.add(this) } } } diff --git a/compiler/fir/tree/src/org/jetbrains/kotlin/fir/expressions/ExhaustivenessStatus.kt b/compiler/fir/tree/src/org/jetbrains/kotlin/fir/expressions/ExhaustivenessStatus.kt index 50966403df1..ddac6cf5b4f 100644 --- a/compiler/fir/tree/src/org/jetbrains/kotlin/fir/expressions/ExhaustivenessStatus.kt +++ b/compiler/fir/tree/src/org/jetbrains/kotlin/fir/expressions/ExhaustivenessStatus.kt @@ -5,19 +5,61 @@ package org.jetbrains.kotlin.fir.expressions +import org.jetbrains.kotlin.fir.symbols.CallableId import org.jetbrains.kotlin.name.ClassId sealed class ExhaustivenessStatus { object Exhaustive : ExhaustivenessStatus() - class NotExhaustive(val reasons: List) : ExhaustivenessStatus() + class NotExhaustive(val reasons: List) : ExhaustivenessStatus() { + companion object { + val NO_ELSE_BRANCH = NotExhaustive(listOf(WhenMissingCase.Unknown)) + } + } } -sealed class WhenMissingCase { - object Unknown : WhenMissingCase() - object NullIsMissing : WhenMissingCase() - class BooleanIsMissing(val value: Boolean) : WhenMissingCase() - class IsTypeCheckIsMissing(val classId: ClassId) : WhenMissingCase() - class EnumCheckIsMissing(val classId: ClassId) : WhenMissingCase() +sealed class WhenMissingCase() { + abstract val branchConditionText: String + + object Unknown : WhenMissingCase() { + override fun toString(): String = "unknown" + + override val branchConditionText: String = "else" + } + + object NullIsMissing : WhenMissingCase() { + override val branchConditionText: String = "null" + } + + sealed class BooleanIsMissing(val value: Boolean) : WhenMissingCase() { + object True : BooleanIsMissing(true) + object False : BooleanIsMissing(false) + + override val branchConditionText: String = value.toString() + } + + class IsTypeCheckIsMissing(val classId: ClassId, val isSingleton: Boolean) : WhenMissingCase() { + override val branchConditionText: String = run { + val fqName = classId.asSingleFqName().toString() + if (isSingleton) fqName else "is $fqName" + } + + override fun toString(): String { + val name = classId.shortClassName.identifier + return if (isSingleton) name else "is $name" + } + } + + class EnumCheckIsMissing(val callableId: CallableId) : WhenMissingCase() { + override val branchConditionText: String = callableId.asFqNameForDebugInfo().toString() + + override fun toString(): String { + return callableId.callableName.identifier + } + } + + override fun toString(): String { + return branchConditionText + } } val FirWhenExpression.isExhaustive: Boolean diff --git a/compiler/testData/diagnostics/tests/sealed/ExhaustiveOnRoot.fir.kt b/compiler/testData/diagnostics/tests/sealed/ExhaustiveOnRoot.fir.kt index c393564b456..04b049c2dc3 100644 --- a/compiler/testData/diagnostics/tests/sealed/ExhaustiveOnRoot.fir.kt +++ b/compiler/testData/diagnostics/tests/sealed/ExhaustiveOnRoot.fir.kt @@ -18,6 +18,6 @@ fun test2(x: Stmt): String = } fun test3(x: Expr): String = - when (x) { + when (x) { is Stmt -> "stmt" } diff --git a/compiler/testData/diagnostics/tests/sealed/ExhaustiveOnTree.fir.kt b/compiler/testData/diagnostics/tests/sealed/ExhaustiveOnTree.fir.kt deleted file mode 100644 index fab86d219d0..00000000000 --- a/compiler/testData/diagnostics/tests/sealed/ExhaustiveOnTree.fir.kt +++ /dev/null @@ -1,34 +0,0 @@ -sealed class Base { - sealed class A : Base() { - object A1 : A() - sealed class A2 : A() - } - sealed class B : Base() { - sealed class B1 : B() - object B2 : B() - } - - fun foo() = when (this) { - is A -> 1 - is B.B1 -> 2 - B.B2 -> 3 - // No else required - } - - fun bar() = when (this) { - is A -> 1 - is B.B1 -> 2 - } - - fun baz() = when (this) { - is A -> 1 - B.B2 -> 3 - // No else required (no possible B1 instances) - } - - fun negated() = when (this) { - !is A -> 1 - A.A1 -> 2 - is A.A2 -> 3 - } -} diff --git a/compiler/testData/diagnostics/tests/sealed/ExhaustiveOnTree.kt b/compiler/testData/diagnostics/tests/sealed/ExhaustiveOnTree.kt index 292767ade22..598ce99c605 100644 --- a/compiler/testData/diagnostics/tests/sealed/ExhaustiveOnTree.kt +++ b/compiler/testData/diagnostics/tests/sealed/ExhaustiveOnTree.kt @@ -1,3 +1,4 @@ +// FIR_IDENTICAL sealed class Base { sealed class A : Base() { object A1 : A()