diff --git a/idea/idea-frontend-fir/idea-fir-low-level-api/src/org/jetbrains/kotlin/idea/fir/low/level/api/api/FirModuleResolveState.kt b/idea/idea-frontend-fir/idea-fir-low-level-api/src/org/jetbrains/kotlin/idea/fir/low/level/api/api/FirModuleResolveState.kt index a9c668e3623..2399012538a 100644 --- a/idea/idea-frontend-fir/idea-fir-low-level-api/src/org/jetbrains/kotlin/idea/fir/low/level/api/api/FirModuleResolveState.kt +++ b/idea/idea-frontend-fir/idea-fir-low-level-api/src/org/jetbrains/kotlin/idea/fir/low/level/api/api/FirModuleResolveState.kt @@ -85,20 +85,4 @@ abstract class FirModuleResolveState { ) internal abstract fun getFirFile(declaration: FirDeclaration, cache: ModuleFileCache): FirFile? - - fun withFirDeclaration(declaration: D, action: (D) -> R): R { - val originalDeclaration = (declaration as? FirCallableDeclaration<*>)?.unwrapFakeOverrides() ?: declaration - val session = originalDeclaration.session - return when { - originalDeclaration.origin == FirDeclarationOrigin.Source - && session is FirIdeSourcesSession - -> { - val cache = session.cache - val file = getFirFile(declaration, cache) - ?: error("Fir file was not found for\n${declaration.render()}\n${declaration.ktDeclaration.getElementTextInContext()}") - cache.firFileLockProvider.withReadLock(file) { action(declaration) } - } - else -> action(declaration) - } - } } diff --git a/idea/idea-frontend-fir/idea-fir-low-level-api/src/org/jetbrains/kotlin/idea/fir/low/level/api/api/LowLevelFirApiFacade.kt b/idea/idea-frontend-fir/idea-fir-low-level-api/src/org/jetbrains/kotlin/idea/fir/low/level/api/api/LowLevelFirApiFacade.kt index 21f456f7b1e..e51d168ec1f 100644 --- a/idea/idea-frontend-fir/idea-fir-low-level-api/src/org/jetbrains/kotlin/idea/fir/low/level/api/api/LowLevelFirApiFacade.kt +++ b/idea/idea-frontend-fir/idea-fir-low-level-api/src/org/jetbrains/kotlin/idea/fir/low/level/api/api/LowLevelFirApiFacade.kt @@ -7,13 +7,16 @@ package org.jetbrains.kotlin.idea.fir.low.level.api.api import org.jetbrains.kotlin.diagnostics.Diagnostic import org.jetbrains.kotlin.fir.FirElement -import org.jetbrains.kotlin.fir.declarations.FirDeclaration -import org.jetbrains.kotlin.fir.declarations.FirFile -import org.jetbrains.kotlin.fir.declarations.FirResolvePhase +import org.jetbrains.kotlin.fir.declarations.* +import org.jetbrains.kotlin.fir.render +import org.jetbrains.kotlin.fir.unwrapFakeOverrides import org.jetbrains.kotlin.idea.caches.project.IdeaModuleInfo import org.jetbrains.kotlin.idea.caches.project.getModuleInfo import org.jetbrains.kotlin.idea.fir.low.level.api.FirIdeResolveStateService import org.jetbrains.kotlin.idea.fir.low.level.api.annotations.InternalForInline +import org.jetbrains.kotlin.idea.fir.low.level.api.sessions.FirIdeSourcesSession +import org.jetbrains.kotlin.idea.fir.low.level.api.util.ktDeclaration +import org.jetbrains.kotlin.idea.util.getElementTextInContext import org.jetbrains.kotlin.psi.KtDeclaration import org.jetbrains.kotlin.psi.KtElement import org.jetbrains.kotlin.psi.KtFile @@ -88,6 +91,31 @@ inline fun KtLambdaExpression.withFirDeclaration return action(firDeclaration) } +/** + * Executes [action] with given [FirDeclaration] + * [FirDeclaration] passed to [action] will be resolved at least to [phase] when executing [action] on it + */ +fun D.withFirDeclaration( + resolveState: FirModuleResolveState, + phase: FirResolvePhase = FirResolvePhase.RAW_FIR, + action: (D) -> R, +): R { + resolvedFirToPhase(phase, resolveState) + val originalDeclaration = (this as? FirCallableDeclaration<*>)?.unwrapFakeOverrides() ?: this + val session = originalDeclaration.session + return when { + originalDeclaration.origin == FirDeclarationOrigin.Source + && session is FirIdeSourcesSession + -> { + val cache = session.cache + val file = resolveState.getFirFile(this, cache) + ?: error("Fir file was not found for\n${render()}\n${ktDeclaration.getElementTextInContext()}") + cache.firFileLockProvider.withReadLock(file) { action(this) } + } + else -> action(this) + } +} + /** * Returns a list of Diagnostics compiler finds for given [KtElement] */ diff --git a/idea/idea-frontend-fir/src/org/jetbrains/kotlin/idea/frontend/api/fir/utils/FirRefWithValidityCheck.kt b/idea/idea-frontend-fir/src/org/jetbrains/kotlin/idea/frontend/api/fir/utils/FirRefWithValidityCheck.kt index 708fd7a25ee..d143492571c 100644 --- a/idea/idea-frontend-fir/src/org/jetbrains/kotlin/idea/frontend/api/fir/utils/FirRefWithValidityCheck.kt +++ b/idea/idea-frontend-fir/src/org/jetbrains/kotlin/idea/frontend/api/fir/utils/FirRefWithValidityCheck.kt @@ -9,6 +9,7 @@ import org.jetbrains.kotlin.fir.declarations.FirDeclaration import org.jetbrains.kotlin.fir.declarations.FirResolvePhase import org.jetbrains.kotlin.idea.fir.low.level.api.api.FirModuleResolveState import org.jetbrains.kotlin.idea.fir.low.level.api.api.resolvedFirToPhase +import org.jetbrains.kotlin.idea.fir.low.level.api.api.withFirDeclaration import org.jetbrains.kotlin.idea.frontend.api.ValidityToken import org.jetbrains.kotlin.idea.frontend.api.ValidityTokenOwner import org.jetbrains.kotlin.idea.frontend.api.assertIsValid @@ -24,16 +25,15 @@ internal class FirRefWithValidityCheck(fir: D, resolveState: ?: throw EntityWasGarbageCollectedException("FirElement") val resolveState = resolveStateWeakRef.get() ?: throw EntityWasGarbageCollectedException("FirModuleResolveState") - fir.resolvedFirToPhase(phase, resolveState) return when (phase) { FirResolvePhase.BODY_RESOLVE -> { /* The BODY_RESOLVE phase is the maximum possible phase we can resolve our declaration to So there is not need to run whole `action` under read lock */ - action(resolveState.withFirDeclaration(fir) { it }) + action(fir.withFirDeclaration(resolveState, phase) { it }) } - else -> resolveState.withFirDeclaration(fir) { action(it) } + else -> fir.withFirDeclaration(resolveState, phase) { action(it) } } }