[Commonizer] Extract CIR classifiers cache from the root node

This commit is contained in:
Dmitriy Dolovov
2020-11-23 15:59:36 +03:00
parent 8d9abed3dc
commit eca231a01d
9 changed files with 71 additions and 50 deletions
@@ -14,6 +14,7 @@ import org.jetbrains.kotlin.descriptors.commonizer.utils.internedClassId
import org.jetbrains.kotlin.descriptors.commonizer.utils.isUnderStandardKotlinPackages
internal class CommonizationVisitor(
private val cache: CirClassifiersCache,
private val root: CirRootNode
) : CirNodeVisitor<Unit, Unit> {
override fun visitRootNode(node: CirRootNode, data: Unit) {
@@ -87,7 +88,7 @@ internal class CommonizationVisitor(
val companionObjectName = node.targetDeclarations.mapTo(HashSet()) { it!!.companion }.singleOrNull()
if (companionObjectName != null) {
val companionObjectClassId = internedClassId(node.classId, companionObjectName)
val companionObjectNode = root.cache.classes[companionObjectClassId]
val companionObjectNode = cache.classNode(companionObjectClassId)
?: error("Can't find companion object with class ID $companionObjectClassId")
if (companionObjectNode.commonDeclaration() != null) {
@@ -131,7 +132,7 @@ internal class CommonizationVisitor(
if (expandedClassId.packageFqName.isUnderStandardKotlinPackages)
return null // this case is not supported
val expandedClassNode = root.cache.classes[expandedClassId] ?: return null
val expandedClassNode = cache.classNode(expandedClassId) ?: return null
val expandedClass = expandedClassNode.targetDeclarations[index]
?: error("Can't find expanded class with class ID $expandedClassId and index $index for type alias $classId")
@@ -147,7 +148,7 @@ internal class CommonizationVisitor(
if (supertypesMap.isNullOrEmpty())
emptyList()
else
supertypesMap.values.compactMapNotNull { supertypesGroup -> commonize(supertypesGroup, TypeCommonizer(root.cache)) }
supertypesMap.values.compactMapNotNull { supertypesGroup -> commonize(supertypesGroup, TypeCommonizer(cache)) }
)
}
}
@@ -9,7 +9,7 @@ import org.jetbrains.kotlin.descriptors.DescriptorVisibility
import org.jetbrains.kotlin.descriptors.commonizer.cir.*
import org.jetbrains.kotlin.descriptors.commonizer.cir.factory.CirTypeFactory
import org.jetbrains.kotlin.descriptors.commonizer.mergedtree.CirClassifiersCache
import org.jetbrains.kotlin.descriptors.commonizer.mergedtree.CirNode
import org.jetbrains.kotlin.descriptors.commonizer.mergedtree.CirNodeWithClassId
import org.jetbrains.kotlin.descriptors.commonizer.utils.isUnderStandardKotlinPackages
import org.jetbrains.kotlin.name.ClassId
import org.jetbrains.kotlin.types.Variance
@@ -66,7 +66,7 @@ private class ClassTypeCommonizer(private val cache: CirClassifiersCache) : Abst
isMarkedNullable == next.isMarkedNullable
&& classId == next.classifierId
&& outerType.commonizeWith(next.outerType)
&& commonizeClassifier(classId, cache.classes).first
&& commonizeClassifier(classId) { cache.classNode(classId) }.first
&& arguments.commonizeWith(next.arguments)
}
@@ -102,7 +102,7 @@ private class TypeAliasTypeCommonizer(private val cache: CirClassifiersCache) :
return false
if (commonizedTypeBuilder == null) {
val (commonized, commonClassifier) = commonizeClassifier(typeAliasId, cache.typeAliases)
val (commonized, commonClassifier) = commonizeClassifier(typeAliasId) { cache.typeAliasNode(typeAliasId) }
if (!commonized)
return false
@@ -218,7 +218,7 @@ private class TypeArgumentListCommonizer(cache: CirClassifiersCache) : AbstractL
private inline fun <reified T : CirClassifier> commonizeClassifier(
classifierId: ClassId,
classifierNodes: Map<ClassId, CirNode<*, T>>,
classifierNode: (classifierId: ClassId) -> CirNodeWithClassId<*, T>?
): Pair<Boolean, T?> {
if (classifierId.packageFqName.isUnderStandardKotlinPackages) {
/* either class or type alias from Kotlin stdlib */
@@ -226,7 +226,7 @@ private inline fun <reified T : CirClassifier> commonizeClassifier(
}
/* or descriptors themselves can be commonized */
return when (val node = classifierNodes[classifierId]) {
return when (val node = classifierNode(classifierId)) {
null -> {
// No node means that the class or type alias was not subject for commonization at all, probably it lays
// not in commonized module descriptors but somewhere in their dependencies.
@@ -10,6 +10,7 @@ import org.jetbrains.kotlin.descriptors.commonizer.builder.DeclarationsBuilderVi
import org.jetbrains.kotlin.descriptors.commonizer.builder.createGlobalBuilderComponents
import org.jetbrains.kotlin.descriptors.commonizer.core.CommonizationVisitor
import org.jetbrains.kotlin.descriptors.commonizer.mergedtree.CirTreeMerger
import org.jetbrains.kotlin.descriptors.commonizer.mergedtree.DefaultCirClassifiersCache
import org.jetbrains.kotlin.storage.LockBasedStorageManager
fun runCommonization(parameters: Parameters): Result {
@@ -19,11 +20,12 @@ fun runCommonization(parameters: Parameters): Result {
val storageManager = LockBasedStorageManager("Declaration descriptors commonization")
// build merged tree:
val mergeResult = CirTreeMerger(storageManager, parameters).merge()
val cache = DefaultCirClassifiersCache()
val mergeResult = CirTreeMerger(storageManager, cache, parameters).merge()
// commonize:
val mergedTree = mergeResult.root
mergedTree.accept(CommonizationVisitor(mergedTree), Unit)
mergedTree.accept(CommonizationVisitor(cache, mergedTree), Unit)
parameters.progressLogger?.invoke("Commonized declarations")
// build resulting descriptors:
@@ -5,9 +5,31 @@
package org.jetbrains.kotlin.descriptors.commonizer.mergedtree
import gnu.trove.THashMap
import org.jetbrains.kotlin.name.ClassId
interface CirClassifiersCache {
val classes: Map<ClassId, CirClassNode>
val typeAliases: Map<ClassId, CirTypeAliasNode>
fun classNode(classId: ClassId): CirClassNode?
fun typeAliasNode(typeAliasId: ClassId): CirTypeAliasNode?
fun addClassNode(classId: ClassId, node: CirClassNode)
fun addTypeAliasNode(typeAliasId: ClassId, node: CirTypeAliasNode)
}
class DefaultCirClassifiersCache : CirClassifiersCache {
private val classNodes = THashMap<ClassId, CirClassNode>()
private val typeAliases = THashMap<ClassId, CirTypeAliasNode>()
override fun classNode(classId: ClassId): CirClassNode? = classNodes[classId]
override fun typeAliasNode(typeAliasId: ClassId): CirTypeAliasNode? = typeAliases[typeAliasId]
override fun addClassNode(classId: ClassId, node: CirClassNode) {
val oldNode = classNodes.put(classId, node)
check(oldNode == null) { "Rewriting class node $classId" }
}
override fun addTypeAliasNode(typeAliasId: ClassId, node: CirTypeAliasNode) {
val oldNode = typeAliases.put(typeAliasId, node)
check(oldNode == null) { "Rewriting type alias node $typeAliasId" }
}
}
@@ -8,7 +8,6 @@ package org.jetbrains.kotlin.descriptors.commonizer.mergedtree
import gnu.trove.THashMap
import org.jetbrains.kotlin.descriptors.commonizer.cir.CirRoot
import org.jetbrains.kotlin.descriptors.commonizer.utils.CommonizedGroup
import org.jetbrains.kotlin.name.ClassId
import org.jetbrains.kotlin.name.Name
import org.jetbrains.kotlin.storage.NullableLazyValue
@@ -16,13 +15,7 @@ class CirRootNode(
override val targetDeclarations: CommonizedGroup<CirRoot>,
override val commonDeclaration: NullableLazyValue<CirRoot>
) : CirNode<CirRoot, CirRoot> {
class CirClassifiersCacheImpl : CirClassifiersCache {
override val classes = THashMap<ClassId, CirClassNode>()
override val typeAliases = THashMap<ClassId, CirTypeAliasNode>()
}
val modules: MutableMap<Name, CirModuleNode> = THashMap()
val cache = CirClassifiersCacheImpl()
override fun <T, R> accept(visitor: CirNodeVisitor<T, R>, data: T): R =
visitor.visitRootNode(this, data)
@@ -12,7 +12,6 @@ import org.jetbrains.kotlin.descriptors.commonizer.Parameters
import org.jetbrains.kotlin.descriptors.commonizer.TargetProvider
import org.jetbrains.kotlin.descriptors.commonizer.cir.CirClass
import org.jetbrains.kotlin.descriptors.commonizer.cir.factory.*
import org.jetbrains.kotlin.descriptors.commonizer.mergedtree.CirRootNode.CirClassifiersCacheImpl
import org.jetbrains.kotlin.descriptors.commonizer.utils.intern
import org.jetbrains.kotlin.descriptors.commonizer.utils.internedClassId
import org.jetbrains.kotlin.name.ClassId
@@ -24,6 +23,7 @@ import org.jetbrains.kotlin.storage.StorageManager
class CirTreeMerger(
private val storageManager: StorageManager,
private val cache: CirClassifiersCache,
private val parameters: Parameters
) {
class CirTreeMergeResult(
@@ -32,11 +32,9 @@ class CirTreeMerger(
)
private val size = parameters.targetProviders.size
private lateinit var cacheRW: CirClassifiersCacheImpl
fun merge(): CirTreeMergeResult {
val rootNode: CirRootNode = buildRootNode(storageManager, size)
cacheRW = rootNode.cache
val allModuleInfos: List<Map<String, ModuleInfo>> = parameters.targetProviders.map { it.modulesProvider.loadModuleInfos() }
val commonModuleNames = allModuleInfos.map { it.keys }.reduce { a, b -> a intersect b }
@@ -140,7 +138,7 @@ class CirTreeMerger(
parentCommonDeclaration: NullableLazyValue<*>?
) {
val propertyNode: CirPropertyNode = properties.getOrPut(PropertyApproximationKey(propertyDescriptor)) {
buildPropertyNode(storageManager, size, cacheRW, parentCommonDeclaration)
buildPropertyNode(storageManager, size, cache, parentCommonDeclaration)
}
propertyNode.targetDeclarations[targetIndex] = CirPropertyFactory.create(propertyDescriptor)
}
@@ -152,7 +150,7 @@ class CirTreeMerger(
parentCommonDeclaration: NullableLazyValue<*>?
) {
val functionNode: CirFunctionNode = functions.getOrPut(FunctionApproximationKey(functionDescriptor)) {
buildFunctionNode(storageManager, size, cacheRW, parentCommonDeclaration)
buildFunctionNode(storageManager, size, cache, parentCommonDeclaration)
}
functionNode.targetDeclarations[targetIndex] = CirFunctionFactory.create(functionDescriptor)
}
@@ -168,7 +166,7 @@ class CirTreeMerger(
val classId = classIdFunction(className)
val classNode: CirClassNode = classes.getOrPut(className) {
buildClassNode(storageManager, size, cacheRW, parentCommonDeclaration, classId)
buildClassNode(storageManager, size, cache, parentCommonDeclaration, classId)
}
classNode.targetDeclarations[targetIndex] = CirClassFactory.create(classDescriptor)
@@ -203,7 +201,7 @@ class CirTreeMerger(
parentCommonDeclaration: NullableLazyValue<*>?
) {
val constructorNode: CirClassConstructorNode = constructors.getOrPut(ConstructorApproximationKey(constructorDescriptor)) {
buildClassConstructorNode(storageManager, size, cacheRW, parentCommonDeclaration)
buildClassConstructorNode(storageManager, size, cache, parentCommonDeclaration)
}
constructorNode.targetDeclarations[targetIndex] = CirClassConstructorFactory.create(constructorDescriptor)
}
@@ -218,7 +216,7 @@ class CirTreeMerger(
val typeAliasClassId = internedClassId(packageFqName, typeAliasName)
val typeAliasNode: CirTypeAliasNode = typeAliases.getOrPut(typeAliasName) {
buildTypeAliasNode(storageManager, size, cacheRW, typeAliasClassId)
buildTypeAliasNode(storageManager, size, cache, typeAliasClassId)
}
typeAliasNode.targetDeclarations[targetIndex] = CirTypeAliasFactory.create(typeAliasDescriptor)
}
@@ -9,7 +9,6 @@ import org.jetbrains.kotlin.descriptors.commonizer.cir.*
import org.jetbrains.kotlin.descriptors.commonizer.cir.impl.CirClassRecursionMarker
import org.jetbrains.kotlin.descriptors.commonizer.cir.impl.CirClassifierRecursionMarker
import org.jetbrains.kotlin.descriptors.commonizer.core.*
import org.jetbrains.kotlin.descriptors.commonizer.mergedtree.CirRootNode.CirClassifiersCacheImpl
import org.jetbrains.kotlin.descriptors.commonizer.utils.CommonizedGroup
import org.jetbrains.kotlin.name.ClassId
import org.jetbrains.kotlin.name.FqName
@@ -80,18 +79,18 @@ internal fun buildFunctionNode(
internal fun buildClassNode(
storageManager: StorageManager,
size: Int,
cacheRW: CirClassifiersCacheImpl,
cache: CirClassifiersCache,
parentCommonDeclaration: NullableLazyValue<*>?,
classId: ClassId
): CirClassNode = buildNode(
storageManager = storageManager,
size = size,
parentCommonDeclaration = parentCommonDeclaration,
commonizerProducer = { ClassCommonizer(cacheRW) },
commonizerProducer = { ClassCommonizer(cache) },
recursionMarker = CirClassRecursionMarker,
nodeProducer = { targetDeclarations, commonDeclaration ->
CirClassNode(targetDeclarations, commonDeclaration, classId).also {
cacheRW.classes[classId] = it
cache.addClassNode(classId, it)
}
}
)
@@ -112,16 +111,16 @@ internal fun buildClassConstructorNode(
internal fun buildTypeAliasNode(
storageManager: StorageManager,
size: Int,
cacheRW: CirClassifiersCacheImpl,
classId: ClassId
cache: CirClassifiersCache,
typeAliasId: ClassId
): CirTypeAliasNode = buildNode(
storageManager = storageManager,
size = size,
commonizerProducer = { TypeAliasCommonizer(cacheRW) },
commonizerProducer = { TypeAliasCommonizer(cache) },
recursionMarker = CirClassifierRecursionMarker,
nodeProducer = { targetDeclarations, commonDeclaration ->
CirTypeAliasNode(targetDeclarations, commonDeclaration, classId).also {
cacheRW.typeAliases[classId] = it
CirTypeAliasNode(targetDeclarations, commonDeclaration, typeAliasId).also {
cache.addTypeAliasNode(typeAliasId, it)
}
}
)
@@ -11,12 +11,10 @@ import org.jetbrains.kotlin.descriptors.commonizer.cir.CirType
import org.jetbrains.kotlin.descriptors.commonizer.cir.factory.CirClassFactory
import org.jetbrains.kotlin.descriptors.commonizer.cir.factory.CirTypeAliasFactory
import org.jetbrains.kotlin.descriptors.commonizer.cir.factory.CirTypeFactory
import org.jetbrains.kotlin.descriptors.commonizer.mergedtree.CirClassifiersCache
import org.jetbrains.kotlin.descriptors.commonizer.mergedtree.CirRootNode.CirClassifiersCacheImpl
import org.jetbrains.kotlin.descriptors.commonizer.mergedtree.buildClassNode
import org.jetbrains.kotlin.descriptors.commonizer.mergedtree.buildTypeAliasNode
import org.jetbrains.kotlin.descriptors.commonizer.mergedtree.*
import org.jetbrains.kotlin.descriptors.commonizer.utils.mockClassType
import org.jetbrains.kotlin.descriptors.commonizer.utils.mockTAType
import org.jetbrains.kotlin.name.ClassId
import org.jetbrains.kotlin.resolve.descriptorUtil.classId
import org.jetbrains.kotlin.storage.LockBasedStorageManager
import org.jetbrains.kotlin.types.KotlinType
@@ -26,11 +24,11 @@ import org.junit.Test
class TypeCommonizerTest : AbstractCommonizerTest<CirType, CirType>() {
private lateinit var cache: CirClassifiersCacheImpl
private lateinit var cache: CirClassifiersCache
@Before
fun initialize() {
cache = CirClassifiersCacheImpl() // reset cache
cache = DefaultCirClassifiersCache() // reset cache
}
@Test
@@ -467,11 +465,11 @@ class TypeCommonizerTest : AbstractCommonizerTest<CirType, CirType>() {
when (descriptor) {
is ClassDescriptor -> {
val classId = descriptor.classId ?: error("No class ID for ${descriptor::class.java}, $descriptor")
val node = cache.classes.getOrPut(classId) {
val node = cache.classNode(classId) {
buildClassNode(
storageManager = LockBasedStorageManager.NO_LOCKS,
size = variants.size,
cacheRW = cache,
cache = cache,
parentCommonDeclaration = null,
classId = classId
)
@@ -479,13 +477,13 @@ class TypeCommonizerTest : AbstractCommonizerTest<CirType, CirType>() {
node.targetDeclarations[index] = CirClassFactory.create(descriptor)
}
is TypeAliasDescriptor -> {
val classId = descriptor.classId ?: error("No class ID for ${descriptor::class.java}, $descriptor")
val node = cache.typeAliases.getOrPut(classId) {
val typeAliasId = descriptor.classId ?: error("No class ID for ${descriptor::class.java}, $descriptor")
val node = cache.typeAliasNode(typeAliasId) {
buildTypeAliasNode(
storageManager = LockBasedStorageManager.NO_LOCKS,
size = variants.size,
cacheRW = cache,
classId = classId
cache = cache,
typeAliasId = typeAliasId
)
}
node.targetDeclarations[index] = CirTypeAliasFactory.create(descriptor)
@@ -526,5 +524,11 @@ class TypeCommonizerTest : AbstractCommonizerTest<CirType, CirType>() {
companion object {
fun areEqual(cache: CirClassifiersCache, a: CirType, b: CirType): Boolean =
TypeCommonizer(cache).run { commonizeWith(a) && commonizeWith(b) }
private fun CirClassifiersCache.classNode(classId: ClassId, computation: () -> CirClassNode) =
classNode(classId) ?: computation()
private fun CirClassifiersCache.typeAliasNode(typeAliasId: ClassId, computation: () -> CirTypeAliasNode) =
typeAliasNode(typeAliasId) ?: computation()
}
}
@@ -122,8 +122,10 @@ private fun createPackageFragmentForClassifier(classifierFqName: FqName): Packag
}
internal val EMPTY_CLASSIFIERS_CACHE = object : CirClassifiersCache {
override val classes: Map<ClassId, CirClassNode> get() = emptyMap()
override val typeAliases: Map<ClassId, CirTypeAliasNode> get() = emptyMap()
override fun classNode(classId: ClassId): CirClassNode? = null
override fun typeAliasNode(typeAliasId: ClassId): CirTypeAliasNode? = null
override fun addClassNode(classId: ClassId, node: CirClassNode) = error("This method should not be called")
override fun addTypeAliasNode(typeAliasId: ClassId, node: CirTypeAliasNode) = error("This method should not be called")
}
internal class MockBuiltInsProvider(private val builtIns: KotlinBuiltIns) : BuiltInsProvider {