diff --git a/compiler/ir/backend.jvm/src/org/jetbrains/kotlin/backend/jvm/intrinsics/Equals.kt b/compiler/ir/backend.jvm/src/org/jetbrains/kotlin/backend/jvm/intrinsics/Equals.kt index 08be60465fb..42c07ddccdd 100644 --- a/compiler/ir/backend.jvm/src/org/jetbrains/kotlin/backend/jvm/intrinsics/Equals.kt +++ b/compiler/ir/backend.jvm/src/org/jetbrains/kotlin/backend/jvm/intrinsics/Equals.kt @@ -85,7 +85,7 @@ class Equals(val operator: IElementType) : IntrinsicMethod() { // what comparison means. The optimization does not apply to `object == primitive` as equals // could be overridden for the object. if ((opToken == IrStatementOrigin.EQEQ || opToken == IrStatementOrigin.EXCLEQ) && - ((AsmUtil.isIntOrLongPrimitive(leftType) && !AsmUtil.isPrimitive(rightType)) || + ((AsmUtil.isIntOrLongPrimitive(leftType) && !isPrimitive(rightType)) || (AsmUtil.isIntOrLongPrimitive(rightType) && AsmUtil.isBoxedPrimitiveType(leftType))) ) { val aValue = a.accept(codegen, data).materializedAt(leftType, a.type) diff --git a/compiler/ir/backend.jvm/src/org/jetbrains/kotlin/backend/jvm/lower/JvmOptimizationLowering.kt b/compiler/ir/backend.jvm/src/org/jetbrains/kotlin/backend/jvm/lower/JvmOptimizationLowering.kt index e752ba19413..990d4793b83 100644 --- a/compiler/ir/backend.jvm/src/org/jetbrains/kotlin/backend/jvm/lower/JvmOptimizationLowering.kt +++ b/compiler/ir/backend.jvm/src/org/jetbrains/kotlin/backend/jvm/lower/JvmOptimizationLowering.kt @@ -10,16 +10,17 @@ import org.jetbrains.kotlin.backend.common.lower.createIrBuilder import org.jetbrains.kotlin.backend.common.lower.irBlock import org.jetbrains.kotlin.backend.common.phaser.makeIrFilePhase import org.jetbrains.kotlin.backend.jvm.JvmBackendContext -import org.jetbrains.kotlin.codegen.intrinsics.Not +import org.jetbrains.kotlin.backend.jvm.ir.createJvmIrBuilder +import org.jetbrains.kotlin.codegen.AsmUtil import org.jetbrains.kotlin.descriptors.Modality import org.jetbrains.kotlin.ir.IrStatement -import org.jetbrains.kotlin.ir.builders.irGetField -import org.jetbrains.kotlin.ir.builders.irSetField +import org.jetbrains.kotlin.ir.builders.* import org.jetbrains.kotlin.ir.declarations.* import org.jetbrains.kotlin.ir.expressions.* import org.jetbrains.kotlin.ir.expressions.impl.IrBlockImpl import org.jetbrains.kotlin.ir.expressions.impl.IrCallImpl import org.jetbrains.kotlin.ir.expressions.impl.IrConstImpl +import org.jetbrains.kotlin.ir.symbols.IrSymbol import org.jetbrains.kotlin.ir.symbols.impl.IrPublicSymbolBase import org.jetbrains.kotlin.ir.types.* import org.jetbrains.kotlin.ir.util.* @@ -64,6 +65,42 @@ class JvmOptimizationLowering(val context: JvmBackendContext) : FileLoweringPass else -> null } + private class SafeCallInfo( + val scopeSymbol: IrSymbol, + val tmpVal: IrVariable, + val ifNullBranch: IrBranch, + val ifNotNullBranch: IrBranch + ) + + private fun parseSafeCall(expression: IrExpression): SafeCallInfo? { + val block = expression as? IrBlock ?: return null + if (block.origin != IrStatementOrigin.SAFE_CALL) return null + if (block.statements.size != 2) return null + val tmpVal = block.statements[0] as? IrVariable ?: return null + val scopeOwner = tmpVal.parent as? IrDeclaration ?: return null + val scopeSymbol = scopeOwner.symbol + val whenExpr = block.statements[1] as? IrWhen ?: return null + if (whenExpr.branches.size != 2) return null + + val ifNullBranch = whenExpr.branches[0] + val ifNullBranchCondition = ifNullBranch.condition + if (ifNullBranchCondition !is IrCall) return null + if (ifNullBranchCondition.symbol != context.irBuiltIns.eqeqSymbol) return null + val arg0 = ifNullBranchCondition.getValueArgument(0) + if (arg0 !is IrGetValue || arg0.symbol != tmpVal.symbol) return null + val arg1 = ifNullBranchCondition.getValueArgument(1) + if (arg1 !is IrConst<*> || arg1.value != null) return null + val ifNullBranchResult = ifNullBranch.result + if (ifNullBranchResult !is IrConst<*> || ifNullBranchResult.value != null) return null + + val ifNotNullBranch = whenExpr.branches[1] + return SafeCallInfo(scopeSymbol, tmpVal, ifNullBranch, ifNotNullBranch) + } + + private fun IrType.isJvmPrimitive(): Boolean = + // TODO get rid of type mapper (take care of '@EnhancedNullability', maybe some other stuff). + AsmUtil.isPrimitive(context.typeMapper.mapType(this)) + override fun lower(irFile: IrFile) { val transformer = object : IrElementTransformer { @@ -105,20 +142,93 @@ class JvmOptimizationLowering(val context: JvmBackendContext) : FileLoweringPass } getOperandsIfCallToEQEQOrEquals(expression)?.let { (left, right) -> - return when { - left.isNullConst() && right.isNullConst() -> - IrConstImpl.constTrue(expression.startOffset, expression.endOffset, context.irBuiltIns.booleanType) + if (left.isNullConst() && right.isNullConst()) + return IrConstImpl.constTrue(expression.startOffset, expression.endOffset, context.irBuiltIns.booleanType) - left.isNullConst() && right is IrConst<*> || right.isNullConst() && left is IrConst<*> -> - IrConstImpl.constFalse(expression.startOffset, expression.endOffset, context.irBuiltIns.booleanType) + if (left.isNullConst() && right is IrConst<*> || right.isNullConst() && left is IrConst<*>) + return IrConstImpl.constFalse(expression.startOffset, expression.endOffset, context.irBuiltIns.booleanType) - else -> expression + val safeCallLeft = parseSafeCall(left) + if (safeCallLeft != null && right.type.isJvmPrimitive()) { + return rewriteSafeCallEqeqPrimitive(safeCallLeft, right, expression) } + + val safeCallRight = parseSafeCall(right) + if (safeCallRight != null && left.type.isJvmPrimitive()) { + return rewritePrimitiveEqeqSafeCall(left, safeCallRight, expression) + } + + return expression } return expression } + private fun rewriteSafeCallEqeqPrimitive(safeCall: SafeCallInfo, primitive: IrExpression, eqeqCall: IrCall): IrExpression = + context.createJvmIrBuilder(safeCall.scopeSymbol).run { + // Fuze safe call with primitive equality to avoid boxing the primitive. + // 'a?.<...> == p' becomes: + // { + // val tmp = a + // when { + // tmp == null -> false + // else -> tmp == p + // } + // } + irBlock { + +safeCall.tmpVal + +irWhen( + eqeqCall.type, + listOf( + irBranch(safeCall.ifNullBranch.condition, irFalse()), + irElseBranch( + irCall(eqeqCall.symbol).apply { + putValueArgument(0, safeCall.ifNotNullBranch.result) + putValueArgument(1, primitive) + } + ) + ) + ) + } + } + + private fun rewritePrimitiveEqeqSafeCall(primitive: IrExpression, safeCall: SafeCallInfo, eqeqCall: IrCall): IrExpression = + context.createJvmIrBuilder(safeCall.scopeSymbol).run { + // Fuze safe call with primitive equality to avoid boxing the primitive. + // 'p == a?.<...>' becomes: + // { + // val tmp_p = p // should evaluate 'p' before 'a' + // val tmp = a + // when { + // tmp == null -> false + // else -> tmp_p == tmp + // } + // } + // 'tmp_p' above could be elided if 'p' is a variable or a constant. + irBlock { + val lhs = + if (primitive.isTrivial()) + primitive + else { + val tmp = irTemporary(primitive) + irGet(tmp) + } + +safeCall.tmpVal + +irWhen( + eqeqCall.type, + listOf( + irBranch(safeCall.ifNullBranch.condition, irFalse()), + irElseBranch( + irCall(eqeqCall.symbol).apply { + putValueArgument(0, lhs) + putValueArgument(1, safeCall.ifNotNullBranch.result) + } + ) + ) + ) + } + } + private fun IrType.isByteOrShort() = isByte() || isShort() // For `==` and `!=`, get rid of safe calls to convert `Byte?` or `Short?` to `Int?`. diff --git a/compiler/testData/codegen/bytecodeText/boxingOptimization/safeCallToPrimitiveEquality.kt b/compiler/testData/codegen/bytecodeText/boxingOptimization/safeCallToPrimitiveEquality.kt index 8bdb3e09c21..afdf3866b32 100644 --- a/compiler/testData/codegen/bytecodeText/boxingOptimization/safeCallToPrimitiveEquality.kt +++ b/compiler/testData/codegen/bytecodeText/boxingOptimization/safeCallToPrimitiveEquality.kt @@ -1,7 +1,3 @@ -// IGNORE_BACKEND_FIR: JVM_IR -// IGNORE_BACKEND: JVM_IR -// TODO KT-36646 Don't box primitive values in equality comparison with nullable primitive values in JVM_IR - fun Long.id() = this fun String.drop2() = if (length >= 2) subSequence(2, length) else null