diff --git a/compiler/backend/src/org/jetbrains/kotlin/codegen/optimization/nullCheck/RedundantNullCheckMethodTransformer.kt b/compiler/backend/src/org/jetbrains/kotlin/codegen/optimization/nullCheck/RedundantNullCheckMethodTransformer.kt index 7df1917106f..9c93c6a25ed 100644 --- a/compiler/backend/src/org/jetbrains/kotlin/codegen/optimization/nullCheck/RedundantNullCheckMethodTransformer.kt +++ b/compiler/backend/src/org/jetbrains/kotlin/codegen/optimization/nullCheck/RedundantNullCheckMethodTransformer.kt @@ -72,6 +72,7 @@ class RedundantNullCheckMethodTransformer(private val generationState: Generatio val value = when { insn.isInstanceOfOrNullCheck() -> frame.top() insn.isCheckNotNull() -> frame.top() + insn.isCheckNotNullWithMessage() -> frame.peek(1) insn.isCheckExpressionValueIsNotNull() -> frame.peek(1) else -> null } as? StrictBasicValue ?: continue @@ -88,6 +89,7 @@ class RedundantNullCheckMethodTransformer(private val generationState: Generatio opcode == Opcodes.IFNONNULL || opcode == Opcodes.INSTANCEOF || isCheckNotNull() || + isCheckNotNullWithMessage() || isCheckExpressionValueIsNotNull() private fun transformTrivialChecks(nullabilityMap: Map) { @@ -99,11 +101,13 @@ class RedundantNullCheckMethodTransformer(private val generationState: Generatio Opcodes.INSTANCEOF -> transformInstanceOf(insn as TypeInsnNode, nullability, value) Opcodes.INVOKESTATIC -> { - if (insn.isCheckNotNull()) { - transformTrivialCheckNotNull(insn, nullability) - } - if (insn.isCheckExpressionValueIsNotNull()) { - transformTrivialCheckExpressionValueIsNotNull(insn, nullability) + when { + insn.isCheckNotNull() -> + transformTrivialCheckNotNull(insn, nullability) + insn.isCheckNotNullWithMessage() -> + transformTrivialCheckNotNullWithMessage(insn, nullability) + insn.isCheckExpressionValueIsNotNull() -> + transformTrivialCheckExpressionValueIsNotNull(insn, nullability) } } } @@ -150,6 +154,17 @@ class RedundantNullCheckMethodTransformer(private val generationState: Generatio } } + private fun transformTrivialCheckNotNullWithMessage(insn: AbstractInsnNode, nullability: Nullability) { + if (nullability != Nullability.NOT_NULL) return + val ldcInsn = insn.previous?.takeIf { it.opcode == Opcodes.LDC } ?: return + val previousInsn = ldcInsn.previous?.takeIf { it.opcode == Opcodes.DUP || it.opcode == Opcodes.ALOAD } ?: return + methodNode.instructions.run { + remove(previousInsn) + remove(ldcInsn) + remove(insn) + } + } + private fun transformTrivialCheckExpressionValueIsNotNull(insn: AbstractInsnNode, nullability: Nullability) { if (nullability != Nullability.NOT_NULL) return val ldcInsn = insn.previous?.takeIf { it.opcode == Opcodes.LDC } ?: return @@ -185,12 +200,14 @@ class RedundantNullCheckMethodTransformer(private val generationState: Generatio } insn.isCheckNotNull() -> { - val previous = insn.previous ?: continue - val aLoadInsn = if (previous.opcode == Opcodes.DUP) { - previous.previous ?: continue - } else previous - if (aLoadInsn.opcode != Opcodes.ALOAD) continue - addDependentCheck(insn, aLoadInsn as VarInsnNode) + val checkedValueInsn = insn.previous ?: continue + addDependentCheckForCheckNotNull(insn, checkedValueInsn) + } + + insn.isCheckNotNullWithMessage() -> { + val ldcInsn = insn.previous?.takeIf { it.opcode == Opcodes.LDC } ?: continue + val checkedValueInsn = ldcInsn.previous ?: continue + addDependentCheckForCheckNotNull(insn, checkedValueInsn) } insn.isCheckParameterIsNotNull() -> { @@ -221,6 +238,16 @@ class RedundantNullCheckMethodTransformer(private val generationState: Generatio } } + private fun addDependentCheckForCheckNotNull(insn: AbstractInsnNode, checkedValueInsn: AbstractInsnNode) { + val aLoadInsn = if (checkedValueInsn.opcode == Opcodes.DUP) { + checkedValueInsn.previous ?: return + } else { + checkedValueInsn + } + if (aLoadInsn.opcode != Opcodes.ALOAD) return + addDependentCheck(insn, aLoadInsn as VarInsnNode) + } + private fun addDependentCheck(insn: AbstractInsnNode, aLoadInsn: VarInsnNode) { checksDependingOnVariable.getOrPut(aLoadInsn.`var`) { SmartList() @@ -249,7 +276,8 @@ class RedundantNullCheckMethodTransformer(private val generationState: Generatio injectAssumptionsForNullCheck(varIndex, insn as JumpInsnNode) Opcodes.INVOKESTATIC -> { when { - insn.isCheckNotNull() || insn.isCheckParameterIsNotNull() || insn.isCheckExpressionValueIsNotNull() -> + insn.isCheckNotNull() || insn.isCheckNotNullWithMessage() || + insn.isCheckParameterIsNotNull() || insn.isCheckExpressionValueIsNotNull() -> injectAssumptionsForNotNullAssertion(varIndex, insn) insn.isPseudo(PseudoInsn.STORE_NOT_NULL) -> injectCodeForStoreNotNull(insn) @@ -300,13 +328,16 @@ class RedundantNullCheckMethodTransformer(private val generationState: Generatio } private fun NullabilityAssumptions.injectAssumptionsForNotNullAssertion(varIndex: Int, insn: AbstractInsnNode) { - // ALOAD v - // DUP - // INVOKESTATIC checkNotNull + // ( INVOKESTATIC checkNotNull + // | LDC ; INVOKESTATIC checkNotNull(Object, String)V + // ) // <...> -- v is not null here (otherwise an exception was thrown) // ALOAD v - // INVOKESTATIC checkNotNull + // DUP? + // ( INVOKESTATIC checkNotNull + // | LDC ; INVOKESTATIC checkNotNull(Object, String)V + // ) // <...> -- v is not null here (otherwise an exception was thrown) // ALOAD v @@ -317,7 +348,9 @@ class RedundantNullCheckMethodTransformer(private val generationState: Generatio // ALOAD v // DUP // LDC * - // INVOKESTATIC checkExpressionValueIsNotNull/checkNotNullExpressionValue + // ( INVOKESTATIC checkExpressionValueIsNotNull + // | INVOKESTATIC checkNotNullExpressionValue + // ) // <...> -- v is not null here (otherwise an exception was thrown) methodNode.instructions.insert(insn, listOfSynthetics { @@ -436,6 +469,13 @@ internal fun AbstractInsnNode.isCheckNotNull() = desc == "(Ljava/lang/Object;)V" } +internal fun AbstractInsnNode.isCheckNotNullWithMessage() = + isInsn(Opcodes.INVOKESTATIC) { + owner == IntrinsicMethods.INTRINSICS_CLASS_NAME && + name == "checkNotNull" && + desc == "(Ljava/lang/Object;Ljava/lang/String;)V" + } + fun MethodNode.usesLocalExceptParameterNullCheck(index: Int): Boolean = instructions.toArray().any { it is VarInsnNode && it.opcode == Opcodes.ALOAD && it.`var` == index && !it.isParameterCheckedForNull() diff --git a/compiler/backend/src/org/jetbrains/kotlin/codegen/optimization/temporaryVals/TemporaryVariablesEliminationTransformer.kt b/compiler/backend/src/org/jetbrains/kotlin/codegen/optimization/temporaryVals/TemporaryVariablesEliminationTransformer.kt index 4112d74e2cc..f4fa3cdc9b8 100644 --- a/compiler/backend/src/org/jetbrains/kotlin/codegen/optimization/temporaryVals/TemporaryVariablesEliminationTransformer.kt +++ b/compiler/backend/src/org/jetbrains/kotlin/codegen/optimization/temporaryVals/TemporaryVariablesEliminationTransformer.kt @@ -10,6 +10,7 @@ import org.jetbrains.kotlin.codegen.optimization.common.InsnSequence import org.jetbrains.kotlin.codegen.optimization.common.isMeaningful import org.jetbrains.kotlin.codegen.optimization.common.removeUnusedLocalVariables import org.jetbrains.kotlin.codegen.optimization.nullCheck.isCheckExpressionValueIsNotNull +import org.jetbrains.kotlin.codegen.optimization.nullCheck.isCheckNotNullWithMessage import org.jetbrains.kotlin.codegen.optimization.transformer.MethodTransformer import org.jetbrains.kotlin.codegen.state.GenerationState import org.jetbrains.kotlin.utils.SmartList @@ -288,7 +289,7 @@ class TemporaryVariablesEliminationTransformer(private val state: GenerationStat if ((aLoad1Insn as VarInsnNode).`var` == tmp.index && (ldcInsn as LdcInsnNode).cst is String && - invokeStaticInsn.isCheckExpressionValueIsNotNull() && + (invokeStaticInsn.isCheckExpressionValueIsNotNull() || invokeStaticInsn.isCheckNotNullWithMessage()) && (aLoad2Insn as VarInsnNode).`var` == tmp.index ) { // Replace instruction sequence: diff --git a/compiler/fir/fir2ir/tests-gen/org/jetbrains/kotlin/test/runners/codegen/FirBytecodeTextTestGenerated.java b/compiler/fir/fir2ir/tests-gen/org/jetbrains/kotlin/test/runners/codegen/FirBytecodeTextTestGenerated.java index 4b8646b1f1c..012ed8eb5ef 100644 --- a/compiler/fir/fir2ir/tests-gen/org/jetbrains/kotlin/test/runners/codegen/FirBytecodeTextTestGenerated.java +++ b/compiler/fir/fir2ir/tests-gen/org/jetbrains/kotlin/test/runners/codegen/FirBytecodeTextTestGenerated.java @@ -5506,6 +5506,12 @@ public class FirBytecodeTextTestGenerated extends AbstractFirBytecodeTextTest { runTest("compiler/testData/codegen/bytecodeText/temporaryVals/elvisChain.kt"); } + @Test + @TestMetadata("noTemporaryInCheckedCast.kt") + public void testNoTemporaryInCheckedCast() throws Exception { + runTest("compiler/testData/codegen/bytecodeText/temporaryVals/noTemporaryInCheckedCast.kt"); + } + @Test @TestMetadata("notNullReceiversInChain.kt") public void testNotNullReceiversInChain() throws Exception { diff --git a/compiler/ir/backend.jvm/lower/src/org/jetbrains/kotlin/backend/jvm/lower/TypeOperatorLowering.kt b/compiler/ir/backend.jvm/lower/src/org/jetbrains/kotlin/backend/jvm/lower/TypeOperatorLowering.kt index 8770951f569..305dfe47637 100644 --- a/compiler/ir/backend.jvm/lower/src/org/jetbrains/kotlin/backend/jvm/lower/TypeOperatorLowering.kt +++ b/compiler/ir/backend.jvm/lower/src/org/jetbrains/kotlin/backend/jvm/lower/TypeOperatorLowering.kt @@ -57,7 +57,9 @@ internal val typeOperatorLowering = makeIrFilePhase( description = "Lower IrTypeOperatorCalls to (implicit) casts and instanceof checks" ) -private class TypeOperatorLowering(private val context: JvmBackendContext) : FileLoweringPass, IrBuildingTransformer(context) { +private class TypeOperatorLowering(private val backendContext: JvmBackendContext) : + FileLoweringPass, IrBuildingTransformer(backendContext) { + override fun lower(irFile: IrFile) = irFile.transformChildrenVoid() private fun IrExpression.transformVoid() = transform(this@TypeOperatorLowering, null) @@ -80,28 +82,44 @@ private class TypeOperatorLowering(private val context: JvmBackendContext) : Fil } } - private fun lowerCast(argument: IrExpression, type: IrType): IrExpression = when { - type.isReifiedTypeParameter -> - builder.irAs(argument, type) - argument.type.isInlineClassType() && argument.type.isSubtypeOfClass(type.erasedUpperBound.symbol) -> - argument - type.isNullable() || argument.isDefinitelyNotNull() -> - builder.irAs(argument, type) - else -> { - with(builder) { - irLetS(argument, irType = context.irBuiltIns.anyNType) { valueSymbol -> - irIfNull( - type, - irGet(valueSymbol.owner), - irCall(throwTypeCastException).apply { - putValueArgument(0, irString("null cannot be cast to non-null type ${type.render()}")) - }, - builder.irAs(irGet(valueSymbol.owner), type.makeNullable()) - ) + private fun lowerCast(argument: IrExpression, type: IrType): IrExpression = + when { + type.isReifiedTypeParameter -> + builder.irAs(argument, type) + argument.type.isInlineClassType() && argument.type.isSubtypeOfClass(type.erasedUpperBound.symbol) -> + argument + type.isNullable() || argument.isDefinitelyNotNull() -> + builder.irAs(argument, type) + else -> { + with(builder) { + irLetS(argument, irType = context.irBuiltIns.anyNType) { tmp -> + val message = irString("null cannot be cast to non-null type ${type.render()}") + if (backendContext.state.unifiedNullChecks) { + // Avoid branching to improve code coverage (KT-27427). + // We have to generate a null check here, because even if argument is of non-null type, + // it can be uninitialized value, which is 'null' for reference types in JMM. + // Most of such null checks will never actually throw, but we can't do anything about it. + irBlock { + +irCall(backendContext.ir.symbols.checkNotNullWithMessage).apply { + putValueArgument(0, irGet(tmp.owner)) + putValueArgument(1, message) + } + +irAs(irGet(tmp.owner), type.makeNullable()) + } + } else { + irIfNull( + type, + irGet(tmp.owner), + irCall(throwTypeCastException).apply { + putValueArgument(0, message) + }, + irAs(irGet(tmp.owner), type.makeNullable()) + ) + } + } } } } - } // TODO extract null check elimination on IR somewhere? private fun IrExpression.isDefinitelyNotNull(): Boolean = @@ -112,7 +130,7 @@ private class TypeOperatorLowering(private val context: JvmBackendContext) : Fil is IrConstructorCall -> true is IrCall -> - this.symbol == context.irBuiltIns.checkNotNullSymbol + this.symbol == backendContext.irBuiltIns.checkNotNullSymbol else -> false } @@ -122,7 +140,7 @@ private class TypeOperatorLowering(private val context: JvmBackendContext) : Fil return !irVariable.isVar && irVariable.initializer?.isDefinitelyNotNull() == true } - private val jvmIndyLambdaMetafactoryIntrinsic = context.ir.symbols.indyLambdaMetafactoryIntrinsic + private val jvmIndyLambdaMetafactoryIntrinsic = backendContext.ir.symbols.indyLambdaMetafactoryIntrinsic private fun JvmIrBuilder.jvmMethodHandle(handle: Handle) = irCall(backendContext.ir.symbols.jvmMethodHandle).apply { @@ -278,15 +296,15 @@ private class TypeOperatorLowering(private val context: JvmBackendContext) : Fil byDeserializedLambdaInfo[deserializedLambdaInfo] = serializableMethodRefInfo } - val deserializeLambdaFun = context.irFactory.buildFun { + val deserializeLambdaFun = backendContext.irFactory.buildFun { name = Name.identifier("\$deserializeLambda\$") visibility = DescriptorVisibilities.PRIVATE origin = JvmLoweredDeclarationOrigin.DESERIALIZE_LAMBDA_FUN } deserializeLambdaFun.parent = irClass - val lambdaParameter = deserializeLambdaFun.addValueParameter("lambda", context.ir.symbols.serializedLambda.irType) - deserializeLambdaFun.returnType = context.irBuiltIns.anyType - deserializeLambdaFun.body = context.createJvmIrBuilder(deserializeLambdaFun.symbol, UNDEFINED_OFFSET, UNDEFINED_OFFSET).run { + val lambdaParameter = deserializeLambdaFun.addValueParameter("lambda", backendContext.ir.symbols.serializedLambda.irType) + deserializeLambdaFun.returnType = backendContext.irBuiltIns.anyType + deserializeLambdaFun.body = backendContext.createJvmIrBuilder(deserializeLambdaFun.symbol, UNDEFINED_OFFSET, UNDEFINED_OFFSET).run { irBlockBody { val tmp = irTemporary( irCall(backendContext.ir.symbols.serializedLambda.getImplMethodName).apply { @@ -332,9 +350,9 @@ private class TypeOperatorLowering(private val context: JvmBackendContext) : Fil private fun mapDeserializedLambda(info: SerializableMethodRefInfo) = DeserializedLambdaInfo( - functionalInterfaceClass = context.typeMapper.mapType(info.samType).internalName, - implMethodHandle = context.methodSignatureMapper.mapToMethodHandle(info.implFunSymbol.owner), - functionalInterfaceMethod = context.methodSignatureMapper.mapAsmMethod(info.samMethodSymbol.owner) + functionalInterfaceClass = backendContext.typeMapper.mapType(info.samType).internalName, + implMethodHandle = backendContext.methodSignatureMapper.mapToMethodHandle(info.implFunSymbol.owner), + functionalInterfaceMethod = backendContext.methodSignatureMapper.mapAsmMethod(info.samMethodSymbol.owner) ) private fun JvmIrBuilder.generateSerializedLambdaEquals( @@ -376,7 +394,7 @@ private class TypeOperatorLowering(private val context: JvmBackendContext) : Fil ) } - private val equalsAny = context.irBuiltIns.anyClass.getSimpleFunction("equals")!! + private val equalsAny = backendContext.irBuiltIns.anyClass.getSimpleFunction("equals")!! private fun JvmIrBuilder.irObjectEquals(receiver: IrExpression, arg: IrExpression) = irCall(equalsAny).apply { @@ -478,7 +496,7 @@ private class TypeOperatorLowering(private val context: JvmBackendContext) : Fil ) } - return context.createJvmIrBuilder(implFunSymbol, startOffset, endOffset) + return backendContext.createJvmIrBuilder(implFunSymbol, startOffset, endOffset) .createLambdaMetafactoryCall( samMethod.symbol, implFunSymbol, instanceMethodRef.symbol, shouldBeSerializable, requiredBridges, dynamicCall ) @@ -551,12 +569,12 @@ private class TypeOperatorLowering(private val context: JvmBackendContext) : Fil samMethod: IrSimpleFunction, extraOverriddenMethods: List ): Collection { - val jvmInstanceMethod = context.methodSignatureMapper.mapAsmMethod(instanceMethod) - val jvmSamMethod = context.methodSignatureMapper.mapAsmMethod(samMethod) + val jvmInstanceMethod = backendContext.methodSignatureMapper.mapAsmMethod(instanceMethod) + val jvmSamMethod = backendContext.methodSignatureMapper.mapAsmMethod(samMethod) val signatureToNonFakeOverride = LinkedHashMap() for (overridden in extraOverriddenMethods) { - val jvmOverriddenMethod = context.methodSignatureMapper.mapAsmMethod(overridden) + val jvmOverriddenMethod = backendContext.methodSignatureMapper.mapAsmMethod(overridden) if (jvmOverriddenMethod != jvmInstanceMethod && jvmOverriddenMethod != jvmSamMethod) { signatureToNonFakeOverride[jvmOverriddenMethod] = overridden } @@ -574,12 +592,12 @@ private class TypeOperatorLowering(private val context: JvmBackendContext) : Fil val dynamicCallArguments = ArrayList() - val irDynamicCallTarget = context.irFactory.buildFun { + val irDynamicCallTarget = backendContext.irFactory.buildFun { origin = JvmLoweredDeclarationOrigin.INVOKEDYNAMIC_CALL_TARGET name = samMethod.name returnType = erasedSamType }.apply { - parent = context.ir.symbols.kotlinJvmInternalInvokeDynamicPackage + parent = backendContext.ir.symbols.kotlinJvmInternalInvokeDynamicPackage val targetFun = targetRef.symbol.owner val refDispatchReceiver = targetRef.dispatchReceiver @@ -640,7 +658,7 @@ private class TypeOperatorLowering(private val context: JvmBackendContext) : Fil ) } - return context.createJvmIrBuilder(irDynamicCallTarget.symbol) + return backendContext.createJvmIrBuilder(irDynamicCallTarget.symbol) .irCall(irDynamicCallTarget.symbol) .apply { for (i in dynamicCallArguments.indices) { @@ -656,11 +674,7 @@ private class TypeOperatorLowering(private val context: JvmBackendContext) : Fil irComposite(resultType = expression.type) { +expression.argument.transformVoid() // TODO: Don't generate these casts in the first place - if (!expression.argument.type.isSubtypeOf( - expression.type.makeNullable(), - this@TypeOperatorLowering.context.typeSystem - ) - ) { + if (!expression.argument.type.isSubtypeOf(expression.type.makeNullable(), backendContext.typeSystem)) { +IrCompositeImpl(UNDEFINED_OFFSET, UNDEFINED_OFFSET, expression.type) } } @@ -762,14 +776,11 @@ private class TypeOperatorLowering(private val context: JvmBackendContext) : Fil declaration.fileParent.getKtFile()!!.viewProvider.contents private val throwTypeCastException: IrSimpleFunctionSymbol = - if (context.state.unifiedNullChecks) - context.ir.symbols.throwNullPointerException - else - context.ir.symbols.throwTypeCastException + backendContext.ir.symbols.throwTypeCastException private val checkExpressionValueIsNotNull: IrSimpleFunctionSymbol = - if (context.state.unifiedNullChecks) - context.ir.symbols.checkNotNullExpressionValue + if (backendContext.state.unifiedNullChecks) + backendContext.ir.symbols.checkNotNullExpressionValue else - context.ir.symbols.checkExpressionValueIsNotNull + backendContext.ir.symbols.checkExpressionValueIsNotNull } diff --git a/compiler/ir/backend.jvm/src/org/jetbrains/kotlin/backend/jvm/JvmSymbols.kt b/compiler/ir/backend.jvm/src/org/jetbrains/kotlin/backend/jvm/JvmSymbols.kt index 2e00beaeb44..0125b79ce06 100644 --- a/compiler/ir/backend.jvm/src/org/jetbrains/kotlin/backend/jvm/JvmSymbols.kt +++ b/compiler/ir/backend.jvm/src/org/jetbrains/kotlin/backend/jvm/JvmSymbols.kt @@ -145,6 +145,10 @@ class JvmSymbols( klass.addFunction("checkNotNull", irBuiltIns.unitType, isStatic = true).apply { addValueParameter("object", irBuiltIns.anyNType) } + klass.addFunction("checkNotNull", irBuiltIns.unitType, isStatic = true).apply { + addValueParameter("object", irBuiltIns.anyNType) + addValueParameter("message", irBuiltIns.stringType) + } klass.addFunction("throwNpe", irBuiltIns.unitType, isStatic = true) klass.declarations.add(irFactory.buildClass { @@ -162,7 +166,10 @@ class JvmSymbols( intrinsicsClass.functions.single { it.owner.name.asString() == "checkNotNullExpressionValue" } val checkNotNull: IrSimpleFunctionSymbol = - intrinsicsClass.functions.single { it.owner.name.asString() == "checkNotNull" } + intrinsicsClass.owner.functions.single { it.name.asString() == "checkNotNull" && it.valueParameters.size == 1 }.symbol + + val checkNotNullWithMessage: IrSimpleFunctionSymbol = + intrinsicsClass.owner.functions.single { it.name.asString() == "checkNotNull" && it.valueParameters.size == 2 }.symbol val throwNpe: IrSimpleFunctionSymbol = intrinsicsClass.functions.single { it.owner.name.asString() == "throwNpe" } diff --git a/compiler/ir/ir.tree/src/org/jetbrains/kotlin/ir/builders/ExpressionHelpers.kt b/compiler/ir/ir.tree/src/org/jetbrains/kotlin/ir/builders/ExpressionHelpers.kt index 17ff92b4852..d9e0920a4b1 100644 --- a/compiler/ir/ir.tree/src/org/jetbrains/kotlin/ir/builders/ExpressionHelpers.kt +++ b/compiler/ir/ir.tree/src/org/jetbrains/kotlin/ir/builders/ExpressionHelpers.kt @@ -10,8 +10,12 @@ import org.jetbrains.kotlin.ir.declarations.* import org.jetbrains.kotlin.ir.expressions.* import org.jetbrains.kotlin.ir.expressions.impl.* import org.jetbrains.kotlin.ir.symbols.* -import org.jetbrains.kotlin.ir.types.* -import org.jetbrains.kotlin.ir.util.* +import org.jetbrains.kotlin.ir.types.IrType +import org.jetbrains.kotlin.ir.types.typeWith +import org.jetbrains.kotlin.ir.util.isImmutable +import org.jetbrains.kotlin.ir.util.parentAsClass +import org.jetbrains.kotlin.ir.util.primaryConstructor +import org.jetbrains.kotlin.ir.util.render import org.jetbrains.kotlin.utils.addToStdlib.assertedCast val IrBuilderWithScope.parent get() = scope.getLocalDeclarationParent() @@ -23,20 +27,25 @@ inline fun IrBuilderWithScope.irLetS( irType: IrType? = null, body: (IrValueSymbol) -> IrExpression ): IrExpression { - val (valueSymbol, irTemporary) = if (value is IrGetValue && value.symbol.owner.isImmutable) { - value.symbol to null + val irTemporary: IrVariable? + val valueSymbol: IrValueSymbol + if (value is IrGetValue && value.symbol.owner.isImmutable) { + irTemporary = null + valueSymbol = value.symbol } else { - scope.createTemporaryVariable(value, nameHint, irType = irType).let { it.symbol to it } + irTemporary = scope.createTemporaryVariable(value, nameHint, irType = irType) + valueSymbol = irTemporary.symbol } val irResult = body(valueSymbol) - return if (irTemporary == null) { - irResult + if (irTemporary == null) return irResult + val irBlock = IrBlockImpl(startOffset, endOffset, irResult.type, origin) + irBlock.statements.add(irTemporary) + if (irResult is IrStatementContainer) { + irBlock.statements.addAll(irResult.statements) } else { - val irBlock = IrBlockImpl(startOffset, endOffset, irResult.type, origin) - irBlock.statements.add(irTemporary) irBlock.statements.add(irResult) - irBlock } + return irBlock } fun IrStatementsBuilder.irTemporary( diff --git a/compiler/testData/codegen/bytecodeText/forLoop/intrinsicArrayConstructorsUseCounterLoop.kt b/compiler/testData/codegen/bytecodeText/forLoop/intrinsicArrayConstructorsUseCounterLoop.kt index 6eb14f9aa50..1972b9b6e55 100644 --- a/compiler/testData/codegen/bytecodeText/forLoop/intrinsicArrayConstructorsUseCounterLoop.kt +++ b/compiler/testData/codegen/bytecodeText/forLoop/intrinsicArrayConstructorsUseCounterLoop.kt @@ -31,7 +31,7 @@ fun testDoubleArray(n: Int) = DoubleArray(n) { it.toDouble() } fun testObjectArray(n: Int) = - Array(n) { it as Any } + Array(n) { it.toString() } // 0 IF_ICMPGT // 0 IF_CMPEQ diff --git a/compiler/testData/codegen/bytecodeText/nullCheckOptimization/noNullCheckAfterCast.kt b/compiler/testData/codegen/bytecodeText/nullCheckOptimization/noNullCheckAfterCast.kt index 35af2351982..f94df9d3b6f 100644 --- a/compiler/testData/codegen/bytecodeText/nullCheckOptimization/noNullCheckAfterCast.kt +++ b/compiler/testData/codegen/bytecodeText/nullCheckOptimization/noNullCheckAfterCast.kt @@ -26,4 +26,8 @@ fun test3() { fun getB(): B = B() +// JVM_TEMPLATES // 1 IFNONNULL +// JVM_IR_TEMPLATES +// 0 IFNONNULL +// 1 INVOKESTATIC kotlin/jvm/internal/Intrinsics.checkNotNull \(Ljava/lang/Object;Ljava/lang/String;\)V diff --git a/compiler/testData/codegen/bytecodeText/temporaryVals/noTemporaryInCheckedCast.kt b/compiler/testData/codegen/bytecodeText/temporaryVals/noTemporaryInCheckedCast.kt new file mode 100644 index 00000000000..d12f3ca0f47 --- /dev/null +++ b/compiler/testData/codegen/bytecodeText/temporaryVals/noTemporaryInCheckedCast.kt @@ -0,0 +1,6 @@ +fun foo(): Any? = "abc" + +fun test() = foo() as String + +// 0 ASTORE +// 0 ALOAD diff --git a/compiler/tests-common-new/tests-gen/org/jetbrains/kotlin/test/runners/codegen/BytecodeTextTestGenerated.java b/compiler/tests-common-new/tests-gen/org/jetbrains/kotlin/test/runners/codegen/BytecodeTextTestGenerated.java index f06503a7552..eaf8965092b 100644 --- a/compiler/tests-common-new/tests-gen/org/jetbrains/kotlin/test/runners/codegen/BytecodeTextTestGenerated.java +++ b/compiler/tests-common-new/tests-gen/org/jetbrains/kotlin/test/runners/codegen/BytecodeTextTestGenerated.java @@ -5356,6 +5356,12 @@ public class BytecodeTextTestGenerated extends AbstractBytecodeTextTest { runTest("compiler/testData/codegen/bytecodeText/temporaryVals/elvisChain.kt"); } + @Test + @TestMetadata("noTemporaryInCheckedCast.kt") + public void testNoTemporaryInCheckedCast() throws Exception { + runTest("compiler/testData/codegen/bytecodeText/temporaryVals/noTemporaryInCheckedCast.kt"); + } + @Test @TestMetadata("notNullReceiversInChain.kt") public void testNotNullReceiversInChain() throws Exception { diff --git a/compiler/tests-common-new/tests-gen/org/jetbrains/kotlin/test/runners/codegen/IrBytecodeTextTestGenerated.java b/compiler/tests-common-new/tests-gen/org/jetbrains/kotlin/test/runners/codegen/IrBytecodeTextTestGenerated.java index 02ca05d4a37..69bf549f6e5 100644 --- a/compiler/tests-common-new/tests-gen/org/jetbrains/kotlin/test/runners/codegen/IrBytecodeTextTestGenerated.java +++ b/compiler/tests-common-new/tests-gen/org/jetbrains/kotlin/test/runners/codegen/IrBytecodeTextTestGenerated.java @@ -5506,6 +5506,12 @@ public class IrBytecodeTextTestGenerated extends AbstractIrBytecodeTextTest { runTest("compiler/testData/codegen/bytecodeText/temporaryVals/elvisChain.kt"); } + @Test + @TestMetadata("noTemporaryInCheckedCast.kt") + public void testNoTemporaryInCheckedCast() throws Exception { + runTest("compiler/testData/codegen/bytecodeText/temporaryVals/noTemporaryInCheckedCast.kt"); + } + @Test @TestMetadata("notNullReceiversInChain.kt") public void testNotNullReceiversInChain() throws Exception {