JVM_IR fold safe calls and elvises

This commit is contained in:
Dmitry Petrov
2021-09-28 17:20:13 +03:00
committed by teamcityserver
parent 7370d096ee
commit 9325660f06
23 changed files with 266 additions and 78 deletions
@@ -39745,6 +39745,12 @@ public class FirBlackBoxCodegenTestGenerated extends AbstractFirBlackBoxCodegenT
public void testSafeCallOnLong() throws Exception {
runTest("compiler/testData/codegen/box/safeCall/safeCallOnLong.kt");
}
@Test
@TestMetadata("safeCallWithElvisFolding.kt")
public void testSafeCallWithElvisFolding() throws Exception {
runTest("compiler/testData/codegen/box/safeCall/safeCallWithElvisFolding.kt");
}
}
@Nested
@@ -5448,6 +5448,12 @@ public class FirBytecodeTextTestGenerated extends AbstractFirBytecodeTextTest {
runTest("compiler/testData/codegen/bytecodeText/temporaryVals/arrayCompoundAssignment.kt");
}
@Test
@TestMetadata("elvisChain.kt")
public void testElvisChain() throws Exception {
runTest("compiler/testData/codegen/bytecodeText/temporaryVals/elvisChain.kt");
}
@Test
@TestMetadata("notNullReceiversInChain.kt")
public void testNotNullReceiversInChain() throws Exception {
@@ -31,19 +31,60 @@ val jvmSafeCallFoldingPhase = makeIrFilePhase(
class JvmSafeCallChainFoldingLowering(val context: JvmBackendContext) : FileLoweringPass {
// Overall idea here is to represent (possibly chained) safe calls as an if-expression in the form:
// if ( { val tmp = <safe_receiver>; tmp != null } )
// <safe_call>
// else
// null
// when {
// { val tmp = <safe_receiver>; tmp != null } -> <safe_call_result>
// else -> null
// }
// This allows chaining safe calls like 'a?.foo()?.bar()?.qux()':
// if ( { val tmp1 = a; tmp1 != null } &&
// { val tmp2 = tmp1.foo(); tmp2 != null } &&
// { val tmp3 = tmp2.bar(); tmp3 != null }
// )
// tmp3.qux()
// else
// null
// This also allows fusing safe calls with elvises (and some other operations).
// when {
// { val tmp1 = a; tmp1 != null } &&
// { val tmp2 = tmp1.foo(); tmp2 != null } &&
// { val tmp3 = tmp2.bar(); tmp3 != null } ->
// tmp3.qux()
// else ->
// null
// }
// Folded safe call always has the following form:
// when {
// <safe_call_condition> -> <safe_call_result>
// else -> null
// }
// Note that <safe_call_result> itself might be nullable.
//
// Similarly, elvises can be represented in the form:
// when {
// { val tmp = <elvis_lhs>; tmp != null } ->
// tmp
// else ->
// <elvis_rhs>
// }
// which also allows chaining.
// Folded and possibly chained elvis has the following form:
// when {
// <elvis_condition_1> -> <elvis_lhs_1>
// ...
// <elvis_condition_N> -> <elvis_lhs_N>
// else -> <elvis_rhs>
// }
// where <elvis_lhs_k> are guaranteed to be non-null if the corresponding <elvis_condition_k> is true.
//
// This allows representing a chain of safe calls and elvises,
// e.g., 'a?.foo() ?: b?.bar()?.qux() ?: c'
// in a folded form:
// when {
// { val tmp1 = a; tmp1 != null } &&
// { val tmp2 = tmp1.foo; tmp1 != null } ->
// tmp2
// { val tmp3 = b; tmp3 != null } &&
// { val tmp4 = tmp2.bar(); tmp4 != null } &&
// { val tmp5 = tmp4.qux(); tmp5 != null } ->
// tmp5
// else ->
// c
// }
// which can be further simplified depending on the nullability of subexpressions in 'a?.foo() ?: b?.bar()?.qux() ?: c'.
// In bytecode this produces a chain of temporary STORE-LOAD and IFNULL checks that can be optimized to a compact sequence
// of stack operations and IFNULL checks.
override fun lower(irFile: IrFile) {
irFile.transformChildrenVoid(Transformer())
@@ -107,13 +148,13 @@ class JvmSafeCallChainFoldingLowering(val context: JvmBackendContext) : FileLowe
}
private fun foldSafeCall(safeCallInfo: SafeCallInfo): IrExpression {
// Rewrite a safe call in the form:
// We have a safe call in the form:
// { // SAFE_CALL
// val tmp = <safe_receiver>
// if (tmp == null)
// null
// else
// <call[tmp]>
// when {
// tmp == null -> null
// else -> <safe_call_result>
// }
// }
val safeCallBlock = safeCallInfo.block
val startOffset = safeCallBlock.startOffset
@@ -126,18 +167,20 @@ class JvmSafeCallChainFoldingLowering(val context: JvmBackendContext) : FileLowe
// Chained safe call.
// If <safe_receiver> is a FOLDED_SAFE_CALL form, rewrite safe call to:
// { // FOLDED_SAFE_CALL
// if ( <safe_receiver_condition> && { val tmp = <safe_receiver_result>; tmp != null } )
// <call[tmp]>
// else
// null
// when {
// <safe_receiver_condition> && { val tmp = <safe_receiver_result>; tmp != null } ->
// <safe_call_result>
// else ->
// null
// }
// }
// where
// <safe_receiver> =
// { // FOLDED_SAFE_CALL
// if ( <safe_receiver_condition> )
// <safe_receiver_result>
// else
// null
// when {
// <safe_receiver_condition> -> <safe_receiver_result>
// else -> null
// }
// }
val foldedBlock: IrBlock = tmpValInitializer
val foldedWhen = foldedBlock.statements[0] as IrWhen
@@ -162,10 +205,12 @@ class JvmSafeCallChainFoldingLowering(val context: JvmBackendContext) : FileLowe
// Simple safe call.
// If <safe_receiver> itself is not a FOLDED_SAFE_CALL form, rewrite safe call to:
// { // FOLDED_SAFE_CALL
// if ( { val tmp = <safe_receiver>; tmp != null } )
// <call[tmp]>
// else
// null
// when {
// { val tmp = <safe_receiver>; tmp != null } ->
// <safe_call_result>
// else ->
// null
// }
// }
val foldedCondition =
@@ -203,24 +248,37 @@ class JvmSafeCallChainFoldingLowering(val context: JvmBackendContext) : FileLowe
// Given elvis expression:
// { // ELVIS
// val tmp = <elvis_lhs>
// if (tmp == null)
// <elvis_rhs>
// else
// null
// when {
// tmp == null -> <elvis_lhs>
// else -> null
// }
// }
// where <elvis_lhs> is a folded safe call in the form:
// { // FOLDED_SAFE_CALL
// if ( <safe_call_condition> )
// <safe_call_result>
// else
// null
// when {
// <safe_call_condition> -> <safe_call_result>
// else -> null
// }
// }
// rewrite it to
// { // FOLDED_ELVIS
// if ( <safe_call_condition> && { val tmp = <safe_call_result>; tmp != null } )
// tmp
// else
// <elvis_rhs>
// when {
// <safe_call_condition> && { val tmp = <safe_call_result>; tmp != null } ->
// tmp
// else ->
// <elvis_rhs>
// }
// }
// If <elvis_rhs> is a folded safe call with non-null result, we can inline the underlying 'when':
// { // FOLDED_ELVIS
// when {
// <safe_call_condition> && { val tmp = <safe_call_result>; tmp != null } ->
// tmp
// <rhs_safe_call_condition> ->
// <rhs_safe_call_result>
// else ->
// null
// }
// }
val safeCallWhen = elvisLhs.statements[0] as IrWhen
@@ -236,39 +294,43 @@ class JvmSafeCallChainFoldingLowering(val context: JvmBackendContext) : FileLowe
irValNotNull(startOffset, endOffset, elvisTmpVal)
)
)
val foldedWhen = IrWhenImpl(
startOffset, endOffset, elvisType, JvmLoweredStatementOrigin.FOLDED_ELVIS,
listOf(
IrBranchImpl(
startOffset, endOffset,
irAndAnd(safeCallCondition, foldedConditionPart),
IrGetValueImpl(startOffset, endOffset, elvisTmpVal.symbol)
),
IrBranchImpl(
startOffset, endOffset,
irTrue(startOffset, endOffset),
elvisInfo.elvisRhs
)
val branches = ArrayList<IrBranch>()
branches.add(
IrBranchImpl(
startOffset, endOffset,
irAndAnd(safeCallCondition, foldedConditionPart),
IrGetValueImpl(startOffset, endOffset, elvisTmpVal.symbol)
)
)
return foldedWhen.wrapWithBlock(JvmLoweredStatementOrigin.FOLDED_ELVIS)
val elvisRhs = elvisInfo.elvisRhs
if (elvisRhs.isFoldedSafeCallWithNonNullResult()) {
val rhsInnerWhen = (elvisRhs as IrBlock).statements[0] as IrWhen
branches.addAll(rhsInnerWhen.branches)
} else {
branches.add(IrBranchImpl(startOffset, endOffset, irTrue(startOffset, endOffset), elvisInfo.elvisRhs))
}
return IrWhenImpl(startOffset, endOffset, elvisType, JvmLoweredStatementOrigin.FOLDED_ELVIS, branches)
.wrapWithBlock(JvmLoweredStatementOrigin.FOLDED_ELVIS)
}
elvisLhs is IrBlock && elvisLhs.origin == JvmLoweredStatementOrigin.FOLDED_ELVIS -> {
// Append branches to the inner elvis:
// val t = { // FOLDED_ELVIS
// if (...) ...
// else if ...
// else <innerElvisRhs>
// when {
// ... <innerElvisBranches> ...
// else -> <innerElvisRhs>
// }
// }
// when {
// t == null -> <outerElvisRhs>
// else -> t
// }
// if (t != null) t else <outerElvisRhs>
// =>
// { // FOLDED_ELVIS
// if (...) ...
// else if ...
// else if ( { val t = <innerElvisRhs>; t != null } )
// t
// else
// <outerElvisRhs>
// when {
// ... <innerElvisBranches> ...
// { val t = <innerElvisRhs>; t != null } -> t
// else -> <outerElvisRhs>
// }
// }
// TODO maybe we can do somewhat better if we analyze innerElvisRhs as well
val innerElvisWhen = elvisLhs.statements[0] as IrWhen
@@ -295,11 +357,45 @@ class JvmSafeCallChainFoldingLowering(val context: JvmBackendContext) : FileLowe
return innerElvisWhen.wrapWithBlock(JvmLoweredStatementOrigin.FOLDED_ELVIS)
}
else -> {
return elvisInfo.block
// Plain elvis.
// { // ELVIS
// val tmp = <lhs>
// when {
// tmp == null -> <rhs>
// else -> tmp
// }
// }
// Fold it as:
// { // FOLDED_ELVIS
// when {
// { val tmp = <lhs>; tmp != null } -> tmp
// else -> rhs
// }
// }
val newCondition = IrCompositeImpl(
startOffset, endOffset, context.irBuiltIns.booleanType, null,
listOf(elvisTmpVal, irValNotNull(startOffset, endOffset, elvisTmpVal))
)
val foldedWhen = IrWhenImpl(
startOffset, endOffset, elvisType, JvmLoweredStatementOrigin.FOLDED_ELVIS,
listOf(
IrBranchImpl(startOffset, endOffset, newCondition, IrGetValueImpl(startOffset, endOffset, elvisTmpVal.symbol)),
IrBranchImpl(startOffset, endOffset, irTrue(startOffset, endOffset), elvisInfo.elvisRhs)
)
)
return foldedWhen.wrapWithBlock(JvmLoweredStatementOrigin.FOLDED_ELVIS)
}
}
}
private fun IrExpression.isFoldedSafeCallWithNonNullResult(): Boolean {
if (this !is IrBlock) return false
if (this.origin != JvmLoweredStatementOrigin.FOLDED_SAFE_CALL) return false
val innerWhen = this.statements[0] as? IrWhen ?: return false
val safeCallResult = innerWhen.branches[0].result
return !safeCallResult.type.isNullable()
}
override fun visitCall(expression: IrCall): IrExpression {
expression.transformChildrenVoid()
@@ -330,10 +426,8 @@ class JvmSafeCallChainFoldingLowering(val context: JvmBackendContext) : FileLowe
return safeCallWhen.wrapWithBlock(origin = null)
}
}
return expression
}
}
}
@@ -0,0 +1,8 @@
fun String.foo(): String? = null
fun test(a: String?, b: String?, c: String) =
a ?: b?.foo() ?: c
// = (a ?: b?.boo()) ?: c
// Here 'b?.foo()' returns null, which may break elvis semantics if we fold it carelessly.
fun box() = test(null, "abc", "OK")
@@ -14,5 +14,5 @@ class A(val x: String) {
// JVM_IR_TEMPLATES
// 4 IFNULL
// 2 IFNONNULL
// 2 ACONST_NULL
// 0 IFNONNULL
// 0 ACONST_NULL
@@ -7,3 +7,4 @@ fun test(xs: IntArray, dx: Int) {
// JVM_IR_TEMPLATES
// 5 ALOAD
// 6 ILOAD
@@ -0,0 +1,6 @@
fun test(a: Any?, b: Any?, c: Any) =
a ?: b ?: c
// 2 IFNONNULL
// 0 IFNULL
// 0 ACONST_NULL
@@ -5,13 +5,14 @@ class C(val s: String)
fun test(na: A?) =
na?.b?.c?.s
// 1 POP
// 1 ACONST_NULL
// JVM_IR_TEMPLATES
// 1 DUP
// 1 IFNULL
// 0 IFNONNULL
// 1 ACONST_NULL
// JVM_TEMPLATES
// 3 DUP
// 3 IFNULL
// 3 IFNULL
// 0 IFNONNULL
// 1 ACONST_NULL
@@ -6,3 +6,6 @@ fun test(an: A?) = an?.b?.c?.s
// JVM_IR_TEMPLATES
// 0 ASTORE
// 1 IFNULL
// 0 IFNONNULL
// 1 ACONST_NULL
@@ -6,3 +6,6 @@ fun test(an: A?) = an?.bn?.cn?.s
// JVM_IR_TEMPLATES
// 0 ASTORE
// 1 ACONST_NULL
// 3 IFNULL
// 0 IFNONNULL
@@ -12,3 +12,6 @@ object Host {
// JVM_IR_TEMPLATES
// 0 ASTORE
// 1 ACONST_NULL
// 1 IFNULL
// 0 IFNONNULL
@@ -16,3 +16,6 @@ fun test(an: A?) = an?.b?.c?.s
// JVM_IR_TEMPLATES
// 0 ASTORE
// 1 ACONST_NULL
// 1 IFNULL
// 0 IFNONNULL
@@ -1,11 +1,12 @@
fun test(a: Any?, b: Any?, c: String) =
a?.toString() ?: b?.toString() ?: c
// 2 IFNULL
// 1 ACONST_NULL
// JVM_IR_TEMPLATES
// 1 IFNONNULL
// 2 IFNULL
// 0 IFNONNULL
// 0 ACONST_NULL
// JVM_TEMPLATES
// 2 IFNULL
// 1 ACONST_NULL
// 2 IFNONNULL
@@ -11,7 +11,11 @@ fun test(a: Any?) =
// JVM_IR_TEMPLATES
// 1 DUP
// 1 IFNULL
// 0 ACONST_NULL
// 0 IFNONNULL
// JVM_TEMPLATES
// 2 DUP
// 2 IFNULL
// 0 ACONST_NULL
// 0 IFNONNULL
@@ -39589,6 +39589,12 @@ public class BlackBoxCodegenTestGenerated extends AbstractBlackBoxCodegenTest {
public void testSafeCallOnLong() throws Exception {
runTest("compiler/testData/codegen/box/safeCall/safeCallOnLong.kt");
}
@Test
@TestMetadata("safeCallWithElvisFolding.kt")
public void testSafeCallWithElvisFolding() throws Exception {
runTest("compiler/testData/codegen/box/safeCall/safeCallWithElvisFolding.kt");
}
}
@Nested
@@ -5304,6 +5304,12 @@ public class BytecodeTextTestGenerated extends AbstractBytecodeTextTest {
runTest("compiler/testData/codegen/bytecodeText/temporaryVals/arrayCompoundAssignment.kt");
}
@Test
@TestMetadata("elvisChain.kt")
public void testElvisChain() throws Exception {
runTest("compiler/testData/codegen/bytecodeText/temporaryVals/elvisChain.kt");
}
@Test
@TestMetadata("notNullReceiversInChain.kt")
public void testNotNullReceiversInChain() throws Exception {
@@ -39745,6 +39745,12 @@ public class IrBlackBoxCodegenTestGenerated extends AbstractIrBlackBoxCodegenTes
public void testSafeCallOnLong() throws Exception {
runTest("compiler/testData/codegen/box/safeCall/safeCallOnLong.kt");
}
@Test
@TestMetadata("safeCallWithElvisFolding.kt")
public void testSafeCallWithElvisFolding() throws Exception {
runTest("compiler/testData/codegen/box/safeCall/safeCallWithElvisFolding.kt");
}
}
@Nested
@@ -5448,6 +5448,12 @@ public class IrBytecodeTextTestGenerated extends AbstractIrBytecodeTextTest {
runTest("compiler/testData/codegen/bytecodeText/temporaryVals/arrayCompoundAssignment.kt");
}
@Test
@TestMetadata("elvisChain.kt")
public void testElvisChain() throws Exception {
runTest("compiler/testData/codegen/bytecodeText/temporaryVals/elvisChain.kt");
}
@Test
@TestMetadata("notNullReceiversInChain.kt")
public void testNotNullReceiversInChain() throws Exception {
@@ -31687,6 +31687,11 @@ public class LightAnalysisModeTestGenerated extends AbstractLightAnalysisModeTes
public void testSafeCallOnLong() throws Exception {
runTest("compiler/testData/codegen/box/safeCall/safeCallOnLong.kt");
}
@TestMetadata("safeCallWithElvisFolding.kt")
public void testSafeCallWithElvisFolding() throws Exception {
runTest("compiler/testData/codegen/box/safeCall/safeCallWithElvisFolding.kt");
}
}
@TestMetadata("compiler/testData/codegen/box/sam")
@@ -26862,6 +26862,11 @@ public class IrJsCodegenBoxES6TestGenerated extends AbstractIrJsCodegenBoxES6Tes
public void testSafeCallOnLong() throws Exception {
runTest("compiler/testData/codegen/box/safeCall/safeCallOnLong.kt");
}
@TestMetadata("safeCallWithElvisFolding.kt")
public void testSafeCallWithElvisFolding() throws Exception {
runTest("compiler/testData/codegen/box/safeCall/safeCallWithElvisFolding.kt");
}
}
@TestMetadata("compiler/testData/codegen/box/sam")
@@ -26268,6 +26268,11 @@ public class IrJsCodegenBoxTestGenerated extends AbstractIrJsCodegenBoxTest {
public void testSafeCallOnLong() throws Exception {
runTest("compiler/testData/codegen/box/safeCall/safeCallOnLong.kt");
}
@TestMetadata("safeCallWithElvisFolding.kt")
public void testSafeCallWithElvisFolding() throws Exception {
runTest("compiler/testData/codegen/box/safeCall/safeCallWithElvisFolding.kt");
}
}
@TestMetadata("compiler/testData/codegen/box/sam")
@@ -26193,6 +26193,11 @@ public class JsCodegenBoxTestGenerated extends AbstractJsCodegenBoxTest {
public void testSafeCallOnLong() throws Exception {
runTest("compiler/testData/codegen/box/safeCall/safeCallOnLong.kt");
}
@TestMetadata("safeCallWithElvisFolding.kt")
public void testSafeCallWithElvisFolding() throws Exception {
runTest("compiler/testData/codegen/box/safeCall/safeCallWithElvisFolding.kt");
}
}
@TestMetadata("compiler/testData/codegen/box/sam")
@@ -15587,6 +15587,11 @@ public class IrCodegenBoxWasmTestGenerated extends AbstractIrCodegenBoxWasmTest
public void testSafeCallOnLong() throws Exception {
runTest("compiler/testData/codegen/box/safeCall/safeCallOnLong.kt");
}
@TestMetadata("safeCallWithElvisFolding.kt")
public void testSafeCallWithElvisFolding() throws Exception {
runTest("compiler/testData/codegen/box/safeCall/safeCallWithElvisFolding.kt");
}
}
@TestMetadata("compiler/testData/codegen/box/sam")