diff --git a/compiler/fir/entrypoint/src/org/jetbrains/kotlin/fir/session/ComponentsContainers.kt b/compiler/fir/entrypoint/src/org/jetbrains/kotlin/fir/session/ComponentsContainers.kt index 45ab6c462da..543ad97a76b 100644 --- a/compiler/fir/entrypoint/src/org/jetbrains/kotlin/fir/session/ComponentsContainers.kt +++ b/compiler/fir/entrypoint/src/org/jetbrains/kotlin/fir/session/ComponentsContainers.kt @@ -25,6 +25,7 @@ import org.jetbrains.kotlin.fir.extensions.FirRegisteredPluginAnnotations import org.jetbrains.kotlin.fir.java.FirJavaVisibilityChecker import org.jetbrains.kotlin.fir.java.FirJvmDefaultModeComponent import org.jetbrains.kotlin.fir.java.enhancement.FirAnnotationTypeQualifierResolver +import org.jetbrains.kotlin.fir.java.enhancement.FirEnhancedSymbolsStorage import org.jetbrains.kotlin.fir.resolve.* import org.jetbrains.kotlin.fir.resolve.calls.ConeCallConflictResolverFactory import org.jetbrains.kotlin.fir.resolve.calls.FirSyntheticNamesProvider @@ -77,6 +78,7 @@ fun FirSession.registerCliCompilerOnlyComponents() { fun FirSession.registerCommonJavaComponents(javaModuleResolver: JavaModuleResolver) { val jsr305State = languageVersionSettings.getFlag(JvmAnalysisFlags.javaTypeEnhancementState) register(FirAnnotationTypeQualifierResolver::class, FirAnnotationTypeQualifierResolver(this, jsr305State, javaModuleResolver)) + register(FirEnhancedSymbolsStorage::class, FirEnhancedSymbolsStorage(this)) register( FirJvmDefaultModeComponent::class, FirJvmDefaultModeComponent(languageVersionSettings.getFlag(JvmAnalysisFlags.jvmDefaultMode)) diff --git a/compiler/fir/java/src/org/jetbrains/kotlin/fir/java/enhancement/SignatureEnhancement.kt b/compiler/fir/java/src/org/jetbrains/kotlin/fir/java/enhancement/SignatureEnhancement.kt index 3e0e1f8d83e..89f5bcb2e8e 100644 --- a/compiler/fir/java/src/org/jetbrains/kotlin/fir/java/enhancement/SignatureEnhancement.kt +++ b/compiler/fir/java/src/org/jetbrains/kotlin/fir/java/enhancement/SignatureEnhancement.kt @@ -8,6 +8,10 @@ package org.jetbrains.kotlin.fir.java.enhancement import org.jetbrains.kotlin.descriptors.ClassKind import org.jetbrains.kotlin.descriptors.Modality import org.jetbrains.kotlin.fir.* +import org.jetbrains.kotlin.fir.caches.FirCache +import org.jetbrains.kotlin.fir.caches.FirCachesFactory +import org.jetbrains.kotlin.fir.caches.createCache +import org.jetbrains.kotlin.fir.caches.firCachesFactory import org.jetbrains.kotlin.fir.declarations.* import org.jetbrains.kotlin.fir.declarations.builder.FirConstructorBuilder import org.jetbrains.kotlin.fir.declarations.builder.FirPrimaryConstructorBuilder @@ -68,30 +72,21 @@ class FirSignatureEnhancement( private val contextQualifiers: JavaTypeQualifiersByElementType? = typeQualifierResolver.extractDefaultQualifiers(owner) - private val enhancements = mutableMapOf, FirCallableSymbol<*>>() + private val enhancementsCache = session.enhancedSymbolStorage.cacheByOwner.getValue(owner.symbol, null) - fun enhancedFunction( - function: FirFunctionSymbol<*>, - name: Name? - ): FirFunctionSymbol<*> { - return enhancements.getOrPut(function) { - enhance(function, name).also { enhancedVersion -> - val enhancedVersionFir = enhancedVersion.fir - (enhancedVersionFir.initialSignatureAttr as? FirSimpleFunction)?.let { - enhancedVersionFir.initialSignatureAttr = enhancedFunction(it.symbol, it.name).fir - } - } - } as FirFunctionSymbol<*> + fun enhancedFunction(function: FirFunctionSymbol<*>, name: Name?): FirFunctionSymbol<*> { + return enhancementsCache.enhancedFunctions.getValue(function, this to name) } fun enhancedProperty(property: FirVariableSymbol<*>, name: Name): FirVariableSymbol<*> { - return enhancements.getOrPut(property) { enhance(property, name) } as FirVariableSymbol<*> + return enhancementsCache.enhancedVariables.getValue(property, this to name) } private fun FirDeclaration.computeDefaultQualifiers() = typeQualifierResolver.extractAndMergeDefaultQualifiers(contextQualifiers, annotations) - private fun enhance( + @PrivateForInline + internal fun enhance( original: FirVariableSymbol<*>, name: Name ): FirVariableSymbol<*> { @@ -182,7 +177,8 @@ class FirSignatureEnhancement( } } - private fun enhance( + @PrivateForInline + internal fun enhance( original: FirFunctionSymbol<*>, name: Name? ): FirFunctionSymbol<*> { @@ -192,7 +188,12 @@ class FirSignatureEnhancement( return original } enhanceTypeParameterBounds(firMethod.typeParameters) - return enhanceMethod(firMethod, original.callableId, name) + return enhanceMethod(firMethod, original.callableId, name).also { enhancedVersion -> + val enhancedVersionFir = enhancedVersion.fir + (enhancedVersionFir.initialSignatureAttr as? FirSimpleFunction)?.let { + enhancedVersionFir.initialSignatureAttr = enhancedFunction(it.symbol, it.name).fir + } + } } private fun enhanceMethod( @@ -551,3 +552,26 @@ private class EnhancementSignatureParts( override val TypeParameterMarker.isFromJava: Boolean get() = (this as ConeTypeParameterLookupTag).symbol.fir.origin == FirDeclarationOrigin.Java } + +class FirEnhancedSymbolsStorage(val session: FirSession) : FirSessionComponent { + private val cachesFactory = session.firCachesFactory + + val cacheByOwner: FirCache = + cachesFactory.createCache { _ -> EnhancementSymbolsCache(cachesFactory) } + + class EnhancementSymbolsCache(cachesFactory: FirCachesFactory) { + @OptIn(PrivateForInline::class) + val enhancedFunctions: FirCache, FirFunctionSymbol<*>, Pair> = + cachesFactory.createCache { original, (enhancement, name) -> + enhancement.enhance(original, name) + } + + @OptIn(PrivateForInline::class) + val enhancedVariables: FirCache, FirVariableSymbol<*>, Pair> = + cachesFactory.createCache { original, (enhancement, name) -> + enhancement.enhance(original, name) + } + } +} + +private val FirSession.enhancedSymbolStorage: FirEnhancedSymbolsStorage by FirSession.sessionComponentAccessor() diff --git a/compiler/testData/diagnostics/testsWithStdLib/java/concurrentHashMapContains.fir.kt b/compiler/testData/diagnostics/testsWithStdLib/java/concurrentHashMapContains.fir.kt index 9808bc3e4de..52f24c84d08 100644 --- a/compiler/testData/diagnostics/testsWithStdLib/java/concurrentHashMapContains.fir.kt +++ b/compiler/testData/diagnostics/testsWithStdLib/java/concurrentHashMapContains.fir.kt @@ -1,5 +1,4 @@ // !LANGUAGE: -ProhibitConcurrentHashMapContains -// FIR_IDE_IGNORE // FULL_JDK class A : java.util.concurrent.ConcurrentHashMap() { diff --git a/compiler/testData/diagnostics/testsWithStdLib/java/concurrentHashMapContains.kt b/compiler/testData/diagnostics/testsWithStdLib/java/concurrentHashMapContains.kt index 3a56ad8a7fb..1afd85ba221 100644 --- a/compiler/testData/diagnostics/testsWithStdLib/java/concurrentHashMapContains.kt +++ b/compiler/testData/diagnostics/testsWithStdLib/java/concurrentHashMapContains.kt @@ -1,5 +1,4 @@ // !LANGUAGE: -ProhibitConcurrentHashMapContains -// FIR_IDE_IGNORE // FULL_JDK class A : java.util.concurrent.ConcurrentHashMap() {