FIR IDE: implement thread safe fir caches for IDE

This commit is contained in:
Ilya Kirillov
2021-01-15 18:03:05 +01:00
parent 191a948ffe
commit 1fef5859e3
6 changed files with 191 additions and 3 deletions
@@ -0,0 +1,21 @@
/*
* Copyright 2010-2021 JetBrains s.r.o. and Kotlin Programming Language contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package org.jetbrains.kotlin.idea.fir.low.level.api.fir.caches
import org.jetbrains.kotlin.fir.caches.FirCache
import java.util.concurrent.ConcurrentHashMap
internal class FirThreadSafeCache<KEY : Any, VALUE, CONTEXT>(
private val createValue: (KEY, CONTEXT) -> VALUE
) : FirCache<KEY, VALUE, CONTEXT>() {
private val map = ConcurrentHashMap<KEY, Any>()
override fun getValue(key: KEY, context: CONTEXT): VALUE =
map.computeIfAbsentWithNullableValue(key) { createValue(it, context) }
override fun getValueIfComputed(key: KEY): VALUE? =
map[key]?.nullValueToNull()
}
@@ -0,0 +1,29 @@
/*
* Copyright 2010-2021 JetBrains s.r.o. and Kotlin Programming Language contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package org.jetbrains.kotlin.idea.fir.low.level.api.fir.caches
import org.jetbrains.kotlin.fir.caches.FirCache
import java.util.concurrent.ConcurrentHashMap
internal class FirThreadSafeCacheWithPostCompute<KEY : Any, VALUE, CONTEXT, DATA>(
private val createValue: (KEY, CONTEXT) -> Pair<VALUE, DATA>,
private val postCompute: (KEY, VALUE, DATA) -> Unit
) : FirCache<KEY, VALUE, CONTEXT>() {
private val map = ConcurrentHashMap<KEY, ValueWithPostCompute<KEY, VALUE, DATA>>()
@Suppress("UNCHECKED_CAST")
override fun getValue(key: KEY, context: CONTEXT): VALUE =
map.computeIfAbsent(key) {
ValueWithPostCompute(
key,
calculate = { createValue(it, context) },
postCompute = postCompute
)
}.getValue()
override fun getValueIfComputed(key: KEY): VALUE? =
map[key]?.getValueIfComputed()
}
@@ -0,0 +1,19 @@
/*
* Copyright 2010-2021 JetBrains s.r.o. and Kotlin Programming Language contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package org.jetbrains.kotlin.idea.fir.low.level.api.fir.caches
import org.jetbrains.kotlin.fir.caches.*
object FirThreadSafeCachesFactory : FirCachesFactory() {
override fun <KEY : Any, VALUE, CONTEXT> createCache(createValue: (KEY, CONTEXT) -> VALUE): FirCache<KEY, VALUE, CONTEXT> =
FirThreadSafeCache(createValue)
override fun <KEY : Any, VALUE, CONTEXT, DATA> createCacheWithPostCompute(
createValue: (KEY, CONTEXT) -> Pair<VALUE, DATA>,
postCompute: (KEY, VALUE, DATA) -> Unit
): FirCache<KEY, VALUE, CONTEXT> =
FirThreadSafeCacheWithPostCompute(createValue, postCompute)
}
@@ -0,0 +1,24 @@
/*
* Copyright 2010-2021 JetBrains s.r.o. and Kotlin Programming Language contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package org.jetbrains.kotlin.idea.fir.low.level.api.fir.caches
import java.util.concurrent.ConcurrentMap
internal object NullValue
@Suppress("NOTHING_TO_INLINE", "UNCHECKED_CAST")
internal inline fun <VALUE> Any.nullValueToNull(): VALUE = when (this) {
NullValue -> null
else -> this
} as VALUE
internal inline fun <KEY : Any, RESULT> ConcurrentMap<KEY, Any>.computeIfAbsentWithNullableValue(
key: KEY,
crossinline compute: (KEY) -> Any?
): RESULT {
val value = computeIfAbsent(key) { k -> compute(k) ?: NullValue }
return value.nullValueToNull()
}
@@ -0,0 +1,92 @@
/*
* Copyright 2010-2020 JetBrains s.r.o. and Kotlin Programming Language contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package org.jetbrains.kotlin.idea.fir.low.level.api.fir.caches
/**
* Lazily calculated value which runs postCompute in the same thread,
* assuming that postCompute may try to read that value inside current thread,
* So in the period then value is calculated but post compute was not finished,
* only thread that initiated the calculating may see the value,
* other threads will have to wait wait until that value is calculated
*/
internal class ValueWithPostCompute<KEY, VALUE, DATA>(
/**
* We need at least one final field to be written in constructor to guarantee safe initialization of our [ValueWithPostCompute]
*/
private val key: KEY,
calculate: (KEY) -> Pair<VALUE, DATA>,
postCompute: (KEY, VALUE, DATA) -> Unit,
) {
private var _calculate: ((KEY) -> Pair<VALUE, DATA>)? = calculate
private var _postCompute: ((KEY, VALUE, DATA) -> Unit)? = postCompute
/**
* can be in one of the following three states:
* [ValueIsNotComputed] -- value is not initialized and thread are now executing [_postCompute]
* [ExceptionWasThrownDuringValueComputation] -- exception was thrown during value computation, it will be rethrown on every value access
* [ValueIsPostComputingNow] -- thread with threadId has computed the value and only it can access it during post compute
* some value of type [VALUE] -- value is computed and post compute was executed, values is visible for all threads
*
* Value may be set only under [LazyValueWithPostCompute] intrinsic lock hold
* And may be read from any thread
*/
@Volatile
private var value: Any? = ValueIsNotComputed
@Suppress("UNCHECKED_CAST")
fun getValue(): VALUE {
when (val stateSnapshot = value) {
is ValueIsPostComputingNow -> {
if (stateSnapshot.threadId == Thread.currentThread().id) {
return stateSnapshot.value as VALUE
} else {
synchronized(this) { // wait until other thread which holds the lock now computes the value
return value as VALUE
}
}
}
ValueIsNotComputed -> synchronized(this) {
// if we entered synchronized section that's mean that the value is not yet calculated and was not started to be calculated
// or the some other thread calculated the value while we were waiting to acquire the lock
if (value != ValueIsNotComputed) { // some other thread calculated our value
return value as VALUE
}
val calculatedValue = try {
val (calculated, data) = _calculate!!(key)
value = ValueIsPostComputingNow(calculated, Thread.currentThread().id) // only current thread may see the value
_postCompute!!(key, calculated, data)
calculated
} catch (e: Throwable) {
value = ExceptionWasThrownDuringValueComputation(e)
throw e
}
_calculate = null
_postCompute = null
value = calculatedValue
return calculatedValue
}
is ExceptionWasThrownDuringValueComputation -> {
throw stateSnapshot.error
}
else -> {
return value as VALUE
}
}
}
@Suppress("UNCHECKED_CAST")
fun getValueIfComputed(): VALUE? = when (val snapshot = value) {
ValueIsNotComputed -> null
is ValueIsPostComputingNow -> null
is ExceptionWasThrownDuringValueComputation -> throw snapshot.error
else -> value as VALUE
}
private class ValueIsPostComputingNow(val value: Any?, val threadId: Long)
private class ExceptionWasThrownDuringValueComputation(val error: Throwable)
private object ValueIsNotComputed
}
@@ -12,6 +12,7 @@ import org.jetbrains.kotlin.fir.BuiltinTypes
import org.jetbrains.kotlin.fir.PrivateSessionConstructor
import org.jetbrains.kotlin.fir.SessionConfiguration
import org.jetbrains.kotlin.fir.backend.jvm.FirJvmTypeMapper
import org.jetbrains.kotlin.fir.caches.FirCachesFactory
import org.jetbrains.kotlin.fir.checkers.registerCommonCheckers
import org.jetbrains.kotlin.fir.dependenciesWithoutSelf
import org.jetbrains.kotlin.fir.java.JavaSymbolProvider
@@ -33,6 +34,7 @@ import org.jetbrains.kotlin.idea.fir.low.level.api.IdeFirPhaseManager
import org.jetbrains.kotlin.idea.fir.low.level.api.IdeSessionComponents
import org.jetbrains.kotlin.idea.fir.low.level.api.file.builder.FirFileBuilder
import org.jetbrains.kotlin.idea.fir.low.level.api.file.builder.ModuleFileCacheImpl
import org.jetbrains.kotlin.idea.fir.low.level.api.fir.caches.FirThreadSafeCachesFactory
import org.jetbrains.kotlin.idea.fir.low.level.api.lazy.resolve.FirLazyDeclarationResolver
import org.jetbrains.kotlin.idea.fir.low.level.api.providers.FirModuleWithDependenciesSymbolProvider
import org.jetbrains.kotlin.idea.fir.low.level.api.providers.FirIdeProvider
@@ -71,9 +73,9 @@ internal object FirIdeSessionFactory {
val cache = ModuleFileCacheImpl(this)
val firPhaseManager = IdeFirPhaseManager(FirLazyDeclarationResolver(firFileBuilder), cache, sessionInvalidator)
registerIdeComponents()
registerCommonComponents(languageVersionSettings)
registerResolveComponents()
registerIdeComponents()
val provider = FirIdeProvider(
project,
@@ -156,9 +158,9 @@ internal object FirIdeSessionFactory {
val kotlinClassFinder = VirtualFileFinderFactory.getInstance(project).create(searchScope)
FirIdeLibrariesSession(moduleInfo, project, searchScope, builtinTypes).apply {
registerIdeComponents()
registerCommonComponents(languageVersionSettings)
registerJavaSpecificResolveComponents()
registerIdeComponents()
val javaSymbolProvider = JavaSymbolProvider(this, project, searchScope)
@@ -198,8 +200,8 @@ internal object FirIdeSessionFactory {
languageVersionSettings: LanguageVersionSettings = LanguageVersionSettingsImpl.DEFAULT
): FirIdeBuiltinsAndCloneableSession {
return FirIdeBuiltinsAndCloneableSession(project, builtinTypes).apply {
registerCommonComponents(languageVersionSettings)
registerIdeComponents()
registerCommonComponents(languageVersionSettings)
val kotlinScopeProvider = KotlinScopeProvider(::wrapScopeWithJvmMapped)
register(
@@ -218,5 +220,6 @@ internal object FirIdeSessionFactory {
private fun FirIdeSession.registerIdeComponents() {
register(IdeSessionComponents::class, IdeSessionComponents.create(this))
register(FirCachesFactory::class, FirThreadSafeCachesFactory)
}
}