JVM: optimize temporary kotlin.jvm.internal.Refs as well

That is, remove the requirement that the `Ref` has a corresponding local
variable table (LVT) entry. Such temporary `Ref`s can be created, for
example, by the JVM_IR backend when a lambda inlined at the IR level
(e.g. an argument to `assert`/`Array`) is the target of a non-local return
from a function inlined at the bytecode level (e.g. `run`):

    IntArray(n) { i ->
        intOrNull?.let { return@IntArray it }
        someInt
    }

->

    val `tmp$0` = IntArray(n)
    for (i in 0 until `tmp$0`.size) {
        var `tmp$1`: Int
        do {
            intOrNull?.let {
                `tmp$1` = it // causes `tmp$1` to become an IntRef
                break
            }
            `tmp$1` = someInt
        } while (false)
        `tmp$0`[i] = `tmp$1`
    }
This commit is contained in:
pyos
2020-05-12 14:53:18 +02:00
committed by max-kammerer
parent 0f2ca5d84c
commit ad53fc931e
5 changed files with 78 additions and 114 deletions
@@ -37,17 +37,9 @@ class CapturedVarsOptimizationMethodTransformer : MethodTransformer() {
// Tracks proper usages of objects corresponding to captured variables.
//
// The 'kotlin.jvm.internal.Ref.*' instance can be replaced with a local variable,
// if all of the following conditions are satisfied:
// * It is created inside a current method.
// * The only permitted operations on it are:
// - store to a local variable
// - ALOAD, ASTORE
// - DUP, POP
// - GETFIELD <owner>.element, PUTFIELD <owner>.element
// * There's a corresponding local variable definition,
// and all ALOAD/ASTORE instructions operate on that particular local variable.
// * Its 'element' field is initialized at start of local variable visibility range.
// The 'kotlin.jvm.internal.Ref.*' instance can be replaced with a local variable, if
// * it is created inside the current method;
// * the only operations on it are ALOAD, ASTORE, DUP, POP, GETFIELD element, PUTFIELD element.
//
// Note that for code that doesn't create Ref objects explicitly these conditions are true,
// unless the Ref object escapes to a local class constructor (including local classes for lambdas).
@@ -58,18 +50,9 @@ class CapturedVarsOptimizationMethodTransformer : MethodTransformer() {
var initCallInsn: MethodInsnNode? = null
var localVar: LocalVariableNode? = null
var localVarIndex = -1
val astoreInsns: MutableCollection<VarInsnNode> = LinkedHashSet()
val aloadInsns: MutableCollection<VarInsnNode> = LinkedHashSet()
val stackInsns: MutableCollection<AbstractInsnNode> = LinkedHashSet()
val wrapperInsns: MutableCollection<AbstractInsnNode> = LinkedHashSet()
val getFieldInsns: MutableCollection<FieldInsnNode> = LinkedHashSet()
val putFieldInsns: MutableCollection<FieldInsnNode> = LinkedHashSet()
var cleanVarInstruction: VarInsnNode? = null
// A descriptor is rewritable only when no hazard was flagged, an init call instruction was recorded,
// a local variable was attached, and a non-negative slot index was assigned.
fun canRewrite(): Boolean {
    if (hazard || initCallInsn == null) return false
    return localVar != null && localVarIndex >= 0
}
override fun onUseAsTainted() {
hazard = true
@@ -79,26 +62,29 @@ class CapturedVarsOptimizationMethodTransformer : MethodTransformer() {
private class Transformer(private val internalClassName: String, private val methodNode: MethodNode) {
private val refValues = ArrayList<CapturedVarDescriptor>()
private val refValuesByNewInsn = LinkedHashMap<TypeInsnNode, CapturedVarDescriptor>()
private val insns = methodNode.instructions.toArray()
private lateinit var frames: Array<out Frame<BasicValue>?>
// True while at least one Ref descriptor is still held in `refValues`.
val hasRewritableRefValues: Boolean
    get() = refValues.any()
fun run() {
createRefValues()
if (!hasRewritableRefValues) return
if (refValues.isEmpty()) return
analyze()
if (!hasRewritableRefValues) return
val frames = analyze(internalClassName, methodNode, Interpreter())
trackPops(frames)
assignLocalVars(frames)
rewrite()
for (refValue in refValues) {
if (!refValue.hazard) {
rewriteRefValue(refValue)
}
}
methodNode.removeEmptyCatchBlocks()
methodNode.removeUnusedLocalVariables()
}
// Position of this instruction within the method's instruction list, as reported by `indexOf`.
private fun AbstractInsnNode.getIndex(): Int {
    return methodNode.instructions.indexOf(this)
}
private fun createRefValues() {
for (insn in insns) {
for (insn in methodNode.instructions) {
if (insn.opcode == Opcodes.NEW && insn is TypeInsnNode) {
val type = Type.getObjectType(insn.desc)
if (AsmTypes.isSharedVarType(type)) {
@@ -113,19 +99,15 @@ class CapturedVarsOptimizationMethodTransformer : MethodTransformer() {
private inner class Interpreter : ReferenceTrackingInterpreter() {
override fun newOperation(insn: AbstractInsnNode): BasicValue =
refValuesByNewInsn[insn]?.let { descriptor ->
ProperTrackedReferenceValue(descriptor.refType, descriptor)
}
?: super.newOperation(insn)
refValuesByNewInsn[insn]?.let { ProperTrackedReferenceValue(it.refType, it) } ?: super.newOperation(insn)
override fun processRefValueUsage(value: TrackedReferenceValue, insn: AbstractInsnNode, position: Int) {
for (descriptor in value.descriptors) {
if (descriptor !is CapturedVarDescriptor) throw AssertionError("Unexpected descriptor: $descriptor")
when {
insn.opcode == Opcodes.ALOAD ->
descriptor.aloadInsns.add(insn as VarInsnNode)
insn.opcode == Opcodes.ASTORE ->
descriptor.astoreInsns.add(insn as VarInsnNode)
insn.opcode == Opcodes.DUP -> descriptor.wrapperInsns.add(insn)
insn.opcode == Opcodes.ALOAD -> descriptor.wrapperInsns.add(insn)
insn.opcode == Opcodes.ASTORE -> descriptor.wrapperInsns.add(insn)
insn.opcode == Opcodes.GETFIELD && insn is FieldInsnNode && insn.name == REF_ELEMENT_FIELD && position == 0 ->
descriptor.getFieldInsns.add(insn)
insn.opcode == Opcodes.PUTFIELD && insn is FieldInsnNode && insn.name == REF_ELEMENT_FIELD && position == 0 ->
@@ -135,32 +117,18 @@ class CapturedVarsOptimizationMethodTransformer : MethodTransformer() {
descriptor.hazard = true
else
descriptor.initCallInsn = insn
insn.opcode == Opcodes.DUP ->
descriptor.stackInsns.add(insn)
else ->
descriptor.hazard = true
else -> descriptor.hazard = true
}
}
}
}
// Runs MethodTransformer.analyze with the reference-tracking Interpreter and stores the result in `frames`,
// then records POP usages, assigns local variable slots, and finally drops every descriptor for which
// canRewrite() is false.
// The order matters: trackPops() reads `frames`, and canRewrite() checks the `localVarIndex`
// that assignLocalVars() fills in.
private fun analyze() {
frames = MethodTransformer.analyze(internalClassName, methodNode, Interpreter())
trackPops()
assignLocalVars()
refValues.removeAll { !it.canRewrite() }
}
private fun trackPops() {
for (i in insns.indices) {
private fun trackPops(frames: Array<out Frame<BasicValue>?>) {
for ((i, insn) in methodNode.instructions.withIndex()) {
val frame = frames[i] ?: continue
val insn = insns[i]
when (insn.opcode) {
Opcodes.POP -> {
frame.top()?.getCapturedVarOrNull()?.run { stackInsns.add(insn) }
frame.top()?.getCapturedVarOrNull()?.run { wrapperInsns.add(insn) }
}
Opcodes.POP2 -> {
val top = frame.top()
@@ -176,7 +144,7 @@ class CapturedVarsOptimizationMethodTransformer : MethodTransformer() {
// Returns the captured-variable descriptor carried by this value, or null when the value is not a
// ProperTrackedReferenceValue or its descriptor is not a CapturedVarDescriptor.
private fun BasicValue.getCapturedVarOrNull(): CapturedVarDescriptor? {
    val trackedValue = safeAs<ProperTrackedReferenceValue>()
    return trackedValue?.descriptor?.safeAs<CapturedVarDescriptor>()
}
private fun assignLocalVars() {
private fun assignLocalVars(frames: Array<out Frame<BasicValue>?>) {
for (localVar in methodNode.localVariables) {
val type = Type.getType(localVar.desc)
if (!AsmTypes.isSharedVarType(type)) continue
@@ -197,51 +165,20 @@ class CapturedVarsOptimizationMethodTransformer : MethodTransformer() {
for (refValue in refValues) {
if (refValue.hazard) continue
val localVar = refValue.localVar ?: continue
val oldVarIndex = localVar.index
if (refValue.valueType.size != 1) {
if (refValue.localVar == null || refValue.valueType.size != 1) {
refValue.localVarIndex = methodNode.maxLocals
methodNode.maxLocals += 2
localVar.index = refValue.localVarIndex
methodNode.maxLocals += refValue.valueType.size
} else {
refValue.localVarIndex = localVar.index
refValue.localVarIndex = refValue.localVar!!.index
}
val cleanInstructions = findCleanInstructions(refValue, oldVarIndex, methodNode.instructions)
if (cleanInstructions.size > 1) {
refValue.hazard = true
continue
}
refValue.cleanVarInstruction = cleanInstructions.firstOrNull()
}
}
// Collects the `ACONST_NULL; ASTORE oldVarIndex` stores that lie strictly between the start and end
// labels of the descriptor's local variable.
private fun findCleanInstructions(refValue: CapturedVarDescriptor, oldVarIndex: Int, instructions: InsnList): List<VarInsnNode> {
    // An ASTORE into the old slot whose immediately preceding instruction is ACONST_NULL.
    fun isNullStore(store: VarInsnNode): Boolean =
        store.opcode == Opcodes.ASTORE && store.`var` == oldVarIndex && store.previous?.opcode == Opcodes.ACONST_NULL

    // Strictly inside the variable's range; only evaluated for stores that passed isNullStore.
    fun isInsideRange(store: VarInsnNode): Boolean {
        val position = instructions.indexOf(store)
        val variable = refValue.localVar!!
        val rangeStart = instructions.indexOf(variable.start)
        return rangeStart < position && position < instructions.indexOf(variable.end)
    }

    return InsnSequence(instructions)
        .filterIsInstance<VarInsnNode>()
        .filter { isNullStore(it) && isInsideRange(it) }
        .toList()
}
private fun rewrite() {
for (refValue in refValues) {
if (!refValue.canRewrite()) continue
rewriteRefValue(refValue)
// Yields the ASTORE instructions into this variable's slot (`index`) that directly follow an ACONST_NULL,
// scanning the method's instructions from this variable's `start` label up to, but not including, `end`.
// NOTE(review): presumably InsnSequence is a lazy Sequence, so the filter runs while the caller iterates;
// confirm this is safe for callers that replace or remove instructions inside the loop.
private fun LocalVariableNode.findCleanInstructions() =
InsnSequence(methodNode.instructions).dropWhile { it != start }.takeWhile { it != end }.filter {
it is VarInsnNode && it.opcode == Opcodes.ASTORE && it.`var` == index && it.previous?.opcode == Opcodes.ACONST_NULL
}
methodNode.removeEmptyCatchBlocks()
methodNode.removeUnusedLocalVariables()
}
// Be careful to not remove instructions that are the only instruction for a line number. That will
// break debugging. If the previous instruction is a line number and the following instruction is
// a label followed by a line number, insert a nop instead of deleting the instruction.
@@ -255,34 +192,38 @@ class CapturedVarsOptimizationMethodTransformer : MethodTransformer() {
private fun rewriteRefValue(capturedVar: CapturedVarDescriptor) {
methodNode.instructions.run {
val localVar = capturedVar.localVar!!
localVar.signature = null
localVar.desc = capturedVar.valueType.descriptor
val loadOpcode = capturedVar.valueType.getOpcode(Opcodes.ILOAD)
val storeOpcode = capturedVar.valueType.getOpcode(Opcodes.ISTORE)
if (capturedVar.putFieldInsns.none { it.getIndex() < localVar.start.getIndex() }) {
// variable needs to be initialized before its live range can begin
insertBefore(capturedVar.newInsn, InsnNode(AsmUtil.defaultValueOpcode(capturedVar.valueType)))
insertBefore(capturedVar.newInsn, VarInsnNode(storeOpcode, capturedVar.localVarIndex))
val localVar = capturedVar.localVar
if (localVar != null) {
if (capturedVar.putFieldInsns.none { it.getIndex() < localVar.start.getIndex() }) {
// variable needs to be initialized before its live range can begin
insertBefore(capturedVar.newInsn, InsnNode(AsmUtil.defaultValueOpcode(capturedVar.valueType)))
insertBefore(capturedVar.newInsn, VarInsnNode(storeOpcode, capturedVar.localVarIndex))
}
for (insn in localVar.findCleanInstructions()) {
// after visiting block codegen tries to delete all allocated references:
// see ExpressionCodegen.addLeaveTaskToRemoveLocalVariableFromFrameMap
if (storeOpcode == Opcodes.ASTORE) {
set(insn.previous, InsnNode(AsmUtil.defaultValueOpcode(capturedVar.valueType)))
} else {
remove(insn.previous)
remove(insn)
}
}
localVar.index = capturedVar.localVarIndex
localVar.desc = capturedVar.valueType.descriptor
localVar.signature = null
}
remove(capturedVar.newInsn)
remove(capturedVar.initCallInsn!!)
capturedVar.stackInsns.forEach { removeOrReplaceByNop(it) }
capturedVar.aloadInsns.forEach { removeOrReplaceByNop(it) }
capturedVar.astoreInsns.forEach { removeOrReplaceByNop(it) }
capturedVar.wrapperInsns.forEach { removeOrReplaceByNop(it) }
capturedVar.getFieldInsns.forEach { set(it, VarInsnNode(loadOpcode, capturedVar.localVarIndex)) }
capturedVar.putFieldInsns.forEach { set(it, VarInsnNode(storeOpcode, capturedVar.localVarIndex)) }
//after visiting block codegen tries to delete all allocated references:
// see ExpressionCodegen.addLeaveTaskToRemoveLocalVariableFromFrameMap
capturedVar.cleanVarInstruction?.let {
remove(it.previous)
remove(it)
}
}
}
@@ -974,6 +974,12 @@ public class FirBytecodeTextTestGenerated extends AbstractFirBytecodeTextTest {
runTest("compiler/testData/codegen/bytecodeText/capturedVarsOptimization/capturedVarsOfSize2.kt");
}
// Runs the capturedVarsOptimization/returnValueOfArrayConstructor.kt bytecode-text test data through runTest.
@Test
@TestMetadata("returnValueOfArrayConstructor.kt")
public void testReturnValueOfArrayConstructor() throws Exception {
runTest("compiler/testData/codegen/bytecodeText/capturedVarsOptimization/returnValueOfArrayConstructor.kt");
}
@Test
@TestMetadata("sharedSlotsWithCapturedVars.kt")
public void testSharedSlotsWithCapturedVars() throws Exception {
@@ -0,0 +1,5 @@
// The value of the IntArray initializer lambda is produced by a non-local return from inside `run`.
fun f() = IntArray(1) { run { return@IntArray 1 } }
// On JVM_IR, the return is an assignment to a captured var followed by
// a non-local `break` from a `do ... while (false)`. The var should be optimized.
// 0 IntRef
@@ -974,6 +974,12 @@ public class BytecodeTextTestGenerated extends AbstractBytecodeTextTest {
runTest("compiler/testData/codegen/bytecodeText/capturedVarsOptimization/capturedVarsOfSize2.kt");
}
// Runs the capturedVarsOptimization/returnValueOfArrayConstructor.kt bytecode-text test data through runTest.
@Test
@TestMetadata("returnValueOfArrayConstructor.kt")
public void testReturnValueOfArrayConstructor() throws Exception {
runTest("compiler/testData/codegen/bytecodeText/capturedVarsOptimization/returnValueOfArrayConstructor.kt");
}
@Test
@TestMetadata("sharedSlotsWithCapturedVars.kt")
public void testSharedSlotsWithCapturedVars() throws Exception {
@@ -974,6 +974,12 @@ public class IrBytecodeTextTestGenerated extends AbstractIrBytecodeTextTest {
runTest("compiler/testData/codegen/bytecodeText/capturedVarsOptimization/capturedVarsOfSize2.kt");
}
// Runs the capturedVarsOptimization/returnValueOfArrayConstructor.kt bytecode-text test data through runTest.
@Test
@TestMetadata("returnValueOfArrayConstructor.kt")
public void testReturnValueOfArrayConstructor() throws Exception {
runTest("compiler/testData/codegen/bytecodeText/capturedVarsOptimization/returnValueOfArrayConstructor.kt");
}
@Test
@TestMetadata("sharedSlotsWithCapturedVars.kt")
public void testSharedSlotsWithCapturedVars() throws Exception {