diff --git a/idea/src/org/jetbrains/kotlin/idea/quickfix/DeprecatedSymbolUsageFix.kt b/idea/src/org/jetbrains/kotlin/idea/quickfix/DeprecatedSymbolUsageFix.kt index 708d9afa8ac..48130db2f7e 100644 --- a/idea/src/org/jetbrains/kotlin/idea/quickfix/DeprecatedSymbolUsageFix.kt +++ b/idea/src/org/jetbrains/kotlin/idea/quickfix/DeprecatedSymbolUsageFix.kt @@ -44,6 +44,7 @@ import org.jetbrains.kotlin.name.FqNameUnsafe import org.jetbrains.kotlin.name.Name import org.jetbrains.kotlin.psi.* import org.jetbrains.kotlin.psi.psiUtil.getReceiverExpression +import org.jetbrains.kotlin.psi.psiUtil.isAncestor import org.jetbrains.kotlin.psi.psiUtil.replaced import org.jetbrains.kotlin.psi.psiUtil.siblings import org.jetbrains.kotlin.resolve.BindingContext @@ -104,42 +105,58 @@ public class DeprecatedSymbolUsageFix( val expressionToReplace = qualifiedExpression ?: callExpression val USER_CODE_KEY = Key("USER_CODE") + val FROM_PARAMETER_KEY = Key("FROM_PARAMETER") + val FROM_THIS_KEY = Key("FROM_THIS") val explicitReceiver = qualifiedExpression?.getReceiverExpression() explicitReceiver?.putCopyableUserData(USER_CODE_KEY, Unit) - var thisReplacement = explicitReceiver + explicitReceiver?.putCopyableUserData(FROM_THIS_KEY, Unit) //TODO: infix and operator calls - var (expression, imports) = replaceWith.toExpression(descriptor.getOriginal(), element.getResolutionFacade(), file, project) + var (expression, imports, parameterUsages) = replaceWith.toExpression(descriptor.getOriginal(), element.getResolutionFacade(), file, project) + + //TODO: implicit receiver is not always "this" + //TODO: this@ + for (thisExpression in expression.collectThisExpressions()) { + if (explicitReceiver != null) { + thisExpression.replace(explicitReceiver) + } + else { + thisExpression.putCopyableUserData(FROM_THIS_KEY, Unit) + } + } + + fun argumentForParameter(parameter: ValueParameterDescriptor): JetExpression? { + //TODO: optional parameters + val arguments = resolvedCall.getValueArguments()[parameter] ?: return null //TODO: what if not? vararg? + return arguments.getArguments().firstOrNull()?.getArgumentExpression() //TODO: what if multiple? + } + + //TODO: check if complex expressions are used twice + //TODO: check for dropping complex expressions + for (parameter in descriptor.getValueParameters()) { + val argument = argumentForParameter(parameter) ?: continue + argument.putCopyableUserData(FROM_PARAMETER_KEY, parameter) + argument.putCopyableUserData(USER_CODE_KEY, Unit) + parameterUsages[parameter.getOriginal()]!!.forEach { it.replace(argument) } + } if (qualifiedExpression is JetSafeQualifiedExpression) { fun processSafeCall() { val qualified = expression as? JetQualifiedExpression if (qualified != null) { - val thisReceiver = qualified.getReceiverExpression() as? JetThisExpression - if (thisReceiver != null && thisReceiver.getLabelName() == null) { + if (qualified.getReceiverExpression().getCopyableUserData(FROM_THIS_KEY) != null) { val selector = qualified.getSelectorExpression() if (selector != null) { - expression = psiFactory.createExpressionByPattern("this?.$0", selector) + expression = psiFactory.createExpressionByPattern("$0?.$1", explicitReceiver!!, selector) return } } } if (expressionToReplace.isUsedAsExpression(bindingContext)) { - if (!isNameUsed("it", callExpression, expression)) { - expression = psiFactory.createExpressionByPattern("$0?.let { $1 }", explicitReceiver!!, expression) - thisReplacement = psiFactory.createExpression("it") - } - else { - val nameValidator = object : JetNameValidator() { - override fun validateInner(name: String) = !isNameUsed(name, callExpression, expression) - } - val name = JetNameSuggester.suggestNamesForExpression(explicitReceiver!!, nameValidator, "t").first() - val nameInCode = IdeDescriptorRenderers.SOURCE_CODE.renderName(Name.identifier(name)) - expression = psiFactory.createExpressionByPattern("$0?.let { $nameInCode -> $1 }", explicitReceiver, expression) - thisReplacement = psiFactory.createExpression(nameInCode) - } + val thisReplaced = expression.collectExpressionsWithData(FROM_THIS_KEY, Unit) + expression = expression.introduceValue(explicitReceiver!!, thisReplaced, safeCall = true) } else { expression = psiFactory.createExpressionByPattern("if ($0 != null) { $1 }", explicitReceiver!!, expression) @@ -148,42 +165,6 @@ public class DeprecatedSymbolUsageFix( processSafeCall() } - val parametersByName = descriptor.getValueParameters().toMap { it.getName().asString() } - expression.accept(object : JetVisitorVoid(){ - override fun visitSimpleNameExpression(expression: JetSimpleNameExpression) { - val qualified = expression.getParent() as? JetDotQualifiedExpression - if (qualified != null && expression == qualified.getSelectorExpression()) return - val name = expression.getReferencedName() - val parameter = parametersByName[name] ?: return //TODO: is this always correct? Lambda inside? - val arguments = resolvedCall.getValueArguments()[parameter] ?: return //TODO: what if not? vararg? - val argumentExpression = arguments.getArguments().firstOrNull()?.getArgumentExpression() ?: return //TODO: what if multiple? - argumentExpression.putCopyableUserData(USER_CODE_KEY, Unit) - expression.replace(argumentExpression) - - //TODO: check if complex expressions are used twice - //TODO: check for dropping complex expressions - } - - override fun visitThisExpression(expression: JetThisExpression) { - if (expression.getLabelName() != null) return //TODO - if (thisReplacement != null) { - expression.replace(thisReplacement!!) - } - //TODO: implicit receiver is not always "this" - } - - override fun visitJetElement(element: JetElement) { - // we do not use acceptChildren because it does not work with replacement - var child: PsiElement? = element.getFirstChild() - while (child != null) { - // whitespace may get invalidated on replace - val next = child.siblings(withItself = false).firstOrNull { it !is PsiWhiteSpace } - child.accept(this) - child = next - } - } - }) - var result = expressionToReplace.replaced(expression) //TODO: drop import of old function (if not needed anymore)? @@ -212,6 +193,8 @@ public class DeprecatedSymbolUsageFix( result.accept(object : PsiRecursiveElementVisitor() { override fun visitElement(element: PsiElement) { element.putCopyableUserData(USER_CODE_KEY, null) + element.putCopyableUserData(FROM_PARAMETER_KEY, null) + element.putCopyableUserData(FROM_THIS_KEY, null) } }) @@ -219,7 +202,18 @@ public class DeprecatedSymbolUsageFix( editor?.moveCaret(offset) } - private fun ReplaceWith.toExpression(symbolDescriptor: CallableDescriptor, resolutionFacade: ResolutionFacade, file: JetFile/*TODO: drop it*/, project: Project): Pair> { + private data class ReplacementExpression( + val expression: JetExpression, + val imports: Collection, + val parameterUsages: Map> + ) + + private fun ReplaceWith.toExpression( + symbolDescriptor: CallableDescriptor, + resolutionFacade: ResolutionFacade, + file: JetFile/*TODO: drop it*/, + project: Project + ): ReplacementExpression { val psiFactory = JetPsiFactory(project) var expression = psiFactory.createExpression(expression) @@ -240,6 +234,8 @@ public class DeprecatedSymbolUsageFix( val receiversToAdd = ArrayList>() + val parameterUsageKey = Key("parameterUsageKey") + expression.accept(object : JetVisitorVoid(){ override fun visitSimpleNameExpression(expression: JetSimpleNameExpression) { val target = bindingContext[BindingContext.REFERENCE_TARGET, expression] ?: return @@ -251,6 +247,10 @@ public class DeprecatedSymbolUsageFix( } if (expression.getReceiverExpression() == null) { + if (target is ValueParameterDescriptor && target.getContainingDeclaration() == symbolDescriptor) { + expression.putCopyableUserData(parameterUsageKey, target) + } + val resolvedCall = expression.getResolvedCall(bindingContext) if (resolvedCall != null && resolvedCall.getStatus().isSuccess()) { val receiver = if (resolvedCall.getResultingDescriptor().isExtension) @@ -285,7 +285,17 @@ public class DeprecatedSymbolUsageFix( } } - return expression to importFqNames + val parameterUsages = symbolDescriptor.getValueParameters() + .map { it to expression.collectExpressionsWithData(parameterUsageKey, it) } + .toMap() + + expression.accept(object : PsiRecursiveElementVisitor() { + override fun visitElement(element: PsiElement) { + element.putCopyableUserData(parameterUsageKey, null) + } + }) + + return ReplacementExpression(expression, importFqNames, parameterUsages) } private fun getResolutionScope(descriptor: DeclarationDescriptor): JetScope { @@ -314,21 +324,86 @@ public class DeprecatedSymbolUsageFix( } } - private fun isNameUsed(name: String, vararg inExpressions: JetExpression): Boolean { - var result = false - inExpressions.forEach { - it.accept(object : JetVisitorVoid(){ - override fun visitSimpleNameExpression(expression: JetSimpleNameExpression) { - if (expression.getReferencedName() == name) { - result = true - } - } + private fun JetExpression.introduceValue(value: JetExpression, usages: Collection, safeCall: Boolean): JetExpression { + assert(usages.all { isAncestor(it, strict = true) }) - override fun visitJetElement(element: JetElement) { - element.acceptChildren(this) - } - }) + val psiFactory = JetPsiFactory(this) + + fun nameInCode(name: String) = IdeDescriptorRenderers.SOURCE_CODE.renderName(Name.identifier(name)) + + fun replaceUsages(name: String) { + val nameInCode = psiFactory.createExpression(nameInCode(name)) + for (usage in usages) { + usage.replace(nameInCode) + } } + + val dot = if (safeCall) "?." else "." + + fun isNameUsed(name: String) = collectNameUsages(name).any { nameUsage -> usages.none { it.isAncestor(nameUsage) } } + + if (!isNameUsed("it")) { + replaceUsages("it") + return psiFactory.createExpressionByPattern("$0${dot}let { $1 }", value, this) + } + else { + val nameValidator = object : JetNameValidator() { + override fun validateInner(name: String) = !isNameUsed(name) + } + val name = JetNameSuggester.suggestNamesForExpression(value, nameValidator, "t").first() + replaceUsages(name) + return psiFactory.createExpressionByPattern("$0${dot}let { ${nameInCode(name)} -> $1 }", value, this) + } + } + + private fun JetExpression.collectExpressionsWithData(key: Key, value: T): Collection { + val result = ArrayList() + this.accept(object : JetVisitorVoid(){ + override fun visitExpression(expression: JetExpression) { + if (expression.getCopyableUserData(key) == value) { + result.add(expression) + } + else { + super.visitExpression(expression) + } + } + + override fun visitJetElement(element: JetElement) { + element.acceptChildren(this) + } + }) + return result + } + + private fun JetExpression.collectThisExpressions(): Collection { + val result = ArrayList() + this.accept(object : JetVisitorVoid(){ + override fun visitThisExpression(expression: JetThisExpression) { + if (expression.getLabelName() == null) { + result.add(expression) + } + } + + override fun visitJetElement(element: JetElement) { + element.acceptChildren(this) + } + }) + return result + } + + private fun JetExpression.collectNameUsages(name: String): ArrayList { + val result = ArrayList() + this.accept(object : JetVisitorVoid(){ + override fun visitSimpleNameExpression(expression: JetSimpleNameExpression) { + if (expression.getReceiverExpression() == null && expression.getReferencedName() == name) { + result.add(expression) + } + } + + override fun visitJetElement(element: JetElement) { + element.acceptChildren(this) + } + }) return result }