From 9fa3e2eb13df7f054e8e6484cd9b66eaebbc86b5 Mon Sep 17 00:00:00 2001 From: pyos Date: Mon, 13 Sep 2021 14:46:08 +0200 Subject: [PATCH] FIR: always load parents of nested classes first --- .../AbstractFirDeserializedSymbolsProvider.kt | 49 +++++-------------- .../KotlinDeserializedJvmSymbolsProvider.kt | 37 +++++--------- 2 files changed, 23 insertions(+), 63 deletions(-) diff --git a/compiler/fir/fir-deserialization/src/org/jetbrains/kotlin/fir/deserialization/AbstractFirDeserializedSymbolsProvider.kt b/compiler/fir/fir-deserialization/src/org/jetbrains/kotlin/fir/deserialization/AbstractFirDeserializedSymbolsProvider.kt index 1b9665fa83c..f4b4581f24d 100644 --- a/compiler/fir/fir-deserialization/src/org/jetbrains/kotlin/fir/deserialization/AbstractFirDeserializedSymbolsProvider.kt +++ b/compiler/fir/fir-deserialization/src/org/jetbrains/kotlin/fir/deserialization/AbstractFirDeserializedSymbolsProvider.kt @@ -67,7 +67,7 @@ abstract class AbstractFirDeserializedSymbolsProvider( private val packagePartsCache = session.firCachesFactory.createCache(::tryComputePackagePartInfos) private val typeAliasCache = session.firCachesFactory.createCache(::findAndDeserializeTypeAlias) - protected val classCache: FirCache = + private val classCache: FirCache = session.firCachesFactory.createCacheWithPostCompute( createValue = { classId, context -> findAndDeserializeClass(classId, context) }, postCompute = { _, symbol, postProcessor -> @@ -103,8 +103,6 @@ abstract class AbstractFirDeserializedSymbolsProvider( val sourceElement: DeserializedContainerSource, val classPostProcessor: DeserializedClassPostProcessor ) : ClassMetadataFindResult() - - object ShouldDeserializeViaParent : ClassMetadataFindResult() } private fun tryComputePackagePartInfos(packageFqName: FqName): List { @@ -149,20 +147,10 @@ abstract class AbstractFirDeserializedSymbolsProvider( ) symbol to postProcessor } - ClassMetadataFindResult.ShouldDeserializeViaParent -> findAndDeserializeClassViaParent(classId) to null null -> null to null } } - private fun findAndDeserializeClassViaParent(classId: ClassId): FirRegularClassSymbol? { - val outerClassId = classId.outerClassId ?: return null - //This will cause cyclic cache request that is highly observable in IDE (but not in the compiler - but possible SOE bug also) - //To avoid it in IDE there is special implementation that forces load parent class before any nested class request: - //[org.jetbrains.kotlin.idea.fir.low.level.api.sessions.FirIdeSessionFactory.KotlinDeserializedJvmSymbolsProviderForIde] - getClass(outerClassId) ?: return null - return classCache.getValueIfComputed(classId) - } - private fun loadFunctionsByCallableId(callableId: CallableId): List { return getPackageParts(callableId.packageName).flatMap { part -> val functionIds = part.topLevelFunctionNameIndex[callableId.callableName] ?: return@flatMap emptyList() @@ -181,37 +169,26 @@ abstract class AbstractFirDeserializedSymbolsProvider( } } - private fun getPackageParts(packageFqName: FqName): Collection { - return packagePartsCache.getValue(packageFqName) - } - - protected open fun shouldLoadParentsFirst(classId: ClassId): Boolean = false + private fun getPackageParts(packageFqName: FqName): Collection = + packagePartsCache.getValue(packageFqName) protected fun getClass( classId: ClassId, parentContext: FirDeserializationContext? = null ): FirRegularClassSymbol? { - if (parentContext == null && shouldLoadParentsFirst(classId)) { - return getClassAfterLoadingParents(classId) + val parentClassId = classId.outerClassId + if (parentContext == null && parentClassId != null) { + val alreadyLoaded = classCache.getValueIfComputed(classId) + if (alreadyLoaded != null) return alreadyLoaded + // Load parent first in case correct `parentContext` is needed to deserialize the metadata of this class. + getClass(parentClassId, null) + // If that's the case, `classCache` should contain a value for `classId`. } return classCache.getValue(classId, parentContext) } - private fun getClassAfterLoadingParents(classId: ClassId): FirRegularClassSymbol? { - classId.outerClassId?.let { parentClassId -> - val alreadyLoaded = classCache.getValueIfComputed(classId) - if (alreadyLoaded != null) return alreadyLoaded - getClassAfterLoadingParents(parentClassId) - } - return classCache.getValue(classId, null) - } - - private fun getTypeAlias( - classId: ClassId, - ): FirTypeAliasSymbol? { - if (!classId.relativeClassName.isOneSegmentFQN()) return null - return typeAliasCache.getValue(classId) - } + private fun getTypeAlias(classId: ClassId): FirTypeAliasSymbol? = + if (classId.relativeClassName.isOneSegmentFQN()) typeAliasCache.getValue(classId) else null // ------------------------ SymbolProvider methods ------------------------ @@ -235,6 +212,4 @@ abstract class AbstractFirDeserializedSymbolsProvider( override fun getClassLikeSymbolByClassId(classId: ClassId): FirClassLikeSymbol<*>? { return getClass(classId) ?: getTypeAlias(classId) } - - override fun getPackage(fqName: FqName): FqName? = null } diff --git a/compiler/fir/java/src/org/jetbrains/kotlin/fir/java/deserialization/KotlinDeserializedJvmSymbolsProvider.kt b/compiler/fir/java/src/org/jetbrains/kotlin/fir/java/deserialization/KotlinDeserializedJvmSymbolsProvider.kt index 47b10652023..aa152fa3c0a 100644 --- a/compiler/fir/java/src/org/jetbrains/kotlin/fir/java/deserialization/KotlinDeserializedJvmSymbolsProvider.kt +++ b/compiler/fir/java/src/org/jetbrains/kotlin/fir/java/deserialization/KotlinDeserializedJvmSymbolsProvider.kt @@ -19,9 +19,6 @@ import org.jetbrains.kotlin.fir.symbols.impl.FirRegularClassSymbol import org.jetbrains.kotlin.load.java.JavaClassFinder import org.jetbrains.kotlin.load.kotlin.* import org.jetbrains.kotlin.load.kotlin.header.KotlinClassHeader -import org.jetbrains.kotlin.metadata.ProtoBuf -import org.jetbrains.kotlin.metadata.deserialization.Flags -import org.jetbrains.kotlin.metadata.deserialization.NameResolver import org.jetbrains.kotlin.metadata.jvm.deserialization.JvmMetadataVersion import org.jetbrains.kotlin.metadata.jvm.deserialization.JvmProtoBufUtil import org.jetbrains.kotlin.name.ClassId @@ -91,9 +88,6 @@ class KotlinDeserializedJvmSymbolsProvider( private val KotlinJvmBinaryClass.isPreReleaseInvisible: Boolean get() = classHeader.isPreRelease - override fun shouldLoadParentsFirst(classId: ClassId): Boolean = - javaClassConverter.hasTopLevelClassOf(classId) - override fun extractClassMetadata(classId: ClassId, parentContext: FirDeserializationContext?): ClassMetadataFindResult? { // Kotlin classes are annotated Java classes, so this check also looks for them. if (!javaClassConverter.hasTopLevelClassOf(classId)) return null @@ -104,7 +98,7 @@ class KotlinDeserializedJvmSymbolsProvider( return null } val kotlinClass = when (result) { - is KotlinClassFinder.Result.KotlinClass -> result + is KotlinClassFinder.Result.KotlinClass -> result.kotlinJvmBinaryClass is KotlinClassFinder.Result.ClassFileContent -> { val javaClass = try { javaClassConverter.findClass(classId, result.content) ?: return null @@ -115,22 +109,20 @@ class KotlinDeserializedJvmSymbolsProvider( javaClassConverter.convertJavaClassToFir(symbol, classId.outerClassId?.let(::getClass), javaClass) } } - null -> return ClassMetadataFindResult.ShouldDeserializeViaParent - } - if (kotlinClass.kotlinJvmBinaryClass.classHeader.kind != KotlinClassHeader.Kind.CLASS) return null - val (nameResolver, classProto) = kotlinClass.extractMetadata() ?: return null - - if (parentContext == null && Flags.CLASS_KIND.get(classProto.flags) == ProtoBuf.Class.Kind.COMPANION_OBJECT) { - return ClassMetadataFindResult.ShouldDeserializeViaParent + null -> return null } + if (kotlinClass.classHeader.kind != KotlinClassHeader.Kind.CLASS) return null + val data = kotlinClass.classHeader.data ?: return null + val strings = kotlinClass.classHeader.strings ?: return null + val (nameResolver, classProto) = JvmProtoBufUtil.readClassDataFrom(data, strings) return ClassMetadataFindResult.Metadata( nameResolver, classProto, - JvmBinaryAnnotationDeserializer(session, kotlinClass.kotlinJvmBinaryClass, kotlinClassFinder, kotlinClass.byteContent), - kotlinClass.kotlinJvmBinaryClass.containingLibrary.toPath(), - KotlinJvmBinarySourceElement(kotlinClass.kotlinJvmBinaryClass), - classPostProcessor = { loadAnnotationsFromClassFile(kotlinClass, it) } + JvmBinaryAnnotationDeserializer(session, kotlinClass, kotlinClassFinder, result.byteContent), + kotlinClass.containingLibrary.toPath(), + KotlinJvmBinarySourceElement(kotlinClass), + classPostProcessor = { loadAnnotationsFromClassFile(result, it) } ) } @@ -141,7 +133,7 @@ class KotlinDeserializedJvmSymbolsProvider( kotlinClass: KotlinClassFinder.Result.KotlinClass, symbol: FirRegularClassSymbol ) { - val annotations = mutableListOf() + val annotations = symbol.fir.annotations as MutableList kotlinClass.kotlinJvmBinaryClass.loadClassAnnotations( object : KotlinJvmBinaryClass.AnnotationVisitor { override fun visitAnnotation(classId: ClassId, source: SourceElement): KotlinJvmBinaryClass.AnnotationArgumentVisitor? { @@ -153,16 +145,9 @@ class KotlinDeserializedJvmSymbolsProvider( }, kotlinClass.byteContent, ) - (symbol.fir.annotations as MutableList) += annotations symbol.fir.replaceDeprecation(symbol.fir.getDeprecationInfos(session.languageVersionSettings.apiVersion)) } - private fun KotlinClassFinder.Result.KotlinClass.extractMetadata(): Pair? { - val data = kotlinJvmBinaryClass.classHeader.data ?: return null - val strings = kotlinJvmBinaryClass.classHeader.strings ?: return null - return JvmProtoBufUtil.readClassDataFrom(data, strings) - } - private fun String?.toPath(): Path? { return this?.let { Paths.get(it).normalize() } }