From 0609aa1e2e05f050aaaed5cb48924b90a6fcc78a Mon Sep 17 00:00:00 2001 From: Roman Golyshev Date: Mon, 21 Dec 2020 20:00:09 +0300 Subject: [PATCH] FIR IDE: Refactor `KtFirReferenceShortener` --- .../fir/components/KtFirReferenceShortener.kt | 75 +++++++++++-------- 1 file changed, 43 insertions(+), 32 deletions(-) diff --git a/idea/idea-frontend-fir/src/org/jetbrains/kotlin/idea/frontend/api/fir/components/KtFirReferenceShortener.kt b/idea/idea-frontend-fir/src/org/jetbrains/kotlin/idea/frontend/api/fir/components/KtFirReferenceShortener.kt index 54da7a7b099..6fa60c8d2a1 100644 --- a/idea/idea-frontend-fir/src/org/jetbrains/kotlin/idea/frontend/api/fir/components/KtFirReferenceShortener.kt +++ b/idea/idea-frontend-fir/src/org/jetbrains/kotlin/idea/frontend/api/fir/components/KtFirReferenceShortener.kt @@ -24,6 +24,7 @@ import org.jetbrains.kotlin.idea.frontend.api.components.ShortenCommand import org.jetbrains.kotlin.idea.frontend.api.fir.KtFirAnalysisSession import org.jetbrains.kotlin.name.ClassId import org.jetbrains.kotlin.name.Name +import org.jetbrains.kotlin.psi.KtElement import org.jetbrains.kotlin.psi.KtFile import org.jetbrains.kotlin.psi.KtTypeReference import org.jetbrains.kotlin.psi.KtUserType @@ -39,37 +40,7 @@ internal class KtFirReferenceShortener( val firFile = file.getOrBuildFirOfType(firResolveState) val typesToShorten = mutableListOf() - - firFile.acceptChildren(object : FirVisitorVoid() { - override fun visitElement(element: FirElement) { - element.acceptChildren(this) - } - - override fun visitResolvedTypeRef(resolvedTypeRef: FirResolvedTypeRef) { - resolvedTypeRef.acceptChildren(this) - - val wholeTypeReference = resolvedTypeRef.psi as? KtTypeReference ?: return - - val wholeTypeElement = wholeTypeReference.typeElement as? KtUserType ?: return - if (wholeTypeElement.qualifier == null) return - - val wholeClassifierId = resolvedTypeRef.type.classId ?: return - - val allClassIds = generateSequence(wholeClassifierId) { it.outerClassId } - val allTypeElements = generateSequence(wholeTypeElement) { it.qualifier } - - val positionScopes = findScopesAtPosition(wholeTypeReference) ?: return - - for ((classId, typeElement) in allClassIds.zip(allTypeElements)) { - val firstFoundClass = findFirstClassifierInScopesByName(positionScopes, classId.shortClassName) - - if (firstFoundClass == classId) { - typesToShorten.add(typeElement) - break - } - } - } - }) + firFile.acceptChildren(TypesCollectingVisitor(typesToShorten)) return ShortenCommand(file, emptyList(), typesToShorten.map { it.createSmartPointer() }) } @@ -105,10 +76,50 @@ internal class KtFirReferenceShortener( return element } - private fun findScopesAtPosition(targetTypeReference: KtTypeReference): List? { + private fun findScopesAtPosition(targetTypeReference: KtElement): List? { val towerDataContext = firResolveState.getTowerDataContextForElement(targetTypeReference) ?: return null val availableScopes = towerDataContext.towerDataElements.mapNotNull { it.scope } return availableScopes.asReversed() } + + private inner class TypesCollectingVisitor(private val collectedTypes: MutableList) : FirVisitorVoid() { + override fun visitElement(element: FirElement) { + element.acceptChildren(this) + } + + override fun visitResolvedTypeRef(resolvedTypeRef: FirResolvedTypeRef) { + processTypeRef(resolvedTypeRef) + + resolvedTypeRef.acceptChildren(this) + resolvedTypeRef.delegatedTypeRef?.accept(this) + } + + private fun processTypeRef(resolvedTypeRef: FirResolvedTypeRef) { + val wholeTypeReference = resolvedTypeRef.psi as? KtTypeReference ?: return + + val wholeClassifierId = resolvedTypeRef.type.classId ?: return + val wholeTypeElement = wholeTypeReference.typeElement as? KtUserType ?: return + + if (wholeTypeElement.qualifier == null) return + + val typeToShorten = findBiggestClassifierToShorten(wholeClassifierId, wholeTypeElement) ?: return + collectedTypes.add(typeToShorten) + } + + private fun findBiggestClassifierToShorten(wholeClassifierId: ClassId, wholeTypeElement: KtUserType): KtUserType? { + val allClassIds = generateSequence(wholeClassifierId) { it.outerClassId } + val allTypeElements = generateSequence(wholeTypeElement) { it.qualifier } + + val positionScopes = findScopesAtPosition(wholeTypeElement) ?: return null + + for ((classId, typeElement) in allClassIds.zip(allTypeElements)) { + val firstFoundClass = findFirstClassifierInScopesByName(positionScopes, classId.shortClassName) + + if (firstFoundClass == classId) return typeElement + } + + return null + } + } } \ No newline at end of file