Rename continuation fields according the convention and count them in IR

This commit is contained in:
Iaroslav Postovalov
2020-12-10 02:27:31 +07:00
committed by Ilmir Usmanov
parent cd2b05eb00
commit 8a7bc2ef6f
10 changed files with 56 additions and 10 deletions
@@ -1113,7 +1113,7 @@ inline fun withInstructionAdapter(block: InstructionAdapter.() -> Unit): InsnLis
return tmpMethodNode.instructions
}
internal fun Type.normalize() =
fun Type.normalize() =
when (sort) {
Type.ARRAY, Type.OBJECT -> AsmTypes.OBJECT_TYPE
else -> this
@@ -1525,6 +1525,11 @@ public class FirBytecodeTextTestGenerated extends AbstractFirBytecodeTextTest {
runTest("compiler/testData/codegen/bytecodeText/coroutines/cleanup/simple.kt");
}
@TestMetadata("twoRefs.kt")
public void testTwoRefs() throws Exception {
runTest("compiler/testData/codegen/bytecodeText/coroutines/cleanup/twoRefs.kt");
}
@TestMetadata("unusedParamNotSpill.kt")
public void testUnusedParamNotSpill() throws Exception {
runTest("compiler/testData/codegen/bytecodeText/coroutines/cleanup/unusedParamNotSpill.kt");
@@ -128,6 +128,8 @@ class JvmBackendContext(
val inlineClassReplacements = MemoizedInlineClassReplacements(state.functionsWithInlineClassReturnTypesMangled, irFactory, this)
internal val continuationClassesVarsCountByType: MutableMap<IrClass, Map<Type, Int>> = hashMapOf()
internal fun referenceClass(descriptor: ClassDescriptor): IrClassSymbol =
symbolTable.lazyWrapper.referenceClass(descriptor)
@@ -358,7 +358,16 @@ class ClassCodegen private constructor(
// lazily so that if tail call optimization kicks in, the unused class will not be written to the output.
val continuationClass = method.continuationClass() // null if `SuspendLambda.invokeSuspend` - `this` is continuation itself
val continuationClassCodegen = lazy { if (continuationClass != null) getOrCreate(continuationClass, context, method) else this }
node.acceptWithStateMachine(method, this, smapCopyingVisitor) { continuationClassCodegen.value.visitor }
node.acceptWithStateMachine(
method,
this,
smapCopyingVisitor,
context.continuationClassesVarsCountByType[continuationClass] ?: emptyMap()
) {
continuationClassCodegen.value.visitor
}
if (continuationClass != null && (continuationClassCodegen.isInitialized() || method.isSuspendCapturingCrossinline())) {
continuationClassCodegen.value.generate()
}
@@ -35,13 +35,15 @@ import org.jetbrains.kotlin.ir.util.parentAsClass
import org.jetbrains.kotlin.psi.KtElement
import org.jetbrains.kotlin.types.Variance
import org.jetbrains.org.objectweb.asm.MethodVisitor
import org.jetbrains.org.objectweb.asm.Type
import org.jetbrains.org.objectweb.asm.tree.MethodNode
internal fun MethodNode.acceptWithStateMachine(
irFunction: IrFunction,
classCodegen: ClassCodegen,
methodVisitor: MethodVisitor,
obtainContinuationClassBuilder: () -> ClassBuilder
varsCountByType: Map<Type, Int>,
obtainContinuationClassBuilder: () -> ClassBuilder,
) {
val state = classCodegen.context.state
val languageVersionSettings = state.languageVersionSettings
@@ -82,7 +84,8 @@ internal fun MethodNode.acceptWithStateMachine(
disableTailCallOptimizationForFunctionReturningUnit = irFunction.isSuspend && irFunction.suspendFunctionOriginal().let {
it.returnType.isUnit() && it.anyOfOverriddenFunctionsReturnsNonUnit()
},
useOldSpilledVarTypeAnalysis = state.configuration.getBoolean(JVMConfigurationKeys.USE_OLD_SPILLED_VAR_TYPE_ANALYSIS)
useOldSpilledVarTypeAnalysis = state.configuration.getBoolean(JVMConfigurationKeys.USE_OLD_SPILLED_VAR_TYPE_ANALYSIS),
initialVarsCountByType = varsCountByType,
)
accept(visitor)
}
@@ -18,11 +18,11 @@ import org.jetbrains.kotlin.backend.jvm.ir.IrInlineReferenceLocator
import org.jetbrains.kotlin.codegen.coroutines.COROUTINE_LABEL_FIELD_NAME
import org.jetbrains.kotlin.codegen.coroutines.INVOKE_SUSPEND_METHOD_NAME
import org.jetbrains.kotlin.codegen.coroutines.SUSPEND_FUNCTION_COMPLETION_PARAMETER_NAME
import org.jetbrains.kotlin.codegen.coroutines.normalize
import org.jetbrains.kotlin.codegen.inline.coroutines.FOR_INLINE_SUFFIX
import org.jetbrains.kotlin.descriptors.DescriptorVisibilities
import org.jetbrains.kotlin.descriptors.Modality
import org.jetbrains.kotlin.ir.IrElement
import org.jetbrains.kotlin.ir.ObsoleteDescriptorBasedAPI
import org.jetbrains.kotlin.ir.UNDEFINED_OFFSET
import org.jetbrains.kotlin.ir.builders.*
import org.jetbrains.kotlin.ir.builders.declarations.*
@@ -38,6 +38,8 @@ import org.jetbrains.kotlin.ir.visitors.*
import org.jetbrains.kotlin.load.java.JavaDescriptorVisibilities
import org.jetbrains.kotlin.name.Name
import org.jetbrains.kotlin.name.SpecialNames
import org.jetbrains.kotlin.resolve.jvm.diagnostics.JvmDeclarationOrigin
import org.jetbrains.org.objectweb.asm.Type
internal val suspendLambdaPhase = makeIrFilePhase(
::SuspendLambdaLowering,
@@ -164,7 +166,7 @@ private class SuspendLambdaLowering(context: JvmBackendContext) : SuspendLowerin
+ context.irBuiltIns.anyNType
)
superTypes = listOf(suspendLambda.defaultType, functionNType)
val usedParams = mutableListOf<IrSymbolOwner>()
val usedParams = ArrayList<IrSymbolOwner>(function.explicitParameters.size)
// marking the parameters referenced in the function
function.acceptChildrenVoid(
@@ -173,22 +175,28 @@ private class SuspendLambdaLowering(context: JvmBackendContext) : SuspendLowerin
if (element is IrDeclarationReference && element.symbol is IrValueParameterSymbol && element.symbol.owner in function.explicitParameters)
usedParams += element.symbol.owner
else
Unit
element.acceptChildrenVoid(this)
},
)
addField(COROUTINE_LABEL_FIELD_NAME, context.irBuiltIns.intType, JavaDescriptorVisibilities.PACKAGE_VISIBILITY)
val varsCountByType = HashMap<Type, Int>()
val parametersFields = function.explicitParameters.filter { it in usedParams }.map {
addField {
val normalizedType = context.typeMapper.mapType(it.type).normalize()
val index = varsCountByType[normalizedType]?.plus(1) ?: 0
varsCountByType[normalizedType] = index
// Rename `$this` to avoid being caught by inlineCodegenUtils.isCapturedFieldName()
name = if (it.index < 0) Name.identifier("p\$") else it.name
name = Name.identifier("${normalizedType.descriptor[0]}$$index")
type = it.type
origin = LocalDeclarationsLowering.DECLARATION_ORIGIN_FIELD_FOR_CAPTURED_VALUE
isFinal = false
visibility = if (it.index < 0) DescriptorVisibilities.PRIVATE else JavaDescriptorVisibilities.PACKAGE_VISIBILITY
}
}
context.continuationClassesVarsCountByType[this] = varsCountByType
val constructor = addPrimaryConstructorForLambda(suspendLambda, arity)
val invokeToOverride = functionNClass.functions.single {
it.owner.valueParameters.size == arity + 1 && it.owner.name.asString() == "invoke"
@@ -0,0 +1,9 @@
suspend fun blackhole(a: Any?) {}
suspend fun cleanUpExample(a: String, b: String) {
blackhole(a) // 1
blackhole(b) // 2
}
// 3 ACONST_NULL
// 2 PUTFIELD .*L\$0 : .*;
@@ -1,5 +1,5 @@
val f: suspend (Int) -> Unit = { unused ->
}
// 0 GETFIELD p\$0
// 0 PUTFIELD p\$0
// 0 GETFIELD .*I\$0
// 0 PUTFIELD .*I\$0
@@ -1515,6 +1515,11 @@ public class BytecodeTextTestGenerated extends AbstractBytecodeTextTest {
runTest("compiler/testData/codegen/bytecodeText/coroutines/cleanup/simple.kt");
}
@TestMetadata("twoRefs.kt")
public void testTwoRefs() throws Exception {
runTest("compiler/testData/codegen/bytecodeText/coroutines/cleanup/twoRefs.kt");
}
@TestMetadata("unusedParamNotSpill.kt")
public void testUnusedParamNotSpill() throws Exception {
runTest("compiler/testData/codegen/bytecodeText/coroutines/cleanup/unusedParamNotSpill.kt");
@@ -1525,6 +1525,11 @@ public class IrBytecodeTextTestGenerated extends AbstractIrBytecodeTextTest {
runTest("compiler/testData/codegen/bytecodeText/coroutines/cleanup/simple.kt");
}
@TestMetadata("twoRefs.kt")
public void testTwoRefs() throws Exception {
runTest("compiler/testData/codegen/bytecodeText/coroutines/cleanup/twoRefs.kt");
}
@TestMetadata("unusedParamNotSpill.kt")
public void testUnusedParamNotSpill() throws Exception {
runTest("compiler/testData/codegen/bytecodeText/coroutines/cleanup/unusedParamNotSpill.kt");