From 771c839d7422cbb24bb6dd0983c84fc0737ffcf5 Mon Sep 17 00:00:00 2001 From: Jinseong Jeon Date: Tue, 27 Oct 2020 01:14:29 -0700 Subject: [PATCH] FIR DFA: element-wise join at merging points of try expression --- .../fir/resolve/dfa/FirDataFlowAnalyzer.kt | 29 ++++++-- .../kotlin/fir/resolve/dfa/LogicSystem.kt | 6 ++ .../fir/resolve/dfa/PersistentLogicSystem.kt | 69 +++++++++++++++++++ .../tryCatch/correctSmartcasts.fir.kt | 8 +-- .../tryCatch/correctSmartcasts_after.fir.kt | 8 +-- .../tryCatch/falsePositiveSmartcasts.fir.kt | 2 +- .../falsePositiveSmartcasts_after.fir.kt | 2 +- 7 files changed, 108 insertions(+), 16 deletions(-) diff --git a/compiler/fir/resolve/src/org/jetbrains/kotlin/fir/resolve/dfa/FirDataFlowAnalyzer.kt b/compiler/fir/resolve/src/org/jetbrains/kotlin/fir/resolve/dfa/FirDataFlowAnalyzer.kt index 2689e1fddf1..6c915ae8c07 100644 --- a/compiler/fir/resolve/src/org/jetbrains/kotlin/fir/resolve/dfa/FirDataFlowAnalyzer.kt +++ b/compiler/fir/resolve/src/org/jetbrains/kotlin/fir/resolve/dfa/FirDataFlowAnalyzer.kt @@ -649,7 +649,8 @@ abstract class FirDataFlowAnalyzer( fun enterCatchClause(catch: FirCatch) { // NB: fork to isolate effects inside the catch clause // Otherwise, changes in the catch clause could affect the previous node: try main block. - graphBuilder.enterCatchClause(catch).mergeIncomingFlow(updateReceivers = true, shouldForkFlow = true) + // NB: element-wise join due to multiple incoming flows: try main enter and exit + graphBuilder.enterCatchClause(catch).mergeIncomingFlowElementwise(updateReceivers = true, shouldForkFlow = true) } fun exitCatchClause(catch: FirCatch) { @@ -659,7 +660,8 @@ abstract class FirDataFlowAnalyzer( fun enterFinallyBlock() { // NB: fork to isolate effects inside the finally block // Otherwise, changes in the finally block could affect the previous nodes: try main block and catch clauses. - graphBuilder.enterFinallyBlock().mergeIncomingFlow(shouldForkFlow = true) + // NB: element-wise join due to multiple incoming flows: try expression enter, try main exit, and catch exits + graphBuilder.enterFinallyBlock().mergeIncomingFlowElementwise(shouldForkFlow = true) } fun exitFinallyBlock(tryExpression: FirTryExpression) { @@ -670,7 +672,8 @@ abstract class FirDataFlowAnalyzer( val (tryExpressionExitNode, unionNode) = graphBuilder.exitTryExpression(callCompleted) // NB: fork to prevent effects after the try expression from being flown into the try expression // Otherwise, changes in any following nodes could affect the previous nodes, including try main block and finally block if any. - tryExpressionExitNode.mergeIncomingFlow(shouldForkFlow = true) + // NB: element-wise join due to multiple incoming flows: try main exit and catch exits (if no finally exists) + tryExpressionExitNode.mergeIncomingFlowElementwise(shouldForkFlow = true) unionNode?.let { unionFlowFromArguments(it) } } @@ -899,8 +902,11 @@ abstract class FirDataFlowAnalyzer( } if (isAssignment) { - if (initializer is FirConstExpression<*> && initializer.kind == FirConstKind.Null) return - flow.addTypeStatement(propertyVariable typeEq initializer.typeRef.coneType) + if (initializer is FirConstExpression<*> && initializer.kind == FirConstKind.Null) { + flow.addTypeStatement(propertyVariable typeEq property.returnTypeRef.coneType.withNullability(ConeNullability.NULLABLE)) + } else { + flow.addTypeStatement(propertyVariable typeEq initializer.typeRef.coneType) + } } } @@ -1121,12 +1127,23 @@ abstract class FirDataFlowAnalyzer( private fun > T.mergeIncomingFlow( updateReceivers: Boolean = false, shouldForkFlow: Boolean = false + ): T = foldIncomingFlow(logicSystem::joinFlow, updateReceivers, shouldForkFlow) + + private fun > T.mergeIncomingFlowElementwise( + updateReceivers: Boolean = false, + shouldForkFlow: Boolean = false + ): T = foldIncomingFlow(logicSystem::elementwiseJoinFlow, updateReceivers, shouldForkFlow) + + private inline fun > T.foldIncomingFlow( + mergeOperation: (Collection) -> FLOW, + updateReceivers: Boolean = false, + shouldForkFlow: Boolean = false ): T = this.also { node -> val previousFlows = if (node.isDead) node.previousNodes.mapNotNull { runIf(!node.incomingEdges.getValue(it).kind.isBack) { it.flow } } else node.previousNodes.mapNotNull { prev -> prev.takeIf { node.incomingEdges.getValue(it).kind.usedInDfa }?.flow } - var flow = logicSystem.joinFlow(previousFlows) + var flow = mergeOperation.invoke(previousFlows) if (updateReceivers) { logicSystem.updateAllReceivers(flow) } diff --git a/compiler/fir/resolve/src/org/jetbrains/kotlin/fir/resolve/dfa/LogicSystem.kt b/compiler/fir/resolve/src/org/jetbrains/kotlin/fir/resolve/dfa/LogicSystem.kt index 20c2001ec71..d0abe642ecd 100644 --- a/compiler/fir/resolve/src/org/jetbrains/kotlin/fir/resolve/dfa/LogicSystem.kt +++ b/compiler/fir/resolve/src/org/jetbrains/kotlin/fir/resolve/dfa/LogicSystem.kt @@ -28,9 +28,15 @@ abstract class LogicSystem(protected val context: ConeInferenceCont abstract fun createEmptyFlow(): FLOW abstract fun forkFlow(flow: FLOW): FLOW + + // Differential computation abstract fun joinFlow(flows: Collection): FLOW abstract fun unionFlow(flows: Collection): FLOW + // Comprehensive element-wise computation + abstract fun elementwiseJoinFlow(flows: Collection): FLOW + abstract fun elementwiseUnionFlow(flows: Collection): FLOW + abstract fun addTypeStatement(flow: FLOW, statement: TypeStatement) abstract fun addImplication(flow: FLOW, implication: Implication) diff --git a/compiler/fir/resolve/src/org/jetbrains/kotlin/fir/resolve/dfa/PersistentLogicSystem.kt b/compiler/fir/resolve/src/org/jetbrains/kotlin/fir/resolve/dfa/PersistentLogicSystem.kt index e1d8ff91d28..5db207daafa 100644 --- a/compiler/fir/resolve/src/org/jetbrains/kotlin/fir/resolve/dfa/PersistentLogicSystem.kt +++ b/compiler/fir/resolve/src/org/jetbrains/kotlin/fir/resolve/dfa/PersistentLogicSystem.kt @@ -275,6 +275,75 @@ abstract class PersistentLogicSystem(context: ConeInferenceContext) : LogicSyste } } + override fun elementwiseJoinFlow(flows: Collection): PersistentFlow { + return elementwiseFoldFlow( + flows, + mergeOperation = { statements -> this.or(statements).takeIf { it.isNotEmpty } } + ) + } + + override fun elementwiseUnionFlow(flows: Collection): PersistentFlow { + return elementwiseFoldFlow( + flows, + mergeOperation = this::and + ) + } + + private inline fun elementwiseFoldFlow( + flows: Collection, + mergeOperation: (Collection) -> MutableTypeStatement?, + ): PersistentFlow { + if (flows.isEmpty()) return createEmptyFlow() + flows.singleOrNull()?.let { return it } + + val aliasedVariablesThatDontChangeAlias = computeAliasesThatDontChange(flows) + + val commonFlow = flows.reduce(::lowestCommonFlow) + + // >>> comprehensive element-wise fold >>> + val variables = flows.flatMap { it.approvedTypeStatements.keys }.toSet() + for (variable in variables) { + val info = mergeOperation(flows.map { it.getApprovedTypeStatements(variable, commonFlow) }) ?: continue + removeAllAboutVariable(commonFlow, variable) + commonFlow.addApprovedStatements(info) + } + // <<< comprehensive element-wise fold <<< + + commonFlow.addVariableAliases(aliasedVariablesThatDontChangeAlias) + + updateAllReceivers(commonFlow) + + return commonFlow + } + + @OptIn(DfaInternals::class) + private fun PersistentFlow.getApprovedTypeStatements(variable: RealVariable, parentFlow: PersistentFlow): MutableTypeStatement { + var flow = this + val result = MutableTypeStatement(variable) + val variableUnderAlias = directAliasMap[variable] + if (variableUnderAlias == null) { + // >>> comprehensive element-wise fold >>> + // get approved type statement even though the starting flow == parent flow + if (flow == parentFlow) { + flow.approvedTypeStatements[variable]?.let { + result += it + } + } else { + while (flow != parentFlow) { + flow.approvedTypeStatements[variable]?.let { + result += it + } + flow = flow.previousFlow!! + } + } + // <<< comprehensive element-wise fold <<< + } else { + result.exactType.addIfNotNull(variableUnderAlias.originalType) + flow.approvedTypeStatements[variableUnderAlias.variable]?.let { result += it } + } + return result + } + override fun addTypeStatement(flow: PersistentFlow, statement: TypeStatement) { if (statement.isEmpty) return with(flow) { diff --git a/compiler/testData/diagnostics/testsWithStdLib/tryCatch/correctSmartcasts.fir.kt b/compiler/testData/diagnostics/testsWithStdLib/tryCatch/correctSmartcasts.fir.kt index 34a9ab25056..b24665e36d8 100644 --- a/compiler/testData/diagnostics/testsWithStdLib/tryCatch/correctSmartcasts.fir.kt +++ b/compiler/testData/diagnostics/testsWithStdLib/tryCatch/correctSmartcasts.fir.kt @@ -16,7 +16,7 @@ fun test1(s: String?) { requireNotNull(s) } t2.not() - s.length + s.length } } @@ -45,7 +45,7 @@ fun test3() { s = null return } - s.length + s.length } fun test4() { @@ -61,7 +61,7 @@ fun test4() { catch (e: ExcB) { } - s.length + s.length } fun test5(s: String?) { @@ -74,7 +74,7 @@ fun test5(s: String?) { catch (e: ExcB) { } - s.length + s.length } fun test6(s: String?) { diff --git a/compiler/testData/diagnostics/testsWithStdLib/tryCatch/correctSmartcasts_after.fir.kt b/compiler/testData/diagnostics/testsWithStdLib/tryCatch/correctSmartcasts_after.fir.kt index ca50d78463d..48729b80495 100644 --- a/compiler/testData/diagnostics/testsWithStdLib/tryCatch/correctSmartcasts_after.fir.kt +++ b/compiler/testData/diagnostics/testsWithStdLib/tryCatch/correctSmartcasts_after.fir.kt @@ -17,7 +17,7 @@ fun test1(s: String?) { requireNotNull(s) } t2.not() - s.length + s.length } } @@ -46,7 +46,7 @@ fun test3() { s = null return } - s.length + s.length } fun test4() { @@ -62,7 +62,7 @@ fun test4() { catch (e: ExcB) { } - s.length + s.length } fun test5(s: String?) { @@ -75,7 +75,7 @@ fun test5(s: String?) { catch (e: ExcB) { } - s.length + s.length } fun test6(s: String?) { diff --git a/compiler/testData/diagnostics/testsWithStdLib/tryCatch/falsePositiveSmartcasts.fir.kt b/compiler/testData/diagnostics/testsWithStdLib/tryCatch/falsePositiveSmartcasts.fir.kt index 1cadd1dd25a..156ddc884e2 100644 --- a/compiler/testData/diagnostics/testsWithStdLib/tryCatch/falsePositiveSmartcasts.fir.kt +++ b/compiler/testData/diagnostics/testsWithStdLib/tryCatch/falsePositiveSmartcasts.fir.kt @@ -101,6 +101,6 @@ fun test6(s1: String?, s2: String?) { requireNotNull(s2) } s.length - s1.length + s1.length s2.length } \ No newline at end of file diff --git a/compiler/testData/diagnostics/testsWithStdLib/tryCatch/falsePositiveSmartcasts_after.fir.kt b/compiler/testData/diagnostics/testsWithStdLib/tryCatch/falsePositiveSmartcasts_after.fir.kt index 2cd1a23c5ea..0e40b596a63 100644 --- a/compiler/testData/diagnostics/testsWithStdLib/tryCatch/falsePositiveSmartcasts_after.fir.kt +++ b/compiler/testData/diagnostics/testsWithStdLib/tryCatch/falsePositiveSmartcasts_after.fir.kt @@ -103,6 +103,6 @@ fun test6(s1: String?, s2: String?) { requireNotNull(s2) } s.length - s1.length + s1.length s2.length } \ No newline at end of file