[FIR] Reimplement when exhaustiveness checker to fir it's logic with FE 1.0

This commit is contained in:
Dmitriy Novozhilov
2021-02-06 12:34:32 +03:00
parent 2a1c9283a4
commit 18bde2c542
11 changed files with 329 additions and 233 deletions
@@ -107,10 +107,10 @@ FILE: exhaustiveness_sealedSubClass.kt
}
}
.<Unresolved name: plus>#(Int(0))
lval d: R|ERROR CLASS: Unresolved name: plus| = when (R|<local>/e|) {
lval d: R|kotlin/Int| = when (R|<local>/e|) {
($subj$ is R|C|) -> {
Int(1)
}
}
.<Unresolved name: plus>#(Int(0))
.R|kotlin/Int.plus|(Int(0))
}
@@ -48,7 +48,7 @@ fun test_2(e: A) {
is D -> 2
}.<!UNRESOLVED_REFERENCE{LT}!><!UNRESOLVED_REFERENCE{PSI}!>plus<!>(0)<!>
val d = <!NO_ELSE_IN_WHEN!>when<!> (e) {
val d = when (e) {
is C -> 1
}.<!UNRESOLVED_REFERENCE{LT}!><!UNRESOLVED_REFERENCE{PSI}!>plus<!>(0)<!>
}.plus(0)
}
@@ -11,19 +11,19 @@ import org.jetbrains.kotlin.fir.analysis.checkers.context.CheckerContext
import org.jetbrains.kotlin.fir.analysis.diagnostics.DiagnosticReporter
import org.jetbrains.kotlin.fir.analysis.diagnostics.FirErrors
import org.jetbrains.kotlin.fir.analysis.diagnostics.reportOn
import org.jetbrains.kotlin.fir.expressions.ExhaustivenessStatus
import org.jetbrains.kotlin.fir.expressions.FirWhenExpression
import org.jetbrains.kotlin.fir.expressions.isExhaustive
object FirExhaustiveWhenChecker : FirWhenExpressionChecker() {
override fun check(expression: FirWhenExpression, context: CheckerContext, reporter: DiagnosticReporter) {
// TODO: add reporting of proper missing clauses, see class WhenMissingCase
if (expression.usedAsExpression && !expression.isExhaustive) {
val factory = if (expression.source?.isIfExpression == true) {
FirErrors.INVALID_IF_AS_EXPRESSION
if (expression.source?.isIfExpression == true) {
reporter.reportOn(expression.source, FirErrors.INVALID_IF_AS_EXPRESSION, context)
} else {
FirErrors.NO_ELSE_IN_WHEN
val missingCases = (expression.exhaustivenessStatus as ExhaustivenessStatus.NotExhaustive).reasons
reporter.reportOn(expression.source, FirErrors.NO_ELSE_IN_WHEN, missingCases, context)
}
reporter.reportOn(expression.source, factory, context)
}
}
@@ -18,6 +18,7 @@ import org.jetbrains.kotlin.fir.analysis.diagnostics.FirDiagnosticRenderers.SYMB
import org.jetbrains.kotlin.fir.analysis.diagnostics.FirDiagnosticRenderers.SYMBOLS
import org.jetbrains.kotlin.fir.analysis.diagnostics.FirDiagnosticRenderers.TO_STRING
import org.jetbrains.kotlin.fir.analysis.diagnostics.FirDiagnosticRenderers.VISIBILITY
import org.jetbrains.kotlin.fir.analysis.diagnostics.FirDiagnosticRenderers.WHEN_MISSING_CASES
import org.jetbrains.kotlin.fir.analysis.diagnostics.FirErrors.ABSTRACT_DELEGATED_PROPERTY
import org.jetbrains.kotlin.fir.analysis.diagnostics.FirErrors.ABSTRACT_FUNCTION_IN_NON_ABSTRACT_CLASS
import org.jetbrains.kotlin.fir.analysis.diagnostics.FirErrors.ABSTRACT_FUNCTION_WITH_BODY
@@ -504,7 +505,7 @@ class FirDefaultErrorMessages : DefaultErrorMessages.Extension {
//)
// When expressions
map.put(NO_ELSE_IN_WHEN, "''when'' expression must be exhaustive")
map.put(NO_ELSE_IN_WHEN, "''when'' expression must be exhaustive, add necessary {0}", WHEN_MISSING_CASES)
map.put(INVALID_IF_AS_EXPRESSION, "'if' must have both main and 'else' branches if used as an expression")
// Extended checkers group
@@ -10,6 +10,7 @@ import org.jetbrains.kotlin.diagnostics.rendering.Renderer
import org.jetbrains.kotlin.fir.FirElement
import org.jetbrains.kotlin.fir.FirRenderer
import org.jetbrains.kotlin.fir.declarations.*
import org.jetbrains.kotlin.fir.expressions.WhenMissingCase
import org.jetbrains.kotlin.fir.render
import org.jetbrains.kotlin.fir.renderWithType
import org.jetbrains.kotlin.fir.symbols.AbstractFirBasedSymbol
@@ -88,4 +89,16 @@ object FirDiagnosticRenderers {
SYMBOL.render(symbol)
}
}
private const val WHEN_MISSING_LIMIT = 7
val WHEN_MISSING_CASES = Renderer { missingCases: List<WhenMissingCase> ->
if (missingCases.firstOrNull() == WhenMissingCase.Unknown) {
"'else' branch"
} else {
val list = missingCases.joinToString(", ", limit = WHEN_MISSING_LIMIT) { "'$it'" }
val branches = if (missingCases.size > 1) "branches" else "branch"
"$list $branches or 'else' branch instead"
}
}
}
@@ -14,6 +14,7 @@ import org.jetbrains.kotlin.fir.FirSourceElement
import org.jetbrains.kotlin.fir.declarations.FirCallableDeclaration
import org.jetbrains.kotlin.fir.declarations.FirClass
import org.jetbrains.kotlin.fir.declarations.FirMemberDeclaration
import org.jetbrains.kotlin.fir.expressions.WhenMissingCase
import org.jetbrains.kotlin.fir.symbols.AbstractFirBasedSymbol
import org.jetbrains.kotlin.fir.symbols.impl.FirClassLikeSymbol
import org.jetbrains.kotlin.fir.symbols.impl.FirPropertySymbol
@@ -231,7 +232,7 @@ object FirErrors {
// TODO: val UNEXPECTED_SAFE_CALL by ...
// When expressions
val NO_ELSE_IN_WHEN by error0<FirSourceElement, KtWhenExpression>(SourceElementPositioningStrategies.WHEN_EXPRESSION)
val NO_ELSE_IN_WHEN by error1<FirSourceElement, KtWhenExpression, List<WhenMissingCase>>(SourceElementPositioningStrategies.WHEN_EXPRESSION)
val INVALID_IF_AS_EXPRESSION by error0<FirSourceElement, KtIfExpression>(SourceElementPositioningStrategies.IF_EXPRESSION)
// Extended checkers group
@@ -8,21 +8,34 @@ package org.jetbrains.kotlin.fir.resolve.transformers
import org.jetbrains.kotlin.descriptors.ClassKind
import org.jetbrains.kotlin.descriptors.Modality
import org.jetbrains.kotlin.fir.FirElement
import org.jetbrains.kotlin.fir.FirSession
import org.jetbrains.kotlin.fir.declarations.*
import org.jetbrains.kotlin.fir.expressions.*
import org.jetbrains.kotlin.fir.expressions.LogicOperationKind.OR
import org.jetbrains.kotlin.fir.expressions.impl.FirElseIfTrueCondition
import org.jetbrains.kotlin.fir.references.FirResolvedNamedReference
import org.jetbrains.kotlin.fir.resolve.BodyResolveComponents
import org.jetbrains.kotlin.fir.resolve.getSymbolByLookupTag
import org.jetbrains.kotlin.fir.resolve.toSymbol
import org.jetbrains.kotlin.fir.symbols.ConeClassLikeLookupTag
import org.jetbrains.kotlin.fir.resolve.fullyExpandedType
import org.jetbrains.kotlin.fir.resolve.symbolProvider
import org.jetbrains.kotlin.fir.symbols.AbstractFirBasedSymbol
import org.jetbrains.kotlin.fir.symbols.StandardClassIds
import org.jetbrains.kotlin.fir.symbols.impl.FirClassSymbol
import org.jetbrains.kotlin.fir.symbols.impl.FirRegularClassSymbol
import org.jetbrains.kotlin.fir.symbols.impl.FirVariableSymbol
import org.jetbrains.kotlin.fir.types.*
import org.jetbrains.kotlin.fir.visitors.*
import org.jetbrains.kotlin.name.ClassId
import org.jetbrains.kotlin.fir.visitors.CompositeTransformResult
import org.jetbrains.kotlin.fir.visitors.FirTransformer
import org.jetbrains.kotlin.fir.visitors.FirVisitor
import org.jetbrains.kotlin.fir.visitors.compose
class FirWhenExhaustivenessTransformer(private val bodyResolveComponents: BodyResolveComponents) : FirTransformer<Nothing?>() {
companion object {
private val exhaustivenessCheckers = listOf(
WhenOnBooleanExhaustivenessChecker,
WhenOnEnumExhaustivenessChecker,
WhenOnSealedClassExhaustivenessChecker
)
}
override fun <E : FirElement> transformElement(element: E, data: Nothing?): CompositeTransformResult<E> {
throw IllegalArgumentException("Should not be there")
}
@@ -32,220 +45,279 @@ class FirWhenExhaustivenessTransformer(private val bodyResolveComponents: BodyRe
return whenExpression.compose()
}
@OptIn(ExperimentalStdlibApi::class)
private fun processExhaustivenessCheck(whenExpression: FirWhenExpression) {
if (whenExpression.branches.any { it.condition is FirElseIfTrueCondition }) {
whenExpression.replaceExhaustivenessStatus(ExhaustivenessStatus.Exhaustive)
return
}
val typeRef = whenExpression.subjectVariable?.returnTypeRef
?: whenExpression.subject?.typeRef
?: return
// TODO: add some report logic about flexible type (see WHEN_ENUM_CAN_BE_NULL_IN_JAVA diagnostic in old frontend)
val type = typeRef.coneType.lowerBoundIfFlexible()
val lookupTag = (type as? ConeLookupTagBasedType)?.lookupTag ?: return
val nullable = type.nullability == ConeNullability.NULLABLE
val isExhaustive = when {
((lookupTag as? ConeClassLikeLookupTag)?.classId == bodyResolveComponents.session.builtinTypes.booleanType.id) -> {
checkBooleanExhaustiveness(whenExpression, nullable)
}
whenExpression.branches.isEmpty() -> false
else -> {
val klass = lookupTag.toSymbol(bodyResolveComponents.session)?.fir as? FirRegularClass ?: return
when {
klass.classKind == ClassKind.ENUM_CLASS -> checkEnumExhaustiveness(whenExpression, klass, nullable)
klass.modality == Modality.SEALED -> checkSealedClassExhaustiveness(whenExpression, klass, nullable)
else -> return
}
}
}
if (isExhaustive) {
whenExpression.replaceExhaustivenessStatus(ExhaustivenessStatus.Exhaustive)
} else {
whenExpression.replaceExhaustivenessStatus(ExhaustivenessStatus.NotExhaustive(listOf()))
}
}
// ------------------------ Enum exhaustiveness ------------------------
private fun checkEnumExhaustiveness(whenExpression: FirWhenExpression, enum: FirRegularClass, nullable: Boolean): Boolean {
val data = EnumExhaustivenessData(
enum.collectEnumEntries().map { it.symbol }.toMutableSet(),
!nullable
)
for (branch in whenExpression.branches) {
branch.condition.accept(EnumExhaustivenessVisitor, data)
}
return data.containsNull && data.remainingEntries.isEmpty()
}
private class EnumExhaustivenessData(val remainingEntries: MutableSet<FirVariableSymbol<FirEnumEntry>>, var containsNull: Boolean)
private object EnumExhaustivenessVisitor : FirVisitor<Unit, EnumExhaustivenessData>() {
override fun visitElement(element: FirElement, data: EnumExhaustivenessData) {}
override fun visitEqualityOperatorCall(equalityOperatorCall: FirEqualityOperatorCall, data: EnumExhaustivenessData) {
if (equalityOperatorCall.operation.let { it == FirOperation.EQ || it == FirOperation.IDENTITY }) {
when (val argument = equalityOperatorCall.arguments[1]) {
is FirConstExpression<*> -> {
if (argument.value == null) {
data.containsNull = true
}
}
is FirQualifiedAccessExpression -> {
if (argument.typeRef.isNullableNothing) {
data.containsNull = true
return
}
val reference = argument.calleeReference as? FirResolvedNamedReference ?: return
val symbol = (reference.resolvedSymbol.fir as? FirEnumEntry)?.symbol ?: return
data.remainingEntries.remove(symbol)
}
}
}
}
override fun visitBinaryLogicExpression(binaryLogicExpression: FirBinaryLogicExpression, data: EnumExhaustivenessData) {
if (binaryLogicExpression.kind == LogicOperationKind.OR) {
binaryLogicExpression.acceptChildren(this, data)
}
}
}
// ------------------------ Sealed class exhaustiveness ------------------------
private fun checkSealedClassExhaustiveness(
whenExpression: FirWhenExpression,
sealedClass: FirRegularClass,
nullable: Boolean
): Boolean {
if (sealedClass.sealedInheritors.isNullOrEmpty()) return true
val data = SealedExhaustivenessData(sealedClass, !nullable)
for (branch in whenExpression.branches) {
branch.condition.accept(SealedExhaustivenessVisitor, data)
}
return data.isExhaustive()
}
private inner class SealedExhaustivenessData(regularClass: FirRegularClass, var containsNull: Boolean) {
val symbolProvider = bodyResolveComponents.symbolProvider
private val rootNode = SealedClassInheritors(regularClass.classId, regularClass.sealedInheritors.mapToSealedInheritors())
private fun List<ClassId>?.mapToSealedInheritors(): MutableSet<SealedClassInheritors>? {
if (this.isNullOrEmpty()) return null
return this.mapNotNull {
val inheritor = symbolProvider.getClassLikeSymbolByFqName(it)?.fir as? FirRegularClass ?: return@mapNotNull null
SealedClassInheritors(inheritor.classId, inheritor.sealedInheritors.mapToSealedInheritors())
}.takeIf { it.isNotEmpty() }?.toMutableSet()
}
fun removeInheritor(classId: ClassId) {
if (rootNode.classId == classId) {
rootNode.inheritors?.clear()
val subjectType = whenExpression.subjectVariable?.returnTypeRef?.coneType
?: whenExpression.subject?.typeRef?.coneType
?: run {
whenExpression.replaceExhaustivenessStatus(ExhaustivenessStatus.NotExhaustive.NO_ELSE_BRANCH)
return
}
rootNode.removeInheritor(classId)
}
val session = bodyResolveComponents.session
val cleanSubjectType = subjectType.fullyExpandedType(session).lowerBoundIfFlexible()
fun isExhaustive() = containsNull && rootNode.isEmpty()
}
private data class SealedClassInheritors(val classId: ClassId, val inheritors: MutableSet<SealedClassInheritors>? = null) {
fun removeInheritor(classId: ClassId): Boolean {
return inheritors != null && (inheritors.removeIf { it.classId == classId } || inheritors.any { it.removeInheritor(classId) })
}
fun isEmpty(): Boolean {
return inheritors != null && inheritors.all { it.isEmpty() }
}
}
private object SealedExhaustivenessVisitor : FirDefaultVisitor<Unit, SealedExhaustivenessData>() {
override fun visitElement(element: FirElement, data: SealedExhaustivenessData) {}
override fun visitTypeOperatorCall(typeOperatorCall: FirTypeOperatorCall, data: SealedExhaustivenessData) {
if (typeOperatorCall.operation == FirOperation.IS) {
typeOperatorCall.conversionTypeRef.accept(this, data)
val checkers = buildList {
exhaustivenessCheckers.filterTo(this) { it.isApplicable(cleanSubjectType, session) }
if (isNotEmpty() && cleanSubjectType.isMarkedNullable) {
add(WhenOnNullableExhaustivenessChecker)
}
}
override fun visitEqualityOperatorCall(equalityOperatorCall: FirEqualityOperatorCall, data: SealedExhaustivenessData) {
if (equalityOperatorCall.operation.let { it == FirOperation.EQ || it == FirOperation.IDENTITY }) {
when (val argument = equalityOperatorCall.arguments[1]) {
is FirConstExpression<*> -> {
if (argument.value == null) {
data.containsNull = true
}
}
is FirResolvedQualifier -> {
argument.typeRef.accept(this, data)
}
is FirQualifiedAccessExpression -> {
if (argument.typeRef.isNullableNothing) {
data.containsNull = true
}
}
}
}
if (checkers.isEmpty()) {
whenExpression.replaceExhaustivenessStatus(ExhaustivenessStatus.NotExhaustive.NO_ELSE_BRANCH)
return
}
val whenMissingCases = mutableListOf<WhenMissingCase>()
for (checker in checkers) {
checker.computeMissingCases(whenExpression, cleanSubjectType, session, whenMissingCases)
}
if (whenMissingCases.isEmpty() && whenExpression.branches.isEmpty()) {
whenMissingCases.add(WhenMissingCase.Unknown)
}
override fun visitResolvedTypeRef(resolvedTypeRef: FirResolvedTypeRef, data: SealedExhaustivenessData) {
val lookupTag = (resolvedTypeRef.type as? ConeLookupTagBasedType)?.lookupTag ?: return
val symbol = data.symbolProvider.getSymbolByLookupTag(lookupTag) as? FirClassSymbol ?: return
data.removeInheritor(symbol.classId)
val status = if (whenMissingCases.isEmpty()) {
ExhaustivenessStatus.Exhaustive
} else {
ExhaustivenessStatus.NotExhaustive(whenMissingCases)
}
whenExpression.replaceExhaustivenessStatus(status)
}
}
private sealed class WhenExhaustivenessChecker {
abstract fun isApplicable(subjectType: ConeKotlinType, session: FirSession): Boolean
abstract fun computeMissingCases(
whenExpression: FirWhenExpression,
subjectType: ConeKotlinType,
session: FirSession,
destination: MutableCollection<WhenMissingCase>
)
protected abstract class AbstractConditionChecker<in D : Any> : FirVisitor<Unit, D>() {
override fun visitElement(element: FirElement, data: D) {}
override fun visitWhenExpression(whenExpression: FirWhenExpression, data: D) {
whenExpression.branches.forEach { it.accept(this, data) }
}
override fun visitBinaryLogicExpression(binaryLogicExpression: FirBinaryLogicExpression, data: SealedExhaustivenessData) {
if (binaryLogicExpression.kind == LogicOperationKind.OR) {
override fun visitWhenBranch(whenBranch: FirWhenBranch, data: D) {
whenBranch.condition.accept(this, data)
}
override fun visitBinaryLogicExpression(binaryLogicExpression: FirBinaryLogicExpression, data: D) {
if (binaryLogicExpression.kind == OR) {
binaryLogicExpression.acceptChildren(this, data)
}
}
}
}
// ------------------------ Boolean exhaustiveness ------------------------
private fun checkBooleanExhaustiveness(whenExpression: FirWhenExpression, nullable: Boolean): Boolean {
val flags = BooleanExhaustivenessFlags(!nullable)
for (branch in whenExpression.branches) {
branch.condition.accept(BooleanExhaustivenessVisitor, flags)
}
return flags.containsTrue && flags.containsFalse && flags.containsNull
private object WhenOnNullableExhaustivenessChecker : WhenExhaustivenessChecker() {
override fun isApplicable(subjectType: ConeKotlinType, session: FirSession): Boolean {
return subjectType.isNullable
}
private class BooleanExhaustivenessFlags(var containsNull: Boolean) {
override fun computeMissingCases(
whenExpression: FirWhenExpression,
subjectType: ConeKotlinType,
session: FirSession,
destination: MutableCollection<WhenMissingCase>
) {
val flags = Flags()
whenExpression.accept(ConditionChecker, flags)
if (!flags.containsNull) {
destination.add(WhenMissingCase.NullIsMissing)
}
}
private class Flags {
var containsNull = false
}
private object ConditionChecker : AbstractConditionChecker<Flags>() {
override fun visitEqualityOperatorCall(equalityOperatorCall: FirEqualityOperatorCall, data: Flags) {
val argument = equalityOperatorCall.arguments[1]
if (argument.typeRef.isNullableNothing) {
data.containsNull = true
}
}
override fun visitTypeOperatorCall(typeOperatorCall: FirTypeOperatorCall, data: Flags) {
if (typeOperatorCall.conversionTypeRef.coneType.isNullable) {
data.containsNull = true
}
}
}
}
private object WhenOnBooleanExhaustivenessChecker : WhenExhaustivenessChecker() {
override fun isApplicable(subjectType: ConeKotlinType, session: FirSession): Boolean {
return subjectType.classId == StandardClassIds.Boolean
}
private class Flags {
var containsTrue = false
var containsFalse = false
}
private object BooleanExhaustivenessVisitor : FirVisitor<Unit, BooleanExhaustivenessFlags>() {
override fun visitElement(element: FirElement, data: BooleanExhaustivenessFlags) {}
override fun computeMissingCases(
whenExpression: FirWhenExpression,
subjectType: ConeKotlinType,
session: FirSession,
destination: MutableCollection<WhenMissingCase>
) {
val flags = Flags()
whenExpression.accept(ConditionChecker, flags)
if (!flags.containsFalse) {
destination.add(WhenMissingCase.BooleanIsMissing.False)
}
if (!flags.containsTrue) {
destination.add(WhenMissingCase.BooleanIsMissing.True)
}
}
override fun visitEqualityOperatorCall(equalityOperatorCall: FirEqualityOperatorCall, data: BooleanExhaustivenessFlags) {
private object ConditionChecker : AbstractConditionChecker<Flags>() {
override fun visitEqualityOperatorCall(equalityOperatorCall: FirEqualityOperatorCall, data: Flags) {
if (equalityOperatorCall.operation.let { it == FirOperation.EQ || it == FirOperation.IDENTITY }) {
val argument = equalityOperatorCall.arguments[1]
if (argument is FirConstExpression<*>) {
when (argument.value) {
true -> data.containsTrue = true
false -> data.containsFalse = true
null -> data.containsNull = true
}
}
}
}
}
}
override fun visitBinaryLogicExpression(binaryLogicExpression: FirBinaryLogicExpression, data: BooleanExhaustivenessFlags) {
if (binaryLogicExpression.kind == LogicOperationKind.OR) {
binaryLogicExpression.acceptChildren(this, data)
}
private object WhenOnEnumExhaustivenessChecker : WhenExhaustivenessChecker() {
override fun isApplicable(subjectType: ConeKotlinType, session: FirSession): Boolean {
val symbol = subjectType.toSymbol(session) as? FirRegularClassSymbol ?: return false
return symbol.fir.classKind == ClassKind.ENUM_CLASS
}
override fun computeMissingCases(
whenExpression: FirWhenExpression,
subjectType: ConeKotlinType,
session: FirSession,
destination: MutableCollection<WhenMissingCase>
) {
val enumClass = (subjectType.toSymbol(session) as FirRegularClassSymbol).fir
val allEntries = enumClass.declarations.mapNotNullTo(mutableSetOf()) { it as? FirEnumEntry }
val checkedEntries = mutableSetOf<FirEnumEntry>()
whenExpression.accept(ConditionChecker, checkedEntries)
val notCheckedEntries = allEntries - checkedEntries
notCheckedEntries.mapTo(destination) { WhenMissingCase.EnumCheckIsMissing(it.symbol.callableId) }
}
private object ConditionChecker : AbstractConditionChecker<MutableSet<FirEnumEntry>>() {
override fun visitEqualityOperatorCall(equalityOperatorCall: FirEqualityOperatorCall, data: MutableSet<FirEnumEntry>) {
if (!equalityOperatorCall.operation.let { it == FirOperation.EQ || it == FirOperation.IDENTITY }) return
val argument = equalityOperatorCall.arguments[1]
val symbol = argument.toResolvedCallableReference()?.resolvedSymbol as? FirVariableSymbol<*> ?: return
val checkedEnumEntry = symbol.fir as? FirEnumEntry ?: return
data.add(checkedEnumEntry)
}
}
}
private object WhenOnSealedClassExhaustivenessChecker : WhenExhaustivenessChecker() {
override fun isApplicable(subjectType: ConeKotlinType, session: FirSession): Boolean {
return (subjectType.toSymbol(session)?.fir as? FirRegularClass)?.modality == Modality.SEALED
}
override fun computeMissingCases(
whenExpression: FirWhenExpression,
subjectType: ConeKotlinType,
session: FirSession,
destination: MutableCollection<WhenMissingCase>
) {
val allSubclasses = subjectType.toSymbol(session)?.collectAllSubclasses(session) ?: return
val checkedSubclasses = mutableSetOf<AbstractFirBasedSymbol<*>>()
whenExpression.accept(ConditionChecker, Flags(allSubclasses, checkedSubclasses, session))
(allSubclasses - checkedSubclasses).mapNotNullTo(destination) {
when (it) {
is FirClassSymbol<*> -> WhenMissingCase.IsTypeCheckIsMissing(it.classId, it.fir.classKind.isSingleton)
is FirVariableSymbol<*> -> WhenMissingCase.EnumCheckIsMissing(it.callableId)
else -> null
}
}
}
private class Flags(
val allSubclasses: Set<AbstractFirBasedSymbol<*>>,
val checkedSubclasses: MutableSet<AbstractFirBasedSymbol<*>>,
val session: FirSession
)
private object ConditionChecker : AbstractConditionChecker<Flags>() {
override fun visitEqualityOperatorCall(equalityOperatorCall: FirEqualityOperatorCall, data: Flags) {
val isNegated = when (equalityOperatorCall.operation) {
FirOperation.EQ, FirOperation.IDENTITY -> false
FirOperation.NOT_EQ, FirOperation.NOT_IDENTITY -> true
else -> return
}
val symbol = when (val argument = equalityOperatorCall.arguments[1]) {
is FirResolvedQualifier -> {
val firClass = (argument.symbol as? FirRegularClassSymbol)?.fir
if (firClass?.classKind == ClassKind.OBJECT) {
firClass.symbol
} else {
firClass?.companionObject?.symbol
}
}
else -> {
argument.toResolvedCallableSymbol()?.takeIf { it.fir is FirEnumEntry }
}
} ?: return
processBranch(symbol, isNegated, data)
}
override fun visitTypeOperatorCall(typeOperatorCall: FirTypeOperatorCall, data: Flags) {
val isNegated = when (typeOperatorCall.operation) {
FirOperation.IS -> false
FirOperation.NOT_IS -> true
else -> return
}
val symbol = typeOperatorCall.conversionTypeRef.coneType.fullyExpandedType(data.session).toSymbol(data.session) ?: return
processBranch(symbol, isNegated, data)
}
private fun processBranch(symbolToCheck: AbstractFirBasedSymbol<*>, isNegated: Boolean, flags: Flags) {
val subclassesOfType = symbolToCheck.collectAllSubclasses(flags.session)
if (subclassesOfType.none { it in flags.allSubclasses }) {
return
}
val checkedSubclasses = if (isNegated) flags.allSubclasses - subclassesOfType else subclassesOfType
flags.checkedSubclasses.addAll(checkedSubclasses)
}
}
private fun AbstractFirBasedSymbol<*>.collectAllSubclasses(session: FirSession): Set<AbstractFirBasedSymbol<*>> {
return mutableSetOf<AbstractFirBasedSymbol<*>>().apply { collectAllSubclassesTo(this, session) }
}
private fun AbstractFirBasedSymbol<*>.collectAllSubclassesTo(destination: MutableSet<AbstractFirBasedSymbol<*>>, session: FirSession) {
if (this !is FirRegularClassSymbol) {
destination.add(this)
return
}
when {
fir.modality == Modality.SEALED -> fir.sealedInheritors?.forEach {
val symbol = session.symbolProvider.getClassLikeSymbolByFqName(it) as? FirRegularClassSymbol
symbol?.collectAllSubclassesTo(destination, session)
}
fir.classKind == ClassKind.ENUM_CLASS -> fir.collectEnumEntries().mapTo(destination) { it.symbol }
else -> destination.add(this)
}
}
}
@@ -5,19 +5,61 @@
package org.jetbrains.kotlin.fir.expressions
import org.jetbrains.kotlin.fir.symbols.CallableId
import org.jetbrains.kotlin.name.ClassId
sealed class ExhaustivenessStatus {
object Exhaustive : ExhaustivenessStatus()
class NotExhaustive(val reasons: List<WhenMissingCase>) : ExhaustivenessStatus()
class NotExhaustive(val reasons: List<WhenMissingCase>) : ExhaustivenessStatus() {
companion object {
val NO_ELSE_BRANCH = NotExhaustive(listOf(WhenMissingCase.Unknown))
}
}
}
sealed class WhenMissingCase {
object Unknown : WhenMissingCase()
object NullIsMissing : WhenMissingCase()
class BooleanIsMissing(val value: Boolean) : WhenMissingCase()
class IsTypeCheckIsMissing(val classId: ClassId) : WhenMissingCase()
class EnumCheckIsMissing(val classId: ClassId) : WhenMissingCase()
sealed class WhenMissingCase() {
abstract val branchConditionText: String
object Unknown : WhenMissingCase() {
override fun toString(): String = "unknown"
override val branchConditionText: String = "else"
}
object NullIsMissing : WhenMissingCase() {
override val branchConditionText: String = "null"
}
sealed class BooleanIsMissing(val value: Boolean) : WhenMissingCase() {
object True : BooleanIsMissing(true)
object False : BooleanIsMissing(false)
override val branchConditionText: String = value.toString()
}
class IsTypeCheckIsMissing(val classId: ClassId, val isSingleton: Boolean) : WhenMissingCase() {
override val branchConditionText: String = run {
val fqName = classId.asSingleFqName().toString()
if (isSingleton) fqName else "is $fqName"
}
override fun toString(): String {
val name = classId.shortClassName.identifier
return if (isSingleton) name else "is $name"
}
}
class EnumCheckIsMissing(val callableId: CallableId) : WhenMissingCase() {
override val branchConditionText: String = callableId.asFqNameForDebugInfo().toString()
override fun toString(): String {
return callableId.callableName.identifier
}
}
override fun toString(): String {
return branchConditionText
}
}
val FirWhenExpression.isExhaustive: Boolean
@@ -18,6 +18,6 @@ fun test2(x: Stmt): String =
}
fun test3(x: Expr): String =
<!NO_ELSE_IN_WHEN!>when<!> (x) {
when (x) {
is Stmt -> "stmt"
}
@@ -1,34 +0,0 @@
sealed class Base {
sealed class A : Base() {
object A1 : A()
sealed class A2 : A()
}
sealed class B : Base() {
sealed class B1 : B()
object B2 : B()
}
fun foo() = when (this) {
is A -> 1
is B.B1 -> 2
B.B2 -> 3
// No else required
}
fun bar() = <!NO_ELSE_IN_WHEN!>when<!> (this) {
is A -> 1
is B.B1 -> 2
}
fun baz() = when (this) {
is A -> 1
B.B2 -> 3
// No else required (no possible B1 instances)
}
fun negated() = when (this) {
!is A -> 1
A.A1 -> 2
is A.A2 -> 3
}
}
@@ -1,3 +1,4 @@
// FIR_IDENTICAL
sealed class Base {
sealed class A : Base() {
object A1 : A()