From 58c1b5dd1f728d45550b2a1ff11fb5ec75bca7b7 Mon Sep 17 00:00:00 2001 From: "Denis.Zharkov" Date: Tue, 13 Dec 2022 15:43:50 +0100 Subject: [PATCH] K2: Optimize AbstractFirDeserializedSymbolProvider Avoid filling caches with keys that are definitely empty (if it's cheap to compute that), to decrease the size of backing maps. The strategy is pre-computing the sets of names that might be met. NB: the size of the sets is way fewer than a size of all queried names. --- .../services/PackagePartProviderTestImpl.kt | 3 + .../fir/session/KlibBasedSymbolProvider.kt | 5 ++ .../AbstractFirDeserializedSymbolProvider.kt | 58 +++++++++++++++++-- .../JvmClassFileBasedSymbolProvider.kt | 6 +- .../OptionalAnnotationClassesProvider.kt | 13 +++-- .../IncrementalPackagePartProvider.kt | 9 +++ .../load/kotlin/JvmPackagePartProviderBase.kt | 5 ++ .../kotlin/load/kotlin/PackagePartProvider.kt | 8 +++ 8 files changed, 96 insertions(+), 11 deletions(-) diff --git a/analysis/low-level-api-fir/tests/org/jetbrains/kotlin/analysis/low/level/api/fir/services/PackagePartProviderTestImpl.kt b/analysis/low-level-api-fir/tests/org/jetbrains/kotlin/analysis/low/level/api/fir/services/PackagePartProviderTestImpl.kt index 164407c66d0..99acff0b225 100644 --- a/analysis/low-level-api-fir/tests/org/jetbrains/kotlin/analysis/low/level/api/fir/services/PackagePartProviderTestImpl.kt +++ b/analysis/low-level-api-fir/tests/org/jetbrains/kotlin/analysis/low/level/api/fir/services/PackagePartProviderTestImpl.kt @@ -26,6 +26,9 @@ internal class PackagePartProviderTestImpl( return providers.flatMapTo(mutableSetOf()) { it.findPackageParts(packageFqName) }.toList() } + override fun computePackageSetWithNonClassDeclarations(): Set = + providers.flatMapTo(mutableSetOf()) { it.computePackageSetWithNonClassDeclarations() } + override fun getAnnotationsOnBinaryModule(moduleName: String): List { return providers.flatMapTo(mutableSetOf()) { it.getAnnotationsOnBinaryModule(moduleName) }.toList() } diff --git a/compiler/fir/entrypoint/src/org/jetbrains/kotlin/fir/session/KlibBasedSymbolProvider.kt b/compiler/fir/entrypoint/src/org/jetbrains/kotlin/fir/session/KlibBasedSymbolProvider.kt index 4b2429dd305..1efdcf0ddcc 100644 --- a/compiler/fir/entrypoint/src/org/jetbrains/kotlin/fir/session/KlibBasedSymbolProvider.kt +++ b/compiler/fir/entrypoint/src/org/jetbrains/kotlin/fir/session/KlibBasedSymbolProvider.kt @@ -94,6 +94,11 @@ class KlibBasedSymbolProvider( } } + override fun computePackageSetWithNonClassDeclarations(): Set = fragmentNamesInLibraries.keys + + // Looks like it's expensive to compute the presence of a class properly for KLib + override fun mayHaveTopLevelClass(classId: ClassId): Boolean = true + @OptIn(SymbolInternals::class) override fun extractClassMetadata(classId: ClassId, parentContext: FirDeserializationContext?): ClassMetadataFindResult? { val packageStringName = classId.packageFqName.asString() diff --git a/compiler/fir/fir-deserialization/src/org/jetbrains/kotlin/fir/deserialization/AbstractFirDeserializedSymbolProvider.kt b/compiler/fir/fir-deserialization/src/org/jetbrains/kotlin/fir/deserialization/AbstractFirDeserializedSymbolProvider.kt index 6287fd84436..97f6326997b 100644 --- a/compiler/fir/fir-deserialization/src/org/jetbrains/kotlin/fir/deserialization/AbstractFirDeserializedSymbolProvider.kt +++ b/compiler/fir/fir-deserialization/src/org/jetbrains/kotlin/fir/deserialization/AbstractFirDeserializedSymbolProvider.kt @@ -75,6 +75,22 @@ abstract class AbstractFirDeserializedSymbolProvider( ) : FirSymbolProvider(session) { // ------------------------ Caches ------------------------ + private val packageNamesForNonClassDeclarations: Set by lazy(LazyThreadSafetyMode.PUBLICATION) { + computePackageSetWithNonClassDeclarations() + } + + private val typeAliasesNamesByPackage: FirCache, Nothing?> = + session.firCachesFactory.createCache { fqName: FqName -> + getPackageParts(fqName).flatMapTo(mutableSetOf()) { it.typeAliasNameIndex.keys } + } + + private val allNamesByPackage: FirCache, Nothing?> = + session.firCachesFactory.createCache { fqName: FqName -> + getPackageParts(fqName).flatMapTo(mutableSetOf()) { + it.topLevelFunctionNameIndex.keys + it.topLevelPropertyNameIndex.keys + } + } + private val packagePartsCache = session.firCachesFactory.createCache(::tryComputePackagePartInfos) private val typeAliasCache: FirCache = session.firCachesFactory.createCacheWithPostCompute( @@ -102,6 +118,14 @@ abstract class AbstractFirDeserializedSymbolProvider( protected abstract fun computePackagePartsInfos(packageFqName: FqName): List + // Return full package names that might be not empty (have some non-class declarations) in this provider + // In JVM, it's expensive to compute all the packages that might contain a Java class among dependencies + // But, as we have all the metadata, we may be sure about top-level callables and type aliases + // This method should only be used for sake of optimization to avoid having too many empty-list/null values in our caches + protected abstract fun computePackageSetWithNonClassDeclarations(): Set + + protected abstract fun mayHaveTopLevelClass(classId: ClassId): Boolean + protected abstract fun extractClassMetadata( classId: ClassId, parentContext: FirDeserializationContext? = null @@ -201,6 +225,11 @@ abstract class AbstractFirDeserializedSymbolProvider( parentContext: FirDeserializationContext? = null ): FirRegularClassSymbol? { val parentClassId = classId.outerClassId + + // Actually, the second "if" should be enough but the first one might work faster + if (parentClassId == null && !mayHaveTopLevelClass(classId)) return null + if (parentClassId != null && !mayHaveTopLevelClass(classId.outermostClassId)) return null + if (parentContext == null && parentClassId != null) { val alreadyLoaded = classCache.getValueIfComputed(classId) if (alreadyLoaded != null) return alreadyLoaded @@ -211,26 +240,43 @@ abstract class AbstractFirDeserializedSymbolProvider( return classCache.getValue(classId, parentContext) } - private fun getTypeAlias(classId: ClassId): FirTypeAliasSymbol? = - if (classId.relativeClassName.isOneSegmentFQN()) typeAliasCache.getValue(classId) else null + private fun getTypeAlias(classId: ClassId): FirTypeAliasSymbol? { + if (!classId.relativeClassName.isOneSegmentFQN()) return null + + // Don't actually query FirCache when we're sure there are no relevant value + // It helps to decrease the size of a cache thus leading to better query time + val packageFqName = classId.packageFqName + if (packageFqName.asString() !in packageNamesForNonClassDeclarations) return null + if (classId.shortClassName !in typeAliasesNamesByPackage.getValue(packageFqName)) return null + + return typeAliasCache.getValue(classId) + } // ------------------------ SymbolProvider methods ------------------------ @FirSymbolProviderInternals override fun getTopLevelCallableSymbolsTo(destination: MutableList>, packageFqName: FqName, name: Name) { val callableId = CallableId(packageFqName, name) - destination += functionCache.getValue(callableId) - destination += propertyCache.getValue(callableId) + destination += functionCache.getCallables(callableId) + destination += propertyCache.getCallables(callableId) + } + + private fun > FirCache, Nothing?>.getCallables(id: CallableId): List { + // Don't actually query FirCache when we're sure there are no relevant value + // It helps to decrease the size of a cache thus leading to better query time + if (id.packageName.asString() !in packageNamesForNonClassDeclarations) return emptyList() + if (id.callableName !in allNamesByPackage.getValue(id.packageName)) return emptyList() + return getValue(id) } @FirSymbolProviderInternals override fun getTopLevelFunctionSymbolsTo(destination: MutableList, packageFqName: FqName, name: Name) { - destination += functionCache.getValue(CallableId(packageFqName, name)) + destination += functionCache.getCallables(CallableId(packageFqName, name)) } @FirSymbolProviderInternals override fun getTopLevelPropertySymbolsTo(destination: MutableList, packageFqName: FqName, name: Name) { - destination += propertyCache.getValue(CallableId(packageFqName, name)) + destination += propertyCache.getCallables(CallableId(packageFqName, name)) } override fun getClassLikeSymbolByClassId(classId: ClassId): FirClassLikeSymbol<*>? { diff --git a/compiler/fir/java/src/org/jetbrains/kotlin/fir/java/deserialization/JvmClassFileBasedSymbolProvider.kt b/compiler/fir/java/src/org/jetbrains/kotlin/fir/java/deserialization/JvmClassFileBasedSymbolProvider.kt index 7f7fd9972d7..2da8c6ed8d7 100644 --- a/compiler/fir/java/src/org/jetbrains/kotlin/fir/java/deserialization/JvmClassFileBasedSymbolProvider.kt +++ b/compiler/fir/java/src/org/jetbrains/kotlin/fir/java/deserialization/JvmClassFileBasedSymbolProvider.kt @@ -28,7 +28,7 @@ import org.jetbrains.kotlin.metadata.jvm.deserialization.JvmProtoBufUtil import org.jetbrains.kotlin.name.ClassId import org.jetbrains.kotlin.name.FqName import org.jetbrains.kotlin.resolve.jvm.JvmClassName -import org.jetbrains.kotlin.serialization.deserialization.* +import org.jetbrains.kotlin.serialization.deserialization.IncompatibleVersionErrorData import org.jetbrains.kotlin.serialization.deserialization.builtins.BuiltInSerializerProtocol import org.jetbrains.kotlin.utils.toMetadataVersion import java.nio.file.Path @@ -97,6 +97,10 @@ class JvmClassFileBasedSymbolProvider( } } + override fun computePackageSetWithNonClassDeclarations(): Set = packagePartProvider.computePackageSetWithNonClassDeclarations() + + override fun mayHaveTopLevelClass(classId: ClassId): Boolean = javaFacade.hasTopLevelClassOf(classId) + private val KotlinJvmBinaryClass.incompatibility: IncompatibleVersionErrorData? get() { // TODO: skipMetadataVersionCheck diff --git a/compiler/fir/java/src/org/jetbrains/kotlin/fir/java/deserialization/OptionalAnnotationClassesProvider.kt b/compiler/fir/java/src/org/jetbrains/kotlin/fir/java/deserialization/OptionalAnnotationClassesProvider.kt index 0fbdfa737e8..2ca555c59ab 100644 --- a/compiler/fir/java/src/org/jetbrains/kotlin/fir/java/deserialization/OptionalAnnotationClassesProvider.kt +++ b/compiler/fir/java/src/org/jetbrains/kotlin/fir/java/deserialization/OptionalAnnotationClassesProvider.kt @@ -34,12 +34,12 @@ class OptionalAnnotationClassesProvider( private val optionalAnnotationClassesAndPackages by lazy(LazyThreadSafetyMode.PUBLICATION) { val optionalAnnotationClasses = mutableMapOf() - val optionalAnnotationPackages = mutableSetOf() + val optionalAnnotationPackages = mutableSetOf() for (klass in packagePartProvider.getAllOptionalAnnotationClasses()) { val classId = klass.nameResolver.getClassId(klass.classProto.fqName) optionalAnnotationClasses[classId] = klass - optionalAnnotationPackages.add(classId.packageFqName) + optionalAnnotationPackages.add(classId.packageFqName.asString()) } return@lazy Pair(optionalAnnotationClasses, optionalAnnotationPackages) @@ -49,6 +49,10 @@ class OptionalAnnotationClassesProvider( return emptyList() } + override fun computePackageSetWithNonClassDeclarations(): Set = optionalAnnotationClassesAndPackages.second + + override fun mayHaveTopLevelClass(classId: ClassId): Boolean = classId in optionalAnnotationClassesAndPackages.first + override fun extractClassMetadata( classId: ClassId, parentContext: FirDeserializationContext? @@ -69,5 +73,6 @@ class OptionalAnnotationClassesProvider( return JvmFlags.IS_COMPILED_IN_JVM_DEFAULT_MODE.get(classProto.getExtension(JvmProtoBuf.jvmClassFlags)) } - override fun getPackage(fqName: FqName): FqName? = if (optionalAnnotationClassesAndPackages.second.contains(fqName)) fqName else null -} \ No newline at end of file + override fun getPackage(fqName: FqName): FqName? = + if (optionalAnnotationClassesAndPackages.second.contains(fqName.asString())) fqName else null +} diff --git a/compiler/frontend.java/src/org/jetbrains/kotlin/load/kotlin/incremental/IncrementalPackagePartProvider.kt b/compiler/frontend.java/src/org/jetbrains/kotlin/load/kotlin/incremental/IncrementalPackagePartProvider.kt index 90bc4dc3cae..4d55c4f92ec 100644 --- a/compiler/frontend.java/src/org/jetbrains/kotlin/load/kotlin/incremental/IncrementalPackagePartProvider.kt +++ b/compiler/frontend.java/src/org/jetbrains/kotlin/load/kotlin/incremental/IncrementalPackagePartProvider.kt @@ -52,6 +52,15 @@ class IncrementalPackagePartProvider( parent.findPackageParts(packageFqName)).distinct() } + private val allPackageNames: Set by lazy { + buildSet { + moduleMappings.flatMapTo(this@buildSet) { it.packageFqName2Parts.keys } + addAll(parent.computePackageSetWithNonClassDeclarations()) + } + } + + override fun computePackageSetWithNonClassDeclarations(): Set = allPackageNames + override fun getAnnotationsOnBinaryModule(moduleName: String): List { return parent.getAnnotationsOnBinaryModule(moduleName) } diff --git a/core/descriptors.jvm/src/org/jetbrains/kotlin/load/kotlin/JvmPackagePartProviderBase.kt b/core/descriptors.jvm/src/org/jetbrains/kotlin/load/kotlin/JvmPackagePartProviderBase.kt index 262434cff81..3d7a8c20dd6 100644 --- a/core/descriptors.jvm/src/org/jetbrains/kotlin/load/kotlin/JvmPackagePartProviderBase.kt +++ b/core/descriptors.jvm/src/org/jetbrains/kotlin/load/kotlin/JvmPackagePartProviderBase.kt @@ -39,6 +39,11 @@ abstract class JvmPackagePartProviderBase : PackagePartProvider, Me return result.toList() } + private val allPackageNames: Set by lazy { + loadedModules.flatMapTo(mutableSetOf()) { it.mapping.packageFqName2Parts.keys } + } + + override fun computePackageSetWithNonClassDeclarations(): Set = allPackageNames override fun findMetadataPackageParts(packageFqName: String): List = getPackageParts(packageFqName).flatMap(PackageParts::metadataParts).distinct() diff --git a/core/deserialization.common.jvm/src/org/jetbrains/kotlin/load/kotlin/PackagePartProvider.kt b/core/deserialization.common.jvm/src/org/jetbrains/kotlin/load/kotlin/PackagePartProvider.kt index 2502123a181..8d4eabe5eb5 100644 --- a/core/deserialization.common.jvm/src/org/jetbrains/kotlin/load/kotlin/PackagePartProvider.kt +++ b/core/deserialization.common.jvm/src/org/jetbrains/kotlin/load/kotlin/PackagePartProvider.kt @@ -18,6 +18,12 @@ interface PackagePartProvider { */ fun findPackageParts(packageFqName: String): List + /** + * This method is only for sake of optimization + * @return package names set for which that provider has package parts + */ + fun computePackageSetWithNonClassDeclarations(): Set + fun getAnnotationsOnBinaryModule(moduleName: String): List fun getAllOptionalAnnotationClasses(): List @@ -28,5 +34,7 @@ interface PackagePartProvider { override fun getAnnotationsOnBinaryModule(moduleName: String): List = emptyList() override fun getAllOptionalAnnotationClasses(): List = emptyList() + + override fun computePackageSetWithNonClassDeclarations(): Set = emptySet() } }