Generate state machine for named functions in their bodies
Inline functions aren't supported yet in the change #KT-17585 In Progress
This commit is contained in:
@@ -646,6 +646,8 @@ public class AsmUtil {
|
||||
@NotNull FrameMap frameMap
|
||||
) {
|
||||
if (state.isParamAssertionsDisabled()) return;
|
||||
// currently when resuming a suspend function we pass default values instead of real arguments (i.e. nulls for references)
|
||||
if (descriptor.isSuspend()) return;
|
||||
|
||||
// Private method is not accessible from other classes, no assertions needed
|
||||
if (getVisibilityAccessFlag(descriptor) == ACC_PRIVATE) return;
|
||||
|
||||
@@ -35,7 +35,7 @@ import org.jetbrains.kotlin.builtins.KotlinBuiltIns;
|
||||
import org.jetbrains.kotlin.codegen.binding.CalculatedClosure;
|
||||
import org.jetbrains.kotlin.codegen.binding.CodegenBinding;
|
||||
import org.jetbrains.kotlin.codegen.context.*;
|
||||
import org.jetbrains.kotlin.codegen.coroutines.CoroutineCodegen;
|
||||
import org.jetbrains.kotlin.codegen.coroutines.CoroutineCodegenForLambda;
|
||||
import org.jetbrains.kotlin.codegen.coroutines.CoroutineCodegenUtilKt;
|
||||
import org.jetbrains.kotlin.codegen.coroutines.ResolvedCallWithRealDescriptor;
|
||||
import org.jetbrains.kotlin.codegen.extensions.ExpressionCodegenExtension;
|
||||
@@ -52,6 +52,7 @@ import org.jetbrains.kotlin.codegen.when.SwitchCodegen;
|
||||
import org.jetbrains.kotlin.codegen.when.SwitchCodegenUtil;
|
||||
import org.jetbrains.kotlin.config.ApiVersion;
|
||||
import org.jetbrains.kotlin.descriptors.*;
|
||||
import org.jetbrains.kotlin.descriptors.impl.AnonymousFunctionDescriptor;
|
||||
import org.jetbrains.kotlin.descriptors.impl.LocalVariableDescriptor;
|
||||
import org.jetbrains.kotlin.descriptors.impl.SyntheticFieldDescriptor;
|
||||
import org.jetbrains.kotlin.descriptors.impl.TypeAliasConstructorDescriptor;
|
||||
@@ -942,7 +943,7 @@ public class ExpressionCodegen extends KtVisitor<StackValue, StackValue> impleme
|
||||
declaration.getContainingFile()
|
||||
);
|
||||
|
||||
ClosureCodegen coroutineCodegen = CoroutineCodegen.createByLambda(this, descriptor, declaration, cv);
|
||||
ClosureCodegen coroutineCodegen = CoroutineCodegenForLambda.create(this, descriptor, declaration, cv);
|
||||
ClosureCodegen closureCodegen = coroutineCodegen != null ? coroutineCodegen : new ClosureCodegen(
|
||||
state, declaration, samType, context.intoClosure(descriptor, this, typeMapper),
|
||||
functionReferenceTarget, strategy, parentCodegen, cv
|
||||
@@ -1188,12 +1189,14 @@ public class ExpressionCodegen extends KtVisitor<StackValue, StackValue> impleme
|
||||
bindingContext.get(ENCLOSING_SUSPEND_FUNCTION_FOR_SUSPEND_FUNCTION_CALL, resolvedCall.getCall());
|
||||
|
||||
if (enclosingSuspendLambdaForSuspensionPoint == null) return null;
|
||||
return genCoroutineInstanceBySuspendFunction(enclosingSuspendLambdaForSuspensionPoint);
|
||||
return genCoroutineInstanceForSuspendLambda(enclosingSuspendLambdaForSuspensionPoint);
|
||||
}
|
||||
|
||||
@Nullable
|
||||
private StackValue genCoroutineInstanceBySuspendFunction(@NotNull FunctionDescriptor suspendFunction) {
|
||||
private StackValue genCoroutineInstanceForSuspendLambda(@NotNull FunctionDescriptor suspendFunction) {
|
||||
if (!CoroutineCodegenUtilKt.isStateMachineNeeded(suspendFunction, bindingContext)) return null;
|
||||
if (!(suspendFunction instanceof AnonymousFunctionDescriptor)) return null;
|
||||
|
||||
ClassDescriptor suspendLambdaClassDescriptor = bindingContext.get(CodegenBinding.CLASS_FOR_CALLABLE, suspendFunction);
|
||||
assert suspendLambdaClassDescriptor != null : "Coroutine class descriptor should not be null";
|
||||
|
||||
|
||||
@@ -135,7 +135,8 @@ public class FunctionCodegen {
|
||||
strategy = new SuspendFunctionGenerationStrategy(
|
||||
state,
|
||||
CoroutineCodegenUtilKt.<FunctionDescriptor>unwrapInitialDescriptorForSuspendFunction(functionDescriptor),
|
||||
function
|
||||
function,
|
||||
v.getThisName()
|
||||
);
|
||||
}
|
||||
else {
|
||||
|
||||
@@ -43,6 +43,7 @@ class JvmRuntimeTypes(module: ModuleDescriptor) {
|
||||
private val localVariableReference: ClassDescriptor by klass("LocalVariableReference")
|
||||
private val mutableLocalVariableReference: ClassDescriptor by klass("MutableLocalVariableReference")
|
||||
private val coroutineImplClass by lazy { createClass(kotlinCoroutinesJvmInternalPackage, "CoroutineImpl") }
|
||||
private val coroutineImplForNamedFunctionClass by lazy { createClass(kotlinCoroutinesJvmInternalPackage, "CoroutineImplForNamedFunction") }
|
||||
|
||||
private val propertyReferences: List<ClassDescriptor> by lazy {
|
||||
(0..2).map { i -> createClass(kotlinJvmInternalPackage, "PropertyReference$i") }
|
||||
@@ -84,11 +85,13 @@ class JvmRuntimeTypes(module: ModuleDescriptor) {
|
||||
|
||||
if (descriptor.isSuspend) {
|
||||
return mutableListOf<KotlinType>().apply {
|
||||
add(coroutineImplClass.defaultType)
|
||||
|
||||
if (descriptor.isSuspendLambda) {
|
||||
add(coroutineImplClass.defaultType)
|
||||
add(functionType)
|
||||
}
|
||||
else {
|
||||
add(coroutineImplForNamedFunctionClass.defaultType)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+166
-94
@@ -21,7 +21,6 @@ import org.jetbrains.kotlin.codegen.*
|
||||
import org.jetbrains.kotlin.codegen.binding.CodegenBinding
|
||||
import org.jetbrains.kotlin.codegen.context.ClosureContext
|
||||
import org.jetbrains.kotlin.codegen.context.MethodContext
|
||||
import org.jetbrains.kotlin.codegen.state.GenerationState
|
||||
import org.jetbrains.kotlin.coroutines.isSuspendLambda
|
||||
import org.jetbrains.kotlin.descriptors.*
|
||||
import org.jetbrains.kotlin.descriptors.annotations.Annotations
|
||||
@@ -42,7 +41,7 @@ import org.jetbrains.kotlin.resolve.jvm.jvmSignature.JvmMethodSignature
|
||||
import org.jetbrains.kotlin.types.KotlinType
|
||||
import org.jetbrains.kotlin.types.typeUtil.makeNullable
|
||||
import org.jetbrains.kotlin.utils.addToStdlib.safeAs
|
||||
import org.jetbrains.org.objectweb.asm.Label
|
||||
import org.jetbrains.kotlin.utils.sure
|
||||
import org.jetbrains.org.objectweb.asm.MethodVisitor
|
||||
import org.jetbrains.org.objectweb.asm.Opcodes
|
||||
import org.jetbrains.org.objectweb.asm.Type
|
||||
@@ -50,19 +49,64 @@ import org.jetbrains.org.objectweb.asm.commons.InstructionAdapter
|
||||
import org.jetbrains.org.objectweb.asm.commons.Method
|
||||
|
||||
|
||||
class CoroutineCodegen private constructor(
|
||||
abstract class AbstractCoroutineCodegen(
|
||||
outerExpressionCodegen: ExpressionCodegen,
|
||||
element: KtElement,
|
||||
private val closureContext: ClosureContext,
|
||||
classBuilder: ClassBuilder,
|
||||
private val originalSuspendFunctionDescriptor: FunctionDescriptor,
|
||||
private val isSuspendLambda: Boolean
|
||||
closureContext: ClosureContext,
|
||||
classBuilder: ClassBuilder
|
||||
) : ClosureCodegen(
|
||||
outerExpressionCodegen.state,
|
||||
element, null, closureContext, null,
|
||||
FailingFunctionGenerationStrategy,
|
||||
outerExpressionCodegen.parentCodegen, classBuilder
|
||||
) {
|
||||
override fun generateConstructor(): Method {
|
||||
val args = calculateConstructorParameters(typeMapper, closure, asmType)
|
||||
val argTypes = args.map { it.fieldType }.plus(CONTINUATION_ASM_TYPE).toTypedArray()
|
||||
|
||||
val constructor = Method("<init>", Type.VOID_TYPE, argTypes)
|
||||
val mv = v.newMethod(
|
||||
OtherOrigin(element, funDescriptor), visibilityFlag, "<init>", constructor.descriptor, null,
|
||||
ArrayUtil.EMPTY_STRING_ARRAY
|
||||
)
|
||||
|
||||
if (state.classBuilderMode.generateBodies) {
|
||||
mv.visitCode()
|
||||
val iv = InstructionAdapter(mv)
|
||||
|
||||
iv.generateClosureFieldsInitializationFromParameters(closure, args)
|
||||
|
||||
iv.load(0, AsmTypes.OBJECT_TYPE)
|
||||
if (passArityToSuperClass) {
|
||||
iv.iconst(calculateArity())
|
||||
}
|
||||
iv.load(argTypes.map { it.size }.sum(), AsmTypes.OBJECT_TYPE)
|
||||
|
||||
val superClassConstructorDescriptor = Type.getMethodDescriptor(
|
||||
Type.VOID_TYPE,
|
||||
*(if (passArityToSuperClass) arrayOf(Type.INT_TYPE) else emptyArray()),
|
||||
CONTINUATION_ASM_TYPE
|
||||
)
|
||||
iv.invokespecial(superClassAsmType.internalName, "<init>", superClassConstructorDescriptor, false)
|
||||
|
||||
iv.visitInsn(Opcodes.RETURN)
|
||||
|
||||
FunctionCodegen.endVisit(iv, "constructor", element)
|
||||
}
|
||||
|
||||
return constructor
|
||||
}
|
||||
|
||||
abstract protected val passArityToSuperClass: Boolean
|
||||
}
|
||||
|
||||
class CoroutineCodegenForLambda private constructor(
|
||||
outerExpressionCodegen: ExpressionCodegen,
|
||||
element: KtElement,
|
||||
private val closureContext: ClosureContext,
|
||||
classBuilder: ClassBuilder,
|
||||
private val originalSuspendFunctionDescriptor: FunctionDescriptor
|
||||
) : AbstractCoroutineCodegen(outerExpressionCodegen, element, closureContext, classBuilder) {
|
||||
private val classDescriptor = closureContext.contextDescriptor
|
||||
private val builtIns = funDescriptor.builtIns
|
||||
|
||||
@@ -126,16 +170,9 @@ class CoroutineCodegen private constructor(
|
||||
generateDoResume()
|
||||
}
|
||||
|
||||
override fun generateBridges() {
|
||||
if (!isSuspendLambda) return
|
||||
super.generateBridges()
|
||||
}
|
||||
|
||||
override fun generateBody() {
|
||||
super.generateBody()
|
||||
|
||||
if (!isSuspendLambda) return
|
||||
|
||||
// create() = ...
|
||||
functionCodegen.generateMethod(JvmDeclarationOrigin.NO_ORIGIN, createCoroutineDescriptor,
|
||||
object : FunctionGenerationStrategy.CodegenBased(state) {
|
||||
@@ -187,41 +224,14 @@ class CoroutineCodegen private constructor(
|
||||
areturn(AsmTypes.OBJECT_TYPE)
|
||||
}
|
||||
|
||||
override val passArityToSuperClass get() = true
|
||||
|
||||
override fun generateConstructor(): Method {
|
||||
val args = calculateConstructorParameters(typeMapper, closure, asmType)
|
||||
val argTypes = args.map { it.fieldType }.plus(CONTINUATION_ASM_TYPE).toTypedArray()
|
||||
|
||||
val constructor = Method("<init>", Type.VOID_TYPE, argTypes)
|
||||
val mv = v.newMethod(
|
||||
OtherOrigin(element, funDescriptor), visibilityFlag, "<init>", constructor.descriptor, null,
|
||||
ArrayUtil.EMPTY_STRING_ARRAY
|
||||
)
|
||||
|
||||
constructorToUseFromInvoke = constructor
|
||||
|
||||
if (state.classBuilderMode.generateBodies) {
|
||||
mv.visitCode()
|
||||
val iv = InstructionAdapter(mv)
|
||||
|
||||
iv.generateClosureFieldsInitializationFromParameters(closure, args)
|
||||
|
||||
iv.load(0, AsmTypes.OBJECT_TYPE)
|
||||
iv.iconst(calculateArity())
|
||||
iv.load(argTypes.map { it.size }.sum(), AsmTypes.OBJECT_TYPE)
|
||||
|
||||
val superClassConstructorDescriptor = Type.getMethodDescriptor(Type.VOID_TYPE, Type.INT_TYPE, CONTINUATION_ASM_TYPE)
|
||||
iv.invokespecial(superClassAsmType.internalName, "<init>", superClassConstructorDescriptor, false)
|
||||
|
||||
iv.visitInsn(Opcodes.RETURN)
|
||||
|
||||
FunctionCodegen.endVisit(iv, "constructor", element)
|
||||
}
|
||||
|
||||
return constructor
|
||||
constructorToUseFromInvoke = super.generateConstructor()
|
||||
return constructorToUseFromInvoke
|
||||
}
|
||||
|
||||
private fun generateCreateCoroutineMethod(codegen: ExpressionCodegen) {
|
||||
assert(isSuspendLambda) { "create method should only be generated for suspend lambdas" }
|
||||
val classDescriptor = closureContext.contextDescriptor
|
||||
val owner = typeMapper.mapClass(classDescriptor)
|
||||
|
||||
@@ -260,15 +270,11 @@ class CoroutineCodegen private constructor(
|
||||
}
|
||||
|
||||
private fun ExpressionCodegen.initializeCoroutineParameters() {
|
||||
if (!isSuspendLambda && !originalSuspendFunctionDescriptor.isTailrec) return
|
||||
for (parameter in allFunctionParameters()) {
|
||||
val fieldStackValue =
|
||||
if (isSuspendLambda)
|
||||
StackValue.field(
|
||||
parameter.getFieldInfoForCoroutineLambdaParameter(), generateThisOrOuter(context.thisDescriptor, false)
|
||||
)
|
||||
else
|
||||
closureContext.lookupInContext(parameter, null, state, /* ignoreNoOuter = */ false)
|
||||
StackValue.field(
|
||||
parameter.getFieldInfoForCoroutineLambdaParameter(), generateThisOrOuter(context.thisDescriptor, false)
|
||||
)
|
||||
|
||||
val mappedType = typeMapper.mapType(parameter.type)
|
||||
fieldStackValue.put(mappedType, v)
|
||||
@@ -277,14 +283,7 @@ class CoroutineCodegen private constructor(
|
||||
v.store(newIndex, mappedType)
|
||||
}
|
||||
|
||||
// necessary for proper tailrec codegen
|
||||
val actualMethodStartLabel = Label()
|
||||
v.visitLabel(actualMethodStartLabel)
|
||||
context.setMethodStartLabel(actualMethodStartLabel)
|
||||
|
||||
if (isSuspendLambda) {
|
||||
initializeVariablesForDestructuredLambdaParameters(this, originalSuspendFunctionDescriptor.valueParameters)
|
||||
}
|
||||
initializeVariablesForDestructuredLambdaParameters(this, originalSuspendFunctionDescriptor.valueParameters)
|
||||
}
|
||||
|
||||
private fun allFunctionParameters() =
|
||||
@@ -308,7 +307,12 @@ class CoroutineCodegen private constructor(
|
||||
object : FunctionGenerationStrategy.FunctionDefault(state, element as KtDeclarationWithBody) {
|
||||
|
||||
override fun wrapMethodVisitor(mv: MethodVisitor, access: Int, name: String, desc: String): MethodVisitor {
|
||||
return CoroutineTransformerMethodVisitor(mv, access, name, desc, null, null, v)
|
||||
return CoroutineTransformerMethodVisitor(
|
||||
mv, access, name, desc, null, null,
|
||||
obtainClassBuilderForCoroutineState = { v },
|
||||
containingClassInternalName = v.thisName,
|
||||
isForNamedFunction = false
|
||||
)
|
||||
}
|
||||
|
||||
override fun doGenerateBody(codegen: ExpressionCodegen, signature: JvmMethodSignature) {
|
||||
@@ -319,34 +323,17 @@ class CoroutineCodegen private constructor(
|
||||
)
|
||||
}
|
||||
|
||||
override fun generateKotlinMetadataAnnotation() {
|
||||
if (isSuspendLambda) {
|
||||
super.generateKotlinMetadataAnnotation()
|
||||
}
|
||||
else {
|
||||
writeKotlinMetadata(v, state, KotlinClassHeader.Kind.SYNTHETIC_CLASS, 0) {
|
||||
// Do not write method metadata for raw coroutine state machines
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
companion object {
|
||||
fun shouldCreateByLambda(
|
||||
originalSuspendLambdaDescriptor: CallableDescriptor,
|
||||
declaration: KtElement): Boolean {
|
||||
return (declaration is KtFunctionLiteral && originalSuspendLambdaDescriptor.isSuspendLambda)
|
||||
}
|
||||
|
||||
@JvmStatic
|
||||
fun createByLambda(
|
||||
fun create(
|
||||
expressionCodegen: ExpressionCodegen,
|
||||
originalSuspendLambdaDescriptor: FunctionDescriptor,
|
||||
declaration: KtElement,
|
||||
classBuilder: ClassBuilder
|
||||
): ClosureCodegen? {
|
||||
if (!shouldCreateByLambda(originalSuspendLambdaDescriptor, declaration)) return null
|
||||
if (declaration !is KtFunctionLiteral || !originalSuspendLambdaDescriptor.isSuspendLambda) return null
|
||||
|
||||
return CoroutineCodegen(
|
||||
return CoroutineCodegenForLambda(
|
||||
expressionCodegen,
|
||||
declaration,
|
||||
expressionCodegen.context.intoCoroutineClosure(
|
||||
@@ -354,31 +341,116 @@ class CoroutineCodegen private constructor(
|
||||
originalSuspendLambdaDescriptor, expressionCodegen, expressionCodegen.state.typeMapper
|
||||
),
|
||||
classBuilder,
|
||||
originalSuspendLambdaDescriptor,
|
||||
isSuspendLambda = true
|
||||
originalSuspendLambdaDescriptor
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class CoroutineCodegenForNamedFunction private constructor(
|
||||
outerExpressionCodegen: ExpressionCodegen,
|
||||
element: KtElement,
|
||||
closureContext: ClosureContext,
|
||||
classBuilder: ClassBuilder,
|
||||
originalSuspendFunctionDescriptor: FunctionDescriptor
|
||||
) : AbstractCoroutineCodegen(outerExpressionCodegen, element, closureContext, classBuilder) {
|
||||
private val classDescriptor = closureContext.contextDescriptor
|
||||
|
||||
private val suspendFunctionJvmView =
|
||||
bindingContext[CodegenBinding.SUSPEND_FUNCTION_TO_JVM_VIEW, originalSuspendFunctionDescriptor]!!
|
||||
|
||||
// protected fun doResume(): Any?
|
||||
private val doResumeDescriptor =
|
||||
SimpleFunctionDescriptorImpl.create(
|
||||
classDescriptor, Annotations.EMPTY, Name.identifier(DO_RESUME_METHOD_NAME), CallableMemberDescriptor.Kind.DECLARATION,
|
||||
funDescriptor.source
|
||||
).apply doResume@{
|
||||
initialize(
|
||||
/* receiverParameterType = */ null,
|
||||
classDescriptor.thisAsReceiverParameter,
|
||||
/* typeParameters = */ emptyList(),
|
||||
listOf(),
|
||||
builtIns.nullableAnyType,
|
||||
Modality.FINAL,
|
||||
Visibilities.PUBLIC
|
||||
)
|
||||
}
|
||||
|
||||
override val passArityToSuperClass get() = false
|
||||
|
||||
override fun generateBridges() {
|
||||
// Do not generate any closure bridges
|
||||
}
|
||||
|
||||
override fun generateClosureBody() {
|
||||
generateDoResume()
|
||||
}
|
||||
|
||||
private fun generateDoResume() {
|
||||
functionCodegen.generateMethod(
|
||||
OtherOrigin(element),
|
||||
doResumeDescriptor,
|
||||
object : FunctionGenerationStrategy.CodegenBased(state) {
|
||||
override fun doGenerateBody(codegen: ExpressionCodegen, signature: JvmMethodSignature) {
|
||||
val captureThisType = closure.captureThis?.let(typeMapper::mapType)
|
||||
if (captureThisType != null) {
|
||||
StackValue.field(
|
||||
captureThisType, Type.getObjectType(v.thisName), AsmUtil.CAPTURED_THIS_FIELD,
|
||||
false, StackValue.LOCAL_0
|
||||
).put(captureThisType, codegen.v)
|
||||
}
|
||||
|
||||
val callableMethod = typeMapper.mapToCallableMethod(suspendFunctionJvmView, false)
|
||||
|
||||
for (argumentType in callableMethod.getAsmMethod().argumentTypes.dropLast(1)) {
|
||||
AsmUtil.pushDefaultValueOnStack(argumentType, codegen.v)
|
||||
}
|
||||
|
||||
codegen.v.load(0, AsmTypes.OBJECT_TYPE)
|
||||
callableMethod.genInvokeInstruction(codegen.v)
|
||||
|
||||
codegen.v.visitInsn(Opcodes.ARETURN)
|
||||
}
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
override fun generateKotlinMetadataAnnotation() {
|
||||
writeKotlinMetadata(v, state, KotlinClassHeader.Kind.SYNTHETIC_CLASS, 0) {
|
||||
// Do not write method metadata for raw coroutine state machines
|
||||
}
|
||||
}
|
||||
|
||||
companion object {
|
||||
fun create(
|
||||
cv: ClassBuilder,
|
||||
expressionCodegen: ExpressionCodegen,
|
||||
originalSuspendDescriptor: FunctionDescriptor,
|
||||
declaration: KtFunction,
|
||||
state: GenerationState
|
||||
): CoroutineCodegen {
|
||||
val cv = state.factory.newVisitor(
|
||||
OtherOrigin(declaration, originalSuspendDescriptor),
|
||||
CodegenBinding.asmTypeForAnonymousClass(state.bindingContext, originalSuspendDescriptor),
|
||||
declaration.containingFile
|
||||
)
|
||||
declaration: KtFunction
|
||||
): CoroutineCodegenForNamedFunction {
|
||||
val bindingContext = expressionCodegen.state.bindingContext
|
||||
val closure =
|
||||
bindingContext[
|
||||
CodegenBinding.CLOSURE,
|
||||
bindingContext[CodegenBinding.CLASS_FOR_CALLABLE, originalSuspendDescriptor]
|
||||
].sure { "There must be a closure defined for $originalSuspendDescriptor" }
|
||||
|
||||
return CoroutineCodegen(
|
||||
val suspendFunctionView =
|
||||
bindingContext[
|
||||
CodegenBinding.SUSPEND_FUNCTION_TO_JVM_VIEW, originalSuspendDescriptor
|
||||
].sure { "There must be a jvm view defined for $originalSuspendDescriptor" }
|
||||
|
||||
if (suspendFunctionView.dispatchReceiverParameter != null) {
|
||||
closure.setCaptureThis()
|
||||
}
|
||||
|
||||
return CoroutineCodegenForNamedFunction(
|
||||
expressionCodegen, declaration,
|
||||
expressionCodegen.context.intoClosure(
|
||||
originalSuspendDescriptor, expressionCodegen, expressionCodegen.state.typeMapper
|
||||
),
|
||||
cv,
|
||||
originalSuspendDescriptor,
|
||||
isSuspendLambda = false
|
||||
originalSuspendDescriptor
|
||||
)
|
||||
}
|
||||
}
|
||||
+34
-22
@@ -44,8 +44,15 @@ class CoroutineTransformerMethodVisitor(
|
||||
desc: String,
|
||||
signature: String?,
|
||||
exceptions: Array<out String>?,
|
||||
private val classBuilder: ClassBuilder
|
||||
private val containingClassInternalName: String,
|
||||
private val classBuilderForCoroutineState: ClassBuilder,
|
||||
isForNamedFunction: Boolean
|
||||
) : TransformationMethodVisitor(delegate, access, name, desc, signature, exceptions) {
|
||||
|
||||
private val continuationIndex = if (isForNamedFunction) getLastParameterIndex(desc, access) else 0
|
||||
private val dataIndex = continuationIndex + 1
|
||||
private val exceptionIndex = dataIndex + 1
|
||||
|
||||
override fun performTransformations(methodNode: MethodNode) {
|
||||
val customCoroutineStartMarker = methodNode.instructions.toArray().filterIsInstance<MethodInsnNode>().firstOrNull {
|
||||
it.owner == COROUTINE_MARKER_OWNER && it.name == ACTUAL_COROUTINE_START_MARKER_NAME
|
||||
@@ -61,7 +68,7 @@ class CoroutineTransformerMethodVisitor(
|
||||
}
|
||||
|
||||
// Spill stack to variables before suspension points, try/catch blocks
|
||||
FixStackWithLabelNormalizationMethodTransformer().transform(classBuilder.thisName, methodNode)
|
||||
FixStackWithLabelNormalizationMethodTransformer().transform(containingClassInternalName, methodNode)
|
||||
|
||||
// Remove unreachable suspension points
|
||||
// If we don't do this, then relevant frames will not be analyzed, that is unexpected from point of view of next steps (e.g. variable spilling)
|
||||
@@ -86,7 +93,7 @@ class CoroutineTransformerMethodVisitor(
|
||||
insnListOf(
|
||||
*withInstructionAdapter { loadCoroutineSuspendedMarker() }.toArray(),
|
||||
VarInsnNode(Opcodes.ASTORE, suspendMarkerVarIndex),
|
||||
VarInsnNode(Opcodes.ALOAD, 0),
|
||||
VarInsnNode(Opcodes.ALOAD, continuationIndex),
|
||||
FieldInsnNode(
|
||||
Opcodes.GETFIELD,
|
||||
COROUTINE_IMPL_ASM_TYPE.internalName,
|
||||
@@ -101,7 +108,7 @@ class CoroutineTransformerMethodVisitor(
|
||||
)
|
||||
)
|
||||
|
||||
insert(startLabel, withInstructionAdapter(InstructionAdapter::generateResumeWithExceptionCheck))
|
||||
insert(startLabel, withInstructionAdapter { generateResumeWithExceptionCheck(exceptionIndex) })
|
||||
|
||||
insert(last, withInstructionAdapter {
|
||||
visitLabel(defaultLabel.label)
|
||||
@@ -116,7 +123,7 @@ class CoroutineTransformerMethodVisitor(
|
||||
}
|
||||
|
||||
private fun removeUnreachableSuspensionPointsAndExitPoints(methodNode: MethodNode, suspensionPoints: MutableList<SuspensionPoint>) {
|
||||
val dceResult = DeadCodeEliminationMethodTransformer().transformWithResult(classBuilder.thisName, methodNode)
|
||||
val dceResult = DeadCodeEliminationMethodTransformer().transformWithResult(containingClassInternalName, methodNode)
|
||||
|
||||
// If the suspension call begin is alive and suspension call end is dead
|
||||
// (e.g., an inlined suspend function call ends with throwing a exception -- see KT-15017),
|
||||
@@ -166,7 +173,7 @@ class CoroutineTransformerMethodVisitor(
|
||||
|
||||
private fun spillVariables(suspensionPoints: List<SuspensionPoint>, methodNode: MethodNode) {
|
||||
val instructions = methodNode.instructions
|
||||
val frames = performRefinedTypeAnalysis(methodNode, classBuilder.thisName)
|
||||
val frames = performRefinedTypeAnalysis(methodNode, containingClassInternalName)
|
||||
fun AbstractInsnNode.index() = instructions.indexOf(this)
|
||||
|
||||
// We postpone these actions because they change instruction indices that we use when obtaining frames
|
||||
@@ -200,13 +207,15 @@ class CoroutineTransformerMethodVisitor(
|
||||
val livenessFrame = livenessFrames[suspensionCallBegin.index()]
|
||||
|
||||
// 0 - this
|
||||
// 1 - continuation argument
|
||||
// 2 - continuation exception
|
||||
// 1 - parameter
|
||||
// ...
|
||||
// k - continuation
|
||||
// k + 1 - data
|
||||
// k + 2 - exception
|
||||
val variablesToSpill =
|
||||
(3 until localsCount)
|
||||
((exceptionIndex + 1) until localsCount)
|
||||
.map { Pair(it, frame.getLocal(it)) }
|
||||
.filter {
|
||||
val (index, value) = it
|
||||
.filter { (index, value) ->
|
||||
value != StrictBasicValue.UNINITIALIZED_VALUE && livenessFrame.isAlive(index)
|
||||
}
|
||||
|
||||
@@ -235,16 +244,16 @@ class CoroutineTransformerMethodVisitor(
|
||||
with(instructions) {
|
||||
// store variable before suspension call
|
||||
insertBefore(suspension.suspensionCallBegin, withInstructionAdapter {
|
||||
load(0, AsmTypes.OBJECT_TYPE)
|
||||
load(continuationIndex, AsmTypes.OBJECT_TYPE)
|
||||
load(index, type)
|
||||
StackValue.coerce(type, normalizedType, this)
|
||||
putfield(classBuilder.thisName, fieldName, normalizedType.descriptor)
|
||||
putfield(classBuilderForCoroutineState.thisName, fieldName, normalizedType.descriptor)
|
||||
})
|
||||
|
||||
// restore variable after suspension call
|
||||
insert(suspension.tryCatchBlockEndLabelAfterSuspensionCall, withInstructionAdapter {
|
||||
load(0, AsmTypes.OBJECT_TYPE)
|
||||
getfield(classBuilder.thisName, fieldName, normalizedType.descriptor)
|
||||
load(continuationIndex, AsmTypes.OBJECT_TYPE)
|
||||
getfield(classBuilderForCoroutineState.thisName, fieldName, normalizedType.descriptor)
|
||||
StackValue.coerce(normalizedType, type, this)
|
||||
store(index, type)
|
||||
})
|
||||
@@ -262,8 +271,8 @@ class CoroutineTransformerMethodVisitor(
|
||||
maxVarsCountByType.forEach { entry ->
|
||||
val (type, maxIndex) = entry
|
||||
for (index in 0..maxIndex) {
|
||||
classBuilder.newField(
|
||||
JvmDeclarationOrigin.NO_ORIGIN, Opcodes.ACC_PRIVATE,
|
||||
classBuilderForCoroutineState.newField(
|
||||
JvmDeclarationOrigin.NO_ORIGIN, AsmUtil.NO_FLAG_PACKAGE_PRIVATE,
|
||||
type.fieldNameForVar(index), type.descriptor, null, null)
|
||||
}
|
||||
}
|
||||
@@ -294,7 +303,7 @@ class CoroutineTransformerMethodVisitor(
|
||||
// Save state
|
||||
insertBefore(suspension.suspensionCallBegin,
|
||||
insnListOf(
|
||||
VarInsnNode(Opcodes.ALOAD, 0),
|
||||
VarInsnNode(Opcodes.ALOAD, continuationIndex),
|
||||
*withInstructionAdapter { iconst(id) }.toArray(),
|
||||
FieldInsnNode(
|
||||
Opcodes.PUTFIELD, COROUTINE_IMPL_ASM_TYPE.internalName, COROUTINE_LABEL_FIELD_NAME,
|
||||
@@ -327,10 +336,10 @@ class CoroutineTransformerMethodVisitor(
|
||||
remove(possibleTryCatchBlockStart.previous)
|
||||
|
||||
insert(possibleTryCatchBlockStart, withInstructionAdapter {
|
||||
generateResumeWithExceptionCheck()
|
||||
generateResumeWithExceptionCheck(exceptionIndex)
|
||||
|
||||
// Load continuation argument just like suspending function returns it
|
||||
load(1, AsmTypes.OBJECT_TYPE)
|
||||
load(dataIndex, AsmTypes.OBJECT_TYPE)
|
||||
|
||||
visitLabel(continuationLabelAfterLoadedResult.label)
|
||||
})
|
||||
@@ -389,9 +398,9 @@ class CoroutineTransformerMethodVisitor(
|
||||
}
|
||||
}
|
||||
|
||||
private fun InstructionAdapter.generateResumeWithExceptionCheck() {
|
||||
private fun InstructionAdapter.generateResumeWithExceptionCheck(exceptionIndex: Int) {
|
||||
// Check if resumeWithException has been called
|
||||
load(2, AsmTypes.OBJECT_TYPE)
|
||||
load(exceptionIndex, AsmTypes.OBJECT_TYPE)
|
||||
dup()
|
||||
val noExceptionLabel = Label()
|
||||
ifnull(noExceptionLabel)
|
||||
@@ -432,3 +441,6 @@ private class SuspensionPoint(
|
||||
) {
|
||||
lateinit var tryCatchBlocksContinuationLabel: LabelNode
|
||||
}
|
||||
|
||||
private fun getLastParameterIndex(desc: String, access: Int) =
|
||||
Type.getArgumentTypes(desc).dropLast(1).map { it.size }.sum() + (if (access and Opcodes.ACC_STATIC != 0) 0 else 1)
|
||||
|
||||
+91
-8
@@ -18,33 +18,116 @@ package org.jetbrains.kotlin.codegen.coroutines
|
||||
|
||||
import org.jetbrains.kotlin.codegen.ExpressionCodegen
|
||||
import org.jetbrains.kotlin.codegen.FunctionGenerationStrategy
|
||||
import org.jetbrains.kotlin.codegen.binding.CodegenBinding
|
||||
import org.jetbrains.kotlin.codegen.state.GenerationState
|
||||
import org.jetbrains.kotlin.descriptors.FunctionDescriptor
|
||||
import org.jetbrains.kotlin.psi.KtFunction
|
||||
import org.jetbrains.kotlin.resolve.jvm.AsmTypes
|
||||
import org.jetbrains.kotlin.resolve.jvm.diagnostics.OtherOrigin
|
||||
import org.jetbrains.kotlin.resolve.jvm.jvmSignature.JvmMethodSignature
|
||||
import org.jetbrains.org.objectweb.asm.Label
|
||||
import org.jetbrains.org.objectweb.asm.MethodVisitor
|
||||
import org.jetbrains.org.objectweb.asm.Opcodes
|
||||
import org.jetbrains.org.objectweb.asm.Type
|
||||
|
||||
class SuspendFunctionGenerationStrategy(
|
||||
state: GenerationState,
|
||||
private val originalSuspendDescriptor: FunctionDescriptor,
|
||||
private val declaration: KtFunction
|
||||
private val declaration: KtFunction,
|
||||
private val containingClassInternalName: String
|
||||
) : FunctionGenerationStrategy.CodegenBased(state) {
|
||||
private val containsNonTailSuspensionCalls = originalSuspendDescriptor.containsNonTailSuspensionCalls(state.bindingContext)
|
||||
|
||||
private val classBuilderForCoroutineState by lazy {
|
||||
state.factory.newVisitor(
|
||||
OtherOrigin(declaration, originalSuspendDescriptor),
|
||||
CodegenBinding.asmTypeForAnonymousClass(state.bindingContext, originalSuspendDescriptor),
|
||||
declaration.containingFile
|
||||
)
|
||||
}
|
||||
|
||||
override fun wrapMethodVisitor(mv: MethodVisitor, access: Int, name: String, desc: String): MethodVisitor {
|
||||
if (containsNonTailSuspensionCalls) {
|
||||
return CoroutineTransformerMethodVisitor(
|
||||
mv, access, name, desc, null, null, containingClassInternalName, classBuilderForCoroutineState,
|
||||
isForNamedFunction = true
|
||||
)
|
||||
}
|
||||
|
||||
return super.wrapMethodVisitor(mv, access, name, desc)
|
||||
}
|
||||
|
||||
override fun doGenerateBody(codegen: ExpressionCodegen, signature: JvmMethodSignature) {
|
||||
if (!originalSuspendDescriptor.containsNonTailSuspensionCalls(state.bindingContext)) {
|
||||
if (!containsNonTailSuspensionCalls) {
|
||||
codegen.returnExpression(declaration.bodyExpression)
|
||||
return
|
||||
}
|
||||
|
||||
val coroutineCodegen = CoroutineCodegen.create(codegen, originalSuspendDescriptor, declaration, state)
|
||||
val coroutineCodegen =
|
||||
CoroutineForNamedFunctionCodegen.create(classBuilderForCoroutineState, codegen, originalSuspendDescriptor, declaration)
|
||||
coroutineCodegen.generate()
|
||||
|
||||
codegen.markLineNumber(declaration, false)
|
||||
codegen.putClosureInstanceOnStack(coroutineCodegen, null).put(Type.getObjectType(coroutineCodegen.className), codegen.v)
|
||||
codegen.v.invokeDoResumeWithUnit(coroutineCodegen.v.thisName)
|
||||
val functionDescriptor = codegen.context.functionDescriptor
|
||||
|
||||
codegen.markLineNumber(declaration, true)
|
||||
codegen.v.areturn(AsmTypes.OBJECT_TYPE)
|
||||
val continuationIndex = codegen.frameMap.getIndex(functionDescriptor.valueParameters.last())
|
||||
val objectTypeForState = Type.getObjectType(classBuilderForCoroutineState.thisName)
|
||||
|
||||
val dataIndex = codegen.frameMap.enterTemp(AsmTypes.OBJECT_TYPE)
|
||||
val exceptionIndex = codegen.frameMap.enterTemp(AsmTypes.OBJECT_TYPE)
|
||||
|
||||
with(codegen.v) {
|
||||
val createStateInstance = Label()
|
||||
val storeStateObject = Label()
|
||||
|
||||
// We have to distinguish the following situations:
|
||||
// - Our function got called in a common way (e.g. from another function or via recursive call) and we should execute our
|
||||
// code from the beginning
|
||||
// - We got called from `doResume` of our continuation, i.e. we need to continue from the last suspension point
|
||||
//
|
||||
// Also in the first case we wrap the completion into a special anonymous class instance (let's call it X$1)
|
||||
// that we'll use as a continuation argument for suspension points
|
||||
//
|
||||
// How we distinguish the cases:
|
||||
// - If the continuation is not an instance of X$1 we know exactly it's not the second case, because when resuming
|
||||
// the continuation we pass an instance of that class
|
||||
// - Otherwise it's still can be a recursive call. To check it's not the case we set the last bit in the label in
|
||||
// `doResume` just before calling the suspend function (see kotlin.coroutines.experimental.jvm.internal.CoroutineImplForNamedFunction).
|
||||
// So, if it's set we're in continuation.
|
||||
visitVarInsn(Opcodes.ALOAD, continuationIndex)
|
||||
instanceOf(objectTypeForState)
|
||||
ifeq(createStateInstance)
|
||||
|
||||
visitVarInsn(Opcodes.ALOAD, continuationIndex)
|
||||
checkcast(objectTypeForState)
|
||||
invokevirtual(
|
||||
COROUTINE_IMPL_FOR_NAMED_ASM_TYPE.internalName,
|
||||
"checkAndFlushLastBit", Type.getMethodDescriptor(Type.BOOLEAN_TYPE), false
|
||||
)
|
||||
ifeq(createStateInstance)
|
||||
|
||||
visitVarInsn(Opcodes.ALOAD, continuationIndex)
|
||||
checkcast(objectTypeForState)
|
||||
goTo(storeStateObject)
|
||||
|
||||
visitLabel(createStateInstance)
|
||||
coroutineCodegen.putInstanceOnStack(codegen, null).put(objectTypeForState, this)
|
||||
|
||||
visitLabel(storeStateObject)
|
||||
visitVarInsn(Opcodes.ASTORE, continuationIndex)
|
||||
|
||||
visitVarInsn(Opcodes.ALOAD, continuationIndex)
|
||||
checkcast(objectTypeForState)
|
||||
getfield(COROUTINE_IMPL_FOR_NAMED_ASM_TYPE.internalName, "data", AsmTypes.OBJECT_TYPE.descriptor)
|
||||
visitVarInsn(Opcodes.ASTORE, dataIndex)
|
||||
|
||||
visitVarInsn(Opcodes.ALOAD, continuationIndex)
|
||||
checkcast(objectTypeForState)
|
||||
getfield(COROUTINE_IMPL_FOR_NAMED_ASM_TYPE.internalName, "exception", AsmTypes.JAVA_THROWABLE_TYPE.descriptor)
|
||||
visitVarInsn(Opcodes.ASTORE, exceptionIndex)
|
||||
|
||||
invokestatic(COROUTINE_MARKER_OWNER, ACTUAL_COROUTINE_START_MARKER_NAME, "()V", false)
|
||||
}
|
||||
|
||||
codegen.returnExpression(declaration.bodyExpression)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -72,6 +72,10 @@ val CONTINUATION_ASM_TYPE = DescriptorUtils.CONTINUATION_INTERFACE_FQ_NAME.topLe
|
||||
@JvmField
|
||||
val COROUTINE_IMPL_ASM_TYPE = COROUTINES_JVM_INTERNAL_PACKAGE_FQ_NAME.child(Name.identifier("CoroutineImpl")).topLevelClassAsmType()
|
||||
|
||||
@JvmField
|
||||
val COROUTINE_IMPL_FOR_NAMED_ASM_TYPE =
|
||||
COROUTINES_JVM_INTERNAL_PACKAGE_FQ_NAME.child(Name.identifier("CoroutineImplForNamedFunction")).topLevelClassAsmType()
|
||||
|
||||
private val COROUTINES_INTRINSICS_FILE_FACADE_INTERNAL_NAME =
|
||||
COROUTINES_INTRINSICS_PACKAGE_FQ_NAME.child(Name.identifier("IntrinsicsKt")).topLevelClassAsmType()
|
||||
|
||||
|
||||
@@ -598,7 +598,7 @@ public class InlineCodegen extends CallGenerator {
|
||||
new SuspendFunctionGenerationStrategy(
|
||||
state,
|
||||
CoroutineCodegenUtilKt.unwrapInitialDescriptorForSuspendFunction(descriptor),
|
||||
(KtFunction) expression
|
||||
(KtFunction) expression, null
|
||||
);
|
||||
}
|
||||
else {
|
||||
|
||||
@@ -35,7 +35,7 @@ abstract class CoroutineImpl(
|
||||
// label == -1 when coroutine cannot be started (it is just a factory object) or has already finished execution
|
||||
// label == 0 in initial part of the coroutine
|
||||
@JvmField
|
||||
protected var label: Int = if (completion != null) 0 else -1
|
||||
var label: Int = if (completion != null) 0 else -1
|
||||
|
||||
private val _context: CoroutineContext? = completion?.context
|
||||
|
||||
@@ -71,3 +71,35 @@ abstract class CoroutineImpl(
|
||||
throw IllegalStateException("create(Any?;Continuation) has not been overridden")
|
||||
}
|
||||
}
|
||||
|
||||
abstract class CoroutineImplForNamedFunction(
|
||||
completion: Continuation<Any?>?
|
||||
) : CoroutineImpl(0, completion), Continuation<Any?> {
|
||||
companion object {
|
||||
private const val LAST_BIT_MASK = 1 shl 31
|
||||
}
|
||||
|
||||
@JvmField
|
||||
var data: Any? = null
|
||||
|
||||
@JvmField
|
||||
var exception: Throwable? = null
|
||||
|
||||
fun checkAndFlushLastBit(): Boolean {
|
||||
if (label and LAST_BIT_MASK != 0) {
|
||||
label -= LAST_BIT_MASK
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
override fun doResume(data: Any?, exception: Throwable?): Any? {
|
||||
this.data = data
|
||||
this.exception = exception
|
||||
this.label = this.label or LAST_BIT_MASK
|
||||
return doResume()
|
||||
}
|
||||
|
||||
protected abstract fun doResume(): Any?
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user