diff --git a/compiler/fir/resolve/src/org/jetbrains/kotlin/fir/resolve/ResolveUtils.kt b/compiler/fir/resolve/src/org/jetbrains/kotlin/fir/resolve/ResolveUtils.kt index b0df3b5e09f..3c953b5cb94 100644 --- a/compiler/fir/resolve/src/org/jetbrains/kotlin/fir/resolve/ResolveUtils.kt +++ b/compiler/fir/resolve/src/org/jetbrains/kotlin/fir/resolve/ResolveUtils.kt @@ -409,13 +409,6 @@ private fun BodyResolveComponents.transformExpressionUsingSm stability: PropertyStability, typesFromSmartCast: MutableList, ): FirSmartCastExpression? { - val smartcastStability = stability.impliedSmartcastStability - ?: if (dataFlowAnalyzer.isAccessToUnstableLocalVariable(expression)) { - SmartcastStability.CAPTURED_VARIABLE - } else { - SmartcastStability.STABLE_VALUE - } - val originalType = expression.resolvedType.fullyExpandedType(session) val allTypes = typesFromSmartCast.also { if (originalType !is ConeStubType) { @@ -430,6 +423,13 @@ private fun BodyResolveComponents.transformExpressionUsingSm type = intersectedType } + val smartcastStability = stability.impliedSmartcastStability + ?: if (dataFlowAnalyzer.isAccessToUnstableLocalVariable(expression, intersectedType)) { + SmartcastStability.CAPTURED_VARIABLE + } else { + SmartcastStability.STABLE_VALUE + } + // Example (1): if (x is String) { ... }, where x: dynamic // the dynamic type will "consume" all other, erasing information. // Example (2): if (x == null) { ... }, diff --git a/compiler/fir/resolve/src/org/jetbrains/kotlin/fir/resolve/dfa/FirDataFlowAnalyzer.kt b/compiler/fir/resolve/src/org/jetbrains/kotlin/fir/resolve/dfa/FirDataFlowAnalyzer.kt index ea098c6ccc5..f78f0783daa 100644 --- a/compiler/fir/resolve/src/org/jetbrains/kotlin/fir/resolve/dfa/FirDataFlowAnalyzer.kt +++ b/compiler/fir/resolve/src/org/jetbrains/kotlin/fir/resolve/dfa/FirDataFlowAnalyzer.kt @@ -131,10 +131,36 @@ abstract class FirDataFlowAnalyzer( // ----------------------------------- Requests ----------------------------------- - fun isAccessToUnstableLocalVariable(expression: FirExpression): Boolean = - context.variableAssignmentAnalyzer.isAccessToUnstableLocalVariable(expression) + /** + * When variable access resolution encounters a variable access which has smartcast information, assignments associated with that + * variable are checked to determine variable stability, and therefore smartcast stability. These assignments are tracked by + * [FirLocalVariableAssignmentAnalyzer], which knows how each assignment may limit variable stability, like assignments within or after + * a non-in-place lambda body. So for a given lexical scope (function body, lambda body, and even local class init) and a given + * variable, [FirLocalVariableAssignmentAnalyzer] knows all associated assignments (past and/or future) which could limit stability. + * + * When a [targetType] is provided, all assignments are checked for the specified variable access expression: + * 1. If there are no assignments, the variable is always considered **stable**. + * 2. If there is an unresolved assignment type, the variable is considered **unstable**. + * 3. If any resolved assignment type is not a subtype of the [targetType], the variable is considered **unstable**. + * 4. If none of the previous conditions are true, the variable is considered **stable**. + * + * When a [targetType] is **not** provided, **any** assignments cause the variable to be considered **unstable**. + * + * @param expression The variable access expression. + * @param targetType Smartcast target type (optional: see function description). + * + * @see [getTypeUsingSmartcastInfo] + * @see [FirLocalVariableAssignmentAnalyzer.isAccessToUnstableLocalVariable] + * @see [FirLocalVariableAssignmentAnalyzer.isStableType] + */ + fun isAccessToUnstableLocalVariable(expression: FirExpression, targetType: ConeKotlinType?): Boolean = + context.variableAssignmentAnalyzer.isAccessToUnstableLocalVariable(expression, targetType, components.session) /** + * Retrieve smartcast type information [FirDataFlowAnalyzer] may have for the specified variable access expression. Type information + * is **stateful** and changes as the FIR tree is navigated by [FirDataFlowAnalyzer]. + * + * @param expression The variable access expression. * @param ignoreCallArguments Should be set to `true` when call argument flow should not be used for smart-casting. This is important * because the receiver of implicit `invoke` calls is visited *after* the call arguments due to tower resolution. */ @@ -170,6 +196,12 @@ abstract class FirDataFlowAnalyzer( } localFunctionNode?.mergeIncomingFlow() functionEnterNode.mergeIncomingFlow { _, flow -> + /* + * Anonymous functions which can be revisited, either in-place or not in-place, are treated as repeatable statements. This + * causes any assignments to local variables within the anonymous function body to clear type statements for those local + * variables. + * TODO(KT-57678): is it possible for FirLocalVariableAssignmentAnalyzer to handle this for both lambdas and loops? + */ if (function is FirAnonymousFunction && function.invocationKind?.canBeRevisited() != false) { enterRepeatableStatement(flow, function) } @@ -1044,8 +1076,13 @@ abstract class FirDataFlowAnalyzer( } fun exitVariableAssignment(assignment: FirVariableAssignment) { + val property = assignment.calleeReference?.toResolvedPropertySymbol()?.fir + if (property != null && property.isLocal) { + context.variableAssignmentAnalyzer.visitAssignment(property, assignment.rValue.resolvedType) + } + graphBuilder.exitVariableAssignment(assignment).mergeIncomingFlow { _, flow -> - val property = assignment.calleeReference?.toResolvedPropertySymbol()?.fir ?: return@mergeIncomingFlow + property ?: return@mergeIncomingFlow if (property.isLocal || property.isVal) { exitVariableInitialization(flow, assignment.rValue, property, assignment.lValue, hasExplicitType = false) } else { @@ -1110,7 +1147,7 @@ abstract class FirDataFlowAnalyzer( private val RealVariable.hasLocalStability get() = stability == PropertyStability.LOCAL_VAR private fun RealVariable.isStableOrLocalStableAccess(access: FirExpression): Boolean { - return isStable || (hasLocalStability && !isAccessToUnstableLocalVariable(access)) + return isStable || (hasLocalStability && !isAccessToUnstableLocalVariable(access, targetType = null)) } fun exitThrowExceptionNode(throwExpression: FirThrowExpression) { diff --git a/compiler/fir/resolve/src/org/jetbrains/kotlin/fir/resolve/dfa/FirLocalVariableAssignmentAnalyzer.kt b/compiler/fir/resolve/src/org/jetbrains/kotlin/fir/resolve/dfa/FirLocalVariableAssignmentAnalyzer.kt index d17ad4cfacf..0bff6fbf497 100644 --- a/compiler/fir/resolve/src/org/jetbrains/kotlin/fir/resolve/dfa/FirLocalVariableAssignmentAnalyzer.kt +++ b/compiler/fir/resolve/src/org/jetbrains/kotlin/fir/resolve/dfa/FirLocalVariableAssignmentAnalyzer.kt @@ -7,6 +7,7 @@ package org.jetbrains.kotlin.fir.resolve.dfa import org.jetbrains.kotlin.contracts.description.isInPlace import org.jetbrains.kotlin.fir.FirElement +import org.jetbrains.kotlin.fir.FirSession import org.jetbrains.kotlin.fir.declarations.* import org.jetbrains.kotlin.fir.expressions.* import org.jetbrains.kotlin.fir.references.FirNamedReference @@ -15,8 +16,11 @@ import org.jetbrains.kotlin.fir.references.toResolvedPropertySymbol import org.jetbrains.kotlin.fir.symbols.FirBasedSymbol import org.jetbrains.kotlin.fir.symbols.impl.FirFunctionSymbol import org.jetbrains.kotlin.fir.expressions.explicitReceiver +import org.jetbrains.kotlin.fir.types.ConeKotlinType +import org.jetbrains.kotlin.fir.types.typeContext import org.jetbrains.kotlin.fir.visitors.FirVisitor import org.jetbrains.kotlin.name.Name +import org.jetbrains.kotlin.types.AbstractTypeChecker /** * Helper that checks if an access to a local variable access is stable. @@ -28,8 +32,9 @@ import org.jetbrains.kotlin.name.Name internal class FirLocalVariableAssignmentAnalyzer { private var rootFunction: FirFunctionSymbol<*>? = null private var assignedLocalVariablesByDeclaration: Map, Fork>? = null + private var variableAssignments: Map>? = null - private val scopes: Stack>> = stackOf() + private val scopes: Stack> = stackOf() // Example of control-flow-postponed lambdas: callBoth({ a.x }, { a = null }) // Lambdas are called in an unknown order, so control flow edges to both of them go from before the call. @@ -47,19 +52,20 @@ internal class FirLocalVariableAssignmentAnalyzer { fun reset() { rootFunction = null assignedLocalVariablesByDeclaration = null + variableAssignments = null postponedLambdas.reset() scopes.reset() } /** Checks whether the given access is an unstable access to a local variable at this moment. */ @OptIn(DfaInternals::class) - fun isAccessToUnstableLocalVariable(fir: FirExpression): Boolean { + fun isAccessToUnstableLocalVariable(fir: FirExpression, targetType: ConeKotlinType?, session: FirSession): Boolean { if (assignedLocalVariablesByDeclaration == null) return false val realFir = fir.unwrapElement() as? FirQualifiedAccessExpression ?: return false val property = realFir.calleeReference.toResolvedPropertySymbol()?.fir ?: return false // Have data => have a root function => `scopes` is not empty. - return property in scopes.top().second || postponedLambdas.all().any { lambdas -> + return !isStableType(scopes.top().second[property], targetType, session) || postponedLambdas.all().any { lambdas -> // Control-flow-postponed lambdas' assignments should be in `functionScopes.top()`. // The reason we can't check them here is that one of the entries may be the lambda // that is currently being analyzed, and assignments in it are, in fact, totally fine. @@ -67,20 +73,40 @@ internal class FirLocalVariableAssignmentAnalyzer { } } + private fun isStableType(assignments: Collection?, targetType: ConeKotlinType?, session: FirSession): Boolean { + if (assignments == null) return true // No assignments => always stable. + if (targetType == null) return false // No target type => always unstable. + if (assignments.any { it.type == null }) return false // At least 1 unknown assignment type => always unstable. + + // Stability is determined by assignments. All assignments must be a subtype of the target type. + return assignments.all { AbstractTypeChecker.isSubtypeOf(session.typeContext, it.type!!, targetType) } + } + private fun getInfoForDeclaration(symbol: FirBasedSymbol<*>): Fork? { val root = rootFunction ?: return null if (root == symbol) return null - val cachedMap = assignedLocalVariablesByDeclaration ?: run { - val data = MiniCfgBuilder.MiniCfgData() - MiniCfgBuilder().visitElement(root.fir, data) - data.forks.also { assignedLocalVariablesByDeclaration = it } - } + val cachedMap = buildInfoForRoot(root) return cachedMap[symbol] } - private fun enterScope(symbol: FirBasedSymbol<*>, evaluatedInPlace: Boolean): Pair> { + private fun buildInfoForRoot(root: FirFunctionSymbol<*>): Map, Fork> { + assignedLocalVariablesByDeclaration?.let { return it } + + val data = MiniCfgBuilder.MiniCfgData() + MiniCfgBuilder().visitElement(root.fir, data) + + assignedLocalVariablesByDeclaration = data.forks + variableAssignments = data.assignments + + return data.forks + } + + private fun enterScope( + symbol: FirBasedSymbol<*>, + evaluatedInPlace: Boolean, + ): Pair { val currentInfo = getInfoForDeclaration(symbol) - val prohibitInThisScope = scopes.top().second.toMutableSet() + val prohibitInThisScope = scopes.top().second.copy() scopes.push(currentInfo to prohibitInThisScope) if (!evaluatedInPlace) { for ((outerInfo, prohibitInOuterScope) in scopes.all()) { @@ -100,9 +126,9 @@ internal class FirLocalVariableAssignmentAnalyzer { // } // FE1.0 has the same behavior. // KT-59692 - currentInfo?.assignedInside?.let(prohibitInOuterScope::addAll) + prohibitInOuterScope.merge(currentInfo?.assignedInside) // => any write to a variable outside the callable invalidates smart casts inside it - outerInfo?.assignedLater?.let(prohibitInThisScope::addAll) + prohibitInThisScope.merge(outerInfo?.assignedLater) } } return scopes.top() @@ -111,7 +137,7 @@ internal class FirLocalVariableAssignmentAnalyzer { fun enterFunction(function: FirFunction) { if (rootFunction == null) { rootFunction = function.symbol - scopes.push(null to mutableSetOf()) + scopes.push(null to VariableAssignments()) return } val (info, prohibitSmartCasts) = @@ -119,7 +145,7 @@ internal class FirLocalVariableAssignmentAnalyzer { for (concurrentLambdas in postponedLambdas.all()) { for ((otherLambda, dataFlowOnly) in concurrentLambdas) { if (!dataFlowOnly && otherLambda != info) { - prohibitSmartCasts += otherLambda.assignedInside + prohibitSmartCasts.merge(otherLambda.assignedInside) } } } @@ -130,6 +156,7 @@ internal class FirLocalVariableAssignmentAnalyzer { if (scopes.isEmpty) { rootFunction = null assignedLocalVariablesByDeclaration = null + variableAssignments = null } } @@ -138,7 +165,7 @@ internal class FirLocalVariableAssignmentAnalyzer { val (info, prohibitSmartCasts) = enterScope(klass.symbol, klass is FirAnonymousObject) if (klass is FirAnonymousObject && info != null) { // Assignments in initializers and methods invalidate smart casts in other members. - prohibitSmartCasts.addAll(info.assignedInside) + prohibitSmartCasts.merge(info.assignedInside) } } @@ -175,6 +202,13 @@ internal class FirLocalVariableAssignmentAnalyzer { } } + fun visitAssignment(property: FirProperty, type: ConeKotlinType) { + buildInfoForRoot(rootFunction ?: return) + val assignments = variableAssignments?.get(property) ?: return + val assignment = assignments.firstOrNull { it.type == null } ?: return + assignment.type = type + } + companion object { /** * Computes assigned local variables in each execution path. This analyzer runs before BODY_RESOLVE. Hence, it works on @@ -253,13 +287,50 @@ internal class FirLocalVariableAssignmentAnalyzer { * so that shadowed names are handled correctly. This works because local variables at any scope have higher priority * than members on implicit receivers, even if the implicit receiver is introduced by a later scope. */ - class Fork( - val assignedLater: Set, - val assignedInside: Set, + private class Fork( + val assignedLater: VariableAssignments, + val assignedInside: VariableAssignments, ) + private class Assignment( + var type: ConeKotlinType? = null, + ) + + private class VariableAssignments { + private val assignments: MutableMap> = mutableMapOf() + + operator fun get(property: FirProperty): Set? { + return assignments[property] + } + + operator fun contains(property: FirProperty): Boolean { + return property in assignments + } + + fun add(property: FirProperty, assignment: Assignment) { + assignments.getOrPut(property) { mutableSetOf() }.add(assignment) + } + + fun copy(): VariableAssignments { + val copy = VariableAssignments() + copy.assignments += this.assignments + return copy + } + + fun merge(other: VariableAssignments?) { + if (other == null) return + for ((property, values) in other.assignments) { + assignments.getOrPut(property) { mutableSetOf() }.addAll(values) + } + } + + fun retain(properties: Set) { + assignments.keys.retainAll(properties) + } + } + private class MiniFlow(val parents: Set) { - val assignedLater: MutableSet = mutableSetOf() + val assignedLater = VariableAssignments() fun fork(): MiniFlow = MiniFlow(setOf(this)) @@ -273,13 +344,13 @@ internal class FirLocalVariableAssignmentAnalyzer { element.acceptChildren(this, data) } - private fun visitElementWithLexicalScope(element: FirElement, data: MiniCfgData): Set { + private fun visitElementWithLexicalScope(element: FirElement, data: MiniCfgData): VariableAssignments { // Detach the flow so that variables declared inside the structure do not leak into the outside. val flow = MiniFlow.start() val freeVariables = data.variableDeclarations.flatMapTo(mutableSetOf()) { it.values } data.flow = flow element.acceptChildren(this, data) - return flow.assignedLater.apply { retainAll(freeVariables) } + return flow.assignedLater.apply { retain(freeVariables) } } override fun visitAnonymousFunction(anonymousFunction: FirAnonymousFunction, data: MiniCfgData) = @@ -397,19 +468,26 @@ internal class FirLocalVariableAssignmentAnalyzer { private fun MiniCfgData.recordAssignment(reference: FirReference) { val name = (reference as? FirNamedReference)?.name ?: return val property = variableDeclarations.lastOrNull { name in it }?.get(name) ?: return - flow.recordAssignments(setOf(property)) + + val assignment = Assignment() + assignments.getOrPut(property) { mutableListOf() }.add(assignment) + flow.recordAssignment(property, assignment) } - private fun MiniFlow.recordAssignments(properties: Set) { - // All assignments already recorded here should also have been recorded in all parents, - // so if (properties - assignedLater) is empty, no point in continuing. - if (!assignedLater.addAll(properties)) return + private fun MiniFlow.recordAssignment(property: FirProperty, assignment: Assignment) { + assignedLater.add(property, assignment) + parents.forEach { it.recordAssignment(property, assignment) } + } + + private fun MiniFlow.recordAssignments(properties: VariableAssignments) { + assignedLater.merge(properties) parents.forEach { it.recordAssignments(properties) } } class MiniCfgData { var flow: MiniFlow = MiniFlow.start() val variableDeclarations: ArrayDeque> = ArrayDeque(listOf(mutableMapOf())) + val assignments: MutableMap> = mutableMapOf() val forks: MutableMap, Fork> = mutableMapOf() } } diff --git a/compiler/testData/diagnostics/tests/smartCasts/lambdasWithContracts/lambdaWithCallInPlace.fir.kt b/compiler/testData/diagnostics/tests/smartCasts/lambdasWithContracts/lambdaWithCallInPlace.fir.kt index c1d2ce2459e..03168697ce5 100644 --- a/compiler/testData/diagnostics/tests/smartCasts/lambdasWithContracts/lambdaWithCallInPlace.fir.kt +++ b/compiler/testData/diagnostics/tests/smartCasts/lambdasWithContracts/lambdaWithCallInPlace.fir.kt @@ -55,7 +55,7 @@ fun test4() { require(x is String) runWithoutContract { x = "" - x.length + x.length } } diff --git a/compiler/testData/diagnostics/tests/smartCasts/lambdasWithContracts/lambdaWithCallInPlaceAndBounds.fir.kt b/compiler/testData/diagnostics/tests/smartCasts/lambdasWithContracts/lambdaWithCallInPlaceAndBounds.fir.kt index c5d1ba7b3ec..feb7b8d63af 100644 --- a/compiler/testData/diagnostics/tests/smartCasts/lambdasWithContracts/lambdaWithCallInPlaceAndBounds.fir.kt +++ b/compiler/testData/diagnostics/tests/smartCasts/lambdasWithContracts/lambdaWithCallInPlaceAndBounds.fir.kt @@ -184,7 +184,7 @@ fun test15() { x = 10 y = x y.length - y.inc() + y.inc() } } @@ -193,7 +193,7 @@ fun test16(x: Any) { runWithoutContract { require(x is String) y = x - y.length + y.length } } diff --git a/compiler/testData/diagnostics/tests/smartCasts/variables/capturedLoopVariable.fir.kt b/compiler/testData/diagnostics/tests/smartCasts/variables/capturedLoopVariable.fir.kt index 5f50f78d3ad..009a0af5863 100644 --- a/compiler/testData/diagnostics/tests/smartCasts/variables/capturedLoopVariable.fir.kt +++ b/compiler/testData/diagnostics/tests/smartCasts/variables/capturedLoopVariable.fir.kt @@ -1,3 +1,4 @@ +// SKIP_TXT // ISSUE: KT-55338 fun test_1() { @@ -8,7 +9,7 @@ fun test_1() { s = "hello" } s.length // smartcast in K1 and K2 - noInlineRun { s.length } // smartcast in K1, unsafe call in K2 <------------ + noInlineRun { s.length } // smartcast in K1 and K2 } } @@ -62,7 +63,7 @@ fun test_3_2() { s = "world" } s.length // smartcast in K1 and K2 - noInlineRun { s.length } // smartcast in K1, unsafe call in K2 <------------ + noInlineRun { s.length } // smartcast in K1 and K2 } } @@ -90,7 +91,18 @@ fun test_4_2() { s = getString() } s.length // smartcast in K1 and K2 - noInlineRun { s.length } // smartcast in K1, unsafe call in K2 <------------ + noInlineRun { s.length } // smartcast in K1 and K2 + } +} + +fun test_5() { + var s: String? = null + + for (i in 1..10) { + s = null + s = getString() + s.length // smartcast in K1 and K2 + noInlineRun { s.length } // smartcast in K1, unsafe call in K2 } } diff --git a/compiler/testData/diagnostics/tests/smartCasts/variables/capturedLoopVariable.kt b/compiler/testData/diagnostics/tests/smartCasts/variables/capturedLoopVariable.kt index 533d08ffc73..1873ff0751d 100644 --- a/compiler/testData/diagnostics/tests/smartCasts/variables/capturedLoopVariable.kt +++ b/compiler/testData/diagnostics/tests/smartCasts/variables/capturedLoopVariable.kt @@ -1,3 +1,4 @@ +// SKIP_TXT // ISSUE: KT-55338 fun test_1() { @@ -8,7 +9,7 @@ fun test_1() { s = "hello" } s.length // smartcast in K1 and K2 - noInlineRun { s.length } // smartcast in K1, unsafe call in K2 <------------ + noInlineRun { s.length } // smartcast in K1 and K2 } } @@ -62,7 +63,7 @@ fun test_3_2() { s = "world" } s.length // smartcast in K1 and K2 - noInlineRun { s.length } // smartcast in K1, unsafe call in K2 <------------ + noInlineRun { s.length } // smartcast in K1 and K2 } } @@ -90,7 +91,18 @@ fun test_4_2() { s = getString() } s.length // smartcast in K1 and K2 - noInlineRun { s.length } // smartcast in K1, unsafe call in K2 <------------ + noInlineRun { s.length } // smartcast in K1 and K2 + } +} + +fun test_5() { + var s: String? = null + + for (i in 1..10) { + s = null + s = getString() + s.length // smartcast in K1 and K2 + noInlineRun { s.length } // smartcast in K1, unsafe call in K2 } } diff --git a/compiler/testData/diagnostics/tests/smartCasts/variables/capturedLoopVariable.txt b/compiler/testData/diagnostics/tests/smartCasts/variables/capturedLoopVariable.txt deleted file mode 100644 index 91f22a217ef..00000000000 --- a/compiler/testData/diagnostics/tests/smartCasts/variables/capturedLoopVariable.txt +++ /dev/null @@ -1,12 +0,0 @@ -package - -public fun getNullableString(): kotlin.String? -public fun getString(): kotlin.String -public fun noInlineRun(/*0*/ block: () -> kotlin.Unit): kotlin.Unit -public fun test_1(): kotlin.Unit -public fun test_2_1(): kotlin.Unit -public fun test_2_2(): kotlin.Unit -public fun test_3_1(): kotlin.Unit -public fun test_3_2(): kotlin.Unit -public fun test_4_1(): kotlin.Unit -public fun test_4_2(): kotlin.Unit