diff --git a/j2k/j2k.iml b/j2k/j2k.iml index c76077ab780..41c3c2ee98c 100644 --- a/j2k/j2k.iml +++ b/j2k/j2k.iml @@ -7,10 +7,10 @@ - + diff --git a/j2k/src/org/jetbrains/jet/j2k/Converter.kt b/j2k/src/org/jetbrains/jet/j2k/Converter.kt index 793fb9ff3d8..63c20ab46be 100644 --- a/j2k/src/org/jetbrains/jet/j2k/Converter.kt +++ b/j2k/src/org/jetbrains/jet/j2k/Converter.kt @@ -221,7 +221,7 @@ public class Converter(val project: Project, val settings: ConverterSettings) { private fun convertMethod(method: PsiMethod, membersToRemove: MutableSet): Function { methodReturnType = method.getReturnType() - val returnType = convertType(method.getReturnType(), method.nullabilityFromAnnotations()) + val returnType = convertMethodReturnType(method) val modifiers = convertModifiers(method) @@ -275,6 +275,38 @@ public class Converter(val project: Project, val settings: ConverterSettings) { } } + private fun convertMethodReturnType(method: PsiMethod): Type { + var nullability = method.nullabilityFromAnnotations() + + if (nullability == Nullability.Default) { + var isInAnonymousClass = false + method.getBody()?.accept(object: JavaRecursiveElementVisitor() { + override fun visitAnonymousClass(aClass: PsiAnonymousClass) { + isInAnonymousClass = true + super.visitAnonymousClass(aClass) + isInAnonymousClass = false + } + + override fun visitReturnStatement(statement: PsiReturnStatement) { + if (!isInAnonymousClass && statement.getReturnValue()?.nullability() == Nullability.Nullable) { + nullability = Nullability.Nullable + } + } + }) + } + + if (nullability == Nullability.Default) { + val scope = searchScope(method) + if (scope != null) { + if (findMethodCalls(method, scope).any { isNullableFromUsage(it) }) { + nullability = Nullability.Nullable + } + } + } + + return convertType(method.getReturnType(), nullability) + } + /** * Overrides of methods from Object should not be marked as overrides in Kotlin unless the class itself has java ancestors */ @@ -374,7 +406,7 @@ public class Converter(val project: Project, val settings: ConverterSettings) { private fun findBackingFieldForConstructorParameter(parameter: PsiParameter, constructor: PsiMethod): Pair? { val body = constructor.getBody() ?: return null - val refs = findVariableReferences(parameter, body) + val refs = findVariableUsages(parameter, body) if (refs.any { PsiUtil.isAccessedForWriting(it) }) return null @@ -393,7 +425,7 @@ public class Converter(val project: Project, val settings: ConverterSettings) { if (statement.getParent() != body) continue // and no other assignments to field should exist in the constructor - if (findVariableReferences(field, body).any { it != assignee && PsiUtil.isAccessedForWriting(it) && isQualifierEmptyOrThis(it) }) continue + if (findVariableUsages(field, body).any { it != assignee && PsiUtil.isAccessedForWriting(it) && isQualifierEmptyOrThis(it) }) continue //TODO: check access to field before assignment return field to statement @@ -486,22 +518,43 @@ public class Converter(val project: Project, val settings: ConverterSettings) { } if (nullability == Nullability.Default) { - val scope = usageScope(variable) + val scope = searchScope(variable) if (scope != null) { - if (findVariableReferences(variable, scope).any { isVariableNullableFromUsage(it) }) { + if (findVariableUsages(variable, scope).any { isNullableFromUsage(it) }) { nullability = Nullability.Nullable } } } + if (nullability == Nullability.Default && variable is PsiParameter) { + val method = variable.getDeclarationScope() as? PsiMethod + if (method != null) { + val scope = searchScope(method) + if (scope != null) { + val parameters = method.getParameterList().getParameters() + val parameterIndex = parameters.indexOf(variable) + for (call in findMethodCalls(method, scope)) { + val args = call.getArgumentList().getExpressions() + if (args.size == parameters.size) { + if (args[parameterIndex].nullability() == Nullability.Nullable) { + nullability = Nullability.Nullable + break + } + } + } + } + } + } + return convertType(variable.getType(), nullability) } - private fun usageScope(variable: PsiVariable): PsiElement? { - return when(variable) { - is PsiParameter -> variable.getDeclarationScope() - is PsiField -> if (variable.hasModifierProperty(PsiModifier.PRIVATE)) variable.getContainingClass() else variable.getContainingFile() - is PsiLocalVariable -> variable.getContainingMethod() + private fun searchScope(element: PsiElement): PsiElement? { + return when(element) { + is PsiParameter -> element.getDeclarationScope() + is PsiField -> if (element.hasModifierProperty(PsiModifier.PRIVATE)) element.getContainingClass() else element.getContainingFile() + is PsiMethod -> if (element.hasModifierProperty(PsiModifier.PRIVATE)) element.getContainingClass() else element.getContainingFile() + is PsiLocalVariable -> element.getContainingMethod() else -> null } } @@ -529,15 +582,15 @@ public class Converter(val project: Project, val settings: ConverterSettings) { } } - private fun isVariableNullableFromUsage(ref: PsiReferenceExpression): Boolean { - val parent = ref.getParent() ?: return false - if (parent is PsiAssignmentExpression && parent.getOperationTokenType() == JavaTokenType.EQ && ref == parent.getLExpression()) { + private fun isNullableFromUsage(usage: PsiExpression): Boolean { + val parent = usage.getParent() ?: return false + if (parent is PsiAssignmentExpression && parent.getOperationTokenType() == JavaTokenType.EQ && usage == parent.getLExpression()) { return parent.getRExpression()?.nullability() == Nullability.Nullable } else if (parent is PsiBinaryExpression) { val operationType = parent.getOperationTokenType() if (operationType == JavaTokenType.EQEQ || operationType == JavaTokenType.NE) { - val otherOperand = if (ref == parent.getLOperand()) parent.getROperand() else parent.getLOperand() + val otherOperand = if (usage == parent.getLOperand()) parent.getROperand() else parent.getLOperand() return otherOperand?.nullability() == Nullability.Nullable } } diff --git a/j2k/src/org/jetbrains/jet/j2k/Utils.kt b/j2k/src/org/jetbrains/jet/j2k/Utils.kt index 3ae7152d322..f2a968372f5 100644 --- a/j2k/src/org/jetbrains/jet/j2k/Utils.kt +++ b/j2k/src/org/jetbrains/jet/j2k/Utils.kt @@ -19,32 +19,32 @@ package org.jetbrains.jet.j2k import org.jetbrains.jet.j2k.ast.Identifier import org.jetbrains.jet.j2k.ast.Field import org.jetbrains.jet.lang.types.expressions.OperatorConventions -import java.util.ArrayList import org.jetbrains.jet.j2k.ast.Nullability import com.intellij.psi.* import com.intellij.psi.util.PsiUtil +import com.intellij.psi.search.LocalSearchScope +import com.intellij.psi.search.searches.ReferencesSearch fun quoteKeywords(packageName: String): String = packageName.split("\\.").map { Identifier(it).toKotlin() }.makeString(".") -fun findVariableReferences(variable: PsiVariable, scope: PsiElement): Collection { - class Visitor : JavaRecursiveElementVisitor() { - val refs = ArrayList() +fun findVariableUsages(variable: PsiVariable, scope: PsiElement): Collection { + return ReferencesSearch.search(variable, LocalSearchScope(scope)).findAll().filterIsInstance(javaClass()) +} - override fun visitReferenceExpression(expression: PsiReferenceExpression) { - super.visitReferenceExpression(expression) - if (expression.isReferenceTo(variable)) { - refs.add(expression) - } +fun findMethodCalls(method: PsiMethod, scope: PsiElement): Collection { + return ReferencesSearch.search(method, LocalSearchScope(scope)).findAll().map { + if (it is PsiReferenceExpression) { + val methodCall = it.getParent() as? PsiMethodCallExpression + if (methodCall?.getMethodExpression() == it) methodCall else null } - } - - val visitor = Visitor() - scope.accept(visitor) - return visitor.refs + else { + null + } + }.filterNotNull() } fun PsiVariable.countWriteAccesses(scope: PsiElement?): Int - = if (scope != null) findVariableReferences(this, scope).count { PsiUtil.isAccessedForWriting(it) } else 0 + = if (scope != null) findVariableUsages(this, scope).count { PsiUtil.isAccessedForWriting(it) } else 0 fun PsiModifierListOwner.nullabilityFromAnnotations(): Nullability { val annotations = getModifierList()?.getAnnotations() ?: return Nullability.Default diff --git a/j2k/tests/test/org/jetbrains/jet/j2k/test/JavaToKotlinConverterTestGenerated.java b/j2k/tests/test/org/jetbrains/jet/j2k/test/JavaToKotlinConverterTestGenerated.java index d5abbf8f57b..6b74bb3104b 100644 --- a/j2k/tests/test/org/jetbrains/jet/j2k/test/JavaToKotlinConverterTestGenerated.java +++ b/j2k/tests/test/org/jetbrains/jet/j2k/test/JavaToKotlinConverterTestGenerated.java @@ -1810,6 +1810,41 @@ public class JavaToKotlinConverterTestGenerated extends AbstractJavaToKotlinConv doTest("j2k/tests/testData/ast/nullability/FieldInitializedWithNull.java"); } + @TestMetadata("MethodInvokedWithNullArg.java") + public void testMethodInvokedWithNullArg() throws Exception { + doTest("j2k/tests/testData/ast/nullability/MethodInvokedWithNullArg.java"); + } + + @TestMetadata("MethodInvokedWithNullArg2.java") + public void testMethodInvokedWithNullArg2() throws Exception { + doTest("j2k/tests/testData/ast/nullability/MethodInvokedWithNullArg2.java"); + } + + @TestMetadata("MethodInvokedWithTernaryNullArg.java") + public void testMethodInvokedWithTernaryNullArg() throws Exception { + doTest("j2k/tests/testData/ast/nullability/MethodInvokedWithTernaryNullArg.java"); + } + + @TestMetadata("MethodResultComparedWithNull.java") + public void testMethodResultComparedWithNull() throws Exception { + doTest("j2k/tests/testData/ast/nullability/MethodResultComparedWithNull.java"); + } + + @TestMetadata("MethodReturnsNull.java") + public void testMethodReturnsNull() throws Exception { + doTest("j2k/tests/testData/ast/nullability/MethodReturnsNull.java"); + } + + @TestMetadata("MethodReturnsNullInAnonymousClass.java") + public void testMethodReturnsNullInAnonymousClass() throws Exception { + doTest("j2k/tests/testData/ast/nullability/MethodReturnsNullInAnonymousClass.java"); + } + + @TestMetadata("MethodReturnsTernaryNull.java") + public void testMethodReturnsTernaryNull() throws Exception { + doTest("j2k/tests/testData/ast/nullability/MethodReturnsTernaryNull.java"); + } + @TestMetadata("ParameterComparedWithNull.java") public void testParameterComparedWithNull() throws Exception { doTest("j2k/tests/testData/ast/nullability/ParameterComparedWithNull.java"); diff --git a/j2k/tests/testData/ast/kotlinExclusion/kt-656.kt b/j2k/tests/testData/ast/kotlinExclusion/kt-656.kt index c01283f0bc7..c937addb228 100644 --- a/j2k/tests/testData/ast/kotlinExclusion/kt-656.kt +++ b/j2k/tests/testData/ast/kotlinExclusion/kt-656.kt @@ -1,7 +1,7 @@ package demo class Test() : java.lang.Iterable { - override fun iterator(): java.util.Iterator { + override fun iterator(): java.util.Iterator? { return null } @@ -12,7 +12,7 @@ class Test() : java.lang.Iterable { } class FullTest() : java.lang.Iterable { - override fun iterator(): java.util.Iterator { + override fun iterator(): java.util.Iterator? { return null } diff --git a/j2k/tests/testData/ast/nullability/MethodInvokedWithNullArg.java b/j2k/tests/testData/ast/nullability/MethodInvokedWithNullArg.java new file mode 100644 index 00000000000..3e616176614 --- /dev/null +++ b/j2k/tests/testData/ast/nullability/MethodInvokedWithNullArg.java @@ -0,0 +1,8 @@ +//file +class C { + private void foo(String s){} + + void bar() { + foo(null) + } +} \ No newline at end of file diff --git a/j2k/tests/testData/ast/nullability/MethodInvokedWithNullArg.kt b/j2k/tests/testData/ast/nullability/MethodInvokedWithNullArg.kt new file mode 100644 index 00000000000..2b784988bc6 --- /dev/null +++ b/j2k/tests/testData/ast/nullability/MethodInvokedWithNullArg.kt @@ -0,0 +1,8 @@ +class C() { + private fun foo(s: String?) { + } + + fun bar() { + foo(null) + } +} \ No newline at end of file diff --git a/j2k/tests/testData/ast/nullability/MethodInvokedWithNullArg2.java b/j2k/tests/testData/ast/nullability/MethodInvokedWithNullArg2.java new file mode 100644 index 00000000000..51d580f53d9 --- /dev/null +++ b/j2k/tests/testData/ast/nullability/MethodInvokedWithNullArg2.java @@ -0,0 +1,10 @@ +//file +class C { + public void foo(String s){} +} + +class D { + void bar(C c) { + c.foo(null); + } +} \ No newline at end of file diff --git a/j2k/tests/testData/ast/nullability/MethodInvokedWithNullArg2.kt b/j2k/tests/testData/ast/nullability/MethodInvokedWithNullArg2.kt new file mode 100644 index 00000000000..9d8c3d6d0d9 --- /dev/null +++ b/j2k/tests/testData/ast/nullability/MethodInvokedWithNullArg2.kt @@ -0,0 +1,10 @@ +class C() { + public fun foo(s: String?) { + } +} + +class D() { + fun bar(c: C) { + c.foo(null) + } +} \ No newline at end of file diff --git a/j2k/tests/testData/ast/nullability/MethodInvokedWithTernaryNullArg.java b/j2k/tests/testData/ast/nullability/MethodInvokedWithTernaryNullArg.java new file mode 100644 index 00000000000..d3e65b101b5 --- /dev/null +++ b/j2k/tests/testData/ast/nullability/MethodInvokedWithTernaryNullArg.java @@ -0,0 +1,8 @@ +//file +class C { + private void foo(String s){} + + void bar(boolean b) { + foo(b ? "a" : null) + } +} \ No newline at end of file diff --git a/j2k/tests/testData/ast/nullability/MethodInvokedWithTernaryNullArg.kt b/j2k/tests/testData/ast/nullability/MethodInvokedWithTernaryNullArg.kt new file mode 100644 index 00000000000..7718cd50010 --- /dev/null +++ b/j2k/tests/testData/ast/nullability/MethodInvokedWithTernaryNullArg.kt @@ -0,0 +1,11 @@ +class C() { + private fun foo(s: String?) { + } + + fun bar(b: Boolean) { + foo((if (b) + "a" + else + null)) + } +} \ No newline at end of file diff --git a/j2k/tests/testData/ast/nullability/MethodResultComparedWithNull.java b/j2k/tests/testData/ast/nullability/MethodResultComparedWithNull.java new file mode 100644 index 00000000000..7a439558d6b --- /dev/null +++ b/j2k/tests/testData/ast/nullability/MethodResultComparedWithNull.java @@ -0,0 +1,12 @@ +//file +interface I { + String getString(); +} + +class C { + void foo(I i) { + if (i.getString() == null) { + println("null") + } + } +} \ No newline at end of file diff --git a/j2k/tests/testData/ast/nullability/MethodResultComparedWithNull.kt b/j2k/tests/testData/ast/nullability/MethodResultComparedWithNull.kt new file mode 100644 index 00000000000..ed9ed55a8d3 --- /dev/null +++ b/j2k/tests/testData/ast/nullability/MethodResultComparedWithNull.kt @@ -0,0 +1,11 @@ +trait I { + public fun getString(): String? +} + +class C() { + fun foo(i: I) { + if (i.getString() == null) { + println("null") + } + } +} \ No newline at end of file diff --git a/j2k/tests/testData/ast/nullability/MethodReturnsNull.java b/j2k/tests/testData/ast/nullability/MethodReturnsNull.java new file mode 100644 index 00000000000..44478fdb1d6 --- /dev/null +++ b/j2k/tests/testData/ast/nullability/MethodReturnsNull.java @@ -0,0 +1,11 @@ +//file +class C { + String foo(boolean b) { + if (b) { + return "abc" + } + else { + return null + } + } +} \ No newline at end of file diff --git a/j2k/tests/testData/ast/nullability/MethodReturnsNull.kt b/j2k/tests/testData/ast/nullability/MethodReturnsNull.kt new file mode 100644 index 00000000000..69986554f46 --- /dev/null +++ b/j2k/tests/testData/ast/nullability/MethodReturnsNull.kt @@ -0,0 +1,9 @@ +class C() { + fun foo(b: Boolean): String? { + if (b) { + return "abc" + } else { + return null + } + } +} \ No newline at end of file diff --git a/j2k/tests/testData/ast/nullability/MethodReturnsNullInAnonymousClass.java b/j2k/tests/testData/ast/nullability/MethodReturnsNullInAnonymousClass.java new file mode 100644 index 00000000000..4d0a1cbab02 --- /dev/null +++ b/j2k/tests/testData/ast/nullability/MethodReturnsNullInAnonymousClass.java @@ -0,0 +1,19 @@ +import java.lang.Override; +import java.lang.String; + +//file +interface Getter { + String get() +} + +class C { + String foo(boolean b) { + Getter getter = new Getter() { + @Override + public String get() { + return null; + } + }; + return ""; + } +} \ No newline at end of file diff --git a/j2k/tests/testData/ast/nullability/MethodReturnsNullInAnonymousClass.kt b/j2k/tests/testData/ast/nullability/MethodReturnsNullInAnonymousClass.kt new file mode 100644 index 00000000000..9acea5f077a --- /dev/null +++ b/j2k/tests/testData/ast/nullability/MethodReturnsNullInAnonymousClass.kt @@ -0,0 +1,16 @@ +import java.lang.Override + +trait Getter { + public fun get(): String +} + +class C() { + fun foo(b: Boolean): String { + val getter = object : Getter() { + override fun get(): String? { + return null + } + } + return "" + } +} \ No newline at end of file diff --git a/j2k/tests/testData/ast/nullability/MethodReturnsTernaryNull.java b/j2k/tests/testData/ast/nullability/MethodReturnsTernaryNull.java new file mode 100644 index 00000000000..57e25579403 --- /dev/null +++ b/j2k/tests/testData/ast/nullability/MethodReturnsTernaryNull.java @@ -0,0 +1,6 @@ +//file +class C { + String foo(boolean b) { + return b ? "abc" : null + } +} \ No newline at end of file diff --git a/j2k/tests/testData/ast/nullability/MethodReturnsTernaryNull.kt b/j2k/tests/testData/ast/nullability/MethodReturnsTernaryNull.kt new file mode 100644 index 00000000000..607a860bcab --- /dev/null +++ b/j2k/tests/testData/ast/nullability/MethodReturnsTernaryNull.kt @@ -0,0 +1,8 @@ +class C() { + fun foo(b: Boolean): String? { + return (if (b) + "abc" + else + null) + } +} \ No newline at end of file