diff --git a/analysis/analysis-api-fir/src/org/jetbrains/kotlin/analysis/api/fir/components/KtFirCallResolver.kt b/analysis/analysis-api-fir/src/org/jetbrains/kotlin/analysis/api/fir/components/KtFirCallResolver.kt index b716b265170..eb4b9a3cfab 100644 --- a/analysis/analysis-api-fir/src/org/jetbrains/kotlin/analysis/api/fir/components/KtFirCallResolver.kt +++ b/analysis/analysis-api-fir/src/org/jetbrains/kotlin/analysis/api/fir/components/KtFirCallResolver.kt @@ -48,12 +48,14 @@ import org.jetbrains.kotlin.psi.* import org.jetbrains.kotlin.psi.psiUtil.findAssignment import org.jetbrains.kotlin.util.OperatorNameConventions import org.jetbrains.kotlin.utils.addToStdlib.safeAs +import java.util.concurrent.ConcurrentHashMap internal class KtFirCallResolver( override val analysisSession: KtFirAnalysisSession, override val token: ValidityToken, ) : KtCallResolver(), KtFirAnalysisSessionComponent { private val diagnosticCache = mutableListOf() + private val cache: ConcurrentHashMap = ConcurrentHashMap() override fun resolveAccessorCall(call: KtSimpleNameExpression): KtCall? = withValidityAssertion { when (val fir = call.getOrBuildFir(firResolveState)) { @@ -88,44 +90,52 @@ internal class KtFirCallResolver( } override fun resolveCall(call: KtBinaryExpression): KtCall? = withValidityAssertion { - when (val fir = call.getOrBuildFir(firResolveState)) { - is FirFunctionCall -> resolveCall(fir) - is FirComparisonExpression -> resolveCall(fir.compareToCall) - is FirEqualityOperatorCall -> null // TODO - else -> null + cache.computeIfAbsent(call) { + when (val fir = call.getOrBuildFir(firResolveState)) { + is FirFunctionCall -> resolveCall(fir) + is FirComparisonExpression -> resolveCall(fir.compareToCall) + is FirEqualityOperatorCall -> null // TODO + else -> null + } } } override fun resolveCall(call: KtUnaryExpression): KtCall? = withValidityAssertion { - when (val fir = call.getOrBuildFir(firResolveState)) { - is FirFunctionCall -> resolveCall(fir) - is FirBlock -> { - // Desugared increment or decrement block. See [BaseFirBuilder#generateIncrementOrDecrementBlock] - // There would be corresponding inc()/dec() call that is assigned back to a temp variable. - val prefix = fir.statements.filterIsInstance().find { it.rValue is FirFunctionCall } - (prefix?.rValue as? FirFunctionCall)?.let { resolveCall(it) } + cache.computeIfAbsent(call) { + when (val fir = call.getOrBuildFir(firResolveState)) { + is FirFunctionCall -> resolveCall(fir) + is FirBlock -> { + // Desugared increment or decrement block. See [BaseFirBuilder#generateIncrementOrDecrementBlock] + // There would be corresponding inc()/dec() call that is assigned back to a temp variable. + val prefix = fir.statements.filterIsInstance().find { it.rValue is FirFunctionCall } + (prefix?.rValue as? FirFunctionCall)?.let { resolveCall(it) } + } + is FirCheckNotNullCall -> null // TODO + else -> null } - is FirCheckNotNullCall -> null // TODO - else -> null } } override fun resolveCall(call: KtCallElement): KtCall? = withValidityAssertion { - return when (val fir = call.getOrBuildFir(firResolveState)) { - is FirArrayOfCall -> resolveArrayOfCall(fir) - is FirFunctionCall -> resolveCall(fir) - is FirAnnotationCall -> fir.asAnnotationCall() - is FirDelegatedConstructorCall -> fir.asDelegatedConstructorCall() - is FirConstructor -> fir.asDelegatedConstructorCall() - is FirSafeCallExpression -> fir.regularQualifiedAccess.safeAs()?.let { resolveCall(it) } - else -> null + cache.computeIfAbsent(call) { + when (val fir = call.getOrBuildFir(firResolveState)) { + is FirArrayOfCall -> resolveArrayOfCall(fir) + is FirFunctionCall -> resolveCall(fir) + is FirAnnotationCall -> fir.asAnnotationCall() + is FirDelegatedConstructorCall -> fir.asDelegatedConstructorCall() + is FirConstructor -> fir.asDelegatedConstructorCall() + is FirSafeCallExpression -> fir.regularQualifiedAccess.safeAs()?.let { resolveCall(it) } + else -> null + } } } override fun resolveCall(call: KtArrayAccessExpression): KtCall? = withValidityAssertion { - return when (val fir = call.getOrBuildFir(firResolveState)) { - is FirFunctionCall -> resolveCall(fir) - else -> null + cache.computeIfAbsent(call) { + when (val fir = call.getOrBuildFir(firResolveState)) { + is FirFunctionCall -> resolveCall(fir) + else -> null + } } }