FIR: always load parents of nested classes first

This commit is contained in:
pyos
2021-09-13 14:46:08 +02:00
committed by TeamCityServer
parent 032cf5a892
commit 9fa3e2eb13
2 changed files with 23 additions and 63 deletions
@@ -67,7 +67,7 @@ abstract class AbstractFirDeserializedSymbolsProvider(
private val packagePartsCache = session.firCachesFactory.createCache(::tryComputePackagePartInfos)
private val typeAliasCache = session.firCachesFactory.createCache(::findAndDeserializeTypeAlias)
protected val classCache: FirCache<ClassId, FirRegularClassSymbol?, FirDeserializationContext?> =
private val classCache: FirCache<ClassId, FirRegularClassSymbol?, FirDeserializationContext?> =
session.firCachesFactory.createCacheWithPostCompute(
createValue = { classId, context -> findAndDeserializeClass(classId, context) },
postCompute = { _, symbol, postProcessor ->
@@ -103,8 +103,6 @@ abstract class AbstractFirDeserializedSymbolsProvider(
val sourceElement: DeserializedContainerSource,
val classPostProcessor: DeserializedClassPostProcessor
) : ClassMetadataFindResult()
object ShouldDeserializeViaParent : ClassMetadataFindResult()
}
private fun tryComputePackagePartInfos(packageFqName: FqName): List<PackagePartsCacheData> {
@@ -149,20 +147,10 @@ abstract class AbstractFirDeserializedSymbolsProvider(
)
symbol to postProcessor
}
ClassMetadataFindResult.ShouldDeserializeViaParent -> findAndDeserializeClassViaParent(classId) to null
null -> null to null
}
}
private fun findAndDeserializeClassViaParent(classId: ClassId): FirRegularClassSymbol? {
val outerClassId = classId.outerClassId ?: return null
//This will cause cyclic cache request that is highly observable in IDE (but not in the compiler - but possible SOE bug also)
//To avoid it in IDE there is special implementation that forces load parent class before any nested class request:
//[org.jetbrains.kotlin.idea.fir.low.level.api.sessions.FirIdeSessionFactory.KotlinDeserializedJvmSymbolsProviderForIde]
getClass(outerClassId) ?: return null
return classCache.getValueIfComputed(classId)
}
private fun loadFunctionsByCallableId(callableId: CallableId): List<FirNamedFunctionSymbol> {
return getPackageParts(callableId.packageName).flatMap { part ->
val functionIds = part.topLevelFunctionNameIndex[callableId.callableName] ?: return@flatMap emptyList()
@@ -181,37 +169,26 @@ abstract class AbstractFirDeserializedSymbolsProvider(
}
}
private fun getPackageParts(packageFqName: FqName): Collection<PackagePartsCacheData> {
return packagePartsCache.getValue(packageFqName)
}
protected open fun shouldLoadParentsFirst(classId: ClassId): Boolean = false
private fun getPackageParts(packageFqName: FqName): Collection<PackagePartsCacheData> =
packagePartsCache.getValue(packageFqName)
protected fun getClass(
classId: ClassId,
parentContext: FirDeserializationContext? = null
): FirRegularClassSymbol? {
if (parentContext == null && shouldLoadParentsFirst(classId)) {
return getClassAfterLoadingParents(classId)
val parentClassId = classId.outerClassId
if (parentContext == null && parentClassId != null) {
val alreadyLoaded = classCache.getValueIfComputed(classId)
if (alreadyLoaded != null) return alreadyLoaded
// Load parent first in case correct `parentContext` is needed to deserialize the metadata of this class.
getClass(parentClassId, null)
// If that's the case, `classCache` should contain a value for `classId`.
}
return classCache.getValue(classId, parentContext)
}
private fun getClassAfterLoadingParents(classId: ClassId): FirRegularClassSymbol? {
classId.outerClassId?.let { parentClassId ->
val alreadyLoaded = classCache.getValueIfComputed(classId)
if (alreadyLoaded != null) return alreadyLoaded
getClassAfterLoadingParents(parentClassId)
}
return classCache.getValue(classId, null)
}
private fun getTypeAlias(
classId: ClassId,
): FirTypeAliasSymbol? {
if (!classId.relativeClassName.isOneSegmentFQN()) return null
return typeAliasCache.getValue(classId)
}
private fun getTypeAlias(classId: ClassId): FirTypeAliasSymbol? =
if (classId.relativeClassName.isOneSegmentFQN()) typeAliasCache.getValue(classId) else null
// ------------------------ SymbolProvider methods ------------------------
@@ -235,6 +212,4 @@ abstract class AbstractFirDeserializedSymbolsProvider(
override fun getClassLikeSymbolByClassId(classId: ClassId): FirClassLikeSymbol<*>? {
return getClass(classId) ?: getTypeAlias(classId)
}
override fun getPackage(fqName: FqName): FqName? = null
}
@@ -19,9 +19,6 @@ import org.jetbrains.kotlin.fir.symbols.impl.FirRegularClassSymbol
import org.jetbrains.kotlin.load.java.JavaClassFinder
import org.jetbrains.kotlin.load.kotlin.*
import org.jetbrains.kotlin.load.kotlin.header.KotlinClassHeader
import org.jetbrains.kotlin.metadata.ProtoBuf
import org.jetbrains.kotlin.metadata.deserialization.Flags
import org.jetbrains.kotlin.metadata.deserialization.NameResolver
import org.jetbrains.kotlin.metadata.jvm.deserialization.JvmMetadataVersion
import org.jetbrains.kotlin.metadata.jvm.deserialization.JvmProtoBufUtil
import org.jetbrains.kotlin.name.ClassId
@@ -91,9 +88,6 @@ class KotlinDeserializedJvmSymbolsProvider(
private val KotlinJvmBinaryClass.isPreReleaseInvisible: Boolean
get() = classHeader.isPreRelease
override fun shouldLoadParentsFirst(classId: ClassId): Boolean =
javaClassConverter.hasTopLevelClassOf(classId)
override fun extractClassMetadata(classId: ClassId, parentContext: FirDeserializationContext?): ClassMetadataFindResult? {
// Kotlin classes are annotated Java classes, so this check also looks for them.
if (!javaClassConverter.hasTopLevelClassOf(classId)) return null
@@ -104,7 +98,7 @@ class KotlinDeserializedJvmSymbolsProvider(
return null
}
val kotlinClass = when (result) {
is KotlinClassFinder.Result.KotlinClass -> result
is KotlinClassFinder.Result.KotlinClass -> result.kotlinJvmBinaryClass
is KotlinClassFinder.Result.ClassFileContent -> {
val javaClass = try {
javaClassConverter.findClass(classId, result.content) ?: return null
@@ -115,22 +109,20 @@ class KotlinDeserializedJvmSymbolsProvider(
javaClassConverter.convertJavaClassToFir(symbol, classId.outerClassId?.let(::getClass), javaClass)
}
}
null -> return ClassMetadataFindResult.ShouldDeserializeViaParent
}
if (kotlinClass.kotlinJvmBinaryClass.classHeader.kind != KotlinClassHeader.Kind.CLASS) return null
val (nameResolver, classProto) = kotlinClass.extractMetadata() ?: return null
if (parentContext == null && Flags.CLASS_KIND.get(classProto.flags) == ProtoBuf.Class.Kind.COMPANION_OBJECT) {
return ClassMetadataFindResult.ShouldDeserializeViaParent
null -> return null
}
if (kotlinClass.classHeader.kind != KotlinClassHeader.Kind.CLASS) return null
val data = kotlinClass.classHeader.data ?: return null
val strings = kotlinClass.classHeader.strings ?: return null
val (nameResolver, classProto) = JvmProtoBufUtil.readClassDataFrom(data, strings)
return ClassMetadataFindResult.Metadata(
nameResolver,
classProto,
JvmBinaryAnnotationDeserializer(session, kotlinClass.kotlinJvmBinaryClass, kotlinClassFinder, kotlinClass.byteContent),
kotlinClass.kotlinJvmBinaryClass.containingLibrary.toPath(),
KotlinJvmBinarySourceElement(kotlinClass.kotlinJvmBinaryClass),
classPostProcessor = { loadAnnotationsFromClassFile(kotlinClass, it) }
JvmBinaryAnnotationDeserializer(session, kotlinClass, kotlinClassFinder, result.byteContent),
kotlinClass.containingLibrary.toPath(),
KotlinJvmBinarySourceElement(kotlinClass),
classPostProcessor = { loadAnnotationsFromClassFile(result, it) }
)
}
@@ -141,7 +133,7 @@ class KotlinDeserializedJvmSymbolsProvider(
kotlinClass: KotlinClassFinder.Result.KotlinClass,
symbol: FirRegularClassSymbol
) {
val annotations = mutableListOf<FirAnnotation>()
val annotations = symbol.fir.annotations as MutableList<FirAnnotation>
kotlinClass.kotlinJvmBinaryClass.loadClassAnnotations(
object : KotlinJvmBinaryClass.AnnotationVisitor {
override fun visitAnnotation(classId: ClassId, source: SourceElement): KotlinJvmBinaryClass.AnnotationArgumentVisitor? {
@@ -153,16 +145,9 @@ class KotlinDeserializedJvmSymbolsProvider(
},
kotlinClass.byteContent,
)
(symbol.fir.annotations as MutableList<FirAnnotation>) += annotations
symbol.fir.replaceDeprecation(symbol.fir.getDeprecationInfos(session.languageVersionSettings.apiVersion))
}
private fun KotlinClassFinder.Result.KotlinClass.extractMetadata(): Pair<NameResolver, ProtoBuf.Class>? {
val data = kotlinJvmBinaryClass.classHeader.data ?: return null
val strings = kotlinJvmBinaryClass.classHeader.strings ?: return null
return JvmProtoBufUtil.readClassDataFrom(data, strings)
}
private fun String?.toPath(): Path? {
return this?.let { Paths.get(it).normalize() }
}