diff --git a/compiler/backend/src/org/jetbrains/jet/codegen/ExpressionCodegen.java b/compiler/backend/src/org/jetbrains/jet/codegen/ExpressionCodegen.java index fc041debc8d..41cf7fd8bce 100644 --- a/compiler/backend/src/org/jetbrains/jet/codegen/ExpressionCodegen.java +++ b/compiler/backend/src/org/jetbrains/jet/codegen/ExpressionCodegen.java @@ -2669,12 +2669,12 @@ If finally block is present, its last expression is the value of try expression. @Override public StackValue visitIsExpression(final JetIsExpression expression, StackValue receiver) { final StackValue match = StackValue.expression(TYPE_OBJECT, expression.getLeftHandSide(), this); - return generatePatternMatch(expression.getPattern(), expression.isNegated(), match, null); + return generatePatternMatch(expression.getPattern(), expression.isNegated(), match, false, null); } // on entering the function, expressionToMatch is already placed on stack, and we should consume it private StackValue generatePatternMatch(JetPattern pattern, boolean negated, StackValue expressionToMatch, - @Nullable Label nextEntry) { + boolean expressionToMatchIsNullable, @Nullable Label nextEntry) { if (pattern instanceof JetTypePattern) { JetTypeReference typeReference = ((JetTypePattern) pattern).getTypeReference(); JetType jetType = bindingContext.get(BindingContext.TYPE, typeReference); @@ -2692,9 +2692,18 @@ If finally block is present, its last expression is the value of try expression. expressionToMatch.dupReceiver(v); expressionToMatch.put(subjectType, v); JetExpression condExpression = ((JetExpressionPattern) pattern).getExpression(); - Type condType = isNumberPrimitive(subjectType) ? expressionType(condExpression) : TYPE_OBJECT; + boolean patternIsNullable = false; + JetType condJetType = bindingContext.get(BindingContext.EXPRESSION_TYPE, condExpression); + Type condType; + if (isNumberPrimitive(subjectType)) { + condType = asmType(condJetType); + } + else { + condType = TYPE_OBJECT; + patternIsNullable = condJetType != null && condJetType.isNullable(); + } gen(condExpression, condType); - return generateEqualsForExpressionsOnStack(JetTokens.EQEQ, subjectType, condType, false, false); + return generateEqualsForExpressionsOnStack(JetTokens.EQEQ, subjectType, condType, expressionToMatchIsNullable, patternIsNullable); } else { JetExpression condExpression = ((JetExpressionPattern) pattern).getExpression(); @@ -2714,7 +2723,7 @@ If finally block is present, its last expression is the value of try expression. expressionToMatch.put(varType, v); final int varIndex = myFrameMap.getIndex(variableDescriptor); v.store(varIndex, varType); - return generateWhenCondition(varType, varIndex, ((JetBindingPattern) pattern).getCondition(), null); + return generateWhenCondition(varType, varIndex, false, ((JetBindingPattern) pattern).getCondition(), null); } else { throw new UnsupportedOperationException("Unsupported pattern type: " + pattern); @@ -2743,7 +2752,7 @@ If finally block is present, its last expression is the value of try expression. v.mark(lblCheck); for (int i = 0; i < entries.size(); i++) { final StackValue tupleField = StackValue.field(TYPE_OBJECT, tupleClassName, "_" + (i + 1), false); - final StackValue stackValue = generatePatternMatch(entries.get(i).getPattern(), false, tupleField, nextEntry); + final StackValue stackValue = generatePatternMatch(entries.get(i).getPattern(), false, tupleField, false, nextEntry); stackValue.condJump(lblPopAndFail, true, v); } @@ -2793,7 +2802,8 @@ If finally block is present, its last expression is the value of try expression. @Override public StackValue visitWhenExpression(JetWhenExpression expression, StackValue receiver) { JetExpression expr = expression.getSubjectExpression(); - final Type subjectType = expressionType(expr); + JetType subjectJetType = bindingContext.get(BindingContext.EXPRESSION_TYPE, expr); + final Type subjectType = subjectJetType == null ? Type.VOID_TYPE : asmType(subjectJetType); final Type resultType = expressionType(expression); final int subjectLocal = expr != null ? myFrameMap.enterTemp(subjectType.getSize()) : -1; if(subjectLocal != -1) { @@ -2821,7 +2831,9 @@ If finally block is present, its last expression is the value of try expression. if (!whenEntry.isElse()) { final JetWhenCondition[] conditions = whenEntry.getConditions(); for (int i = 0; i < conditions.length; i++) { - StackValue conditionValue = generateWhenCondition(subjectType, subjectLocal, conditions[i], nextCondition); + StackValue conditionValue = generateWhenCondition(subjectType, subjectLocal, + subjectJetType != null && subjectJetType.isNullable(), + conditions[i], nextCondition); conditionValue.condJump(nextCondition, true, v); if (i < conditions.length - 1) { v.goTo(thisEntry); @@ -2848,7 +2860,8 @@ If finally block is present, its last expression is the value of try expression. return StackValue.onStack(resultType); } - private StackValue generateWhenCondition(Type subjectType, int subjectLocal, JetWhenCondition condition, @Nullable Label nextEntry) { + private StackValue generateWhenCondition(Type subjectType, int subjectLocal, boolean subjectIsNullable, + JetWhenCondition condition, @Nullable Label nextEntry) { if (condition instanceof JetWhenConditionInRange) { JetWhenConditionInRange conditionInRange = (JetWhenConditionInRange) condition; JetExpression rangeExpression = conditionInRange.getRangeExpression(); @@ -2878,7 +2891,8 @@ If finally block is present, its last expression is the value of try expression. throw new UnsupportedOperationException("unsupported kind of when condition"); } return generatePatternMatch(pattern, isNegated, - subjectLocal == -1 ? null : StackValue.local(subjectLocal, subjectType), nextEntry); + subjectLocal == -1 ? null : StackValue.local(subjectLocal, subjectType), + subjectIsNullable, nextEntry); } private boolean isIntRangeExpr(JetExpression rangeExpression) { diff --git a/compiler/testData/codegen/patternMatching/nullableWhen.kt b/compiler/testData/codegen/patternMatching/nullableWhen.kt new file mode 100644 index 00000000000..fe59c3183ae --- /dev/null +++ b/compiler/testData/codegen/patternMatching/nullableWhen.kt @@ -0,0 +1,10 @@ +fun f(p: Int?): Int { + return when(p) { + null -> 3 + else -> p!! + } +} + +fun box(): String { + return if (f(null) == 3) "OK" else "fail" +} diff --git a/compiler/tests/org/jetbrains/jet/codegen/PatternMatchingTest.java b/compiler/tests/org/jetbrains/jet/codegen/PatternMatchingTest.java index b4cf3a09651..e07b7a525eb 100644 --- a/compiler/tests/org/jetbrains/jet/codegen/PatternMatchingTest.java +++ b/compiler/tests/org/jetbrains/jet/codegen/PatternMatchingTest.java @@ -154,4 +154,8 @@ public class PatternMatchingTest extends CodegenTestCase { assertEquals("bit", foo.invoke(null, 1)); assertEquals("something", foo.invoke(null, 2)); } + + public void testNullableWhen() throws Exception { // KT-2148 + blackBoxFile("patternMatching/nullableWhen.kt"); + } }