FIR: always load parents of nested classes first
This commit is contained in:
+12
-37
@@ -67,7 +67,7 @@ abstract class AbstractFirDeserializedSymbolsProvider(
|
||||
|
||||
private val packagePartsCache = session.firCachesFactory.createCache(::tryComputePackagePartInfos)
|
||||
private val typeAliasCache = session.firCachesFactory.createCache(::findAndDeserializeTypeAlias)
|
||||
protected val classCache: FirCache<ClassId, FirRegularClassSymbol?, FirDeserializationContext?> =
|
||||
private val classCache: FirCache<ClassId, FirRegularClassSymbol?, FirDeserializationContext?> =
|
||||
session.firCachesFactory.createCacheWithPostCompute(
|
||||
createValue = { classId, context -> findAndDeserializeClass(classId, context) },
|
||||
postCompute = { _, symbol, postProcessor ->
|
||||
@@ -103,8 +103,6 @@ abstract class AbstractFirDeserializedSymbolsProvider(
|
||||
val sourceElement: DeserializedContainerSource,
|
||||
val classPostProcessor: DeserializedClassPostProcessor
|
||||
) : ClassMetadataFindResult()
|
||||
|
||||
object ShouldDeserializeViaParent : ClassMetadataFindResult()
|
||||
}
|
||||
|
||||
private fun tryComputePackagePartInfos(packageFqName: FqName): List<PackagePartsCacheData> {
|
||||
@@ -149,20 +147,10 @@ abstract class AbstractFirDeserializedSymbolsProvider(
|
||||
)
|
||||
symbol to postProcessor
|
||||
}
|
||||
ClassMetadataFindResult.ShouldDeserializeViaParent -> findAndDeserializeClassViaParent(classId) to null
|
||||
null -> null to null
|
||||
}
|
||||
}
|
||||
|
||||
private fun findAndDeserializeClassViaParent(classId: ClassId): FirRegularClassSymbol? {
|
||||
val outerClassId = classId.outerClassId ?: return null
|
||||
//This will cause cyclic cache request that is highly observable in IDE (but not in the compiler - but possible SOE bug also)
|
||||
//To avoid it in IDE there is special implementation that forces load parent class before any nested class request:
|
||||
//[org.jetbrains.kotlin.idea.fir.low.level.api.sessions.FirIdeSessionFactory.KotlinDeserializedJvmSymbolsProviderForIde]
|
||||
getClass(outerClassId) ?: return null
|
||||
return classCache.getValueIfComputed(classId)
|
||||
}
|
||||
|
||||
private fun loadFunctionsByCallableId(callableId: CallableId): List<FirNamedFunctionSymbol> {
|
||||
return getPackageParts(callableId.packageName).flatMap { part ->
|
||||
val functionIds = part.topLevelFunctionNameIndex[callableId.callableName] ?: return@flatMap emptyList()
|
||||
@@ -181,37 +169,26 @@ abstract class AbstractFirDeserializedSymbolsProvider(
|
||||
}
|
||||
}
|
||||
|
||||
private fun getPackageParts(packageFqName: FqName): Collection<PackagePartsCacheData> {
|
||||
return packagePartsCache.getValue(packageFqName)
|
||||
}
|
||||
|
||||
protected open fun shouldLoadParentsFirst(classId: ClassId): Boolean = false
|
||||
private fun getPackageParts(packageFqName: FqName): Collection<PackagePartsCacheData> =
|
||||
packagePartsCache.getValue(packageFqName)
|
||||
|
||||
protected fun getClass(
|
||||
classId: ClassId,
|
||||
parentContext: FirDeserializationContext? = null
|
||||
): FirRegularClassSymbol? {
|
||||
if (parentContext == null && shouldLoadParentsFirst(classId)) {
|
||||
return getClassAfterLoadingParents(classId)
|
||||
val parentClassId = classId.outerClassId
|
||||
if (parentContext == null && parentClassId != null) {
|
||||
val alreadyLoaded = classCache.getValueIfComputed(classId)
|
||||
if (alreadyLoaded != null) return alreadyLoaded
|
||||
// Load parent first in case correct `parentContext` is needed to deserialize the metadata of this class.
|
||||
getClass(parentClassId, null)
|
||||
// If that's the case, `classCache` should contain a value for `classId`.
|
||||
}
|
||||
return classCache.getValue(classId, parentContext)
|
||||
}
|
||||
|
||||
private fun getClassAfterLoadingParents(classId: ClassId): FirRegularClassSymbol? {
|
||||
classId.outerClassId?.let { parentClassId ->
|
||||
val alreadyLoaded = classCache.getValueIfComputed(classId)
|
||||
if (alreadyLoaded != null) return alreadyLoaded
|
||||
getClassAfterLoadingParents(parentClassId)
|
||||
}
|
||||
return classCache.getValue(classId, null)
|
||||
}
|
||||
|
||||
private fun getTypeAlias(
|
||||
classId: ClassId,
|
||||
): FirTypeAliasSymbol? {
|
||||
if (!classId.relativeClassName.isOneSegmentFQN()) return null
|
||||
return typeAliasCache.getValue(classId)
|
||||
}
|
||||
private fun getTypeAlias(classId: ClassId): FirTypeAliasSymbol? =
|
||||
if (classId.relativeClassName.isOneSegmentFQN()) typeAliasCache.getValue(classId) else null
|
||||
|
||||
// ------------------------ SymbolProvider methods ------------------------
|
||||
|
||||
@@ -235,6 +212,4 @@ abstract class AbstractFirDeserializedSymbolsProvider(
|
||||
override fun getClassLikeSymbolByClassId(classId: ClassId): FirClassLikeSymbol<*>? {
|
||||
return getClass(classId) ?: getTypeAlias(classId)
|
||||
}
|
||||
|
||||
override fun getPackage(fqName: FqName): FqName? = null
|
||||
}
|
||||
|
||||
+11
-26
@@ -19,9 +19,6 @@ import org.jetbrains.kotlin.fir.symbols.impl.FirRegularClassSymbol
|
||||
import org.jetbrains.kotlin.load.java.JavaClassFinder
|
||||
import org.jetbrains.kotlin.load.kotlin.*
|
||||
import org.jetbrains.kotlin.load.kotlin.header.KotlinClassHeader
|
||||
import org.jetbrains.kotlin.metadata.ProtoBuf
|
||||
import org.jetbrains.kotlin.metadata.deserialization.Flags
|
||||
import org.jetbrains.kotlin.metadata.deserialization.NameResolver
|
||||
import org.jetbrains.kotlin.metadata.jvm.deserialization.JvmMetadataVersion
|
||||
import org.jetbrains.kotlin.metadata.jvm.deserialization.JvmProtoBufUtil
|
||||
import org.jetbrains.kotlin.name.ClassId
|
||||
@@ -91,9 +88,6 @@ class KotlinDeserializedJvmSymbolsProvider(
|
||||
private val KotlinJvmBinaryClass.isPreReleaseInvisible: Boolean
|
||||
get() = classHeader.isPreRelease
|
||||
|
||||
override fun shouldLoadParentsFirst(classId: ClassId): Boolean =
|
||||
javaClassConverter.hasTopLevelClassOf(classId)
|
||||
|
||||
override fun extractClassMetadata(classId: ClassId, parentContext: FirDeserializationContext?): ClassMetadataFindResult? {
|
||||
// Kotlin classes are annotated Java classes, so this check also looks for them.
|
||||
if (!javaClassConverter.hasTopLevelClassOf(classId)) return null
|
||||
@@ -104,7 +98,7 @@ class KotlinDeserializedJvmSymbolsProvider(
|
||||
return null
|
||||
}
|
||||
val kotlinClass = when (result) {
|
||||
is KotlinClassFinder.Result.KotlinClass -> result
|
||||
is KotlinClassFinder.Result.KotlinClass -> result.kotlinJvmBinaryClass
|
||||
is KotlinClassFinder.Result.ClassFileContent -> {
|
||||
val javaClass = try {
|
||||
javaClassConverter.findClass(classId, result.content) ?: return null
|
||||
@@ -115,22 +109,20 @@ class KotlinDeserializedJvmSymbolsProvider(
|
||||
javaClassConverter.convertJavaClassToFir(symbol, classId.outerClassId?.let(::getClass), javaClass)
|
||||
}
|
||||
}
|
||||
null -> return ClassMetadataFindResult.ShouldDeserializeViaParent
|
||||
}
|
||||
if (kotlinClass.kotlinJvmBinaryClass.classHeader.kind != KotlinClassHeader.Kind.CLASS) return null
|
||||
val (nameResolver, classProto) = kotlinClass.extractMetadata() ?: return null
|
||||
|
||||
if (parentContext == null && Flags.CLASS_KIND.get(classProto.flags) == ProtoBuf.Class.Kind.COMPANION_OBJECT) {
|
||||
return ClassMetadataFindResult.ShouldDeserializeViaParent
|
||||
null -> return null
|
||||
}
|
||||
if (kotlinClass.classHeader.kind != KotlinClassHeader.Kind.CLASS) return null
|
||||
val data = kotlinClass.classHeader.data ?: return null
|
||||
val strings = kotlinClass.classHeader.strings ?: return null
|
||||
val (nameResolver, classProto) = JvmProtoBufUtil.readClassDataFrom(data, strings)
|
||||
|
||||
return ClassMetadataFindResult.Metadata(
|
||||
nameResolver,
|
||||
classProto,
|
||||
JvmBinaryAnnotationDeserializer(session, kotlinClass.kotlinJvmBinaryClass, kotlinClassFinder, kotlinClass.byteContent),
|
||||
kotlinClass.kotlinJvmBinaryClass.containingLibrary.toPath(),
|
||||
KotlinJvmBinarySourceElement(kotlinClass.kotlinJvmBinaryClass),
|
||||
classPostProcessor = { loadAnnotationsFromClassFile(kotlinClass, it) }
|
||||
JvmBinaryAnnotationDeserializer(session, kotlinClass, kotlinClassFinder, result.byteContent),
|
||||
kotlinClass.containingLibrary.toPath(),
|
||||
KotlinJvmBinarySourceElement(kotlinClass),
|
||||
classPostProcessor = { loadAnnotationsFromClassFile(result, it) }
|
||||
)
|
||||
}
|
||||
|
||||
@@ -141,7 +133,7 @@ class KotlinDeserializedJvmSymbolsProvider(
|
||||
kotlinClass: KotlinClassFinder.Result.KotlinClass,
|
||||
symbol: FirRegularClassSymbol
|
||||
) {
|
||||
val annotations = mutableListOf<FirAnnotation>()
|
||||
val annotations = symbol.fir.annotations as MutableList<FirAnnotation>
|
||||
kotlinClass.kotlinJvmBinaryClass.loadClassAnnotations(
|
||||
object : KotlinJvmBinaryClass.AnnotationVisitor {
|
||||
override fun visitAnnotation(classId: ClassId, source: SourceElement): KotlinJvmBinaryClass.AnnotationArgumentVisitor? {
|
||||
@@ -153,16 +145,9 @@ class KotlinDeserializedJvmSymbolsProvider(
|
||||
},
|
||||
kotlinClass.byteContent,
|
||||
)
|
||||
(symbol.fir.annotations as MutableList<FirAnnotation>) += annotations
|
||||
symbol.fir.replaceDeprecation(symbol.fir.getDeprecationInfos(session.languageVersionSettings.apiVersion))
|
||||
}
|
||||
|
||||
private fun KotlinClassFinder.Result.KotlinClass.extractMetadata(): Pair<NameResolver, ProtoBuf.Class>? {
|
||||
val data = kotlinJvmBinaryClass.classHeader.data ?: return null
|
||||
val strings = kotlinJvmBinaryClass.classHeader.strings ?: return null
|
||||
return JvmProtoBufUtil.readClassDataFrom(data, strings)
|
||||
}
|
||||
|
||||
private fun String?.toPath(): Path? {
|
||||
return this?.let { Paths.get(it).normalize() }
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user