From db1d1be4fc92d6fd7225e205997d4e1c7c8b02af Mon Sep 17 00:00:00 2001 From: Alexander Udalov Date: Wed, 16 Feb 2022 01:09:25 +0100 Subject: [PATCH] FIR tree gen: refactor interface/class solver code To be able to reuse it in the upcoming IR tree generator. --- .../util/InterfaceAbstractClassSolver.kt | 122 ++++++++++-------- 1 file changed, 71 insertions(+), 51 deletions(-) diff --git a/compiler/fir/tree/tree-generator/src/org/jetbrains/kotlin/fir/tree/generator/util/InterfaceAbstractClassSolver.kt b/compiler/fir/tree/tree-generator/src/org/jetbrains/kotlin/fir/tree/generator/util/InterfaceAbstractClassSolver.kt index 2fb2d3d1998..6508420d7ee 100644 --- a/compiler/fir/tree/tree-generator/src/org/jetbrains/kotlin/fir/tree/generator/util/InterfaceAbstractClassSolver.kt +++ b/compiler/fir/tree/tree-generator/src/org/jetbrains/kotlin/fir/tree/generator/util/InterfaceAbstractClassSolver.kt @@ -22,46 +22,67 @@ import org.jetbrains.kotlin.fir.tree.generator.model.KindOwner * if `P1` is a class then `P2` can not be a class (because both of them a parents of E` */ +interface Node { + val parents: List + val origin: Node +} + +private class NodeImpl(val element: KindOwner) : Node { + override val parents: List + get() = element.allParents.map(::NodeImpl) + + override val origin: NodeImpl + get() = if (element.origin == element) this else NodeImpl(element.origin) + + override fun equals(other: Any?): Boolean = + other is NodeImpl && element == other.element + + override fun hashCode(): Int = + element.hashCode() +} + fun configureInterfacesAndAbstractClasses(builder: AbstractFirTreeBuilder) { val elements = collectElements(builder) - val elementMapping = ElementMapping(elements) - - val solution = solve2sat(elements, elementMapping) - processRequirementsFromConfig(solution, elementMapping) - updateKinds(solution, elementMapping) + val solution = solveGraphForClassVsInterface( + elements, + elements.filter { it.element.kind?.isInterface == true }, + elements.filter { it.element.kind?.isInterface == false }, + ) + updateKinds(elements, solution) updateSealedKinds(elements) } -private class ElementMapping(val elements: Collection) { - private val varToElements: Map = elements.mapIndexed { index, element -> 2 * index to element.origin }.toMap() + +fun solveGraphForClassVsInterface( + elements: List, requiredInterfaces: Collection, requiredClasses: Collection, +): List { + val elementMapping = ElementMapping(elements) + val solution = solve2sat(elements, elementMapping) + processRequirementsFromConfig(solution, elementMapping, requiredInterfaces, requiredClasses) + return solution +} + +private class ElementMapping(val elements: Collection) { + private val varToElements: Map = elements.mapIndexed { index, element -> 2 * index to element.origin }.toMap() + elements.mapIndexed { index, element -> 2 * index + 1 to element }.toMap() - private val elementsToVar: Map = elements.mapIndexed { index, element -> element.origin to index }.toMap() - private val hasInheritors = elements.map { it to false }.toMap(mutableMapOf()).also { - for (element in elements) { - for (parent in element.allParents) { - it[parent.origin] = true - } - } - } + private val elementsToVar: Map = elements.mapIndexed { index, element -> element.origin to index }.toMap() - operator fun get(element: KindOwner): Int = elementsToVar.getValue(element) - operator fun get(index: Int): KindOwner = varToElements.getValue(index) - - fun hasInheritors(element: KindOwner): Boolean { - return hasInheritors[element]!! - } + operator fun get(element: Node): Int = elementsToVar.getValue(element) + operator fun get(index: Int): Node = varToElements.getValue(index) val size: Int = elements.size } -private fun collectElements(builder: AbstractFirTreeBuilder): List { - return (builder.elements + builder.elements.flatMap { it.allImplementations }).map { it.origin } +private fun collectElements(builder: AbstractFirTreeBuilder): List { + return (builder.elements + builder.elements.flatMap { it.allImplementations }).map { NodeImpl(it.origin) } } -private fun updateKinds(solution: List, elementMapping: ElementMapping) { +private fun updateKinds(nodes: List, solution: List) { + val allParents = nodes.flatMapTo(mutableSetOf()) { element -> element.parents.map { it.origin } } + for (index in solution.indices) { val isClass = solution[index] - val element = elementMapping[index * 2].origin + val node = nodes[index].origin + val element = node.element val existingKind = element.kind if (isClass) { if (existingKind == Implementation.Kind.Interface) @@ -70,7 +91,7 @@ private fun updateKinds(solution: List, elementMapping: ElementMapping) if (existingKind == null) { element.kind = when (element) { is Implementation -> { - if (elementMapping.hasInheritors(element)) + if (node in allParents) Implementation.Kind.AbstractClass else Implementation.Kind.FinalClass @@ -85,8 +106,9 @@ private fun updateKinds(solution: List, elementMapping: ElementMapping) } } -private fun updateSealedKinds(elements: Collection) { - for (element in elements) { +private fun updateSealedKinds(nodes: Collection) { + for (node in nodes) { + val element = node.element if (element is Element) { if (element.isSealed) { element.kind = when (element.kind) { @@ -99,17 +121,22 @@ private fun updateSealedKinds(elements: Collection) { } } -private fun processRequirementsFromConfig(solution: MutableList, elementMapping: ElementMapping) { - fun forceParentsToBeInterfaces(element: KindOwner) { +private fun processRequirementsFromConfig( + solution: MutableList, + elementMapping: ElementMapping, + requiredInterfaces: Collection, + requiredClasses: Collection, +) { + fun forceParentsToBeInterfaces(element: Node) { val origin = element.origin val index = elementMapping[origin] if (!solution[index]) return solution[index] = false - origin.allParents.forEach { forceParentsToBeInterfaces(it) } + origin.parents.forEach { forceParentsToBeInterfaces(it) } } - fun forceInheritorsToBeClasses(element: KindOwner) { - val queue = ArrayDeque() + fun forceInheritorsToBeClasses(element: Node) { + val queue = ArrayDeque() queue.add(element) while (queue.isNotEmpty()) { val e = queue.removeFirst().origin @@ -117,25 +144,18 @@ private fun processRequirementsFromConfig(solution: MutableList, elemen if (solution[index]) continue solution[index] = true for (inheritor in elementMapping.elements) { - if (e in inheritor.allParents.map { it.origin }) { + if (e in inheritor.parents.map { it.origin }) { queue.add(inheritor) } } } } - for (index in solution.indices) { - val element = elementMapping[index * 2] - val kind = element.kind ?: continue - if (kind.isInterface) { - forceParentsToBeInterfaces(element) - } else { - forceInheritorsToBeClasses(element) - } - } + requiredInterfaces.forEach(::forceParentsToBeInterfaces) + requiredClasses.forEach(::forceInheritorsToBeClasses) } -private fun solve2sat(elements: Collection, elementsToVar: ElementMapping): MutableList { +private fun solve2sat(elements: Collection, elementsToVar: ElementMapping): MutableList { val (g, gt) = buildGraphs(elements, elementsToVar) val used = g.indices.mapTo(mutableListOf()) { false } @@ -188,27 +208,27 @@ private fun solve2sat(elements: Collection, elementsToVar: ElementMap } -private fun buildGraphs(elements: Collection, elementMapping: ElementMapping): Pair>, List>> { +private fun buildGraphs(elements: Collection, elementMapping: ElementMapping): Pair>, List>> { val g = (1..elementMapping.size * 2).map { mutableListOf() } val gt = (1..elementMapping.size * 2).map { mutableListOf() } fun Int.direct(): Int = this fun Int.invert(): Int = this + 1 - fun extractIndex(element: KindOwner) = elementMapping[element] * 2 + fun extractIndex(element: Node) = elementMapping[element] * 2 for (element in elements) { val elementVar = extractIndex(element) - for (parent in element.allParents) { + for (parent in element.parents) { val parentVar = extractIndex(parent.origin) // parent -> element g[parentVar.direct()] += elementVar.direct() g[elementVar.invert()] += parentVar.invert() } - for (i in 0 until element.allParents.size) { - for (j in i + 1 until element.allParents.size) { - val firstParentVar = extractIndex(element.allParents[i].origin) - val secondParentVar = extractIndex(element.allParents[j].origin) + for (i in 0 until element.parents.size) { + for (j in i + 1 until element.parents.size) { + val firstParentVar = extractIndex(element.parents[i].origin) + val secondParentVar = extractIndex(element.parents[j].origin) // firstParent -> !secondParent g[firstParentVar.direct()] += secondParentVar.invert() g[secondParentVar.direct()] += firstParentVar.invert()