FIR tree gen: refactor interface/class solver code

To be able to reuse it in the upcoming IR tree generator.
This commit is contained in:
Alexander Udalov
2022-02-16 01:09:25 +01:00
parent ac46ce908f
commit db1d1be4fc
@@ -22,46 +22,67 @@ import org.jetbrains.kotlin.fir.tree.generator.model.KindOwner
* if `P1` is a class then `P2` can not be a class (because both of them a parents of E`
*/
interface Node {
val parents: List<Node>
val origin: Node
}
private class NodeImpl(val element: KindOwner) : Node {
override val parents: List<Node>
get() = element.allParents.map(::NodeImpl)
override val origin: NodeImpl
get() = if (element.origin == element) this else NodeImpl(element.origin)
override fun equals(other: Any?): Boolean =
other is NodeImpl && element == other.element
override fun hashCode(): Int =
element.hashCode()
}
fun configureInterfacesAndAbstractClasses(builder: AbstractFirTreeBuilder) {
val elements = collectElements(builder)
val elementMapping = ElementMapping(elements)
val solution = solve2sat(elements, elementMapping)
processRequirementsFromConfig(solution, elementMapping)
updateKinds(solution, elementMapping)
val solution = solveGraphForClassVsInterface(
elements,
elements.filter { it.element.kind?.isInterface == true },
elements.filter { it.element.kind?.isInterface == false },
)
updateKinds(elements, solution)
updateSealedKinds(elements)
}
private class ElementMapping(val elements: Collection<KindOwner>) {
private val varToElements: Map<Int, KindOwner> = elements.mapIndexed { index, element -> 2 * index to element.origin }.toMap() +
fun solveGraphForClassVsInterface(
elements: List<Node>, requiredInterfaces: Collection<Node>, requiredClasses: Collection<Node>,
): List<Boolean> {
val elementMapping = ElementMapping(elements)
val solution = solve2sat(elements, elementMapping)
processRequirementsFromConfig(solution, elementMapping, requiredInterfaces, requiredClasses)
return solution
}
private class ElementMapping(val elements: Collection<Node>) {
private val varToElements: Map<Int, Node> = elements.mapIndexed { index, element -> 2 * index to element.origin }.toMap() +
elements.mapIndexed { index, element -> 2 * index + 1 to element }.toMap()
private val elementsToVar: Map<KindOwner, Int> = elements.mapIndexed { index, element -> element.origin to index }.toMap()
private val hasInheritors = elements.map { it to false }.toMap(mutableMapOf()).also {
for (element in elements) {
for (parent in element.allParents) {
it[parent.origin] = true
}
}
}
private val elementsToVar: Map<Node, Int> = elements.mapIndexed { index, element -> element.origin to index }.toMap()
operator fun get(element: KindOwner): Int = elementsToVar.getValue(element)
operator fun get(index: Int): KindOwner = varToElements.getValue(index)
fun hasInheritors(element: KindOwner): Boolean {
return hasInheritors[element]!!
}
operator fun get(element: Node): Int = elementsToVar.getValue(element)
operator fun get(index: Int): Node = varToElements.getValue(index)
val size: Int = elements.size
}
private fun collectElements(builder: AbstractFirTreeBuilder): List<KindOwner> {
return (builder.elements + builder.elements.flatMap { it.allImplementations }).map { it.origin }
private fun collectElements(builder: AbstractFirTreeBuilder): List<NodeImpl> {
return (builder.elements + builder.elements.flatMap { it.allImplementations }).map { NodeImpl(it.origin) }
}
private fun updateKinds(solution: List<Boolean>, elementMapping: ElementMapping) {
private fun updateKinds(nodes: List<NodeImpl>, solution: List<Boolean>) {
val allParents = nodes.flatMapTo(mutableSetOf()) { element -> element.parents.map { it.origin } }
for (index in solution.indices) {
val isClass = solution[index]
val element = elementMapping[index * 2].origin
val node = nodes[index].origin
val element = node.element
val existingKind = element.kind
if (isClass) {
if (existingKind == Implementation.Kind.Interface)
@@ -70,7 +91,7 @@ private fun updateKinds(solution: List<Boolean>, elementMapping: ElementMapping)
if (existingKind == null) {
element.kind = when (element) {
is Implementation -> {
if (elementMapping.hasInheritors(element))
if (node in allParents)
Implementation.Kind.AbstractClass
else
Implementation.Kind.FinalClass
@@ -85,8 +106,9 @@ private fun updateKinds(solution: List<Boolean>, elementMapping: ElementMapping)
}
}
private fun updateSealedKinds(elements: Collection<KindOwner>) {
for (element in elements) {
private fun updateSealedKinds(nodes: Collection<NodeImpl>) {
for (node in nodes) {
val element = node.element
if (element is Element) {
if (element.isSealed) {
element.kind = when (element.kind) {
@@ -99,17 +121,22 @@ private fun updateSealedKinds(elements: Collection<KindOwner>) {
}
}
private fun processRequirementsFromConfig(solution: MutableList<Boolean>, elementMapping: ElementMapping) {
fun forceParentsToBeInterfaces(element: KindOwner) {
private fun processRequirementsFromConfig(
solution: MutableList<Boolean>,
elementMapping: ElementMapping,
requiredInterfaces: Collection<Node>,
requiredClasses: Collection<Node>,
) {
fun forceParentsToBeInterfaces(element: Node) {
val origin = element.origin
val index = elementMapping[origin]
if (!solution[index]) return
solution[index] = false
origin.allParents.forEach { forceParentsToBeInterfaces(it) }
origin.parents.forEach { forceParentsToBeInterfaces(it) }
}
fun forceInheritorsToBeClasses(element: KindOwner) {
val queue = ArrayDeque<KindOwner>()
fun forceInheritorsToBeClasses(element: Node) {
val queue = ArrayDeque<Node>()
queue.add(element)
while (queue.isNotEmpty()) {
val e = queue.removeFirst().origin
@@ -117,25 +144,18 @@ private fun processRequirementsFromConfig(solution: MutableList<Boolean>, elemen
if (solution[index]) continue
solution[index] = true
for (inheritor in elementMapping.elements) {
if (e in inheritor.allParents.map { it.origin }) {
if (e in inheritor.parents.map { it.origin }) {
queue.add(inheritor)
}
}
}
}
for (index in solution.indices) {
val element = elementMapping[index * 2]
val kind = element.kind ?: continue
if (kind.isInterface) {
forceParentsToBeInterfaces(element)
} else {
forceInheritorsToBeClasses(element)
}
}
requiredInterfaces.forEach(::forceParentsToBeInterfaces)
requiredClasses.forEach(::forceInheritorsToBeClasses)
}
private fun solve2sat(elements: Collection<KindOwner>, elementsToVar: ElementMapping): MutableList<Boolean> {
private fun solve2sat(elements: Collection<Node>, elementsToVar: ElementMapping): MutableList<Boolean> {
val (g, gt) = buildGraphs(elements, elementsToVar)
val used = g.indices.mapTo(mutableListOf()) { false }
@@ -188,27 +208,27 @@ private fun solve2sat(elements: Collection<KindOwner>, elementsToVar: ElementMap
}
private fun buildGraphs(elements: Collection<KindOwner>, elementMapping: ElementMapping): Pair<List<List<Int>>, List<List<Int>>> {
private fun buildGraphs(elements: Collection<Node>, elementMapping: ElementMapping): Pair<List<List<Int>>, List<List<Int>>> {
val g = (1..elementMapping.size * 2).map { mutableListOf<Int>() }
val gt = (1..elementMapping.size * 2).map { mutableListOf<Int>() }
fun Int.direct(): Int = this
fun Int.invert(): Int = this + 1
fun extractIndex(element: KindOwner) = elementMapping[element] * 2
fun extractIndex(element: Node) = elementMapping[element] * 2
for (element in elements) {
val elementVar = extractIndex(element)
for (parent in element.allParents) {
for (parent in element.parents) {
val parentVar = extractIndex(parent.origin)
// parent -> element
g[parentVar.direct()] += elementVar.direct()
g[elementVar.invert()] += parentVar.invert()
}
for (i in 0 until element.allParents.size) {
for (j in i + 1 until element.allParents.size) {
val firstParentVar = extractIndex(element.allParents[i].origin)
val secondParentVar = extractIndex(element.allParents[j].origin)
for (i in 0 until element.parents.size) {
for (j in i + 1 until element.parents.size) {
val firstParentVar = extractIndex(element.parents[i].origin)
val secondParentVar = extractIndex(element.parents[j].origin)
// firstParent -> !secondParent
g[firstParentVar.direct()] += secondParentVar.invert()
g[secondParentVar.direct()] += firstParentVar.invert()