diff --git a/idea/src/org/jetbrains/kotlin/idea/inspections/branchedTransformations/IfThenToSafeAccessInspection.kt b/idea/src/org/jetbrains/kotlin/idea/inspections/branchedTransformations/IfThenToSafeAccessInspection.kt index 6ea1b614163..d584a9927f0 100644 --- a/idea/src/org/jetbrains/kotlin/idea/inspections/branchedTransformations/IfThenToSafeAccessInspection.kt +++ b/idea/src/org/jetbrains/kotlin/idea/inspections/branchedTransformations/IfThenToSafeAccessInspection.kt @@ -104,6 +104,8 @@ private fun IfThenToSelectData.clausesReplaceableBySafeCall(): Boolean = when { context.diagnostics.forElement(condition) .any { it.factory == Errors.SENSELESS_COMPARISON || it.factory == Errors.USELESS_IS_CHECK } -> false baseClause.evaluatesTo(receiverExpression) -> true + (baseClause as? KtCallExpression)?.calleeExpression?.evaluatesTo(receiverExpression) == true + && baseClause.isCallingInvokeFunction(context) -> true baseClause.hasFirstReceiverOf(receiverExpression) -> withoutResultInCallChain(baseClause, context) baseClause.anyArgumentEvaluatesTo(receiverExpression) -> true receiverExpression is KtThisExpression -> getImplicitReceiver()?.let { it.type == receiverExpression.getType(context) } == true diff --git a/idea/src/org/jetbrains/kotlin/idea/intentions/branchedTransformations/IfThenUtils.kt b/idea/src/org/jetbrains/kotlin/idea/intentions/branchedTransformations/IfThenUtils.kt index f7532c08e2f..d0ec1386765 100644 --- a/idea/src/org/jetbrains/kotlin/idea/intentions/branchedTransformations/IfThenUtils.kt +++ b/idea/src/org/jetbrains/kotlin/idea/intentions/branchedTransformations/IfThenUtils.kt @@ -12,12 +12,15 @@ import com.intellij.openapi.util.TextRange import com.intellij.psi.search.LocalSearchScope import com.intellij.psi.search.searches.ReferencesSearch import org.jetbrains.kotlin.KtNodeTypes +import org.jetbrains.kotlin.builtins.functions.FunctionInvokeDescriptor +import org.jetbrains.kotlin.descriptors.SimpleFunctionDescriptor import org.jetbrains.kotlin.idea.caches.resolve.analyze import org.jetbrains.kotlin.idea.caches.resolve.findModuleDescriptor import org.jetbrains.kotlin.idea.caches.resolve.getResolutionFacade import org.jetbrains.kotlin.idea.caches.resolve.resolveToCall import org.jetbrains.kotlin.idea.core.KotlinNameSuggester import org.jetbrains.kotlin.idea.core.replaced +import org.jetbrains.kotlin.idea.intentions.callExpression import org.jetbrains.kotlin.idea.intentions.getLeftMostReceiverExpression import org.jetbrains.kotlin.idea.intentions.replaceFirstReceiver import org.jetbrains.kotlin.idea.refactoring.inline.KotlinInlineValHandler @@ -242,12 +245,36 @@ data class IfThenToSelectData( receiverExpression, baseClause ).insertSafeCalls(factory) - baseClause is KtCallExpression -> baseClause.replaceCallWithLet(receiverExpression, factory) - else -> baseClause.insertSafeCalls(factory) + baseClause is KtCallExpression -> { + val callee = baseClause.calleeExpression + if (callee != null && baseClause.isCallingInvokeFunction(context)) { + factory.createExpressionByPattern("$0?.invoke()", callee) + } else { + baseClause.replaceCallWithLet(receiverExpression, factory) + } + } + else -> { + var replaced = baseClause.insertSafeCalls(factory) + if (replaced is KtQualifiedExpression) { + val call = replaced.callExpression + val callee = call?.calleeExpression + if (callee != null && call.isCallingInvokeFunction(context)) { + replaced = factory.createExpressionByPattern("$0?.${callee.text}?.invoke()", replaced.receiverExpression) + } + } + replaced + } } } } + internal fun KtExpression.isCallingInvokeFunction(context: BindingContext): Boolean { + if (this !is KtCallExpression) return false + val resolvedCall = getResolvedCall(context) ?: resolveToCall() ?: return false + val descriptor = resolvedCall.resultingDescriptor as? SimpleFunctionDescriptor ?: return false + return descriptor is FunctionInvokeDescriptor || descriptor.isOperator && descriptor.name.asString() == "invoke" + } + internal fun getImplicitReceiver(): ImplicitReceiver? { val resolvedCall = baseClause.getResolvedCall(context) ?: return null if (resolvedCall.getExplicitReceiverValue() != null) return null diff --git a/idea/testData/inspectionsLocal/branched/ifThenToSafeAccess/callInvokeOperator.kt b/idea/testData/inspectionsLocal/branched/ifThenToSafeAccess/callInvokeOperator.kt new file mode 100644 index 00000000000..08d3c4e6b2e --- /dev/null +++ b/idea/testData/inspectionsLocal/branched/ifThenToSafeAccess/callInvokeOperator.kt @@ -0,0 +1,8 @@ +// HIGHLIGHT: INFORMATION +fun test(foo: Foo?) { + if (foo != null) foo() +} + +class Foo { + operator fun invoke() {} +} \ No newline at end of file diff --git a/idea/testData/inspectionsLocal/branched/ifThenToSafeAccess/callInvokeOperator.kt.after b/idea/testData/inspectionsLocal/branched/ifThenToSafeAccess/callInvokeOperator.kt.after new file mode 100644 index 00000000000..88c18c3a9a0 --- /dev/null +++ b/idea/testData/inspectionsLocal/branched/ifThenToSafeAccess/callInvokeOperator.kt.after @@ -0,0 +1,8 @@ +// HIGHLIGHT: INFORMATION +fun test(foo: Foo?) { + foo?.invoke() +} + +class Foo { + operator fun invoke() {} +} \ No newline at end of file diff --git a/idea/testData/inspectionsLocal/branched/ifThenToSafeAccess/callInvokeOperator2.kt b/idea/testData/inspectionsLocal/branched/ifThenToSafeAccess/callInvokeOperator2.kt new file mode 100644 index 00000000000..f70bbc6c25b --- /dev/null +++ b/idea/testData/inspectionsLocal/branched/ifThenToSafeAccess/callInvokeOperator2.kt @@ -0,0 +1,11 @@ +class Foo(val bar: Bar) + +class Bar { + operator fun invoke() {} +} + +fun test(foo: Foo?) { + if (foo != null) { + foo.bar() + } +} diff --git a/idea/testData/inspectionsLocal/branched/ifThenToSafeAccess/callInvokeOperator2.kt.after b/idea/testData/inspectionsLocal/branched/ifThenToSafeAccess/callInvokeOperator2.kt.after new file mode 100644 index 00000000000..eff2fd7e317 --- /dev/null +++ b/idea/testData/inspectionsLocal/branched/ifThenToSafeAccess/callInvokeOperator2.kt.after @@ -0,0 +1,9 @@ +class Foo(val bar: Bar) + +class Bar { + operator fun invoke() {} +} + +fun test(foo: Foo?) { + foo?.bar?.invoke() +} diff --git a/idea/testData/inspectionsLocal/branched/ifThenToSafeAccess/callVariable.kt b/idea/testData/inspectionsLocal/branched/ifThenToSafeAccess/callVariable.kt new file mode 100644 index 00000000000..62b4f34d582 --- /dev/null +++ b/idea/testData/inspectionsLocal/branched/ifThenToSafeAccess/callVariable.kt @@ -0,0 +1,4 @@ +// HIGHLIGHT: INFORMATION +fun test(foo: (() -> Unit)?) { + if (foo != null) foo() +} \ No newline at end of file diff --git a/idea/testData/inspectionsLocal/branched/ifThenToSafeAccess/callVariable.kt.after b/idea/testData/inspectionsLocal/branched/ifThenToSafeAccess/callVariable.kt.after new file mode 100644 index 00000000000..b0d96e75a07 --- /dev/null +++ b/idea/testData/inspectionsLocal/branched/ifThenToSafeAccess/callVariable.kt.after @@ -0,0 +1,4 @@ +// HIGHLIGHT: INFORMATION +fun test(foo: (() -> Unit)?) { + foo?.invoke() +} \ No newline at end of file diff --git a/idea/testData/inspectionsLocal/branched/ifThenToSafeAccess/callVariable2.kt b/idea/testData/inspectionsLocal/branched/ifThenToSafeAccess/callVariable2.kt new file mode 100644 index 00000000000..825925928e8 --- /dev/null +++ b/idea/testData/inspectionsLocal/branched/ifThenToSafeAccess/callVariable2.kt @@ -0,0 +1,7 @@ +class Foo(val f: () -> Unit) + +fun test(foo: Foo?) { + if (foo != null) { + foo.f() + } +} \ No newline at end of file diff --git a/idea/testData/inspectionsLocal/branched/ifThenToSafeAccess/callVariable2.kt.after b/idea/testData/inspectionsLocal/branched/ifThenToSafeAccess/callVariable2.kt.after new file mode 100644 index 00000000000..75d61e92dab --- /dev/null +++ b/idea/testData/inspectionsLocal/branched/ifThenToSafeAccess/callVariable2.kt.after @@ -0,0 +1,5 @@ +class Foo(val f: () -> Unit) + +fun test(foo: Foo?) { + foo?.f?.invoke() +} \ No newline at end of file diff --git a/idea/tests/org/jetbrains/kotlin/idea/inspections/LocalInspectionTestGenerated.java b/idea/tests/org/jetbrains/kotlin/idea/inspections/LocalInspectionTestGenerated.java index 948d1fb1c55..c1280dff6d1 100644 --- a/idea/tests/org/jetbrains/kotlin/idea/inspections/LocalInspectionTestGenerated.java +++ b/idea/tests/org/jetbrains/kotlin/idea/inspections/LocalInspectionTestGenerated.java @@ -522,6 +522,26 @@ public class LocalInspectionTestGenerated extends AbstractLocalInspectionTest { runTest("idea/testData/inspectionsLocal/branched/ifThenToSafeAccess/call4.kt"); } + @TestMetadata("callInvokeOperator.kt") + public void testCallInvokeOperator() throws Exception { + runTest("idea/testData/inspectionsLocal/branched/ifThenToSafeAccess/callInvokeOperator.kt"); + } + + @TestMetadata("callInvokeOperator2.kt") + public void testCallInvokeOperator2() throws Exception { + runTest("idea/testData/inspectionsLocal/branched/ifThenToSafeAccess/callInvokeOperator2.kt"); + } + + @TestMetadata("callVariable.kt") + public void testCallVariable() throws Exception { + runTest("idea/testData/inspectionsLocal/branched/ifThenToSafeAccess/callVariable.kt"); + } + + @TestMetadata("callVariable2.kt") + public void testCallVariable2() throws Exception { + runTest("idea/testData/inspectionsLocal/branched/ifThenToSafeAccess/callVariable2.kt"); + } + @TestMetadata("conditionComparesNullWithNull.kt") public void testConditionComparesNullWithNull() throws Exception { runTest("idea/testData/inspectionsLocal/branched/ifThenToSafeAccess/conditionComparesNullWithNull.kt");