[FIR] Reimplement when exhaustiveness checker to fir it's logic with FE 1.0
This commit is contained in:
+2
-2
@@ -107,10 +107,10 @@ FILE: exhaustiveness_sealedSubClass.kt
|
||||
}
|
||||
}
|
||||
.<Unresolved name: plus>#(Int(0))
|
||||
lval d: R|ERROR CLASS: Unresolved name: plus| = when (R|<local>/e|) {
|
||||
lval d: R|kotlin/Int| = when (R|<local>/e|) {
|
||||
($subj$ is R|C|) -> {
|
||||
Int(1)
|
||||
}
|
||||
}
|
||||
.<Unresolved name: plus>#(Int(0))
|
||||
.R|kotlin/Int.plus|(Int(0))
|
||||
}
|
||||
|
||||
+2
-2
@@ -48,7 +48,7 @@ fun test_2(e: A) {
|
||||
is D -> 2
|
||||
}.<!UNRESOLVED_REFERENCE{LT}!><!UNRESOLVED_REFERENCE{PSI}!>plus<!>(0)<!>
|
||||
|
||||
val d = <!NO_ELSE_IN_WHEN!>when<!> (e) {
|
||||
val d = when (e) {
|
||||
is C -> 1
|
||||
}.<!UNRESOLVED_REFERENCE{LT}!><!UNRESOLVED_REFERENCE{PSI}!>plus<!>(0)<!>
|
||||
}.plus(0)
|
||||
}
|
||||
|
||||
+5
-5
@@ -11,19 +11,19 @@ import org.jetbrains.kotlin.fir.analysis.checkers.context.CheckerContext
|
||||
import org.jetbrains.kotlin.fir.analysis.diagnostics.DiagnosticReporter
|
||||
import org.jetbrains.kotlin.fir.analysis.diagnostics.FirErrors
|
||||
import org.jetbrains.kotlin.fir.analysis.diagnostics.reportOn
|
||||
import org.jetbrains.kotlin.fir.expressions.ExhaustivenessStatus
|
||||
import org.jetbrains.kotlin.fir.expressions.FirWhenExpression
|
||||
import org.jetbrains.kotlin.fir.expressions.isExhaustive
|
||||
|
||||
object FirExhaustiveWhenChecker : FirWhenExpressionChecker() {
|
||||
override fun check(expression: FirWhenExpression, context: CheckerContext, reporter: DiagnosticReporter) {
|
||||
// TODO: add reporting of proper missing clauses, see class WhenMissingCase
|
||||
if (expression.usedAsExpression && !expression.isExhaustive) {
|
||||
val factory = if (expression.source?.isIfExpression == true) {
|
||||
FirErrors.INVALID_IF_AS_EXPRESSION
|
||||
if (expression.source?.isIfExpression == true) {
|
||||
reporter.reportOn(expression.source, FirErrors.INVALID_IF_AS_EXPRESSION, context)
|
||||
} else {
|
||||
FirErrors.NO_ELSE_IN_WHEN
|
||||
val missingCases = (expression.exhaustivenessStatus as ExhaustivenessStatus.NotExhaustive).reasons
|
||||
reporter.reportOn(expression.source, FirErrors.NO_ELSE_IN_WHEN, missingCases, context)
|
||||
}
|
||||
reporter.reportOn(expression.source, factory, context)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+2
-1
@@ -18,6 +18,7 @@ import org.jetbrains.kotlin.fir.analysis.diagnostics.FirDiagnosticRenderers.SYMB
|
||||
import org.jetbrains.kotlin.fir.analysis.diagnostics.FirDiagnosticRenderers.SYMBOLS
|
||||
import org.jetbrains.kotlin.fir.analysis.diagnostics.FirDiagnosticRenderers.TO_STRING
|
||||
import org.jetbrains.kotlin.fir.analysis.diagnostics.FirDiagnosticRenderers.VISIBILITY
|
||||
import org.jetbrains.kotlin.fir.analysis.diagnostics.FirDiagnosticRenderers.WHEN_MISSING_CASES
|
||||
import org.jetbrains.kotlin.fir.analysis.diagnostics.FirErrors.ABSTRACT_DELEGATED_PROPERTY
|
||||
import org.jetbrains.kotlin.fir.analysis.diagnostics.FirErrors.ABSTRACT_FUNCTION_IN_NON_ABSTRACT_CLASS
|
||||
import org.jetbrains.kotlin.fir.analysis.diagnostics.FirErrors.ABSTRACT_FUNCTION_WITH_BODY
|
||||
@@ -504,7 +505,7 @@ class FirDefaultErrorMessages : DefaultErrorMessages.Extension {
|
||||
//)
|
||||
|
||||
// When expressions
|
||||
map.put(NO_ELSE_IN_WHEN, "''when'' expression must be exhaustive")
|
||||
map.put(NO_ELSE_IN_WHEN, "''when'' expression must be exhaustive, add necessary {0}", WHEN_MISSING_CASES)
|
||||
map.put(INVALID_IF_AS_EXPRESSION, "'if' must have both main and 'else' branches if used as an expression")
|
||||
|
||||
// Extended checkers group
|
||||
|
||||
+13
@@ -10,6 +10,7 @@ import org.jetbrains.kotlin.diagnostics.rendering.Renderer
|
||||
import org.jetbrains.kotlin.fir.FirElement
|
||||
import org.jetbrains.kotlin.fir.FirRenderer
|
||||
import org.jetbrains.kotlin.fir.declarations.*
|
||||
import org.jetbrains.kotlin.fir.expressions.WhenMissingCase
|
||||
import org.jetbrains.kotlin.fir.render
|
||||
import org.jetbrains.kotlin.fir.renderWithType
|
||||
import org.jetbrains.kotlin.fir.symbols.AbstractFirBasedSymbol
|
||||
@@ -88,4 +89,16 @@ object FirDiagnosticRenderers {
|
||||
SYMBOL.render(symbol)
|
||||
}
|
||||
}
|
||||
|
||||
private const val WHEN_MISSING_LIMIT = 7
|
||||
|
||||
val WHEN_MISSING_CASES = Renderer { missingCases: List<WhenMissingCase> ->
|
||||
if (missingCases.firstOrNull() == WhenMissingCase.Unknown) {
|
||||
"'else' branch"
|
||||
} else {
|
||||
val list = missingCases.joinToString(", ", limit = WHEN_MISSING_LIMIT) { "'$it'" }
|
||||
val branches = if (missingCases.size > 1) "branches" else "branch"
|
||||
"$list $branches or 'else' branch instead"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+2
-1
@@ -14,6 +14,7 @@ import org.jetbrains.kotlin.fir.FirSourceElement
|
||||
import org.jetbrains.kotlin.fir.declarations.FirCallableDeclaration
|
||||
import org.jetbrains.kotlin.fir.declarations.FirClass
|
||||
import org.jetbrains.kotlin.fir.declarations.FirMemberDeclaration
|
||||
import org.jetbrains.kotlin.fir.expressions.WhenMissingCase
|
||||
import org.jetbrains.kotlin.fir.symbols.AbstractFirBasedSymbol
|
||||
import org.jetbrains.kotlin.fir.symbols.impl.FirClassLikeSymbol
|
||||
import org.jetbrains.kotlin.fir.symbols.impl.FirPropertySymbol
|
||||
@@ -231,7 +232,7 @@ object FirErrors {
|
||||
// TODO: val UNEXPECTED_SAFE_CALL by ...
|
||||
|
||||
// When expressions
|
||||
val NO_ELSE_IN_WHEN by error0<FirSourceElement, KtWhenExpression>(SourceElementPositioningStrategies.WHEN_EXPRESSION)
|
||||
val NO_ELSE_IN_WHEN by error1<FirSourceElement, KtWhenExpression, List<WhenMissingCase>>(SourceElementPositioningStrategies.WHEN_EXPRESSION)
|
||||
val INVALID_IF_AS_EXPRESSION by error0<FirSourceElement, KtIfExpression>(SourceElementPositioningStrategies.IF_EXPRESSION)
|
||||
|
||||
// Extended checkers group
|
||||
|
||||
+252
-180
@@ -8,21 +8,34 @@ package org.jetbrains.kotlin.fir.resolve.transformers
|
||||
import org.jetbrains.kotlin.descriptors.ClassKind
|
||||
import org.jetbrains.kotlin.descriptors.Modality
|
||||
import org.jetbrains.kotlin.fir.FirElement
|
||||
import org.jetbrains.kotlin.fir.FirSession
|
||||
import org.jetbrains.kotlin.fir.declarations.*
|
||||
import org.jetbrains.kotlin.fir.expressions.*
|
||||
import org.jetbrains.kotlin.fir.expressions.LogicOperationKind.OR
|
||||
import org.jetbrains.kotlin.fir.expressions.impl.FirElseIfTrueCondition
|
||||
import org.jetbrains.kotlin.fir.references.FirResolvedNamedReference
|
||||
import org.jetbrains.kotlin.fir.resolve.BodyResolveComponents
|
||||
import org.jetbrains.kotlin.fir.resolve.getSymbolByLookupTag
|
||||
import org.jetbrains.kotlin.fir.resolve.toSymbol
|
||||
import org.jetbrains.kotlin.fir.symbols.ConeClassLikeLookupTag
|
||||
import org.jetbrains.kotlin.fir.resolve.fullyExpandedType
|
||||
import org.jetbrains.kotlin.fir.resolve.symbolProvider
|
||||
import org.jetbrains.kotlin.fir.symbols.AbstractFirBasedSymbol
|
||||
import org.jetbrains.kotlin.fir.symbols.StandardClassIds
|
||||
import org.jetbrains.kotlin.fir.symbols.impl.FirClassSymbol
|
||||
import org.jetbrains.kotlin.fir.symbols.impl.FirRegularClassSymbol
|
||||
import org.jetbrains.kotlin.fir.symbols.impl.FirVariableSymbol
|
||||
import org.jetbrains.kotlin.fir.types.*
|
||||
import org.jetbrains.kotlin.fir.visitors.*
|
||||
import org.jetbrains.kotlin.name.ClassId
|
||||
import org.jetbrains.kotlin.fir.visitors.CompositeTransformResult
|
||||
import org.jetbrains.kotlin.fir.visitors.FirTransformer
|
||||
import org.jetbrains.kotlin.fir.visitors.FirVisitor
|
||||
import org.jetbrains.kotlin.fir.visitors.compose
|
||||
|
||||
class FirWhenExhaustivenessTransformer(private val bodyResolveComponents: BodyResolveComponents) : FirTransformer<Nothing?>() {
|
||||
companion object {
|
||||
private val exhaustivenessCheckers = listOf(
|
||||
WhenOnBooleanExhaustivenessChecker,
|
||||
WhenOnEnumExhaustivenessChecker,
|
||||
WhenOnSealedClassExhaustivenessChecker
|
||||
)
|
||||
}
|
||||
|
||||
override fun <E : FirElement> transformElement(element: E, data: Nothing?): CompositeTransformResult<E> {
|
||||
throw IllegalArgumentException("Should not be there")
|
||||
}
|
||||
@@ -32,220 +45,279 @@ class FirWhenExhaustivenessTransformer(private val bodyResolveComponents: BodyRe
|
||||
return whenExpression.compose()
|
||||
}
|
||||
|
||||
@OptIn(ExperimentalStdlibApi::class)
|
||||
private fun processExhaustivenessCheck(whenExpression: FirWhenExpression) {
|
||||
if (whenExpression.branches.any { it.condition is FirElseIfTrueCondition }) {
|
||||
whenExpression.replaceExhaustivenessStatus(ExhaustivenessStatus.Exhaustive)
|
||||
return
|
||||
}
|
||||
|
||||
val typeRef = whenExpression.subjectVariable?.returnTypeRef
|
||||
?: whenExpression.subject?.typeRef
|
||||
?: return
|
||||
|
||||
// TODO: add some report logic about flexible type (see WHEN_ENUM_CAN_BE_NULL_IN_JAVA diagnostic in old frontend)
|
||||
val type = typeRef.coneType.lowerBoundIfFlexible()
|
||||
val lookupTag = (type as? ConeLookupTagBasedType)?.lookupTag ?: return
|
||||
val nullable = type.nullability == ConeNullability.NULLABLE
|
||||
val isExhaustive = when {
|
||||
((lookupTag as? ConeClassLikeLookupTag)?.classId == bodyResolveComponents.session.builtinTypes.booleanType.id) -> {
|
||||
checkBooleanExhaustiveness(whenExpression, nullable)
|
||||
}
|
||||
|
||||
whenExpression.branches.isEmpty() -> false
|
||||
|
||||
else -> {
|
||||
val klass = lookupTag.toSymbol(bodyResolveComponents.session)?.fir as? FirRegularClass ?: return
|
||||
when {
|
||||
klass.classKind == ClassKind.ENUM_CLASS -> checkEnumExhaustiveness(whenExpression, klass, nullable)
|
||||
klass.modality == Modality.SEALED -> checkSealedClassExhaustiveness(whenExpression, klass, nullable)
|
||||
else -> return
|
||||
}
|
||||
}
|
||||
}
|
||||
if (isExhaustive) {
|
||||
whenExpression.replaceExhaustivenessStatus(ExhaustivenessStatus.Exhaustive)
|
||||
} else {
|
||||
whenExpression.replaceExhaustivenessStatus(ExhaustivenessStatus.NotExhaustive(listOf()))
|
||||
}
|
||||
}
|
||||
|
||||
// ------------------------ Enum exhaustiveness ------------------------
|
||||
|
||||
private fun checkEnumExhaustiveness(whenExpression: FirWhenExpression, enum: FirRegularClass, nullable: Boolean): Boolean {
|
||||
val data = EnumExhaustivenessData(
|
||||
enum.collectEnumEntries().map { it.symbol }.toMutableSet(),
|
||||
!nullable
|
||||
)
|
||||
for (branch in whenExpression.branches) {
|
||||
branch.condition.accept(EnumExhaustivenessVisitor, data)
|
||||
}
|
||||
return data.containsNull && data.remainingEntries.isEmpty()
|
||||
}
|
||||
|
||||
private class EnumExhaustivenessData(val remainingEntries: MutableSet<FirVariableSymbol<FirEnumEntry>>, var containsNull: Boolean)
|
||||
|
||||
private object EnumExhaustivenessVisitor : FirVisitor<Unit, EnumExhaustivenessData>() {
|
||||
override fun visitElement(element: FirElement, data: EnumExhaustivenessData) {}
|
||||
|
||||
override fun visitEqualityOperatorCall(equalityOperatorCall: FirEqualityOperatorCall, data: EnumExhaustivenessData) {
|
||||
if (equalityOperatorCall.operation.let { it == FirOperation.EQ || it == FirOperation.IDENTITY }) {
|
||||
when (val argument = equalityOperatorCall.arguments[1]) {
|
||||
is FirConstExpression<*> -> {
|
||||
if (argument.value == null) {
|
||||
data.containsNull = true
|
||||
}
|
||||
}
|
||||
is FirQualifiedAccessExpression -> {
|
||||
if (argument.typeRef.isNullableNothing) {
|
||||
data.containsNull = true
|
||||
return
|
||||
}
|
||||
val reference = argument.calleeReference as? FirResolvedNamedReference ?: return
|
||||
val symbol = (reference.resolvedSymbol.fir as? FirEnumEntry)?.symbol ?: return
|
||||
|
||||
data.remainingEntries.remove(symbol)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override fun visitBinaryLogicExpression(binaryLogicExpression: FirBinaryLogicExpression, data: EnumExhaustivenessData) {
|
||||
if (binaryLogicExpression.kind == LogicOperationKind.OR) {
|
||||
binaryLogicExpression.acceptChildren(this, data)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ------------------------ Sealed class exhaustiveness ------------------------
|
||||
|
||||
private fun checkSealedClassExhaustiveness(
|
||||
whenExpression: FirWhenExpression,
|
||||
sealedClass: FirRegularClass,
|
||||
nullable: Boolean
|
||||
): Boolean {
|
||||
if (sealedClass.sealedInheritors.isNullOrEmpty()) return true
|
||||
|
||||
val data = SealedExhaustivenessData(sealedClass, !nullable)
|
||||
for (branch in whenExpression.branches) {
|
||||
branch.condition.accept(SealedExhaustivenessVisitor, data)
|
||||
}
|
||||
return data.isExhaustive()
|
||||
}
|
||||
|
||||
private inner class SealedExhaustivenessData(regularClass: FirRegularClass, var containsNull: Boolean) {
|
||||
val symbolProvider = bodyResolveComponents.symbolProvider
|
||||
private val rootNode = SealedClassInheritors(regularClass.classId, regularClass.sealedInheritors.mapToSealedInheritors())
|
||||
|
||||
private fun List<ClassId>?.mapToSealedInheritors(): MutableSet<SealedClassInheritors>? {
|
||||
if (this.isNullOrEmpty()) return null
|
||||
|
||||
return this.mapNotNull {
|
||||
val inheritor = symbolProvider.getClassLikeSymbolByFqName(it)?.fir as? FirRegularClass ?: return@mapNotNull null
|
||||
SealedClassInheritors(inheritor.classId, inheritor.sealedInheritors.mapToSealedInheritors())
|
||||
}.takeIf { it.isNotEmpty() }?.toMutableSet()
|
||||
}
|
||||
|
||||
fun removeInheritor(classId: ClassId) {
|
||||
if (rootNode.classId == classId) {
|
||||
rootNode.inheritors?.clear()
|
||||
val subjectType = whenExpression.subjectVariable?.returnTypeRef?.coneType
|
||||
?: whenExpression.subject?.typeRef?.coneType
|
||||
?: run {
|
||||
whenExpression.replaceExhaustivenessStatus(ExhaustivenessStatus.NotExhaustive.NO_ELSE_BRANCH)
|
||||
return
|
||||
}
|
||||
|
||||
rootNode.removeInheritor(classId)
|
||||
}
|
||||
val session = bodyResolveComponents.session
|
||||
val cleanSubjectType = subjectType.fullyExpandedType(session).lowerBoundIfFlexible()
|
||||
|
||||
fun isExhaustive() = containsNull && rootNode.isEmpty()
|
||||
}
|
||||
|
||||
private data class SealedClassInheritors(val classId: ClassId, val inheritors: MutableSet<SealedClassInheritors>? = null) {
|
||||
fun removeInheritor(classId: ClassId): Boolean {
|
||||
return inheritors != null && (inheritors.removeIf { it.classId == classId } || inheritors.any { it.removeInheritor(classId) })
|
||||
}
|
||||
|
||||
fun isEmpty(): Boolean {
|
||||
return inheritors != null && inheritors.all { it.isEmpty() }
|
||||
}
|
||||
}
|
||||
|
||||
private object SealedExhaustivenessVisitor : FirDefaultVisitor<Unit, SealedExhaustivenessData>() {
|
||||
override fun visitElement(element: FirElement, data: SealedExhaustivenessData) {}
|
||||
|
||||
override fun visitTypeOperatorCall(typeOperatorCall: FirTypeOperatorCall, data: SealedExhaustivenessData) {
|
||||
if (typeOperatorCall.operation == FirOperation.IS) {
|
||||
typeOperatorCall.conversionTypeRef.accept(this, data)
|
||||
val checkers = buildList {
|
||||
exhaustivenessCheckers.filterTo(this) { it.isApplicable(cleanSubjectType, session) }
|
||||
if (isNotEmpty() && cleanSubjectType.isMarkedNullable) {
|
||||
add(WhenOnNullableExhaustivenessChecker)
|
||||
}
|
||||
}
|
||||
|
||||
override fun visitEqualityOperatorCall(equalityOperatorCall: FirEqualityOperatorCall, data: SealedExhaustivenessData) {
|
||||
if (equalityOperatorCall.operation.let { it == FirOperation.EQ || it == FirOperation.IDENTITY }) {
|
||||
when (val argument = equalityOperatorCall.arguments[1]) {
|
||||
is FirConstExpression<*> -> {
|
||||
if (argument.value == null) {
|
||||
data.containsNull = true
|
||||
}
|
||||
}
|
||||
|
||||
is FirResolvedQualifier -> {
|
||||
argument.typeRef.accept(this, data)
|
||||
}
|
||||
|
||||
is FirQualifiedAccessExpression -> {
|
||||
if (argument.typeRef.isNullableNothing) {
|
||||
data.containsNull = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (checkers.isEmpty()) {
|
||||
whenExpression.replaceExhaustivenessStatus(ExhaustivenessStatus.NotExhaustive.NO_ELSE_BRANCH)
|
||||
return
|
||||
}
|
||||
val whenMissingCases = mutableListOf<WhenMissingCase>()
|
||||
for (checker in checkers) {
|
||||
checker.computeMissingCases(whenExpression, cleanSubjectType, session, whenMissingCases)
|
||||
}
|
||||
if (whenMissingCases.isEmpty() && whenExpression.branches.isEmpty()) {
|
||||
whenMissingCases.add(WhenMissingCase.Unknown)
|
||||
}
|
||||
|
||||
override fun visitResolvedTypeRef(resolvedTypeRef: FirResolvedTypeRef, data: SealedExhaustivenessData) {
|
||||
val lookupTag = (resolvedTypeRef.type as? ConeLookupTagBasedType)?.lookupTag ?: return
|
||||
val symbol = data.symbolProvider.getSymbolByLookupTag(lookupTag) as? FirClassSymbol ?: return
|
||||
data.removeInheritor(symbol.classId)
|
||||
val status = if (whenMissingCases.isEmpty()) {
|
||||
ExhaustivenessStatus.Exhaustive
|
||||
} else {
|
||||
ExhaustivenessStatus.NotExhaustive(whenMissingCases)
|
||||
}
|
||||
whenExpression.replaceExhaustivenessStatus(status)
|
||||
}
|
||||
}
|
||||
|
||||
private sealed class WhenExhaustivenessChecker {
|
||||
abstract fun isApplicable(subjectType: ConeKotlinType, session: FirSession): Boolean
|
||||
abstract fun computeMissingCases(
|
||||
whenExpression: FirWhenExpression,
|
||||
subjectType: ConeKotlinType,
|
||||
session: FirSession,
|
||||
destination: MutableCollection<WhenMissingCase>
|
||||
)
|
||||
|
||||
protected abstract class AbstractConditionChecker<in D : Any> : FirVisitor<Unit, D>() {
|
||||
override fun visitElement(element: FirElement, data: D) {}
|
||||
|
||||
override fun visitWhenExpression(whenExpression: FirWhenExpression, data: D) {
|
||||
whenExpression.branches.forEach { it.accept(this, data) }
|
||||
}
|
||||
|
||||
override fun visitBinaryLogicExpression(binaryLogicExpression: FirBinaryLogicExpression, data: SealedExhaustivenessData) {
|
||||
if (binaryLogicExpression.kind == LogicOperationKind.OR) {
|
||||
override fun visitWhenBranch(whenBranch: FirWhenBranch, data: D) {
|
||||
whenBranch.condition.accept(this, data)
|
||||
}
|
||||
|
||||
override fun visitBinaryLogicExpression(binaryLogicExpression: FirBinaryLogicExpression, data: D) {
|
||||
if (binaryLogicExpression.kind == OR) {
|
||||
binaryLogicExpression.acceptChildren(this, data)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ------------------------ Boolean exhaustiveness ------------------------
|
||||
|
||||
private fun checkBooleanExhaustiveness(whenExpression: FirWhenExpression, nullable: Boolean): Boolean {
|
||||
val flags = BooleanExhaustivenessFlags(!nullable)
|
||||
for (branch in whenExpression.branches) {
|
||||
branch.condition.accept(BooleanExhaustivenessVisitor, flags)
|
||||
}
|
||||
return flags.containsTrue && flags.containsFalse && flags.containsNull
|
||||
private object WhenOnNullableExhaustivenessChecker : WhenExhaustivenessChecker() {
|
||||
override fun isApplicable(subjectType: ConeKotlinType, session: FirSession): Boolean {
|
||||
return subjectType.isNullable
|
||||
}
|
||||
|
||||
private class BooleanExhaustivenessFlags(var containsNull: Boolean) {
|
||||
override fun computeMissingCases(
|
||||
whenExpression: FirWhenExpression,
|
||||
subjectType: ConeKotlinType,
|
||||
session: FirSession,
|
||||
destination: MutableCollection<WhenMissingCase>
|
||||
) {
|
||||
val flags = Flags()
|
||||
whenExpression.accept(ConditionChecker, flags)
|
||||
if (!flags.containsNull) {
|
||||
destination.add(WhenMissingCase.NullIsMissing)
|
||||
}
|
||||
}
|
||||
|
||||
private class Flags {
|
||||
var containsNull = false
|
||||
}
|
||||
|
||||
private object ConditionChecker : AbstractConditionChecker<Flags>() {
|
||||
override fun visitEqualityOperatorCall(equalityOperatorCall: FirEqualityOperatorCall, data: Flags) {
|
||||
val argument = equalityOperatorCall.arguments[1]
|
||||
if (argument.typeRef.isNullableNothing) {
|
||||
data.containsNull = true
|
||||
}
|
||||
}
|
||||
|
||||
override fun visitTypeOperatorCall(typeOperatorCall: FirTypeOperatorCall, data: Flags) {
|
||||
if (typeOperatorCall.conversionTypeRef.coneType.isNullable) {
|
||||
data.containsNull = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private object WhenOnBooleanExhaustivenessChecker : WhenExhaustivenessChecker() {
|
||||
override fun isApplicable(subjectType: ConeKotlinType, session: FirSession): Boolean {
|
||||
return subjectType.classId == StandardClassIds.Boolean
|
||||
}
|
||||
|
||||
private class Flags {
|
||||
var containsTrue = false
|
||||
var containsFalse = false
|
||||
}
|
||||
|
||||
private object BooleanExhaustivenessVisitor : FirVisitor<Unit, BooleanExhaustivenessFlags>() {
|
||||
override fun visitElement(element: FirElement, data: BooleanExhaustivenessFlags) {}
|
||||
override fun computeMissingCases(
|
||||
whenExpression: FirWhenExpression,
|
||||
subjectType: ConeKotlinType,
|
||||
session: FirSession,
|
||||
destination: MutableCollection<WhenMissingCase>
|
||||
) {
|
||||
val flags = Flags()
|
||||
whenExpression.accept(ConditionChecker, flags)
|
||||
if (!flags.containsFalse) {
|
||||
destination.add(WhenMissingCase.BooleanIsMissing.False)
|
||||
}
|
||||
if (!flags.containsTrue) {
|
||||
destination.add(WhenMissingCase.BooleanIsMissing.True)
|
||||
}
|
||||
}
|
||||
|
||||
override fun visitEqualityOperatorCall(equalityOperatorCall: FirEqualityOperatorCall, data: BooleanExhaustivenessFlags) {
|
||||
private object ConditionChecker : AbstractConditionChecker<Flags>() {
|
||||
override fun visitEqualityOperatorCall(equalityOperatorCall: FirEqualityOperatorCall, data: Flags) {
|
||||
if (equalityOperatorCall.operation.let { it == FirOperation.EQ || it == FirOperation.IDENTITY }) {
|
||||
val argument = equalityOperatorCall.arguments[1]
|
||||
if (argument is FirConstExpression<*>) {
|
||||
when (argument.value) {
|
||||
true -> data.containsTrue = true
|
||||
false -> data.containsFalse = true
|
||||
null -> data.containsNull = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override fun visitBinaryLogicExpression(binaryLogicExpression: FirBinaryLogicExpression, data: BooleanExhaustivenessFlags) {
|
||||
if (binaryLogicExpression.kind == LogicOperationKind.OR) {
|
||||
binaryLogicExpression.acceptChildren(this, data)
|
||||
}
|
||||
private object WhenOnEnumExhaustivenessChecker : WhenExhaustivenessChecker() {
|
||||
override fun isApplicable(subjectType: ConeKotlinType, session: FirSession): Boolean {
|
||||
val symbol = subjectType.toSymbol(session) as? FirRegularClassSymbol ?: return false
|
||||
return symbol.fir.classKind == ClassKind.ENUM_CLASS
|
||||
}
|
||||
|
||||
override fun computeMissingCases(
|
||||
whenExpression: FirWhenExpression,
|
||||
subjectType: ConeKotlinType,
|
||||
session: FirSession,
|
||||
destination: MutableCollection<WhenMissingCase>
|
||||
) {
|
||||
val enumClass = (subjectType.toSymbol(session) as FirRegularClassSymbol).fir
|
||||
val allEntries = enumClass.declarations.mapNotNullTo(mutableSetOf()) { it as? FirEnumEntry }
|
||||
val checkedEntries = mutableSetOf<FirEnumEntry>()
|
||||
whenExpression.accept(ConditionChecker, checkedEntries)
|
||||
val notCheckedEntries = allEntries - checkedEntries
|
||||
notCheckedEntries.mapTo(destination) { WhenMissingCase.EnumCheckIsMissing(it.symbol.callableId) }
|
||||
}
|
||||
|
||||
private object ConditionChecker : AbstractConditionChecker<MutableSet<FirEnumEntry>>() {
|
||||
override fun visitEqualityOperatorCall(equalityOperatorCall: FirEqualityOperatorCall, data: MutableSet<FirEnumEntry>) {
|
||||
if (!equalityOperatorCall.operation.let { it == FirOperation.EQ || it == FirOperation.IDENTITY }) return
|
||||
val argument = equalityOperatorCall.arguments[1]
|
||||
val symbol = argument.toResolvedCallableReference()?.resolvedSymbol as? FirVariableSymbol<*> ?: return
|
||||
val checkedEnumEntry = symbol.fir as? FirEnumEntry ?: return
|
||||
data.add(checkedEnumEntry)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private object WhenOnSealedClassExhaustivenessChecker : WhenExhaustivenessChecker() {
|
||||
override fun isApplicable(subjectType: ConeKotlinType, session: FirSession): Boolean {
|
||||
return (subjectType.toSymbol(session)?.fir as? FirRegularClass)?.modality == Modality.SEALED
|
||||
}
|
||||
|
||||
override fun computeMissingCases(
|
||||
whenExpression: FirWhenExpression,
|
||||
subjectType: ConeKotlinType,
|
||||
session: FirSession,
|
||||
destination: MutableCollection<WhenMissingCase>
|
||||
) {
|
||||
val allSubclasses = subjectType.toSymbol(session)?.collectAllSubclasses(session) ?: return
|
||||
val checkedSubclasses = mutableSetOf<AbstractFirBasedSymbol<*>>()
|
||||
whenExpression.accept(ConditionChecker, Flags(allSubclasses, checkedSubclasses, session))
|
||||
(allSubclasses - checkedSubclasses).mapNotNullTo(destination) {
|
||||
when (it) {
|
||||
is FirClassSymbol<*> -> WhenMissingCase.IsTypeCheckIsMissing(it.classId, it.fir.classKind.isSingleton)
|
||||
is FirVariableSymbol<*> -> WhenMissingCase.EnumCheckIsMissing(it.callableId)
|
||||
else -> null
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private class Flags(
|
||||
val allSubclasses: Set<AbstractFirBasedSymbol<*>>,
|
||||
val checkedSubclasses: MutableSet<AbstractFirBasedSymbol<*>>,
|
||||
val session: FirSession
|
||||
)
|
||||
|
||||
private object ConditionChecker : AbstractConditionChecker<Flags>() {
|
||||
override fun visitEqualityOperatorCall(equalityOperatorCall: FirEqualityOperatorCall, data: Flags) {
|
||||
val isNegated = when (equalityOperatorCall.operation) {
|
||||
FirOperation.EQ, FirOperation.IDENTITY -> false
|
||||
FirOperation.NOT_EQ, FirOperation.NOT_IDENTITY -> true
|
||||
else -> return
|
||||
}
|
||||
|
||||
val symbol = when (val argument = equalityOperatorCall.arguments[1]) {
|
||||
is FirResolvedQualifier -> {
|
||||
val firClass = (argument.symbol as? FirRegularClassSymbol)?.fir
|
||||
if (firClass?.classKind == ClassKind.OBJECT) {
|
||||
firClass.symbol
|
||||
} else {
|
||||
firClass?.companionObject?.symbol
|
||||
}
|
||||
}
|
||||
else -> {
|
||||
argument.toResolvedCallableSymbol()?.takeIf { it.fir is FirEnumEntry }
|
||||
}
|
||||
} ?: return
|
||||
processBranch(symbol, isNegated, data)
|
||||
}
|
||||
|
||||
override fun visitTypeOperatorCall(typeOperatorCall: FirTypeOperatorCall, data: Flags) {
|
||||
val isNegated = when (typeOperatorCall.operation) {
|
||||
FirOperation.IS -> false
|
||||
FirOperation.NOT_IS -> true
|
||||
else -> return
|
||||
}
|
||||
val symbol = typeOperatorCall.conversionTypeRef.coneType.fullyExpandedType(data.session).toSymbol(data.session) ?: return
|
||||
processBranch(symbol, isNegated, data)
|
||||
}
|
||||
|
||||
private fun processBranch(symbolToCheck: AbstractFirBasedSymbol<*>, isNegated: Boolean, flags: Flags) {
|
||||
val subclassesOfType = symbolToCheck.collectAllSubclasses(flags.session)
|
||||
if (subclassesOfType.none { it in flags.allSubclasses }) {
|
||||
return
|
||||
}
|
||||
val checkedSubclasses = if (isNegated) flags.allSubclasses - subclassesOfType else subclassesOfType
|
||||
flags.checkedSubclasses.addAll(checkedSubclasses)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private fun AbstractFirBasedSymbol<*>.collectAllSubclasses(session: FirSession): Set<AbstractFirBasedSymbol<*>> {
|
||||
return mutableSetOf<AbstractFirBasedSymbol<*>>().apply { collectAllSubclassesTo(this, session) }
|
||||
}
|
||||
|
||||
private fun AbstractFirBasedSymbol<*>.collectAllSubclassesTo(destination: MutableSet<AbstractFirBasedSymbol<*>>, session: FirSession) {
|
||||
if (this !is FirRegularClassSymbol) {
|
||||
destination.add(this)
|
||||
return
|
||||
}
|
||||
when {
|
||||
fir.modality == Modality.SEALED -> fir.sealedInheritors?.forEach {
|
||||
val symbol = session.symbolProvider.getClassLikeSymbolByFqName(it) as? FirRegularClassSymbol
|
||||
symbol?.collectAllSubclassesTo(destination, session)
|
||||
}
|
||||
fir.classKind == ClassKind.ENUM_CLASS -> fir.collectEnumEntries().mapTo(destination) { it.symbol }
|
||||
else -> destination.add(this)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,19 +5,61 @@
|
||||
|
||||
package org.jetbrains.kotlin.fir.expressions
|
||||
|
||||
import org.jetbrains.kotlin.fir.symbols.CallableId
|
||||
import org.jetbrains.kotlin.name.ClassId
|
||||
|
||||
sealed class ExhaustivenessStatus {
|
||||
object Exhaustive : ExhaustivenessStatus()
|
||||
class NotExhaustive(val reasons: List<WhenMissingCase>) : ExhaustivenessStatus()
|
||||
class NotExhaustive(val reasons: List<WhenMissingCase>) : ExhaustivenessStatus() {
|
||||
companion object {
|
||||
val NO_ELSE_BRANCH = NotExhaustive(listOf(WhenMissingCase.Unknown))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sealed class WhenMissingCase {
|
||||
object Unknown : WhenMissingCase()
|
||||
object NullIsMissing : WhenMissingCase()
|
||||
class BooleanIsMissing(val value: Boolean) : WhenMissingCase()
|
||||
class IsTypeCheckIsMissing(val classId: ClassId) : WhenMissingCase()
|
||||
class EnumCheckIsMissing(val classId: ClassId) : WhenMissingCase()
|
||||
sealed class WhenMissingCase() {
|
||||
abstract val branchConditionText: String
|
||||
|
||||
object Unknown : WhenMissingCase() {
|
||||
override fun toString(): String = "unknown"
|
||||
|
||||
override val branchConditionText: String = "else"
|
||||
}
|
||||
|
||||
object NullIsMissing : WhenMissingCase() {
|
||||
override val branchConditionText: String = "null"
|
||||
}
|
||||
|
||||
sealed class BooleanIsMissing(val value: Boolean) : WhenMissingCase() {
|
||||
object True : BooleanIsMissing(true)
|
||||
object False : BooleanIsMissing(false)
|
||||
|
||||
override val branchConditionText: String = value.toString()
|
||||
}
|
||||
|
||||
class IsTypeCheckIsMissing(val classId: ClassId, val isSingleton: Boolean) : WhenMissingCase() {
|
||||
override val branchConditionText: String = run {
|
||||
val fqName = classId.asSingleFqName().toString()
|
||||
if (isSingleton) fqName else "is $fqName"
|
||||
}
|
||||
|
||||
override fun toString(): String {
|
||||
val name = classId.shortClassName.identifier
|
||||
return if (isSingleton) name else "is $name"
|
||||
}
|
||||
}
|
||||
|
||||
class EnumCheckIsMissing(val callableId: CallableId) : WhenMissingCase() {
|
||||
override val branchConditionText: String = callableId.asFqNameForDebugInfo().toString()
|
||||
|
||||
override fun toString(): String {
|
||||
return callableId.callableName.identifier
|
||||
}
|
||||
}
|
||||
|
||||
override fun toString(): String {
|
||||
return branchConditionText
|
||||
}
|
||||
}
|
||||
|
||||
val FirWhenExpression.isExhaustive: Boolean
|
||||
|
||||
@@ -18,6 +18,6 @@ fun test2(x: Stmt): String =
|
||||
}
|
||||
|
||||
fun test3(x: Expr): String =
|
||||
<!NO_ELSE_IN_WHEN!>when<!> (x) {
|
||||
when (x) {
|
||||
is Stmt -> "stmt"
|
||||
}
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
sealed class Base {
|
||||
sealed class A : Base() {
|
||||
object A1 : A()
|
||||
sealed class A2 : A()
|
||||
}
|
||||
sealed class B : Base() {
|
||||
sealed class B1 : B()
|
||||
object B2 : B()
|
||||
}
|
||||
|
||||
fun foo() = when (this) {
|
||||
is A -> 1
|
||||
is B.B1 -> 2
|
||||
B.B2 -> 3
|
||||
// No else required
|
||||
}
|
||||
|
||||
fun bar() = <!NO_ELSE_IN_WHEN!>when<!> (this) {
|
||||
is A -> 1
|
||||
is B.B1 -> 2
|
||||
}
|
||||
|
||||
fun baz() = when (this) {
|
||||
is A -> 1
|
||||
B.B2 -> 3
|
||||
// No else required (no possible B1 instances)
|
||||
}
|
||||
|
||||
fun negated() = when (this) {
|
||||
!is A -> 1
|
||||
A.A1 -> 2
|
||||
is A.A2 -> 3
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,4 @@
|
||||
// FIR_IDENTICAL
|
||||
sealed class Base {
|
||||
sealed class A : Base() {
|
||||
object A1 : A()
|
||||
|
||||
Reference in New Issue
Block a user