IfThenToSafeAccessInspection: fix it works correctly for variable/operator call

#KT-30513 Fixed
#KT-17071 Fixed
This commit is contained in:
Toshiaki Kameyama
2020-01-06 17:49:29 +09:00
committed by Yan Zhulanow
parent 871ad2b909
commit 5095caee50
11 changed files with 107 additions and 2 deletions
@@ -104,6 +104,8 @@ private fun IfThenToSelectData.clausesReplaceableBySafeCall(): Boolean = when {
context.diagnostics.forElement(condition)
.any { it.factory == Errors.SENSELESS_COMPARISON || it.factory == Errors.USELESS_IS_CHECK } -> false
baseClause.evaluatesTo(receiverExpression) -> true
(baseClause as? KtCallExpression)?.calleeExpression?.evaluatesTo(receiverExpression) == true
&& baseClause.isCallingInvokeFunction(context) -> true
baseClause.hasFirstReceiverOf(receiverExpression) -> withoutResultInCallChain(baseClause, context)
baseClause.anyArgumentEvaluatesTo(receiverExpression) -> true
receiverExpression is KtThisExpression -> getImplicitReceiver()?.let { it.type == receiverExpression.getType(context) } == true
@@ -12,12 +12,15 @@ import com.intellij.openapi.util.TextRange
import com.intellij.psi.search.LocalSearchScope
import com.intellij.psi.search.searches.ReferencesSearch
import org.jetbrains.kotlin.KtNodeTypes
import org.jetbrains.kotlin.builtins.functions.FunctionInvokeDescriptor
import org.jetbrains.kotlin.descriptors.SimpleFunctionDescriptor
import org.jetbrains.kotlin.idea.caches.resolve.analyze
import org.jetbrains.kotlin.idea.caches.resolve.findModuleDescriptor
import org.jetbrains.kotlin.idea.caches.resolve.getResolutionFacade
import org.jetbrains.kotlin.idea.caches.resolve.resolveToCall
import org.jetbrains.kotlin.idea.core.KotlinNameSuggester
import org.jetbrains.kotlin.idea.core.replaced
import org.jetbrains.kotlin.idea.intentions.callExpression
import org.jetbrains.kotlin.idea.intentions.getLeftMostReceiverExpression
import org.jetbrains.kotlin.idea.intentions.replaceFirstReceiver
import org.jetbrains.kotlin.idea.refactoring.inline.KotlinInlineValHandler
@@ -242,12 +245,36 @@ data class IfThenToSelectData(
receiverExpression,
baseClause
).insertSafeCalls(factory)
baseClause is KtCallExpression -> baseClause.replaceCallWithLet(receiverExpression, factory)
else -> baseClause.insertSafeCalls(factory)
baseClause is KtCallExpression -> {
val callee = baseClause.calleeExpression
if (callee != null && baseClause.isCallingInvokeFunction(context)) {
factory.createExpressionByPattern("$0?.invoke()", callee)
} else {
baseClause.replaceCallWithLet(receiverExpression, factory)
}
}
else -> {
var replaced = baseClause.insertSafeCalls(factory)
if (replaced is KtQualifiedExpression) {
val call = replaced.callExpression
val callee = call?.calleeExpression
if (callee != null && call.isCallingInvokeFunction(context)) {
replaced = factory.createExpressionByPattern("$0?.${callee.text}?.invoke()", replaced.receiverExpression)
}
}
replaced
}
}
}
}
internal fun KtExpression.isCallingInvokeFunction(context: BindingContext): Boolean {
if (this !is KtCallExpression) return false
val resolvedCall = getResolvedCall(context) ?: resolveToCall() ?: return false
val descriptor = resolvedCall.resultingDescriptor as? SimpleFunctionDescriptor ?: return false
return descriptor is FunctionInvokeDescriptor || descriptor.isOperator && descriptor.name.asString() == "invoke"
}
internal fun getImplicitReceiver(): ImplicitReceiver? {
val resolvedCall = baseClause.getResolvedCall(context) ?: return null
if (resolvedCall.getExplicitReceiverValue() != null) return null
@@ -0,0 +1,8 @@
// HIGHLIGHT: INFORMATION
fun test(foo: Foo?) {
<caret>if (foo != null) foo()
}
class Foo {
operator fun invoke() {}
}
@@ -0,0 +1,8 @@
// HIGHLIGHT: INFORMATION
fun test(foo: Foo?) {
foo?.invoke()
}
class Foo {
operator fun invoke() {}
}
@@ -0,0 +1,11 @@
class Foo(val bar: Bar)
class Bar {
operator fun invoke() {}
}
fun test(foo: Foo?) {
<caret>if (foo != null) {
foo.bar()
}
}
@@ -0,0 +1,9 @@
class Foo(val bar: Bar)
class Bar {
operator fun invoke() {}
}
fun test(foo: Foo?) {
foo?.bar?.invoke()
}
@@ -0,0 +1,4 @@
// HIGHLIGHT: INFORMATION
fun test(foo: (() -> Unit)?) {
<caret>if (foo != null) foo()
}
@@ -0,0 +1,4 @@
// HIGHLIGHT: INFORMATION
fun test(foo: (() -> Unit)?) {
foo?.invoke()
}
@@ -0,0 +1,7 @@
class Foo(val f: () -> Unit)
fun test(foo: Foo?) {
<caret>if (foo != null) {
foo.f()
}
}
@@ -0,0 +1,5 @@
class Foo(val f: () -> Unit)
fun test(foo: Foo?) {
foo?.f?.invoke()
}
@@ -522,6 +522,26 @@ public class LocalInspectionTestGenerated extends AbstractLocalInspectionTest {
runTest("idea/testData/inspectionsLocal/branched/ifThenToSafeAccess/call4.kt");
}
@TestMetadata("callInvokeOperator.kt")
public void testCallInvokeOperator() throws Exception {
runTest("idea/testData/inspectionsLocal/branched/ifThenToSafeAccess/callInvokeOperator.kt");
}
@TestMetadata("callInvokeOperator2.kt")
public void testCallInvokeOperator2() throws Exception {
runTest("idea/testData/inspectionsLocal/branched/ifThenToSafeAccess/callInvokeOperator2.kt");
}
@TestMetadata("callVariable.kt")
public void testCallVariable() throws Exception {
runTest("idea/testData/inspectionsLocal/branched/ifThenToSafeAccess/callVariable.kt");
}
@TestMetadata("callVariable2.kt")
public void testCallVariable2() throws Exception {
runTest("idea/testData/inspectionsLocal/branched/ifThenToSafeAccess/callVariable2.kt");
}
@TestMetadata("conditionComparesNullWithNull.kt")
public void testConditionComparesNullWithNull() throws Exception {
runTest("idea/testData/inspectionsLocal/branched/ifThenToSafeAccess/conditionComparesNullWithNull.kt");