FIR: use FirCachesFactory for

class cache in KotlinDeserializedJvmSymbolsProvider
This commit is contained in:
Ilya Kirillov
2021-01-19 17:15:43 +01:00
parent 3cee5e848a
commit b270d66f68
2 changed files with 42 additions and 31 deletions
@@ -10,7 +10,7 @@ import com.intellij.openapi.project.Project
import org.jetbrains.kotlin.descriptors.SourceElement
import org.jetbrains.kotlin.fir.FirSession
import org.jetbrains.kotlin.fir.ThreadSafeMutableState
import org.jetbrains.kotlin.fir.caches.firCachesFactory
import org.jetbrains.kotlin.fir.caches.*
import org.jetbrains.kotlin.fir.declarations.*
import org.jetbrains.kotlin.fir.deserialization.FirConstDeserializer
import org.jetbrains.kotlin.fir.deserialization.FirDeserializationContext
@@ -50,14 +50,22 @@ class KotlinDeserializedJvmSymbolsProvider(
private val kotlinScopeProvider: KotlinScopeProvider,
) : FirSymbolProvider(session) {
private val annotationsLoader = AnnotationsLoader(session)
private val classCache = SymbolProviderCache<ClassId, FirRegularClassSymbol>()
private val typeAliasCache = SymbolProviderCache<ClassId, FirTypeAliasSymbol>()
private val packagePartsCache = SymbolProviderCache<FqName, Collection<PackagePartsCacheData>>()
private val classCache =
session.firCachesFactory.createCacheWithPostCompute<ClassId, FirRegularClassSymbol?, FirDeserializationContext?, KotlinClassFinder.Result.KotlinClass?>(
createValue = { classId, context -> findAndDeserializeClass(classId, context) },
postCompute = { _, symbol, result ->
if (result != null && symbol != null) {
postCompute(result.kotlinJvmBinaryClass, result.byteContent, symbol)
}
}
)
private val knownNameInPackageCache = KnownNameInPackageCache(session, javaClassFinder)
// TODO: implement thread safety for this property
private val handledByJava = HashSet<ClassId>()
private class PackagePartsCacheData(
val proto: ProtoBuf.Package,
@@ -123,7 +131,7 @@ class KotlinDeserializedJvmSymbolsProvider(
get() = classHeader.isPreRelease
override fun getClassLikeSymbolByFqName(classId: ClassId): FirClassLikeSymbol<*>? {
return findAndDeserializeClass(classId) ?: findAndDeserializeTypeAlias(classId)
return getClass(classId) ?: findAndDeserializeTypeAlias(classId)
}
private fun findAndDeserializeTypeAlias(
@@ -149,53 +157,58 @@ class KotlinDeserializedJvmSymbolsProvider(
private fun findAndDeserializeClassViaParent(classId: ClassId): FirRegularClassSymbol? {
val outerClassId = classId.outerClassId ?: return null
findAndDeserializeClass(outerClassId) ?: return null
return classCache[classId]
getClass(outerClassId) ?: return null
return classCache.getValueIfComputed(classId)
}
private fun getClass(
classId: ClassId,
parentContext: FirDeserializationContext? = null
): FirRegularClassSymbol? {
return classCache.getValue(classId, parentContext)
}
private fun findAndDeserializeClass(
classId: ClassId,
parentContext: FirDeserializationContext? = null
): FirRegularClassSymbol? {
if (knownNameInPackageCache.hasNoTopLevelClassOf(classId)) return null
if (classId in classCache) return classCache[classId]
if (classId in handledByJava) return null
): Pair<FirRegularClassSymbol?, KotlinClassFinder.Result.KotlinClass?> {
if (knownNameInPackageCache.hasNoTopLevelClassOf(classId)) return null to null
val result = try {
kotlinClassFinder.findKotlinClassOrContent(classId)
} catch (e: ProcessCanceledException) {
return null
return null to null
}
val (kotlinJvmBinaryClass, byteContent) = when (result) {
val kotlinClass = when (result) {
is KotlinClassFinder.Result.KotlinClass -> result
is KotlinClassFinder.Result.ClassFileContent -> {
handledByJava.add(classId)
return try {
javaSymbolProvider.getFirJavaClass(classId, result)
} catch (e: ProcessCanceledException) {
null
}
return javaSymbolProvider.getFirJavaClass(classId, result) to null
}
null -> return findAndDeserializeClassViaParent(classId)
null -> return findAndDeserializeClassViaParent(classId) to null
}
if (kotlinJvmBinaryClass.classHeader.kind != KotlinClassHeader.Kind.CLASS) return null
val (nameResolver, classProto) = kotlinJvmBinaryClass.readClassDataFrom() ?: return null
if (kotlinClass.kotlinJvmBinaryClass.classHeader.kind != KotlinClassHeader.Kind.CLASS) return null to null
val (nameResolver, classProto) = kotlinClass.kotlinJvmBinaryClass.readClassDataFrom() ?: return null to null
if (parentContext == null && Flags.CLASS_KIND.get(classProto.flags) == ProtoBuf.Class.Kind.COMPANION_OBJECT) {
return findAndDeserializeClassViaParent(classId)
return findAndDeserializeClassViaParent(classId) to null
}
val symbol = FirRegularClassSymbol(classId)
deserializeClassToSymbol(
classId, classProto, symbol, nameResolver, session,
JvmBinaryAnnotationDeserializer(session, kotlinJvmBinaryClass, byteContent),
JvmBinaryAnnotationDeserializer(session, kotlinClass.kotlinJvmBinaryClass, kotlinClass.byteContent),
kotlinScopeProvider,
parentContext, KotlinJvmBinarySourceElement(kotlinJvmBinaryClass),
this::findAndDeserializeClass
parentContext, KotlinJvmBinarySourceElement(kotlinClass.kotlinJvmBinaryClass),
this::getClass
)
classCache[classId] = symbol
return symbol to kotlinClass
}
fun postCompute(
kotlinJvmBinaryClass: KotlinJvmBinaryClass,
byteContent: ByteArray?,
symbol: FirRegularClassSymbol
) {
val annotations = mutableListOf<FirAnnotationCall>()
kotlinJvmBinaryClass.loadClassAnnotations(
object : KotlinJvmBinaryClass.AnnotationVisitor {
@@ -209,7 +222,6 @@ class KotlinDeserializedJvmSymbolsProvider(
byteContent,
)
(symbol.fir.annotations as MutableList<FirAnnotationCall>) += annotations
return symbol
}
private fun loadFunctionsByName(part: PackagePartsCacheData, name: Name): List<FirNamedFunctionSymbol> {
@@ -38,7 +38,6 @@ import org.jetbrains.kotlin.idea.fir.low.level.api.fir.caches.FirThreadSafeCache
import org.jetbrains.kotlin.idea.fir.low.level.api.lazy.resolve.FirLazyDeclarationResolver
import org.jetbrains.kotlin.idea.fir.low.level.api.providers.FirModuleWithDependenciesSymbolProvider
import org.jetbrains.kotlin.idea.fir.low.level.api.providers.FirIdeProvider
import org.jetbrains.kotlin.idea.fir.low.level.api.sessions.FirIdeSessionFactory.registerIdeComponents
import org.jetbrains.kotlin.idea.fir.low.level.api.providers.FirThreadSafeSymbolProviderWrapper
import org.jetbrains.kotlin.idea.fir.low.level.api.util.ModuleLibrariesSearchScope
import org.jetbrains.kotlin.idea.fir.low.level.api.util.checkCanceled