FIR IDE: make some access to FIR elements under read locks, resolve under write locks

This commit is contained in:
Ilya Kirillov
2020-10-25 11:53:35 +03:00
parent 12ed92cd49
commit b42bed7b3c
13 changed files with 263 additions and 102 deletions
@@ -16,10 +16,12 @@ import org.jetbrains.kotlin.fir.resolve.providers.FirProvider
import org.jetbrains.kotlin.idea.caches.project.IdeaModuleInfo
import org.jetbrains.kotlin.idea.fir.low.level.api.api.FirModuleResolveState
import org.jetbrains.kotlin.idea.fir.low.level.api.element.builder.FirTowerDataContextCollector
import org.jetbrains.kotlin.idea.fir.low.level.api.file.builder.ModuleFileCache
import org.jetbrains.kotlin.idea.fir.low.level.api.file.structure.FirElementsRecorder
import org.jetbrains.kotlin.psi.KtDeclaration
import org.jetbrains.kotlin.psi.KtElement
import org.jetbrains.kotlin.psi.KtFile
import org.jetbrains.kotlin.psi.KtLambdaExpression
internal class FirModuleResolveStateForCompletion(
override val project: Project,
@@ -37,7 +39,8 @@ internal class FirModuleResolveStateForCompletion(
originalState.getSessionFor(moduleInfo)
override fun getOrBuildFirFor(element: KtElement): FirElement {
completionMapping[originalState.elementBuilder.getPsiAsFirElementSource(element)]?.let { return it }
val psi = originalState.elementBuilder.getPsiAsFirElementSource(element)
synchronized(completionMapping) { completionMapping[psi] }?.let { return it }
return originalState.elementBuilder.getOrBuildFirFor(
element,
originalState.rootModuleSession.cache,
@@ -53,7 +56,7 @@ internal class FirModuleResolveStateForCompletion(
}
override fun recordPsiToFirMappingsForCompletionFrom(fir: FirDeclaration, firFile: FirFile, ktFile: KtFile) {
fir.accept(FirElementsRecorder(), completionMapping)
synchronized(completionMapping) { fir.accept(FirElementsRecorder(), completionMapping) }
}
override fun <D : FirDeclaration> resolvedFirToPhase(declaration: D, toPhase: FirResolvePhase): D {
@@ -70,6 +73,10 @@ internal class FirModuleResolveStateForCompletion(
originalState.lazyResolveDeclarationForCompletion(firFunction, containerFirFile, firIdeProvider, toPhase, towerDataContextCollector)
}
override fun getFirFile(declaration: FirDeclaration, cache: ModuleFileCache): FirFile? {
return cache.getContainerFirFile(declaration)
}
override fun getDiagnostics(element: KtElement): List<Diagnostic> {
error("Diagnostics should not be retrieved in completion")
}
@@ -82,6 +89,14 @@ internal class FirModuleResolveStateForCompletion(
error("Should not be used in completion")
}
override fun findSourceFirDeclaration(ktDeclaration: KtDeclaration): FirDeclaration {
error("Should not be used in completion")
}
override fun findSourceFirDeclaration(ktDeclaration: KtLambdaExpression): FirDeclaration {
error("Should not be used in completion")
}
override fun getBuiltFirFileOrNull(ktFile: KtFile): FirFile? {
error("Should not be used in completion")
}
@@ -12,6 +12,7 @@ import org.jetbrains.kotlin.fir.FirSession
import org.jetbrains.kotlin.fir.declarations.FirDeclaration
import org.jetbrains.kotlin.fir.declarations.FirFile
import org.jetbrains.kotlin.fir.declarations.FirResolvePhase
import org.jetbrains.kotlin.fir.psi
import org.jetbrains.kotlin.fir.resolve.providers.FirProvider
import org.jetbrains.kotlin.idea.caches.project.IdeaModuleInfo
import org.jetbrains.kotlin.idea.caches.project.ModuleSourceInfo
@@ -20,16 +21,18 @@ import org.jetbrains.kotlin.idea.fir.low.level.api.api.FirModuleResolveState
import org.jetbrains.kotlin.idea.fir.low.level.api.diagnostics.DiagnosticsCollector
import org.jetbrains.kotlin.idea.fir.low.level.api.element.builder.FirElementBuilder
import org.jetbrains.kotlin.idea.fir.low.level.api.element.builder.FirTowerDataContextCollector
import org.jetbrains.kotlin.idea.fir.low.level.api.element.builder.getNonLocalContainingOrThisDeclaration
import org.jetbrains.kotlin.idea.fir.low.level.api.file.builder.FirFileBuilder
import org.jetbrains.kotlin.idea.fir.low.level.api.file.builder.ModuleFileCache
import org.jetbrains.kotlin.idea.fir.low.level.api.file.structure.FileStructureCache
import org.jetbrains.kotlin.idea.fir.low.level.api.lazy.resolve.FirLazyDeclarationResolver
import org.jetbrains.kotlin.idea.fir.low.level.api.providers.firIdeProvider
import org.jetbrains.kotlin.idea.fir.low.level.api.sessions.FirIdeSessionProvider
import org.jetbrains.kotlin.idea.fir.low.level.api.sessions.FirIdeSourcesSession
import org.jetbrains.kotlin.idea.fir.low.level.api.util.FirElementFinder
import org.jetbrains.kotlin.idea.fir.low.level.api.util.findSourceNonLocalFirDeclaration
import org.jetbrains.kotlin.psi.KtDeclaration
import org.jetbrains.kotlin.psi.KtElement
import org.jetbrains.kotlin.psi.KtFile
import org.jetbrains.kotlin.idea.util.getElementTextInContext
import org.jetbrains.kotlin.psi.*
internal class FirModuleResolveStateImpl(
override val project: Project,
@@ -76,6 +79,31 @@ internal class FirModuleResolveStateImpl(
sessionProvider.getModuleCache(ktDeclaration.getModuleInfo() as ModuleSourceInfo)
)
override fun findSourceFirDeclaration(ktDeclaration: KtDeclaration): FirDeclaration =
findSourceFirDeclarationByExpression(ktDeclaration)
override fun findSourceFirDeclaration(ktDeclaration: KtLambdaExpression): FirDeclaration =
findSourceFirDeclarationByExpression(ktDeclaration)
/**
* [ktDeclaration] should be either [KtDeclaration] or [KtLambdaExpression]
*/
private fun findSourceFirDeclarationByExpression(ktDeclaration: KtExpression): FirDeclaration {
val nonLocalFirDeclaration = ktDeclaration.getNonLocalContainingOrThisDeclaration()
?: error("Declaration should have non-local container${ktDeclaration.getElementTextInContext()}")
if (ktDeclaration == nonLocalFirDeclaration) return findNonLocalSourceFirDeclaration(ktDeclaration as KtDeclaration)
val container = nonLocalFirDeclaration.findSourceNonLocalFirDeclaration(
firFileBuilder,
rootModuleSession.firIdeProvider.symbolProvider,
sessionProvider.getModuleCache(ktDeclaration.getModuleInfo() as ModuleSourceInfo)
)
val firDeclaration = FirElementFinder.findElementIn<FirDeclaration>(container) { firDeclaration ->
firDeclaration.psi == ktDeclaration
}
return firDeclaration
?: error("FirDeclaration was not found for\n${ktDeclaration.getElementTextInContext()}")
}
override fun isFirFileBuilt(ktFile: KtFile): Boolean {
val moduleSourceInfo = ktFile.getModuleInfo() as? ModuleSourceInfo ?: return true
val cache = sessionProvider.getModuleCache(moduleSourceInfo)
@@ -111,4 +139,7 @@ internal class FirModuleResolveStateImpl(
)
}
}
override fun getFirFile(declaration: FirDeclaration, cache: ModuleFileCache): FirFile? =
cache.getContainerFirFile(declaration)
}
@@ -8,13 +8,11 @@ package org.jetbrains.kotlin.idea.fir.low.level.api.api
import org.jetbrains.annotations.TestOnly
import com.intellij.openapi.project.Project
import org.jetbrains.kotlin.diagnostics.Diagnostic
import org.jetbrains.kotlin.fir.FirElement
import org.jetbrains.kotlin.fir.FirSession
import org.jetbrains.kotlin.fir.declarations.FirDeclaration
import org.jetbrains.kotlin.fir.declarations.FirFile
import org.jetbrains.kotlin.fir.declarations.FirResolvePhase
import org.jetbrains.kotlin.fir.*
import org.jetbrains.kotlin.fir.declarations.*
import org.jetbrains.kotlin.fir.resolve.providers.FirProvider
import org.jetbrains.kotlin.fir.resolve.providers.FirSymbolProvider
import org.jetbrains.kotlin.fir.symbols.FirBasedSymbol
import org.jetbrains.kotlin.idea.caches.project.IdeaModuleInfo
import org.jetbrains.kotlin.idea.caches.project.getModuleInfo
import org.jetbrains.kotlin.idea.fir.low.level.api.FirIdeResolveStateService
@@ -22,9 +20,14 @@ import org.jetbrains.kotlin.idea.fir.low.level.api.FirTransformerProvider
import org.jetbrains.kotlin.idea.fir.low.level.api.element.builder.FirTowerDataContextCollector
import org.jetbrains.kotlin.idea.fir.low.level.api.file.builder.FirFileBuilder
import org.jetbrains.kotlin.idea.fir.low.level.api.file.builder.ModuleFileCache
import org.jetbrains.kotlin.idea.fir.low.level.api.file.builder.withReadLock
import org.jetbrains.kotlin.idea.fir.low.level.api.sessions.FirIdeSourcesSession
import org.jetbrains.kotlin.idea.fir.low.level.api.util.ktDeclaration
import org.jetbrains.kotlin.idea.util.getElementTextInContext
import org.jetbrains.kotlin.psi.KtDeclaration
import org.jetbrains.kotlin.psi.KtElement
import org.jetbrains.kotlin.psi.KtFile
import org.jetbrains.kotlin.psi.KtLambdaExpression
abstract class FirModuleResolveState {
abstract val project: Project
@@ -54,6 +57,15 @@ abstract class FirModuleResolveState {
ktDeclaration: KtDeclaration,
): FirDeclaration
abstract fun findSourceFirDeclaration(
ktDeclaration: KtDeclaration,
): FirDeclaration
abstract fun findSourceFirDeclaration(
ktDeclaration: KtLambdaExpression,
): FirDeclaration
// todo temporary, used only in completion
internal abstract fun recordPsiToFirMappingsForCompletionFrom(fir: FirDeclaration, firFile: FirFile, ktFile: KtFile)
@@ -67,5 +79,22 @@ abstract class FirModuleResolveState {
toPhase: FirResolvePhase,
towerDataContextCollector: FirTowerDataContextCollector
)
}
internal abstract fun getFirFile(declaration: FirDeclaration, cache: ModuleFileCache): FirFile?
fun <D : FirDeclaration, R> withFirDeclaration(declaration: D, action: (D) -> R): R {
val originalDeclaration = (declaration as? FirCallableDeclaration<*>)?.unwrapFakeOverrides() ?: declaration
val session = originalDeclaration.session
return when {
originalDeclaration.origin == FirDeclarationOrigin.Source
&& session is FirIdeSourcesSession
-> {
val cache = session.cache
val file = getFirFile(declaration, cache)
?: error("Fir file was not found for\n${declaration.render()}\n${declaration.ktDeclaration.getElementTextInContext()}")
cache.firFileLockProvider.withReadLock(file) { action(declaration) }
}
else -> action(declaration)
}
}
}
@@ -8,15 +8,21 @@ package org.jetbrains.kotlin.idea.fir.low.level.api.api
import org.jetbrains.kotlin.diagnostics.Diagnostic
import org.jetbrains.kotlin.fir.FirElement
import org.jetbrains.kotlin.fir.FirSession
import org.jetbrains.kotlin.fir.FirSymbolOwner
import org.jetbrains.kotlin.fir.declarations.FirDeclaration
import org.jetbrains.kotlin.fir.declarations.FirDeclarationOrigin
import org.jetbrains.kotlin.fir.declarations.FirFile
import org.jetbrains.kotlin.fir.declarations.FirResolvePhase
import org.jetbrains.kotlin.fir.psi
import org.jetbrains.kotlin.fir.symbols.FirBasedSymbol
import org.jetbrains.kotlin.idea.caches.project.IdeaModuleInfo
import org.jetbrains.kotlin.idea.caches.project.getModuleInfo
import org.jetbrains.kotlin.idea.fir.low.level.api.FirIdeResolveStateService
import org.jetbrains.kotlin.idea.fir.low.level.api.sessions.FirIdeSourcesSession
import org.jetbrains.kotlin.psi.KtDeclaration
import org.jetbrains.kotlin.psi.KtElement
import org.jetbrains.kotlin.psi.KtFile
import org.jetbrains.kotlin.psi.KtLambdaExpression
import kotlin.reflect.KClass
object LowLevelFirApiFacade {
@@ -29,24 +35,15 @@ object LowLevelFirApiFacade {
fun getSessionFor(element: KtElement): FirSession =
getResolveStateFor(element).getSessionFor(element.getModuleInfo())
@Deprecated("Consider using withFirElement")
fun getOrBuildFirFor(element: KtElement, resolveState: FirModuleResolveState): FirElement =
resolveState.getOrBuildFirFor(element)
@Suppress("DEPRECATION")
inline fun <R> withFirElement(element: KtElement, resolveState: FirModuleResolveState, action: (FirElement) -> R): R =
action(getOrBuildFirFor(element, resolveState))
@Deprecated("Consider using withFirFile")
fun getFirFile(ktFile: KtFile, resolveState: FirModuleResolveState) =
resolveState.getFirFile(ktFile)
@Suppress("DEPRECATION")
inline fun <R> withFirFile(ktFile: KtFile, resolveState: FirModuleResolveState, action: (FirFile) -> R): R =
action(getFirFile(ktFile, resolveState))
/**
* Creates [FirDeclaration] by [KtDeclaration] and runs an [action] with it
* [ktDeclaration]
* [FirDeclaration] passed to [action] should not be leaked outside [action] lambda
* [FirDeclaration] passed to [action] will be resolved at least to [phase]
* Otherwise, some threading problems may arise,
@@ -59,11 +56,31 @@ object LowLevelFirApiFacade {
phase: FirResolvePhase = FirResolvePhase.RAW_FIR,
action: (FirDeclaration) -> R
): R {
val firDeclaration = resolveState.findNonLocalSourceFirDeclaration(ktDeclaration)
val firDeclaration = resolveState.findSourceFirDeclaration(ktDeclaration)
resolvedFirToPhase(firDeclaration, phase, resolveState)
return action(firDeclaration)
}
inline fun <reified F : FirDeclaration, R> withFirDeclarationOfType(
ktDeclaration: KtDeclaration,
resolveState: FirModuleResolveState,
action: (F) -> R
): R {
val firDeclaration = resolveState.findSourceFirDeclaration(ktDeclaration)
if (firDeclaration !is F) throw InvalidFirElementTypeException(ktDeclaration, F::class, firDeclaration::class)
return action(firDeclaration)
}
inline fun <reified F : FirDeclaration, R> withFirDeclarationOfType(
ktDeclaration: KtLambdaExpression,
resolveState: FirModuleResolveState,
action: (F) -> R
): R {
val firDeclaration = resolveState.findSourceFirDeclaration(ktDeclaration)
if (firDeclaration !is F) throw InvalidFirElementTypeException(ktDeclaration, F::class, firDeclaration::class)
return action(firDeclaration)
}
inline fun <F : FirElement, R> withFir(fir: F, action: (F) -> R): R {
// TODO locking
return action(fir)
@@ -83,17 +100,16 @@ object LowLevelFirApiFacade {
resolveState.resolvedFirToPhase(firDeclaration, phase)
}
@Deprecated("Consider using withFir")
fun KtElement.getOrBuildFir(
resolveState: FirModuleResolveState,
) = LowLevelFirApiFacade.getOrBuildFirFor(this, resolveState)
@Deprecated("Consider using withFirSafe")
inline fun <reified E : FirElement> KtElement.getOrBuildFirSafe(
resolveState: FirModuleResolveState,
) = LowLevelFirApiFacade.getOrBuildFirFor(this, resolveState) as? E
@Deprecated("Consider using withFirOfType")
inline fun <reified E : FirElement> KtElement.getOrBuildFirOfType(
resolveState: FirModuleResolveState,
): E {
@@ -102,25 +118,6 @@ inline fun <reified E : FirElement> KtElement.getOrBuildFirOfType(
throw InvalidFirElementTypeException(this, E::class, fir::class)
}
inline fun <R> KtElement.withFir(
resolveState: FirModuleResolveState,
action: (FirElement) -> R,
) = LowLevelFirApiFacade.withFirElement(this, resolveState, action)
inline fun <R : Any, reified E : FirElement> KtElement.withFirSafe(
resolveState: FirModuleResolveState,
action: (E) -> R?
) = LowLevelFirApiFacade.withFirElement(this, resolveState) { element -> (element as? E)?.let(action) }
@Suppress("DEPRECATION")
inline fun <R, reified E : FirElement> KtElement.withFirOfType(
resolveState: FirModuleResolveState,
action: (E) -> R
): R = action(getOrBuildFirOfType(resolveState))
class InvalidFirElementTypeException(
ktElement: KtElement,
expectedFirClass: KClass<out FirElement>,
@@ -54,8 +54,8 @@ internal class FirFileBuilder(
val needResolve = toPhase > FirResolvePhase.RAW_FIR
val firFile = buildRawFirFileWithCaching(ktFile, cache, lazyBodiesMode = !needResolve)
if (needResolve) {
cache.firFileLockProvider.withLock(firFile) {
if (firFile.resolvePhase >= toPhase) return@withLock
cache.firFileLockProvider.withWriteLock(firFile) {
if (firFile.resolvePhase >= toPhase) return@withWriteLock
runResolveWithoutLock(firFile, fromPhase = firFile.resolvePhase, toPhase = toPhase, checkPCE = checkPCE)
}
}
@@ -66,11 +66,11 @@ internal class FirFileBuilder(
* Runs [resolve] function (which is considered to do some resolve on [firFile]) under a lock for [firFile]
*/
inline fun <R> runCustomResolveUnderLock(firFile: FirFile, cache: ModuleFileCache, resolve: () -> R): R =
cache.firFileLockProvider.withLock(firFile) { resolve() }
cache.firFileLockProvider.withWriteLock(firFile) { resolve() }
inline fun <R : Any> runCustomResolveWithPCECheck(firFile: FirFile, cache: ModuleFileCache, resolve: () -> R): R {
val lock = cache.firFileLockProvider.getLockFor(firFile)
return lock.lockWithPCECheck(LOCKING_INTERVAL_MS) { resolve() }
return lock.writeLock().lockWithPCECheck(LOCKING_INTERVAL_MS) { resolve() }
}
fun runResolveWithoutLock(
@@ -10,6 +10,7 @@ import org.jetbrains.kotlin.fir.declarations.FirClassLikeDeclaration
import org.jetbrains.kotlin.fir.declarations.FirDeclaration
import org.jetbrains.kotlin.fir.declarations.FirFile
import org.jetbrains.kotlin.fir.psi
import org.jetbrains.kotlin.fir.render
import org.jetbrains.kotlin.fir.symbols.CallableId
import org.jetbrains.kotlin.fir.symbols.impl.FirCallableSymbol
import org.jetbrains.kotlin.name.ClassId
@@ -50,12 +51,13 @@ internal abstract class ModuleFileCache {
abstract fun getCachedFirFile(ktFile: KtFile): FirFile?
// todo make it ReadWriteLock and allow access fir elements only under read lock
// for now locks only held for resolve
// but there can be a situation when we are accessing some fir element in one thread without lock
// in the same time other thread performs resolve of it
// which can cause weird errors on user side
abstract val firFileLockProvider: LockProvider<FirFile, ReentrantLock>
abstract val firFileLockProvider: LockProvider<FirFile, ReentrantReadWriteLock>
inline fun <D : FirDeclaration, R> withReadLockOn(declaration: D, action: (D) -> R): R {
val file = getContainerFirFile(declaration)
?: error("No fir file found for\n${declaration.render()}")
return firFileLockProvider.withReadLock(file) { action(declaration) }
}
}
internal class ModuleFileCacheImpl(override val session: FirSession) : ModuleFileCache() {
@@ -74,5 +76,5 @@ internal class ModuleFileCacheImpl(override val session: FirSession) : ModuleFil
return getCachedFirFile(ktFile)
}
override val firFileLockProvider: LockProvider<FirFile, ReentrantLock> = LockProvider { ReentrantLock() }
override val firFileLockProvider: LockProvider<FirFile, ReentrantReadWriteLock> = LockProvider { ReentrantReadWriteLock() }
}
@@ -12,6 +12,8 @@ import org.jetbrains.kotlin.fir.resolve.toSymbol
import org.jetbrains.kotlin.idea.fir.low.level.api.element.builder.getNonLocalContainingOrThisDeclaration
import org.jetbrains.kotlin.idea.fir.low.level.api.file.builder.FirFileBuilder
import org.jetbrains.kotlin.idea.fir.low.level.api.file.builder.ModuleFileCache
import org.jetbrains.kotlin.idea.fir.low.level.api.file.builder.withReadLock
import org.jetbrains.kotlin.idea.fir.low.level.api.file.builder.withWriteLock
import org.jetbrains.kotlin.idea.fir.low.level.api.lazy.resolve.FirLazyDeclarationResolver
import org.jetbrains.kotlin.idea.fir.low.level.api.providers.firIdeProvider
import org.jetbrains.kotlin.idea.fir.low.level.api.util.findSourceNonLocalFirDeclaration
@@ -90,7 +92,10 @@ internal class FileStructure(
): WithInBlockModificationFileStructureElement {
val newFunction = firIdeProvider.buildFunctionWithBody(containerKtFunction) as FirSimpleFunction
val originalFunction = original.firSymbol.fir as FirSimpleFunction
replaceFunction(originalFunction, newFunction)
moduleFileCache.firFileLockProvider.withWriteLock(firFile) {
replaceFunction(originalFunction, newFunction)
}
try {
firLazyDeclarationResolver.lazyResolveDeclaration(
@@ -100,41 +105,52 @@ internal class FileStructure(
checkPCE = true,
reresolveFile = true,
)
return WithInBlockModificationFileStructureElement(
firFile,
containerKtFunction,
newFunction.symbol,
containerKtFunction.modificationStamp,
)
return moduleFileCache.firFileLockProvider.withReadLock(firFile) {
WithInBlockModificationFileStructureElement(
firFile,
containerKtFunction,
newFunction.symbol,
containerKtFunction.modificationStamp,
)
}
} catch (e: Throwable) {
replaceFunction(newFunction, originalFunction)
moduleFileCache.firFileLockProvider.withWriteLock(firFile) {
replaceFunction(newFunction, originalFunction)
}
throw e
}
}
private fun createDeclarationStructure(declaration: KtDeclaration): FileStructureElement {
val firDeclaration = declaration.findSourceNonLocalFirDeclaration(firFileBuilder, firIdeProvider.symbolProvider, moduleFileCache)
val firDeclaration = declaration.findSourceNonLocalFirDeclaration(
firFileBuilder,
firIdeProvider.symbolProvider,
moduleFileCache,
firFile
)
firLazyDeclarationResolver.lazyResolveDeclaration(
firDeclaration,
moduleFileCache,
FirResolvePhase.BODY_RESOLVE,
checkPCE = true
)
return when {
declaration is KtNamedFunction && declaration.hasExplicitTypeOrUnit -> {
WithInBlockModificationFileStructureElement(
firFile,
declaration,
(firDeclaration as FirSimpleFunction).symbol,
declaration.modificationStamp,
)
}
else -> {
NonLocalDeclarationFileStructureElement(
firFile,
firDeclaration,
declaration,
)
return moduleFileCache.firFileLockProvider.withReadLock(firFile) {
when {
declaration is KtNamedFunction && declaration.hasExplicitTypeOrUnit -> {
WithInBlockModificationFileStructureElement(
firFile,
declaration,
(firDeclaration as FirSimpleFunction).symbol,
declaration.modificationStamp,
)
}
else -> {
NonLocalDeclarationFileStructureElement(
firFile,
firDeclaration,
declaration,
)
}
}
}
}
@@ -148,7 +148,7 @@ internal class FirLazyDeclarationResolver(
val ktDeclaration = firDeclarationToResolve.ktDeclaration
designation += ktDeclaration.parentsOfType<KtClassOrObject>()
.filter { it !is KtEnumEntry }
.map { it.findSourceNonLocalFirDeclaration(firFileBuilder, provider.symbolProvider, moduleFileCache) }
.map { it.findSourceNonLocalFirDeclaration(firFileBuilder, provider.symbolProvider, moduleFileCache, containerFirFile) }
.toList()
.asReversed()
if (nonLocalDeclarationToResolve is FirCallableDeclaration<*>) {
@@ -6,11 +6,14 @@
package org.jetbrains.kotlin.idea.fir.low.level.api.util
import org.jetbrains.kotlin.fir.declarations.FirDeclaration
import org.jetbrains.kotlin.fir.declarations.FirFile
import org.jetbrains.kotlin.fir.declarations.FirRegularClass
import org.jetbrains.kotlin.fir.declarations.FirTypeAlias
import org.jetbrains.kotlin.fir.psi
import org.jetbrains.kotlin.fir.realPsi
import org.jetbrains.kotlin.fir.resolve.providers.FirSymbolProvider
import org.jetbrains.kotlin.idea.fir.low.level.api.api.InvalidFirElementTypeException
import org.jetbrains.kotlin.idea.fir.low.level.api.element.builder.getNonLocalContainingOrThisDeclaration
import org.jetbrains.kotlin.idea.fir.low.level.api.file.builder.FirFileBuilder
import org.jetbrains.kotlin.idea.fir.low.level.api.file.builder.ModuleFileCache
import org.jetbrains.kotlin.idea.util.classIdIfNonLocal
@@ -22,19 +25,46 @@ import org.jetbrains.kotlin.psi.psiUtil.containingClassOrObject
internal fun KtDeclaration.findSourceNonLocalFirDeclaration(
firFileBuilder: FirFileBuilder,
firSymbolProvider: FirSymbolProvider,
moduleFileCache: ModuleFileCache
moduleFileCache: ModuleFileCache,
containerFirFile: FirFile? = null
): FirDeclaration {
//TODO test what way faster
findSourceNonLocalFirDeclarationByProvider(firFileBuilder, firSymbolProvider, moduleFileCache)?.let { return it }
findSourceOfNonLocalFirDeclarationByTraversingWholeTree(firFileBuilder, moduleFileCache)?.let { return it }
findSourceNonLocalFirDeclarationByProvider(firFileBuilder, firSymbolProvider, moduleFileCache, containerFirFile)?.let { return it }
findSourceOfNonLocalFirDeclarationByTraversingWholeTree(firFileBuilder, moduleFileCache, containerFirFile)?.let { return it }
error("No fir element was found for\n${getElementTextInContext()}")
}
internal fun KtDeclaration.findFirDeclarationForAnyFirSourceDeclaration(
firFileBuilder: FirFileBuilder,
firSymbolProvider: FirSymbolProvider,
moduleFileCache: ModuleFileCache
): FirDeclaration {
val nonLocalDeclaration = getNonLocalContainingOrThisDeclaration()
?.findSourceNonLocalFirDeclaration(firFileBuilder, firSymbolProvider, moduleFileCache)
?: firFileBuilder.buildRawFirFileWithCaching(containingKtFile, moduleFileCache, lazyBodiesMode = true)
val fir = FirElementFinder.findElementIn<FirDeclaration>(nonLocalDeclaration) { firDeclaration ->
firDeclaration.psi == this
}
return fir
?: error("FirDeclaration was not found for\n${getElementTextInContext()}")
}
internal inline fun <reified F : FirDeclaration> KtDeclaration.findFirDeclarationForAnyFirSourceDeclarationOfType(
firFileBuilder: FirFileBuilder,
firSymbolProvider: FirSymbolProvider,
moduleFileCache: ModuleFileCache
): FirDeclaration {
val fir = findFirDeclarationForAnyFirSourceDeclaration(firFileBuilder, firSymbolProvider, moduleFileCache)
if (fir !is F) throw InvalidFirElementTypeException(this, F::class, fir::class)
return fir
}
private fun KtDeclaration.findSourceOfNonLocalFirDeclarationByTraversingWholeTree(
firFileBuilder: FirFileBuilder,
moduleFileCache: ModuleFileCache,
containerFirFile: FirFile?,
): FirDeclaration? {
val firFile = firFileBuilder.buildRawFirFileWithCaching(containingKtFile, moduleFileCache, lazyBodiesMode = true)
val firFile = containerFirFile ?: firFileBuilder.buildRawFirFileWithCaching(containingKtFile, moduleFileCache, lazyBodiesMode = true)
val originalDeclaration = originalDeclaration
return FirElementFinder.findElementIn(firFile, goInside = { it is FirRegularClass }) { firDeclaration ->
firDeclaration.psi == this || firDeclaration.psi == originalDeclaration
@@ -44,7 +74,8 @@ private fun KtDeclaration.findSourceOfNonLocalFirDeclarationByTraversingWholeTre
private fun KtDeclaration.findSourceNonLocalFirDeclarationByProvider(
firFileBuilder: FirFileBuilder,
firSymbolProvider: FirSymbolProvider,
moduleFileCache: ModuleFileCache
moduleFileCache: ModuleFileCache,
containerFirFile: FirFile?
): FirDeclaration? {
val candidate = when {
this is KtClassOrObject -> findFir(firSymbolProvider)
@@ -55,7 +86,7 @@ private fun KtDeclaration.findSourceNonLocalFirDeclarationByProvider(
containerClassFir?.declarations
} else {
val ktFile = containingKtFile
val firFile = firFileBuilder.buildRawFirFileWithCaching(ktFile, moduleFileCache, lazyBodiesMode = true)
val firFile = containerFirFile ?: firFileBuilder.buildRawFirFileWithCaching(ktFile, moduleFileCache, lazyBodiesMode = true)
firFile.declarations
}
val original = originalDeclaration
@@ -93,3 +124,4 @@ private fun KtTypeAlias.findFir(firSymbolProvider: FirSymbolProvider): FirTypeAl
?: error("Could not find type alias $typeAlias")
}
}
@@ -15,7 +15,7 @@ import org.jetbrains.kotlin.fir.render
import org.jetbrains.kotlin.idea.caches.project.IdeaModuleInfo
import org.jetbrains.kotlin.psi.KtDeclaration
import java.util.concurrent.TimeUnit
import java.util.concurrent.locks.ReentrantLock
import java.util.concurrent.locks.Lock
internal inline fun <T> executeOrReturnDefaultValueOnPCE(defaultValue: T, action: () -> T): T =
@@ -31,7 +31,7 @@ internal inline fun <T : Any> executeWithoutPCE(crossinline action: () -> T): T
return result!!
}
internal inline fun <T : Any> ReentrantLock.lockWithPCECheck(lockingIntervalMs: Long, action: () -> T): T {
internal inline fun <T : Any> Lock.lockWithPCECheck(lockingIntervalMs: Long, action: () -> T): T {
var needToRun = true
var result: T? = null
while (needToRun) {
@@ -60,7 +60,7 @@ internal class KtFirCompletionCandidateChecker(
private inline fun <reified T : KtFirSymbol<F>, F : FirDeclaration, R> KtCallableSymbol.withResolvedFirOfType(
noinline action: (F) -> R,
): R? = this.safeAs<T>()?.firRef?.withFir(FirResolvePhase.BODY_RESOLVE, action)
): R? = this.safeAs<T>()?.firRef?.withFirResolvedToBodyResolve(action)
private fun checkExtension(
candidateSymbol: FirCallableDeclaration<*>,
@@ -5,9 +5,11 @@
package org.jetbrains.kotlin.idea.frontend.api.fir.symbols
import org.jetbrains.kotlin.fir.declarations.*
import org.jetbrains.kotlin.fir.resolve.providers.FirSymbolProvider
import org.jetbrains.kotlin.fir.symbols.impl.FirRegularClassSymbol
import org.jetbrains.kotlin.idea.fir.low.level.api.api.FirModuleResolveState
import org.jetbrains.kotlin.idea.fir.low.level.api.api.LowLevelFirApiFacade
import org.jetbrains.kotlin.idea.fir.low.level.api.api.getOrBuildFirOfType
import org.jetbrains.kotlin.idea.frontend.api.KtAnalysisSession
import org.jetbrains.kotlin.idea.frontend.api.ValidityTokenOwner
@@ -31,47 +33,70 @@ internal class KtFirSymbolProvider(
private val firSymbolProvider by weakRef(firSymbolProvider)
override fun getParameterSymbol(psi: KtParameter): KtParameterSymbol = withValidityAssertion {
firSymbolBuilder.buildParameterSymbol(psi.getOrBuildFirOfType(resolveState))
LowLevelFirApiFacade.withFirDeclarationOfType<FirValueParameter, KtParameterSymbol>(psi, resolveState) {
firSymbolBuilder.buildParameterSymbol(it)
}
}
override fun getFunctionSymbol(psi: KtNamedFunction): KtFunctionSymbol = withValidityAssertion {
firSymbolBuilder.buildFunctionSymbol(psi.getOrBuildFirOfType(resolveState))
LowLevelFirApiFacade.withFirDeclarationOfType<FirSimpleFunction, KtFunctionSymbol>(psi, resolveState) {
firSymbolBuilder.buildFunctionSymbol(it)
}
}
override fun getConstructorSymbol(psi: KtConstructor<*>): KtConstructorSymbol = withValidityAssertion {
firSymbolBuilder.buildConstructorSymbol(psi.getOrBuildFirOfType(resolveState))
LowLevelFirApiFacade.withFirDeclarationOfType<FirConstructor, KtConstructorSymbol>(psi, resolveState) {
firSymbolBuilder.buildConstructorSymbol(it)
}
}
override fun getTypeParameterSymbol(psi: KtTypeParameter): KtTypeParameterSymbol = withValidityAssertion {
firSymbolBuilder.buildTypeParameterSymbol(psi.getOrBuildFirOfType(resolveState))
LowLevelFirApiFacade.withFirDeclarationOfType<FirTypeParameter, KtTypeParameterSymbol>(psi, resolveState) {
firSymbolBuilder.buildTypeParameterSymbol(it)
}
}
override fun getTypeAliasSymbol(psi: KtTypeAlias): KtTypeAliasSymbol = withValidityAssertion {
firSymbolBuilder.buildTypeAliasSymbol(psi.getOrBuildFirOfType(resolveState))
LowLevelFirApiFacade.withFirDeclarationOfType<FirTypeAlias, KtTypeAliasSymbol>(psi, resolveState) {
firSymbolBuilder.buildTypeAliasSymbol(it)
}
}
override fun getEnumEntrySymbol(psi: KtEnumEntry): KtEnumEntrySymbol = withValidityAssertion {
firSymbolBuilder.buildEnumEntrySymbol(psi.getOrBuildFirOfType(resolveState))
LowLevelFirApiFacade.withFirDeclarationOfType<FirEnumEntry, KtEnumEntrySymbol>(psi, resolveState) {
firSymbolBuilder.buildEnumEntrySymbol(it)
}
}
override fun getAnonymousFunctionSymbol(psi: KtNamedFunction): KtAnonymousFunctionSymbol = withValidityAssertion {
LowLevelFirApiFacade.withFirDeclarationOfType<FirSimpleFunction, KtFunctionSymbol>(psi, resolveState) {
firSymbolBuilder.buildFunctionSymbol(it)
}
firSymbolBuilder.buildAnonymousFunctionSymbol(psi.getOrBuildFirOfType(resolveState))
}
override fun getAnonymousFunctionSymbol(psi: KtLambdaExpression): KtAnonymousFunctionSymbol = withValidityAssertion {
firSymbolBuilder.buildAnonymousFunctionSymbol(psi.getOrBuildFirOfType(resolveState))
LowLevelFirApiFacade.withFirDeclarationOfType<FirAnonymousFunction, KtAnonymousFunctionSymbol>(psi, resolveState) {
firSymbolBuilder.buildAnonymousFunctionSymbol(it)
}
}
override fun getVariableSymbol(psi: KtProperty): KtVariableSymbol = withValidityAssertion {
firSymbolBuilder.buildVariableSymbol(psi.getOrBuildFirOfType(resolveState))
LowLevelFirApiFacade.withFirDeclarationOfType<FirProperty, KtVariableSymbol>(psi, resolveState) {
firSymbolBuilder.buildVariableSymbol(it)
}
}
override fun getClassOrObjectSymbol(psi: KtClassOrObject): KtClassOrObjectSymbol = withValidityAssertion {
firSymbolBuilder.buildClassSymbol(psi.getOrBuildFirOfType(resolveState))
LowLevelFirApiFacade.withFirDeclarationOfType<FirRegularClass, KtClassOrObjectSymbol>(psi, resolveState) {
firSymbolBuilder.buildClassSymbol(it)
}
}
override fun getPropertyAccessorSymbol(psi: KtPropertyAccessor): KtPropertyAccessorSymbol = withValidityAssertion {
firSymbolBuilder.buildPropertyAccessorSymbol(psi.getOrBuildFirOfType(resolveState))
LowLevelFirApiFacade.withFirDeclarationOfType<FirPropertyAccessor, KtPropertyAccessorSymbol>(psi, resolveState) {
firSymbolBuilder.buildPropertyAccessorSymbol(it)
}
}
override fun getClassOrObjectSymbolByClassId(classId: ClassId): KtClassOrObjectSymbol? = withValidityAssertion {
@@ -18,10 +18,24 @@ internal class FirRefWithValidityCheck<D : FirDeclaration>(fir: D, resolveState:
private val firWeakRef = WeakReference(fir)
private val resolveStateWeakRef = WeakReference(resolveState)
inline fun <R> withFir(phase: FirResolvePhase = FirResolvePhase.RAW_FIR, action: (fir: D) -> R): R {
inline fun <R> withFir(phase: FirResolvePhase = FirResolvePhase.RAW_FIR, crossinline action: (fir: D) -> R): R {
token.assertIsValid()
val fir = firWeakRef.get() ?: error("FirElement was garbage collected while analysis session is still valid")
return action(LowLevelFirApiFacade.resolvedFirToPhase(fir, phase, resolveState))
val fir = firWeakRef.get()
?: error("FirElement was garbage collected while analysis session is still valid")
val resolveState =
resolveStateWeakRef.get() ?: error("FirModuleResolveState was garbage collected while analysis session is still valid")
LowLevelFirApiFacade.resolvedFirToPhase(fir, phase, resolveState)
return resolveState.withFirDeclaration(fir) { action(it) }
}
inline fun <R> withFirResolvedToBodyResolve(action: (fir: D) -> R): R {
token.assertIsValid()
val fir = firWeakRef.get()
?: error("FirElement was garbage collected while analysis session is still valid")
val resolveState =
resolveStateWeakRef.get() ?: error("FirModuleResolveState was garbage collected while analysis session is still valid")
LowLevelFirApiFacade.resolvedFirToPhase(fir, FirResolvePhase.BODY_RESOLVE, resolveState)
return action(resolveState.withFirDeclaration(fir) { it })
}
val resolveState