FIR IDE: Add simple types importing

Some tests are not passing
This commit is contained in:
Roman Golyshev
2020-12-22 15:34:21 +03:00
committed by Space
parent 0609aa1e2e
commit e34370554d
17 changed files with 205 additions and 34 deletions
@@ -46,6 +46,16 @@ public class FirShortenRefsTestGenerated extends AbstractFirShortenRefsTest {
runTest("idea/testData/shortenRefsFir/types/ParameterType.kt");
}
@TestMetadata("ParameterTypeConflictingTopLevelClassNotUsed.kt")
public void testParameterTypeConflictingTopLevelClassNotUsed() throws Exception {
runTest("idea/testData/shortenRefsFir/types/ParameterTypeConflictingTopLevelClassNotUsed.kt");
}
@TestMetadata("ParameterTypeConflictingTopLevelClassUsed.kt")
public void testParameterTypeConflictingTopLevelClassUsed() throws Exception {
runTest("idea/testData/shortenRefsFir/types/ParameterTypeConflictingTopLevelClassUsed.kt");
}
@TestMetadata("ParameterTypeFunctionalType.kt")
public void testParameterTypeFunctionalType() throws Exception {
runTest("idea/testData/shortenRefsFir/types/ParameterTypeFunctionalType.kt");
@@ -66,6 +76,11 @@ public class FirShortenRefsTestGenerated extends AbstractFirShortenRefsTest {
runTest("idea/testData/shortenRefsFir/types/ParameterTypeNestedType.kt");
}
@TestMetadata("ParameterTypeNonImportedClass.kt")
public void testParameterTypeNonImportedClass() throws Exception {
runTest("idea/testData/shortenRefsFir/types/ParameterTypeNonImportedClass.kt");
}
@TestMetadata("ParameterTypeStarImportedTypeLoses.kt")
public void testParameterTypeStarImportedTypeLoses() throws Exception {
runTest("idea/testData/shortenRefsFir/types/ParameterTypeStarImportedTypeLoses.kt");
@@ -81,6 +96,11 @@ public class FirShortenRefsTestGenerated extends AbstractFirShortenRefsTest {
runTest("idea/testData/shortenRefsFir/types/ParameterTypeTopLevelTypeWins.kt");
}
@TestMetadata("ParameterTypeTwoNonImportedClassesConflict.kt")
public void testParameterTypeTwoNonImportedClassesConflict() throws Exception {
runTest("idea/testData/shortenRefsFir/types/ParameterTypeTwoNonImportedClassesConflict.kt");
}
@TestMetadata("VariableType.kt")
public void testVariableType() throws Exception {
runTest("idea/testData/shortenRefsFir/types/VariableType.kt");
@@ -5,27 +5,12 @@
package org.jetbrains.kotlin.idea.frontend.api.components
import com.intellij.openapi.application.ApplicationManager
import com.intellij.psi.SmartPsiElementPointer
import org.jetbrains.kotlin.name.FqName
import org.jetbrains.kotlin.psi.KtFile
import org.jetbrains.kotlin.psi.KtUserType
abstract class KtReferenceShortener : KtAnalysisSessionComponent() {
abstract fun collectShortenings(file: KtFile, from: Int, to: Int): ShortenCommand
}
class ShortenCommand(
val targetFile: KtFile,
val importsToAdd: List<FqName>,
val typesToShorten: List<SmartPsiElementPointer<KtUserType>>
) {
fun invokeShortening() {
ApplicationManager.getApplication().assertWriteAccessAllowed()
for (typePointer in typesToShorten) {
val type = typePointer.element ?: continue
type.deleteQualifier()
}
}
interface ShortenCommand {
fun invokeShortening()
}
@@ -5,6 +5,9 @@
package org.jetbrains.kotlin.idea.frontend.api.fir.components
import com.intellij.openapi.application.ApplicationManager
import com.intellij.openapi.project.Project
import com.intellij.psi.SmartPsiElementPointer
import org.jetbrains.kotlin.fir.FirElement
import org.jetbrains.kotlin.fir.declarations.FirFile
import org.jetbrains.kotlin.fir.psi
@@ -19,15 +22,14 @@ import org.jetbrains.kotlin.idea.fir.low.level.api.api.FirModuleResolveState
import org.jetbrains.kotlin.idea.fir.low.level.api.api.getOrBuildFir
import org.jetbrains.kotlin.idea.fir.low.level.api.api.getOrBuildFirOfType
import org.jetbrains.kotlin.idea.frontend.api.ValidityToken
import org.jetbrains.kotlin.idea.frontend.api.components.KtReferenceShortener
import org.jetbrains.kotlin.idea.frontend.api.components.ShortenCommand
import org.jetbrains.kotlin.idea.frontend.api.components.KtReferenceShortener
import org.jetbrains.kotlin.idea.frontend.api.fir.KtFirAnalysisSession
import org.jetbrains.kotlin.idea.frontend.api.fir.utils.addImportToFile
import org.jetbrains.kotlin.name.ClassId
import org.jetbrains.kotlin.name.FqName
import org.jetbrains.kotlin.name.Name
import org.jetbrains.kotlin.psi.KtElement
import org.jetbrains.kotlin.psi.KtFile
import org.jetbrains.kotlin.psi.KtTypeReference
import org.jetbrains.kotlin.psi.KtUserType
import org.jetbrains.kotlin.psi.*
import org.jetbrains.kotlin.psi.psiUtil.createSmartPointer
internal class KtFirReferenceShortener(
@@ -39,10 +41,12 @@ internal class KtFirReferenceShortener(
resolveFileToBodyResolve(file)
val firFile = file.getOrBuildFirOfType<FirFile>(firResolveState)
val typesToImport = mutableListOf<FqName>()
val typesToShorten = mutableListOf<KtUserType>()
firFile.acceptChildren(TypesCollectingVisitor(typesToShorten))
return ShortenCommand(file, emptyList(), typesToShorten.map { it.createSmartPointer() })
firFile.acceptChildren(TypesCollectingVisitor(typesToImport, typesToShorten))
return ShortenCommandImpl(file, typesToImport, typesToShorten.map { it.createSmartPointer() })
}
private fun findFirstClassifierInScopesByName(positionScopes: List<FirScope>, targetClassName: Name): ClassId? {
@@ -56,14 +60,12 @@ internal class KtFirReferenceShortener(
return null
}
private fun resolveFileToBodyResolve(file: KtFile) {
for (declaration in file.declarations) {
declaration.getOrBuildFir(firResolveState) // temporary hack, resolves declaration to BODY_RESOLVE stage
}
}
@OptIn(ExperimentalStdlibApi::class)
private fun FirScope.findFirstClassifierByName(name: Name): FirClassifierSymbol<*>? {
var element: FirClassifierSymbol<*>? = null
@@ -83,7 +85,10 @@ internal class KtFirReferenceShortener(
return availableScopes.asReversed()
}
private inner class TypesCollectingVisitor(private val collectedTypes: MutableList<KtUserType>) : FirVisitorVoid() {
private inner class TypesCollectingVisitor(
private val typesToImport: MutableList<FqName>,
private val typesToShorten: MutableList<KtUserType>,
) : FirVisitorVoid() {
override fun visitElement(element: FirElement) {
element.acceptChildren(this)
}
@@ -103,23 +108,62 @@ internal class KtFirReferenceShortener(
if (wholeTypeElement.qualifier == null) return
val typeToShorten = findBiggestClassifierToShorten(wholeClassifierId, wholeTypeElement) ?: return
collectedTypes.add(typeToShorten)
collectTypeIfNeedsToBeShortened(wholeClassifierId, wholeTypeElement)
}
private fun findBiggestClassifierToShorten(wholeClassifierId: ClassId, wholeTypeElement: KtUserType): KtUserType? {
private fun collectTypeIfNeedsToBeShortened(wholeClassifierId: ClassId, wholeTypeElement: KtUserType) {
val allClassIds = generateSequence(wholeClassifierId) { it.outerClassId }
val allTypeElements = generateSequence(wholeTypeElement) { it.qualifier }
val positionScopes = findScopesAtPosition(wholeTypeElement) ?: return null
val positionScopes = findScopesAtPosition(wholeTypeElement) ?: return
for ((classId, typeElement) in allClassIds.zip(allTypeElements)) {
val firstFoundClass = findFirstClassifierInScopesByName(positionScopes, classId.shortClassName)
if (firstFoundClass == classId) return typeElement
if (firstFoundClass == classId) {
addTypeToShorten(typeElement)
return
}
}
return null
// none class matched
val (mostTopLevelClassId, mostTopLevelTypeElement) = allClassIds.zip(allTypeElements).last()
val firstFoundClass = findFirstClassifierInScopesByName(positionScopes, mostTopLevelClassId.shortClassName)
check(firstFoundClass != mostTopLevelClassId) { "This should not be true" }
if (firstFoundClass == null) {
addTypeToImportAndShorten(mostTopLevelClassId.asSingleFqName(), mostTopLevelTypeElement)
}
}
private fun addTypeToShorten(typeElement: KtUserType) {
typesToShorten.add(typeElement)
}
private fun addTypeToImportAndShorten(classFqName: FqName, mostTopLevelTypeElement: KtUserType) {
typesToImport.add(classFqName)
typesToShorten.add(mostTopLevelTypeElement)
}
}
}
}
private class ShortenCommandImpl(
val targetFile: KtFile,
val importsToAdd: List<FqName>,
val typesToShorten: List<SmartPsiElementPointer<KtUserType>>
) : ShortenCommand {
override fun invokeShortening() {
ApplicationManager.getApplication().assertWriteAccessAllowed()
for (nameToImport in importsToAdd) {
addImportToFile(targetFile.project, targetFile, nameToImport)
}
for (typePointer in typesToShorten) {
val type = typePointer.element ?: continue
type.deleteQualifier()
}
}
}
@@ -0,0 +1,57 @@
/*
* Copyright 2010-2021 JetBrains s.r.o. and Kotlin Programming Language contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package org.jetbrains.kotlin.idea.frontend.api.fir.utils
import com.intellij.openapi.project.Project
import org.jetbrains.kotlin.name.FqName
import org.jetbrains.kotlin.name.Name
import org.jetbrains.kotlin.psi.KtCodeFragment
import org.jetbrains.kotlin.psi.KtFile
import org.jetbrains.kotlin.psi.KtImportDirective
import org.jetbrains.kotlin.psi.KtPsiFactory
import org.jetbrains.kotlin.resolve.ImportPath
private val SimpleImportPathComparator: Comparator<ImportPath> = compareBy(ImportPath::toString)
/**
* This is a partial copy from `org.jetbrains.kotlin.idea.util.ImportInsertHelperImpl.Companion.addImport`.
*
* We want it as a copy because we do not yet care about imports ordering, so we do not need a fancy comparator.
*/
internal fun addImportToFile(
project: Project,
file: KtFile,
fqName: FqName,
allUnder: Boolean = false,
alias: Name? = null
) {
val importPath = ImportPath(fqName, allUnder, alias)
val psiFactory = KtPsiFactory(project)
if (file is KtCodeFragment) {
val newDirective = psiFactory.createImportDirective(importPath)
file.addImportsFromString(newDirective.text)
}
val importList = file.importList
?: error("Trying to insert import $fqName into a file ${file.name} of type ${file::class.java} with no import list.")
val newDirective = psiFactory.createImportDirective(importPath)
val imports = importList.imports
if (imports.isEmpty()) { //TODO: strange hack
importList.add(psiFactory.createNewLine())
importList.add(newDirective)
} else {
val insertAfter = imports
.lastOrNull {
val directivePath = it.importPath
directivePath != null && SimpleImportPathComparator.compare(directivePath, importPath) <= 0
}
importList.addAfter(newDirective, insertAfter)
}
}
@@ -0,0 +1,3 @@
package dependency
class Foo
@@ -0,0 +1,6 @@
// FIR_COMPARISON
package test
class Foo
<selection>fun foo(p: dependency.Foo) {}</selection>
@@ -0,0 +1,8 @@
// FIR_COMPARISON
package test
import dependency.Foo
class Foo
fun foo(p: Foo) {}
@@ -0,0 +1,3 @@
package dependency
class Foo
@@ -0,0 +1,8 @@
// FIR_COMPARISON
package test
class Foo
<selection>fun foo(p: dependency.Foo) {}</selection>
fun bar(): Foo {}
@@ -0,0 +1,8 @@
// FIR_COMPARISON
package test
class Foo
fun foo(p: dependency.Foo) {}
fun bar(): Foo {}
@@ -0,0 +1,3 @@
package dependency
class T
@@ -0,0 +1,4 @@
// FIR_COMPARISON
package test
<selection>fun foo(p: dependency.T) {}</selection>
@@ -0,0 +1,6 @@
// FIR_COMPARISON
package test
import dependency.T
fun foo(p: T) {}
@@ -0,0 +1,3 @@
package dependency1
class T
@@ -0,0 +1,3 @@
package dependency2
class T
@@ -0,0 +1,4 @@
// FIR_COMPARISON
package test
<selection>fun foo(p1: dependency1.T, p2: dependency2.T) {}</selection>
@@ -0,0 +1,6 @@
// FIR_COMPARISON
package test
import dependency1.T
fun foo(p1: T, p2: dependency2.T) {}