Generate state machine for named functions in their bodies

Inline functions aren't supported yet in the change

 #KT-17585 In Progress
This commit is contained in:
Denis Zharkov
2017-04-25 15:10:48 +03:00
parent 59d89a1ae3
commit c3a032ea0b
10 changed files with 345 additions and 133 deletions
@@ -646,6 +646,8 @@ public class AsmUtil {
@NotNull FrameMap frameMap
) {
if (state.isParamAssertionsDisabled()) return;
// currently when resuming a suspend function we pass default values instead of real arguments (i.e. nulls for references)
if (descriptor.isSuspend()) return;
// Private method is not accessible from other classes, no assertions needed
if (getVisibilityAccessFlag(descriptor) == ACC_PRIVATE) return;
@@ -35,7 +35,7 @@ import org.jetbrains.kotlin.builtins.KotlinBuiltIns;
import org.jetbrains.kotlin.codegen.binding.CalculatedClosure;
import org.jetbrains.kotlin.codegen.binding.CodegenBinding;
import org.jetbrains.kotlin.codegen.context.*;
import org.jetbrains.kotlin.codegen.coroutines.CoroutineCodegen;
import org.jetbrains.kotlin.codegen.coroutines.CoroutineCodegenForLambda;
import org.jetbrains.kotlin.codegen.coroutines.CoroutineCodegenUtilKt;
import org.jetbrains.kotlin.codegen.coroutines.ResolvedCallWithRealDescriptor;
import org.jetbrains.kotlin.codegen.extensions.ExpressionCodegenExtension;
@@ -52,6 +52,7 @@ import org.jetbrains.kotlin.codegen.when.SwitchCodegen;
import org.jetbrains.kotlin.codegen.when.SwitchCodegenUtil;
import org.jetbrains.kotlin.config.ApiVersion;
import org.jetbrains.kotlin.descriptors.*;
import org.jetbrains.kotlin.descriptors.impl.AnonymousFunctionDescriptor;
import org.jetbrains.kotlin.descriptors.impl.LocalVariableDescriptor;
import org.jetbrains.kotlin.descriptors.impl.SyntheticFieldDescriptor;
import org.jetbrains.kotlin.descriptors.impl.TypeAliasConstructorDescriptor;
@@ -942,7 +943,7 @@ public class ExpressionCodegen extends KtVisitor<StackValue, StackValue> impleme
declaration.getContainingFile()
);
ClosureCodegen coroutineCodegen = CoroutineCodegen.createByLambda(this, descriptor, declaration, cv);
ClosureCodegen coroutineCodegen = CoroutineCodegenForLambda.create(this, descriptor, declaration, cv);
ClosureCodegen closureCodegen = coroutineCodegen != null ? coroutineCodegen : new ClosureCodegen(
state, declaration, samType, context.intoClosure(descriptor, this, typeMapper),
functionReferenceTarget, strategy, parentCodegen, cv
@@ -1188,12 +1189,14 @@ public class ExpressionCodegen extends KtVisitor<StackValue, StackValue> impleme
bindingContext.get(ENCLOSING_SUSPEND_FUNCTION_FOR_SUSPEND_FUNCTION_CALL, resolvedCall.getCall());
if (enclosingSuspendLambdaForSuspensionPoint == null) return null;
return genCoroutineInstanceBySuspendFunction(enclosingSuspendLambdaForSuspensionPoint);
return genCoroutineInstanceForSuspendLambda(enclosingSuspendLambdaForSuspensionPoint);
}
@Nullable
private StackValue genCoroutineInstanceBySuspendFunction(@NotNull FunctionDescriptor suspendFunction) {
private StackValue genCoroutineInstanceForSuspendLambda(@NotNull FunctionDescriptor suspendFunction) {
if (!CoroutineCodegenUtilKt.isStateMachineNeeded(suspendFunction, bindingContext)) return null;
if (!(suspendFunction instanceof AnonymousFunctionDescriptor)) return null;
ClassDescriptor suspendLambdaClassDescriptor = bindingContext.get(CodegenBinding.CLASS_FOR_CALLABLE, suspendFunction);
assert suspendLambdaClassDescriptor != null : "Coroutine class descriptor should not be null";
@@ -135,7 +135,8 @@ public class FunctionCodegen {
strategy = new SuspendFunctionGenerationStrategy(
state,
CoroutineCodegenUtilKt.<FunctionDescriptor>unwrapInitialDescriptorForSuspendFunction(functionDescriptor),
function
function,
v.getThisName()
);
}
else {
@@ -43,6 +43,7 @@ class JvmRuntimeTypes(module: ModuleDescriptor) {
private val localVariableReference: ClassDescriptor by klass("LocalVariableReference")
private val mutableLocalVariableReference: ClassDescriptor by klass("MutableLocalVariableReference")
private val coroutineImplClass by lazy { createClass(kotlinCoroutinesJvmInternalPackage, "CoroutineImpl") }
private val coroutineImplForNamedFunctionClass by lazy { createClass(kotlinCoroutinesJvmInternalPackage, "CoroutineImplForNamedFunction") }
private val propertyReferences: List<ClassDescriptor> by lazy {
(0..2).map { i -> createClass(kotlinJvmInternalPackage, "PropertyReference$i") }
@@ -84,11 +85,13 @@ class JvmRuntimeTypes(module: ModuleDescriptor) {
if (descriptor.isSuspend) {
return mutableListOf<KotlinType>().apply {
add(coroutineImplClass.defaultType)
if (descriptor.isSuspendLambda) {
add(coroutineImplClass.defaultType)
add(functionType)
}
else {
add(coroutineImplForNamedFunctionClass.defaultType)
}
}
}
@@ -21,7 +21,6 @@ import org.jetbrains.kotlin.codegen.*
import org.jetbrains.kotlin.codegen.binding.CodegenBinding
import org.jetbrains.kotlin.codegen.context.ClosureContext
import org.jetbrains.kotlin.codegen.context.MethodContext
import org.jetbrains.kotlin.codegen.state.GenerationState
import org.jetbrains.kotlin.coroutines.isSuspendLambda
import org.jetbrains.kotlin.descriptors.*
import org.jetbrains.kotlin.descriptors.annotations.Annotations
@@ -42,7 +41,7 @@ import org.jetbrains.kotlin.resolve.jvm.jvmSignature.JvmMethodSignature
import org.jetbrains.kotlin.types.KotlinType
import org.jetbrains.kotlin.types.typeUtil.makeNullable
import org.jetbrains.kotlin.utils.addToStdlib.safeAs
import org.jetbrains.org.objectweb.asm.Label
import org.jetbrains.kotlin.utils.sure
import org.jetbrains.org.objectweb.asm.MethodVisitor
import org.jetbrains.org.objectweb.asm.Opcodes
import org.jetbrains.org.objectweb.asm.Type
@@ -50,19 +49,64 @@ import org.jetbrains.org.objectweb.asm.commons.InstructionAdapter
import org.jetbrains.org.objectweb.asm.commons.Method
class CoroutineCodegen private constructor(
abstract class AbstractCoroutineCodegen(
outerExpressionCodegen: ExpressionCodegen,
element: KtElement,
private val closureContext: ClosureContext,
classBuilder: ClassBuilder,
private val originalSuspendFunctionDescriptor: FunctionDescriptor,
private val isSuspendLambda: Boolean
closureContext: ClosureContext,
classBuilder: ClassBuilder
) : ClosureCodegen(
outerExpressionCodegen.state,
element, null, closureContext, null,
FailingFunctionGenerationStrategy,
outerExpressionCodegen.parentCodegen, classBuilder
) {
override fun generateConstructor(): Method {
val args = calculateConstructorParameters(typeMapper, closure, asmType)
val argTypes = args.map { it.fieldType }.plus(CONTINUATION_ASM_TYPE).toTypedArray()
val constructor = Method("<init>", Type.VOID_TYPE, argTypes)
val mv = v.newMethod(
OtherOrigin(element, funDescriptor), visibilityFlag, "<init>", constructor.descriptor, null,
ArrayUtil.EMPTY_STRING_ARRAY
)
if (state.classBuilderMode.generateBodies) {
mv.visitCode()
val iv = InstructionAdapter(mv)
iv.generateClosureFieldsInitializationFromParameters(closure, args)
iv.load(0, AsmTypes.OBJECT_TYPE)
if (passArityToSuperClass) {
iv.iconst(calculateArity())
}
iv.load(argTypes.map { it.size }.sum(), AsmTypes.OBJECT_TYPE)
val superClassConstructorDescriptor = Type.getMethodDescriptor(
Type.VOID_TYPE,
*(if (passArityToSuperClass) arrayOf(Type.INT_TYPE) else emptyArray()),
CONTINUATION_ASM_TYPE
)
iv.invokespecial(superClassAsmType.internalName, "<init>", superClassConstructorDescriptor, false)
iv.visitInsn(Opcodes.RETURN)
FunctionCodegen.endVisit(iv, "constructor", element)
}
return constructor
}
abstract protected val passArityToSuperClass: Boolean
}
class CoroutineCodegenForLambda private constructor(
outerExpressionCodegen: ExpressionCodegen,
element: KtElement,
private val closureContext: ClosureContext,
classBuilder: ClassBuilder,
private val originalSuspendFunctionDescriptor: FunctionDescriptor
) : AbstractCoroutineCodegen(outerExpressionCodegen, element, closureContext, classBuilder) {
private val classDescriptor = closureContext.contextDescriptor
private val builtIns = funDescriptor.builtIns
@@ -126,16 +170,9 @@ class CoroutineCodegen private constructor(
generateDoResume()
}
override fun generateBridges() {
if (!isSuspendLambda) return
super.generateBridges()
}
override fun generateBody() {
super.generateBody()
if (!isSuspendLambda) return
// create() = ...
functionCodegen.generateMethod(JvmDeclarationOrigin.NO_ORIGIN, createCoroutineDescriptor,
object : FunctionGenerationStrategy.CodegenBased(state) {
@@ -187,41 +224,14 @@ class CoroutineCodegen private constructor(
areturn(AsmTypes.OBJECT_TYPE)
}
override val passArityToSuperClass get() = true
override fun generateConstructor(): Method {
val args = calculateConstructorParameters(typeMapper, closure, asmType)
val argTypes = args.map { it.fieldType }.plus(CONTINUATION_ASM_TYPE).toTypedArray()
val constructor = Method("<init>", Type.VOID_TYPE, argTypes)
val mv = v.newMethod(
OtherOrigin(element, funDescriptor), visibilityFlag, "<init>", constructor.descriptor, null,
ArrayUtil.EMPTY_STRING_ARRAY
)
constructorToUseFromInvoke = constructor
if (state.classBuilderMode.generateBodies) {
mv.visitCode()
val iv = InstructionAdapter(mv)
iv.generateClosureFieldsInitializationFromParameters(closure, args)
iv.load(0, AsmTypes.OBJECT_TYPE)
iv.iconst(calculateArity())
iv.load(argTypes.map { it.size }.sum(), AsmTypes.OBJECT_TYPE)
val superClassConstructorDescriptor = Type.getMethodDescriptor(Type.VOID_TYPE, Type.INT_TYPE, CONTINUATION_ASM_TYPE)
iv.invokespecial(superClassAsmType.internalName, "<init>", superClassConstructorDescriptor, false)
iv.visitInsn(Opcodes.RETURN)
FunctionCodegen.endVisit(iv, "constructor", element)
}
return constructor
constructorToUseFromInvoke = super.generateConstructor()
return constructorToUseFromInvoke
}
private fun generateCreateCoroutineMethod(codegen: ExpressionCodegen) {
assert(isSuspendLambda) { "create method should only be generated for suspend lambdas" }
val classDescriptor = closureContext.contextDescriptor
val owner = typeMapper.mapClass(classDescriptor)
@@ -260,15 +270,11 @@ class CoroutineCodegen private constructor(
}
private fun ExpressionCodegen.initializeCoroutineParameters() {
if (!isSuspendLambda && !originalSuspendFunctionDescriptor.isTailrec) return
for (parameter in allFunctionParameters()) {
val fieldStackValue =
if (isSuspendLambda)
StackValue.field(
parameter.getFieldInfoForCoroutineLambdaParameter(), generateThisOrOuter(context.thisDescriptor, false)
)
else
closureContext.lookupInContext(parameter, null, state, /* ignoreNoOuter = */ false)
StackValue.field(
parameter.getFieldInfoForCoroutineLambdaParameter(), generateThisOrOuter(context.thisDescriptor, false)
)
val mappedType = typeMapper.mapType(parameter.type)
fieldStackValue.put(mappedType, v)
@@ -277,14 +283,7 @@ class CoroutineCodegen private constructor(
v.store(newIndex, mappedType)
}
// necessary for proper tailrec codegen
val actualMethodStartLabel = Label()
v.visitLabel(actualMethodStartLabel)
context.setMethodStartLabel(actualMethodStartLabel)
if (isSuspendLambda) {
initializeVariablesForDestructuredLambdaParameters(this, originalSuspendFunctionDescriptor.valueParameters)
}
initializeVariablesForDestructuredLambdaParameters(this, originalSuspendFunctionDescriptor.valueParameters)
}
private fun allFunctionParameters() =
@@ -308,7 +307,12 @@ class CoroutineCodegen private constructor(
object : FunctionGenerationStrategy.FunctionDefault(state, element as KtDeclarationWithBody) {
override fun wrapMethodVisitor(mv: MethodVisitor, access: Int, name: String, desc: String): MethodVisitor {
return CoroutineTransformerMethodVisitor(mv, access, name, desc, null, null, v)
return CoroutineTransformerMethodVisitor(
mv, access, name, desc, null, null,
obtainClassBuilderForCoroutineState = { v },
containingClassInternalName = v.thisName,
isForNamedFunction = false
)
}
override fun doGenerateBody(codegen: ExpressionCodegen, signature: JvmMethodSignature) {
@@ -319,34 +323,17 @@ class CoroutineCodegen private constructor(
)
}
override fun generateKotlinMetadataAnnotation() {
if (isSuspendLambda) {
super.generateKotlinMetadataAnnotation()
}
else {
writeKotlinMetadata(v, state, KotlinClassHeader.Kind.SYNTHETIC_CLASS, 0) {
// Do not write method metadata for raw coroutine state machines
}
}
}
companion object {
fun shouldCreateByLambda(
originalSuspendLambdaDescriptor: CallableDescriptor,
declaration: KtElement): Boolean {
return (declaration is KtFunctionLiteral && originalSuspendLambdaDescriptor.isSuspendLambda)
}
@JvmStatic
fun createByLambda(
fun create(
expressionCodegen: ExpressionCodegen,
originalSuspendLambdaDescriptor: FunctionDescriptor,
declaration: KtElement,
classBuilder: ClassBuilder
): ClosureCodegen? {
if (!shouldCreateByLambda(originalSuspendLambdaDescriptor, declaration)) return null
if (declaration !is KtFunctionLiteral || !originalSuspendLambdaDescriptor.isSuspendLambda) return null
return CoroutineCodegen(
return CoroutineCodegenForLambda(
expressionCodegen,
declaration,
expressionCodegen.context.intoCoroutineClosure(
@@ -354,31 +341,116 @@ class CoroutineCodegen private constructor(
originalSuspendLambdaDescriptor, expressionCodegen, expressionCodegen.state.typeMapper
),
classBuilder,
originalSuspendLambdaDescriptor,
isSuspendLambda = true
originalSuspendLambdaDescriptor
)
}
}
}
class CoroutineCodegenForNamedFunction private constructor(
outerExpressionCodegen: ExpressionCodegen,
element: KtElement,
closureContext: ClosureContext,
classBuilder: ClassBuilder,
originalSuspendFunctionDescriptor: FunctionDescriptor
) : AbstractCoroutineCodegen(outerExpressionCodegen, element, closureContext, classBuilder) {
private val classDescriptor = closureContext.contextDescriptor
private val suspendFunctionJvmView =
bindingContext[CodegenBinding.SUSPEND_FUNCTION_TO_JVM_VIEW, originalSuspendFunctionDescriptor]!!
// protected fun doResume(): Any?
private val doResumeDescriptor =
SimpleFunctionDescriptorImpl.create(
classDescriptor, Annotations.EMPTY, Name.identifier(DO_RESUME_METHOD_NAME), CallableMemberDescriptor.Kind.DECLARATION,
funDescriptor.source
).apply doResume@{
initialize(
/* receiverParameterType = */ null,
classDescriptor.thisAsReceiverParameter,
/* typeParameters = */ emptyList(),
listOf(),
builtIns.nullableAnyType,
Modality.FINAL,
Visibilities.PUBLIC
)
}
override val passArityToSuperClass get() = false
override fun generateBridges() {
// Do not generate any closure bridges
}
override fun generateClosureBody() {
generateDoResume()
}
private fun generateDoResume() {
functionCodegen.generateMethod(
OtherOrigin(element),
doResumeDescriptor,
object : FunctionGenerationStrategy.CodegenBased(state) {
override fun doGenerateBody(codegen: ExpressionCodegen, signature: JvmMethodSignature) {
val captureThisType = closure.captureThis?.let(typeMapper::mapType)
if (captureThisType != null) {
StackValue.field(
captureThisType, Type.getObjectType(v.thisName), AsmUtil.CAPTURED_THIS_FIELD,
false, StackValue.LOCAL_0
).put(captureThisType, codegen.v)
}
val callableMethod = typeMapper.mapToCallableMethod(suspendFunctionJvmView, false)
for (argumentType in callableMethod.getAsmMethod().argumentTypes.dropLast(1)) {
AsmUtil.pushDefaultValueOnStack(argumentType, codegen.v)
}
codegen.v.load(0, AsmTypes.OBJECT_TYPE)
callableMethod.genInvokeInstruction(codegen.v)
codegen.v.visitInsn(Opcodes.ARETURN)
}
}
)
}
override fun generateKotlinMetadataAnnotation() {
writeKotlinMetadata(v, state, KotlinClassHeader.Kind.SYNTHETIC_CLASS, 0) {
// Do not write method metadata for raw coroutine state machines
}
}
companion object {
fun create(
cv: ClassBuilder,
expressionCodegen: ExpressionCodegen,
originalSuspendDescriptor: FunctionDescriptor,
declaration: KtFunction,
state: GenerationState
): CoroutineCodegen {
val cv = state.factory.newVisitor(
OtherOrigin(declaration, originalSuspendDescriptor),
CodegenBinding.asmTypeForAnonymousClass(state.bindingContext, originalSuspendDescriptor),
declaration.containingFile
)
declaration: KtFunction
): CoroutineCodegenForNamedFunction {
val bindingContext = expressionCodegen.state.bindingContext
val closure =
bindingContext[
CodegenBinding.CLOSURE,
bindingContext[CodegenBinding.CLASS_FOR_CALLABLE, originalSuspendDescriptor]
].sure { "There must be a closure defined for $originalSuspendDescriptor" }
return CoroutineCodegen(
val suspendFunctionView =
bindingContext[
CodegenBinding.SUSPEND_FUNCTION_TO_JVM_VIEW, originalSuspendDescriptor
].sure { "There must be a jvm view defined for $originalSuspendDescriptor" }
if (suspendFunctionView.dispatchReceiverParameter != null) {
closure.setCaptureThis()
}
return CoroutineCodegenForNamedFunction(
expressionCodegen, declaration,
expressionCodegen.context.intoClosure(
originalSuspendDescriptor, expressionCodegen, expressionCodegen.state.typeMapper
),
cv,
originalSuspendDescriptor,
isSuspendLambda = false
originalSuspendDescriptor
)
}
}
@@ -44,8 +44,15 @@ class CoroutineTransformerMethodVisitor(
desc: String,
signature: String?,
exceptions: Array<out String>?,
private val classBuilder: ClassBuilder
private val containingClassInternalName: String,
private val classBuilderForCoroutineState: ClassBuilder,
isForNamedFunction: Boolean
) : TransformationMethodVisitor(delegate, access, name, desc, signature, exceptions) {
private val continuationIndex = if (isForNamedFunction) getLastParameterIndex(desc, access) else 0
private val dataIndex = continuationIndex + 1
private val exceptionIndex = dataIndex + 1
override fun performTransformations(methodNode: MethodNode) {
val customCoroutineStartMarker = methodNode.instructions.toArray().filterIsInstance<MethodInsnNode>().firstOrNull {
it.owner == COROUTINE_MARKER_OWNER && it.name == ACTUAL_COROUTINE_START_MARKER_NAME
@@ -61,7 +68,7 @@ class CoroutineTransformerMethodVisitor(
}
// Spill stack to variables before suspension points, try/catch blocks
FixStackWithLabelNormalizationMethodTransformer().transform(classBuilder.thisName, methodNode)
FixStackWithLabelNormalizationMethodTransformer().transform(containingClassInternalName, methodNode)
// Remove unreachable suspension points
// If we don't do this, then relevant frames will not be analyzed, that is unexpected from point of view of next steps (e.g. variable spilling)
@@ -86,7 +93,7 @@ class CoroutineTransformerMethodVisitor(
insnListOf(
*withInstructionAdapter { loadCoroutineSuspendedMarker() }.toArray(),
VarInsnNode(Opcodes.ASTORE, suspendMarkerVarIndex),
VarInsnNode(Opcodes.ALOAD, 0),
VarInsnNode(Opcodes.ALOAD, continuationIndex),
FieldInsnNode(
Opcodes.GETFIELD,
COROUTINE_IMPL_ASM_TYPE.internalName,
@@ -101,7 +108,7 @@ class CoroutineTransformerMethodVisitor(
)
)
insert(startLabel, withInstructionAdapter(InstructionAdapter::generateResumeWithExceptionCheck))
insert(startLabel, withInstructionAdapter { generateResumeWithExceptionCheck(exceptionIndex) })
insert(last, withInstructionAdapter {
visitLabel(defaultLabel.label)
@@ -116,7 +123,7 @@ class CoroutineTransformerMethodVisitor(
}
private fun removeUnreachableSuspensionPointsAndExitPoints(methodNode: MethodNode, suspensionPoints: MutableList<SuspensionPoint>) {
val dceResult = DeadCodeEliminationMethodTransformer().transformWithResult(classBuilder.thisName, methodNode)
val dceResult = DeadCodeEliminationMethodTransformer().transformWithResult(containingClassInternalName, methodNode)
// If the suspension call begin is alive and suspension call end is dead
// (e.g., an inlined suspend function call ends with throwing a exception -- see KT-15017),
@@ -166,7 +173,7 @@ class CoroutineTransformerMethodVisitor(
private fun spillVariables(suspensionPoints: List<SuspensionPoint>, methodNode: MethodNode) {
val instructions = methodNode.instructions
val frames = performRefinedTypeAnalysis(methodNode, classBuilder.thisName)
val frames = performRefinedTypeAnalysis(methodNode, containingClassInternalName)
fun AbstractInsnNode.index() = instructions.indexOf(this)
// We postpone these actions because they change instruction indices that we use when obtaining frames
@@ -200,13 +207,15 @@ class CoroutineTransformerMethodVisitor(
val livenessFrame = livenessFrames[suspensionCallBegin.index()]
// 0 - this
// 1 - continuation argument
// 2 - continuation exception
// 1 - parameter
// ...
// k - continuation
// k + 1 - data
// k + 2 - exception
val variablesToSpill =
(3 until localsCount)
((exceptionIndex + 1) until localsCount)
.map { Pair(it, frame.getLocal(it)) }
.filter {
val (index, value) = it
.filter { (index, value) ->
value != StrictBasicValue.UNINITIALIZED_VALUE && livenessFrame.isAlive(index)
}
@@ -235,16 +244,16 @@ class CoroutineTransformerMethodVisitor(
with(instructions) {
// store variable before suspension call
insertBefore(suspension.suspensionCallBegin, withInstructionAdapter {
load(0, AsmTypes.OBJECT_TYPE)
load(continuationIndex, AsmTypes.OBJECT_TYPE)
load(index, type)
StackValue.coerce(type, normalizedType, this)
putfield(classBuilder.thisName, fieldName, normalizedType.descriptor)
putfield(classBuilderForCoroutineState.thisName, fieldName, normalizedType.descriptor)
})
// restore variable after suspension call
insert(suspension.tryCatchBlockEndLabelAfterSuspensionCall, withInstructionAdapter {
load(0, AsmTypes.OBJECT_TYPE)
getfield(classBuilder.thisName, fieldName, normalizedType.descriptor)
load(continuationIndex, AsmTypes.OBJECT_TYPE)
getfield(classBuilderForCoroutineState.thisName, fieldName, normalizedType.descriptor)
StackValue.coerce(normalizedType, type, this)
store(index, type)
})
@@ -262,8 +271,8 @@ class CoroutineTransformerMethodVisitor(
maxVarsCountByType.forEach { entry ->
val (type, maxIndex) = entry
for (index in 0..maxIndex) {
classBuilder.newField(
JvmDeclarationOrigin.NO_ORIGIN, Opcodes.ACC_PRIVATE,
classBuilderForCoroutineState.newField(
JvmDeclarationOrigin.NO_ORIGIN, AsmUtil.NO_FLAG_PACKAGE_PRIVATE,
type.fieldNameForVar(index), type.descriptor, null, null)
}
}
@@ -294,7 +303,7 @@ class CoroutineTransformerMethodVisitor(
// Save state
insertBefore(suspension.suspensionCallBegin,
insnListOf(
VarInsnNode(Opcodes.ALOAD, 0),
VarInsnNode(Opcodes.ALOAD, continuationIndex),
*withInstructionAdapter { iconst(id) }.toArray(),
FieldInsnNode(
Opcodes.PUTFIELD, COROUTINE_IMPL_ASM_TYPE.internalName, COROUTINE_LABEL_FIELD_NAME,
@@ -327,10 +336,10 @@ class CoroutineTransformerMethodVisitor(
remove(possibleTryCatchBlockStart.previous)
insert(possibleTryCatchBlockStart, withInstructionAdapter {
generateResumeWithExceptionCheck()
generateResumeWithExceptionCheck(exceptionIndex)
// Load continuation argument just like suspending function returns it
load(1, AsmTypes.OBJECT_TYPE)
load(dataIndex, AsmTypes.OBJECT_TYPE)
visitLabel(continuationLabelAfterLoadedResult.label)
})
@@ -389,9 +398,9 @@ class CoroutineTransformerMethodVisitor(
}
}
private fun InstructionAdapter.generateResumeWithExceptionCheck() {
private fun InstructionAdapter.generateResumeWithExceptionCheck(exceptionIndex: Int) {
// Check if resumeWithException has been called
load(2, AsmTypes.OBJECT_TYPE)
load(exceptionIndex, AsmTypes.OBJECT_TYPE)
dup()
val noExceptionLabel = Label()
ifnull(noExceptionLabel)
@@ -432,3 +441,6 @@ private class SuspensionPoint(
) {
lateinit var tryCatchBlocksContinuationLabel: LabelNode
}
private fun getLastParameterIndex(desc: String, access: Int) =
Type.getArgumentTypes(desc).dropLast(1).map { it.size }.sum() + (if (access and Opcodes.ACC_STATIC != 0) 0 else 1)
@@ -18,33 +18,116 @@ package org.jetbrains.kotlin.codegen.coroutines
import org.jetbrains.kotlin.codegen.ExpressionCodegen
import org.jetbrains.kotlin.codegen.FunctionGenerationStrategy
import org.jetbrains.kotlin.codegen.binding.CodegenBinding
import org.jetbrains.kotlin.codegen.state.GenerationState
import org.jetbrains.kotlin.descriptors.FunctionDescriptor
import org.jetbrains.kotlin.psi.KtFunction
import org.jetbrains.kotlin.resolve.jvm.AsmTypes
import org.jetbrains.kotlin.resolve.jvm.diagnostics.OtherOrigin
import org.jetbrains.kotlin.resolve.jvm.jvmSignature.JvmMethodSignature
import org.jetbrains.org.objectweb.asm.Label
import org.jetbrains.org.objectweb.asm.MethodVisitor
import org.jetbrains.org.objectweb.asm.Opcodes
import org.jetbrains.org.objectweb.asm.Type
class SuspendFunctionGenerationStrategy(
state: GenerationState,
private val originalSuspendDescriptor: FunctionDescriptor,
private val declaration: KtFunction
private val declaration: KtFunction,
private val containingClassInternalName: String
) : FunctionGenerationStrategy.CodegenBased(state) {
private val containsNonTailSuspensionCalls = originalSuspendDescriptor.containsNonTailSuspensionCalls(state.bindingContext)
private val classBuilderForCoroutineState by lazy {
state.factory.newVisitor(
OtherOrigin(declaration, originalSuspendDescriptor),
CodegenBinding.asmTypeForAnonymousClass(state.bindingContext, originalSuspendDescriptor),
declaration.containingFile
)
}
override fun wrapMethodVisitor(mv: MethodVisitor, access: Int, name: String, desc: String): MethodVisitor {
if (containsNonTailSuspensionCalls) {
return CoroutineTransformerMethodVisitor(
mv, access, name, desc, null, null, containingClassInternalName, classBuilderForCoroutineState,
isForNamedFunction = true
)
}
return super.wrapMethodVisitor(mv, access, name, desc)
}
override fun doGenerateBody(codegen: ExpressionCodegen, signature: JvmMethodSignature) {
if (!originalSuspendDescriptor.containsNonTailSuspensionCalls(state.bindingContext)) {
if (!containsNonTailSuspensionCalls) {
codegen.returnExpression(declaration.bodyExpression)
return
}
val coroutineCodegen = CoroutineCodegen.create(codegen, originalSuspendDescriptor, declaration, state)
val coroutineCodegen =
CoroutineForNamedFunctionCodegen.create(classBuilderForCoroutineState, codegen, originalSuspendDescriptor, declaration)
coroutineCodegen.generate()
codegen.markLineNumber(declaration, false)
codegen.putClosureInstanceOnStack(coroutineCodegen, null).put(Type.getObjectType(coroutineCodegen.className), codegen.v)
codegen.v.invokeDoResumeWithUnit(coroutineCodegen.v.thisName)
val functionDescriptor = codegen.context.functionDescriptor
codegen.markLineNumber(declaration, true)
codegen.v.areturn(AsmTypes.OBJECT_TYPE)
val continuationIndex = codegen.frameMap.getIndex(functionDescriptor.valueParameters.last())
val objectTypeForState = Type.getObjectType(classBuilderForCoroutineState.thisName)
val dataIndex = codegen.frameMap.enterTemp(AsmTypes.OBJECT_TYPE)
val exceptionIndex = codegen.frameMap.enterTemp(AsmTypes.OBJECT_TYPE)
with(codegen.v) {
val createStateInstance = Label()
val storeStateObject = Label()
// We have to distinguish the following situations:
// - Our function got called in a common way (e.g. from another function or via recursive call) and we should execute our
// code from the beginning
// - We got called from `doResume` of our continuation, i.e. we need to continue from the last suspension point
//
// Also in the first case we wrap the completion into a special anonymous class instance (let's call it X$1)
// that we'll use as a continuation argument for suspension points
//
// How we distinguish the cases:
// - If the continuation is not an instance of X$1 we know exactly it's not the second case, because when resuming
// the continuation we pass an instance of that class
// - Otherwise it's still can be a recursive call. To check it's not the case we set the last bit in the label in
// `doResume` just before calling the suspend function (see kotlin.coroutines.experimental.jvm.internal.CoroutineImplForNamedFunction).
// So, if it's set we're in continuation.
visitVarInsn(Opcodes.ALOAD, continuationIndex)
instanceOf(objectTypeForState)
ifeq(createStateInstance)
visitVarInsn(Opcodes.ALOAD, continuationIndex)
checkcast(objectTypeForState)
invokevirtual(
COROUTINE_IMPL_FOR_NAMED_ASM_TYPE.internalName,
"checkAndFlushLastBit", Type.getMethodDescriptor(Type.BOOLEAN_TYPE), false
)
ifeq(createStateInstance)
visitVarInsn(Opcodes.ALOAD, continuationIndex)
checkcast(objectTypeForState)
goTo(storeStateObject)
visitLabel(createStateInstance)
coroutineCodegen.putInstanceOnStack(codegen, null).put(objectTypeForState, this)
visitLabel(storeStateObject)
visitVarInsn(Opcodes.ASTORE, continuationIndex)
visitVarInsn(Opcodes.ALOAD, continuationIndex)
checkcast(objectTypeForState)
getfield(COROUTINE_IMPL_FOR_NAMED_ASM_TYPE.internalName, "data", AsmTypes.OBJECT_TYPE.descriptor)
visitVarInsn(Opcodes.ASTORE, dataIndex)
visitVarInsn(Opcodes.ALOAD, continuationIndex)
checkcast(objectTypeForState)
getfield(COROUTINE_IMPL_FOR_NAMED_ASM_TYPE.internalName, "exception", AsmTypes.JAVA_THROWABLE_TYPE.descriptor)
visitVarInsn(Opcodes.ASTORE, exceptionIndex)
invokestatic(COROUTINE_MARKER_OWNER, ACTUAL_COROUTINE_START_MARKER_NAME, "()V", false)
}
codegen.returnExpression(declaration.bodyExpression)
}
}
@@ -72,6 +72,10 @@ val CONTINUATION_ASM_TYPE = DescriptorUtils.CONTINUATION_INTERFACE_FQ_NAME.topLe
@JvmField
val COROUTINE_IMPL_ASM_TYPE = COROUTINES_JVM_INTERNAL_PACKAGE_FQ_NAME.child(Name.identifier("CoroutineImpl")).topLevelClassAsmType()
@JvmField
val COROUTINE_IMPL_FOR_NAMED_ASM_TYPE =
COROUTINES_JVM_INTERNAL_PACKAGE_FQ_NAME.child(Name.identifier("CoroutineImplForNamedFunction")).topLevelClassAsmType()
private val COROUTINES_INTRINSICS_FILE_FACADE_INTERNAL_NAME =
COROUTINES_INTRINSICS_PACKAGE_FQ_NAME.child(Name.identifier("IntrinsicsKt")).topLevelClassAsmType()
@@ -598,7 +598,7 @@ public class InlineCodegen extends CallGenerator {
new SuspendFunctionGenerationStrategy(
state,
CoroutineCodegenUtilKt.unwrapInitialDescriptorForSuspendFunction(descriptor),
(KtFunction) expression
(KtFunction) expression, null
);
}
else {
@@ -35,7 +35,7 @@ abstract class CoroutineImpl(
// label == -1 when coroutine cannot be started (it is just a factory object) or has already finished execution
// label == 0 in initial part of the coroutine
@JvmField
protected var label: Int = if (completion != null) 0 else -1
var label: Int = if (completion != null) 0 else -1
private val _context: CoroutineContext? = completion?.context
@@ -71,3 +71,35 @@ abstract class CoroutineImpl(
throw IllegalStateException("create(Any?;Continuation) has not been overridden")
}
}
abstract class CoroutineImplForNamedFunction(
completion: Continuation<Any?>?
) : CoroutineImpl(0, completion), Continuation<Any?> {
companion object {
private const val LAST_BIT_MASK = 1 shl 31
}
@JvmField
var data: Any? = null
@JvmField
var exception: Throwable? = null
fun checkAndFlushLastBit(): Boolean {
if (label and LAST_BIT_MASK != 0) {
label -= LAST_BIT_MASK
return true
}
return false
}
override fun doResume(data: Any?, exception: Throwable?): Any? {
this.data = data
this.exception = exception
this.label = this.label or LAST_BIT_MASK
return doResume()
}
protected abstract fun doResume(): Any?
}