From c3a032ea0b6c5c48da63e275fe784bffe14e0446 Mon Sep 17 00:00:00 2001 From: Denis Zharkov Date: Tue, 25 Apr 2017 15:10:48 +0300 Subject: [PATCH] Generate state machine for named functions in their bodies Inline functions aren't supported yet in the change #KT-17585 In Progress --- .../org/jetbrains/kotlin/codegen/AsmUtil.java | 2 + .../kotlin/codegen/ExpressionCodegen.java | 11 +- .../kotlin/codegen/FunctionCodegen.java | 3 +- .../kotlin/codegen/JvmRuntimeTypes.kt | 7 +- ...odegen.kt => CoroutineCodegenForLambda.kt} | 260 +++++++++++------- .../CoroutineTransformerMethodVisitor.kt | 56 ++-- .../SuspendFunctionGenerationStrategy.kt | 99 ++++++- .../coroutines/coroutineCodegenUtil.kt | 4 + .../kotlin/codegen/inline/InlineCodegen.java | 2 +- .../jvm/internal/CoroutineImpl.kt | 34 ++- 10 files changed, 345 insertions(+), 133 deletions(-) rename compiler/backend/src/org/jetbrains/kotlin/codegen/coroutines/{CoroutineCodegen.kt => CoroutineCodegenForLambda.kt} (71%) diff --git a/compiler/backend/src/org/jetbrains/kotlin/codegen/AsmUtil.java b/compiler/backend/src/org/jetbrains/kotlin/codegen/AsmUtil.java index 95737704b69..48e2bcb551e 100644 --- a/compiler/backend/src/org/jetbrains/kotlin/codegen/AsmUtil.java +++ b/compiler/backend/src/org/jetbrains/kotlin/codegen/AsmUtil.java @@ -646,6 +646,8 @@ public class AsmUtil { @NotNull FrameMap frameMap ) { if (state.isParamAssertionsDisabled()) return; + // currently when resuming a suspend function we pass default values instead of real arguments (i.e. nulls for references) + if (descriptor.isSuspend()) return; // Private method is not accessible from other classes, no assertions needed if (getVisibilityAccessFlag(descriptor) == ACC_PRIVATE) return; diff --git a/compiler/backend/src/org/jetbrains/kotlin/codegen/ExpressionCodegen.java b/compiler/backend/src/org/jetbrains/kotlin/codegen/ExpressionCodegen.java index b30f6506775..907eb4fc0ba 100644 --- a/compiler/backend/src/org/jetbrains/kotlin/codegen/ExpressionCodegen.java +++ b/compiler/backend/src/org/jetbrains/kotlin/codegen/ExpressionCodegen.java @@ -35,7 +35,7 @@ import org.jetbrains.kotlin.builtins.KotlinBuiltIns; import org.jetbrains.kotlin.codegen.binding.CalculatedClosure; import org.jetbrains.kotlin.codegen.binding.CodegenBinding; import org.jetbrains.kotlin.codegen.context.*; -import org.jetbrains.kotlin.codegen.coroutines.CoroutineCodegen; +import org.jetbrains.kotlin.codegen.coroutines.CoroutineCodegenForLambda; import org.jetbrains.kotlin.codegen.coroutines.CoroutineCodegenUtilKt; import org.jetbrains.kotlin.codegen.coroutines.ResolvedCallWithRealDescriptor; import org.jetbrains.kotlin.codegen.extensions.ExpressionCodegenExtension; @@ -52,6 +52,7 @@ import org.jetbrains.kotlin.codegen.when.SwitchCodegen; import org.jetbrains.kotlin.codegen.when.SwitchCodegenUtil; import org.jetbrains.kotlin.config.ApiVersion; import org.jetbrains.kotlin.descriptors.*; +import org.jetbrains.kotlin.descriptors.impl.AnonymousFunctionDescriptor; import org.jetbrains.kotlin.descriptors.impl.LocalVariableDescriptor; import org.jetbrains.kotlin.descriptors.impl.SyntheticFieldDescriptor; import org.jetbrains.kotlin.descriptors.impl.TypeAliasConstructorDescriptor; @@ -942,7 +943,7 @@ public class ExpressionCodegen extends KtVisitor impleme declaration.getContainingFile() ); - ClosureCodegen coroutineCodegen = CoroutineCodegen.createByLambda(this, descriptor, declaration, cv); + ClosureCodegen coroutineCodegen = CoroutineCodegenForLambda.create(this, descriptor, declaration, cv); ClosureCodegen closureCodegen = coroutineCodegen != null ? coroutineCodegen : new ClosureCodegen( state, declaration, samType, context.intoClosure(descriptor, this, typeMapper), functionReferenceTarget, strategy, parentCodegen, cv @@ -1188,12 +1189,14 @@ public class ExpressionCodegen extends KtVisitor impleme bindingContext.get(ENCLOSING_SUSPEND_FUNCTION_FOR_SUSPEND_FUNCTION_CALL, resolvedCall.getCall()); if (enclosingSuspendLambdaForSuspensionPoint == null) return null; - return genCoroutineInstanceBySuspendFunction(enclosingSuspendLambdaForSuspensionPoint); + return genCoroutineInstanceForSuspendLambda(enclosingSuspendLambdaForSuspensionPoint); } @Nullable - private StackValue genCoroutineInstanceBySuspendFunction(@NotNull FunctionDescriptor suspendFunction) { + private StackValue genCoroutineInstanceForSuspendLambda(@NotNull FunctionDescriptor suspendFunction) { if (!CoroutineCodegenUtilKt.isStateMachineNeeded(suspendFunction, bindingContext)) return null; + if (!(suspendFunction instanceof AnonymousFunctionDescriptor)) return null; + ClassDescriptor suspendLambdaClassDescriptor = bindingContext.get(CodegenBinding.CLASS_FOR_CALLABLE, suspendFunction); assert suspendLambdaClassDescriptor != null : "Coroutine class descriptor should not be null"; diff --git a/compiler/backend/src/org/jetbrains/kotlin/codegen/FunctionCodegen.java b/compiler/backend/src/org/jetbrains/kotlin/codegen/FunctionCodegen.java index 0f172467b6a..7b7e9aed832 100644 --- a/compiler/backend/src/org/jetbrains/kotlin/codegen/FunctionCodegen.java +++ b/compiler/backend/src/org/jetbrains/kotlin/codegen/FunctionCodegen.java @@ -135,7 +135,8 @@ public class FunctionCodegen { strategy = new SuspendFunctionGenerationStrategy( state, CoroutineCodegenUtilKt.unwrapInitialDescriptorForSuspendFunction(functionDescriptor), - function + function, + v.getThisName() ); } else { diff --git a/compiler/backend/src/org/jetbrains/kotlin/codegen/JvmRuntimeTypes.kt b/compiler/backend/src/org/jetbrains/kotlin/codegen/JvmRuntimeTypes.kt index 9534bfab4d5..90ef9009920 100644 --- a/compiler/backend/src/org/jetbrains/kotlin/codegen/JvmRuntimeTypes.kt +++ b/compiler/backend/src/org/jetbrains/kotlin/codegen/JvmRuntimeTypes.kt @@ -43,6 +43,7 @@ class JvmRuntimeTypes(module: ModuleDescriptor) { private val localVariableReference: ClassDescriptor by klass("LocalVariableReference") private val mutableLocalVariableReference: ClassDescriptor by klass("MutableLocalVariableReference") private val coroutineImplClass by lazy { createClass(kotlinCoroutinesJvmInternalPackage, "CoroutineImpl") } + private val coroutineImplForNamedFunctionClass by lazy { createClass(kotlinCoroutinesJvmInternalPackage, "CoroutineImplForNamedFunction") } private val propertyReferences: List by lazy { (0..2).map { i -> createClass(kotlinJvmInternalPackage, "PropertyReference$i") } @@ -84,11 +85,13 @@ class JvmRuntimeTypes(module: ModuleDescriptor) { if (descriptor.isSuspend) { return mutableListOf().apply { - add(coroutineImplClass.defaultType) - if (descriptor.isSuspendLambda) { + add(coroutineImplClass.defaultType) add(functionType) } + else { + add(coroutineImplForNamedFunctionClass.defaultType) + } } } diff --git a/compiler/backend/src/org/jetbrains/kotlin/codegen/coroutines/CoroutineCodegen.kt b/compiler/backend/src/org/jetbrains/kotlin/codegen/coroutines/CoroutineCodegenForLambda.kt similarity index 71% rename from compiler/backend/src/org/jetbrains/kotlin/codegen/coroutines/CoroutineCodegen.kt rename to compiler/backend/src/org/jetbrains/kotlin/codegen/coroutines/CoroutineCodegenForLambda.kt index 160d14e0370..6935edbcf29 100644 --- a/compiler/backend/src/org/jetbrains/kotlin/codegen/coroutines/CoroutineCodegen.kt +++ b/compiler/backend/src/org/jetbrains/kotlin/codegen/coroutines/CoroutineCodegenForLambda.kt @@ -21,7 +21,6 @@ import org.jetbrains.kotlin.codegen.* import org.jetbrains.kotlin.codegen.binding.CodegenBinding import org.jetbrains.kotlin.codegen.context.ClosureContext import org.jetbrains.kotlin.codegen.context.MethodContext -import org.jetbrains.kotlin.codegen.state.GenerationState import org.jetbrains.kotlin.coroutines.isSuspendLambda import org.jetbrains.kotlin.descriptors.* import org.jetbrains.kotlin.descriptors.annotations.Annotations @@ -42,7 +41,7 @@ import org.jetbrains.kotlin.resolve.jvm.jvmSignature.JvmMethodSignature import org.jetbrains.kotlin.types.KotlinType import org.jetbrains.kotlin.types.typeUtil.makeNullable import org.jetbrains.kotlin.utils.addToStdlib.safeAs -import org.jetbrains.org.objectweb.asm.Label +import org.jetbrains.kotlin.utils.sure import org.jetbrains.org.objectweb.asm.MethodVisitor import org.jetbrains.org.objectweb.asm.Opcodes import org.jetbrains.org.objectweb.asm.Type @@ -50,19 +49,64 @@ import org.jetbrains.org.objectweb.asm.commons.InstructionAdapter import org.jetbrains.org.objectweb.asm.commons.Method -class CoroutineCodegen private constructor( +abstract class AbstractCoroutineCodegen( outerExpressionCodegen: ExpressionCodegen, element: KtElement, - private val closureContext: ClosureContext, - classBuilder: ClassBuilder, - private val originalSuspendFunctionDescriptor: FunctionDescriptor, - private val isSuspendLambda: Boolean + closureContext: ClosureContext, + classBuilder: ClassBuilder ) : ClosureCodegen( outerExpressionCodegen.state, element, null, closureContext, null, FailingFunctionGenerationStrategy, outerExpressionCodegen.parentCodegen, classBuilder ) { + override fun generateConstructor(): Method { + val args = calculateConstructorParameters(typeMapper, closure, asmType) + val argTypes = args.map { it.fieldType }.plus(CONTINUATION_ASM_TYPE).toTypedArray() + + val constructor = Method("", Type.VOID_TYPE, argTypes) + val mv = v.newMethod( + OtherOrigin(element, funDescriptor), visibilityFlag, "", constructor.descriptor, null, + ArrayUtil.EMPTY_STRING_ARRAY + ) + + if (state.classBuilderMode.generateBodies) { + mv.visitCode() + val iv = InstructionAdapter(mv) + + iv.generateClosureFieldsInitializationFromParameters(closure, args) + + iv.load(0, AsmTypes.OBJECT_TYPE) + if (passArityToSuperClass) { + iv.iconst(calculateArity()) + } + iv.load(argTypes.map { it.size }.sum(), AsmTypes.OBJECT_TYPE) + + val superClassConstructorDescriptor = Type.getMethodDescriptor( + Type.VOID_TYPE, + *(if (passArityToSuperClass) arrayOf(Type.INT_TYPE) else emptyArray()), + CONTINUATION_ASM_TYPE + ) + iv.invokespecial(superClassAsmType.internalName, "", superClassConstructorDescriptor, false) + + iv.visitInsn(Opcodes.RETURN) + + FunctionCodegen.endVisit(iv, "constructor", element) + } + + return constructor + } + + abstract protected val passArityToSuperClass: Boolean +} + +class CoroutineCodegenForLambda private constructor( + outerExpressionCodegen: ExpressionCodegen, + element: KtElement, + private val closureContext: ClosureContext, + classBuilder: ClassBuilder, + private val originalSuspendFunctionDescriptor: FunctionDescriptor +) : AbstractCoroutineCodegen(outerExpressionCodegen, element, closureContext, classBuilder) { private val classDescriptor = closureContext.contextDescriptor private val builtIns = funDescriptor.builtIns @@ -126,16 +170,9 @@ class CoroutineCodegen private constructor( generateDoResume() } - override fun generateBridges() { - if (!isSuspendLambda) return - super.generateBridges() - } - override fun generateBody() { super.generateBody() - if (!isSuspendLambda) return - // create() = ... functionCodegen.generateMethod(JvmDeclarationOrigin.NO_ORIGIN, createCoroutineDescriptor, object : FunctionGenerationStrategy.CodegenBased(state) { @@ -187,41 +224,14 @@ class CoroutineCodegen private constructor( areturn(AsmTypes.OBJECT_TYPE) } + override val passArityToSuperClass get() = true + override fun generateConstructor(): Method { - val args = calculateConstructorParameters(typeMapper, closure, asmType) - val argTypes = args.map { it.fieldType }.plus(CONTINUATION_ASM_TYPE).toTypedArray() - - val constructor = Method("", Type.VOID_TYPE, argTypes) - val mv = v.newMethod( - OtherOrigin(element, funDescriptor), visibilityFlag, "", constructor.descriptor, null, - ArrayUtil.EMPTY_STRING_ARRAY - ) - - constructorToUseFromInvoke = constructor - - if (state.classBuilderMode.generateBodies) { - mv.visitCode() - val iv = InstructionAdapter(mv) - - iv.generateClosureFieldsInitializationFromParameters(closure, args) - - iv.load(0, AsmTypes.OBJECT_TYPE) - iv.iconst(calculateArity()) - iv.load(argTypes.map { it.size }.sum(), AsmTypes.OBJECT_TYPE) - - val superClassConstructorDescriptor = Type.getMethodDescriptor(Type.VOID_TYPE, Type.INT_TYPE, CONTINUATION_ASM_TYPE) - iv.invokespecial(superClassAsmType.internalName, "", superClassConstructorDescriptor, false) - - iv.visitInsn(Opcodes.RETURN) - - FunctionCodegen.endVisit(iv, "constructor", element) - } - - return constructor + constructorToUseFromInvoke = super.generateConstructor() + return constructorToUseFromInvoke } private fun generateCreateCoroutineMethod(codegen: ExpressionCodegen) { - assert(isSuspendLambda) { "create method should only be generated for suspend lambdas" } val classDescriptor = closureContext.contextDescriptor val owner = typeMapper.mapClass(classDescriptor) @@ -260,15 +270,11 @@ class CoroutineCodegen private constructor( } private fun ExpressionCodegen.initializeCoroutineParameters() { - if (!isSuspendLambda && !originalSuspendFunctionDescriptor.isTailrec) return for (parameter in allFunctionParameters()) { val fieldStackValue = - if (isSuspendLambda) - StackValue.field( - parameter.getFieldInfoForCoroutineLambdaParameter(), generateThisOrOuter(context.thisDescriptor, false) - ) - else - closureContext.lookupInContext(parameter, null, state, /* ignoreNoOuter = */ false) + StackValue.field( + parameter.getFieldInfoForCoroutineLambdaParameter(), generateThisOrOuter(context.thisDescriptor, false) + ) val mappedType = typeMapper.mapType(parameter.type) fieldStackValue.put(mappedType, v) @@ -277,14 +283,7 @@ class CoroutineCodegen private constructor( v.store(newIndex, mappedType) } - // necessary for proper tailrec codegen - val actualMethodStartLabel = Label() - v.visitLabel(actualMethodStartLabel) - context.setMethodStartLabel(actualMethodStartLabel) - - if (isSuspendLambda) { - initializeVariablesForDestructuredLambdaParameters(this, originalSuspendFunctionDescriptor.valueParameters) - } + initializeVariablesForDestructuredLambdaParameters(this, originalSuspendFunctionDescriptor.valueParameters) } private fun allFunctionParameters() = @@ -308,7 +307,12 @@ class CoroutineCodegen private constructor( object : FunctionGenerationStrategy.FunctionDefault(state, element as KtDeclarationWithBody) { override fun wrapMethodVisitor(mv: MethodVisitor, access: Int, name: String, desc: String): MethodVisitor { - return CoroutineTransformerMethodVisitor(mv, access, name, desc, null, null, v) + return CoroutineTransformerMethodVisitor( + mv, access, name, desc, null, null, + obtainClassBuilderForCoroutineState = { v }, + containingClassInternalName = v.thisName, + isForNamedFunction = false + ) } override fun doGenerateBody(codegen: ExpressionCodegen, signature: JvmMethodSignature) { @@ -319,34 +323,17 @@ class CoroutineCodegen private constructor( ) } - override fun generateKotlinMetadataAnnotation() { - if (isSuspendLambda) { - super.generateKotlinMetadataAnnotation() - } - else { - writeKotlinMetadata(v, state, KotlinClassHeader.Kind.SYNTHETIC_CLASS, 0) { - // Do not write method metadata for raw coroutine state machines - } - } - } - companion object { - fun shouldCreateByLambda( - originalSuspendLambdaDescriptor: CallableDescriptor, - declaration: KtElement): Boolean { - return (declaration is KtFunctionLiteral && originalSuspendLambdaDescriptor.isSuspendLambda) - } - @JvmStatic - fun createByLambda( + fun create( expressionCodegen: ExpressionCodegen, originalSuspendLambdaDescriptor: FunctionDescriptor, declaration: KtElement, classBuilder: ClassBuilder ): ClosureCodegen? { - if (!shouldCreateByLambda(originalSuspendLambdaDescriptor, declaration)) return null + if (declaration !is KtFunctionLiteral || !originalSuspendLambdaDescriptor.isSuspendLambda) return null - return CoroutineCodegen( + return CoroutineCodegenForLambda( expressionCodegen, declaration, expressionCodegen.context.intoCoroutineClosure( @@ -354,31 +341,116 @@ class CoroutineCodegen private constructor( originalSuspendLambdaDescriptor, expressionCodegen, expressionCodegen.state.typeMapper ), classBuilder, - originalSuspendLambdaDescriptor, - isSuspendLambda = true + originalSuspendLambdaDescriptor ) } + } +} +class CoroutineCodegenForNamedFunction private constructor( + outerExpressionCodegen: ExpressionCodegen, + element: KtElement, + closureContext: ClosureContext, + classBuilder: ClassBuilder, + originalSuspendFunctionDescriptor: FunctionDescriptor +) : AbstractCoroutineCodegen(outerExpressionCodegen, element, closureContext, classBuilder) { + private val classDescriptor = closureContext.contextDescriptor + + private val suspendFunctionJvmView = + bindingContext[CodegenBinding.SUSPEND_FUNCTION_TO_JVM_VIEW, originalSuspendFunctionDescriptor]!! + + // protected fun doResume(): Any? + private val doResumeDescriptor = + SimpleFunctionDescriptorImpl.create( + classDescriptor, Annotations.EMPTY, Name.identifier(DO_RESUME_METHOD_NAME), CallableMemberDescriptor.Kind.DECLARATION, + funDescriptor.source + ).apply doResume@{ + initialize( + /* receiverParameterType = */ null, + classDescriptor.thisAsReceiverParameter, + /* typeParameters = */ emptyList(), + listOf(), + builtIns.nullableAnyType, + Modality.FINAL, + Visibilities.PUBLIC + ) + } + + override val passArityToSuperClass get() = false + + override fun generateBridges() { + // Do not generate any closure bridges + } + + override fun generateClosureBody() { + generateDoResume() + } + + private fun generateDoResume() { + functionCodegen.generateMethod( + OtherOrigin(element), + doResumeDescriptor, + object : FunctionGenerationStrategy.CodegenBased(state) { + override fun doGenerateBody(codegen: ExpressionCodegen, signature: JvmMethodSignature) { + val captureThisType = closure.captureThis?.let(typeMapper::mapType) + if (captureThisType != null) { + StackValue.field( + captureThisType, Type.getObjectType(v.thisName), AsmUtil.CAPTURED_THIS_FIELD, + false, StackValue.LOCAL_0 + ).put(captureThisType, codegen.v) + } + + val callableMethod = typeMapper.mapToCallableMethod(suspendFunctionJvmView, false) + + for (argumentType in callableMethod.getAsmMethod().argumentTypes.dropLast(1)) { + AsmUtil.pushDefaultValueOnStack(argumentType, codegen.v) + } + + codegen.v.load(0, AsmTypes.OBJECT_TYPE) + callableMethod.genInvokeInstruction(codegen.v) + + codegen.v.visitInsn(Opcodes.ARETURN) + } + } + ) + } + + override fun generateKotlinMetadataAnnotation() { + writeKotlinMetadata(v, state, KotlinClassHeader.Kind.SYNTHETIC_CLASS, 0) { + // Do not write method metadata for raw coroutine state machines + } + } + + companion object { fun create( + cv: ClassBuilder, expressionCodegen: ExpressionCodegen, originalSuspendDescriptor: FunctionDescriptor, - declaration: KtFunction, - state: GenerationState - ): CoroutineCodegen { - val cv = state.factory.newVisitor( - OtherOrigin(declaration, originalSuspendDescriptor), - CodegenBinding.asmTypeForAnonymousClass(state.bindingContext, originalSuspendDescriptor), - declaration.containingFile - ) + declaration: KtFunction + ): CoroutineCodegenForNamedFunction { + val bindingContext = expressionCodegen.state.bindingContext + val closure = + bindingContext[ + CodegenBinding.CLOSURE, + bindingContext[CodegenBinding.CLASS_FOR_CALLABLE, originalSuspendDescriptor] + ].sure { "There must be a closure defined for $originalSuspendDescriptor" } - return CoroutineCodegen( + val suspendFunctionView = + bindingContext[ + CodegenBinding.SUSPEND_FUNCTION_TO_JVM_VIEW, originalSuspendDescriptor + ].sure { "There must be a jvm view defined for $originalSuspendDescriptor" } + + if (suspendFunctionView.dispatchReceiverParameter != null) { + closure.setCaptureThis() + } + + return CoroutineCodegenForNamedFunction( expressionCodegen, declaration, expressionCodegen.context.intoClosure( originalSuspendDescriptor, expressionCodegen, expressionCodegen.state.typeMapper ), cv, - originalSuspendDescriptor, - isSuspendLambda = false + originalSuspendDescriptor ) } } diff --git a/compiler/backend/src/org/jetbrains/kotlin/codegen/coroutines/CoroutineTransformerMethodVisitor.kt b/compiler/backend/src/org/jetbrains/kotlin/codegen/coroutines/CoroutineTransformerMethodVisitor.kt index 9ccd4767b9b..5b2e335f1b6 100644 --- a/compiler/backend/src/org/jetbrains/kotlin/codegen/coroutines/CoroutineTransformerMethodVisitor.kt +++ b/compiler/backend/src/org/jetbrains/kotlin/codegen/coroutines/CoroutineTransformerMethodVisitor.kt @@ -44,8 +44,15 @@ class CoroutineTransformerMethodVisitor( desc: String, signature: String?, exceptions: Array?, - private val classBuilder: ClassBuilder + private val containingClassInternalName: String, + private val classBuilderForCoroutineState: ClassBuilder, + isForNamedFunction: Boolean ) : TransformationMethodVisitor(delegate, access, name, desc, signature, exceptions) { + + private val continuationIndex = if (isForNamedFunction) getLastParameterIndex(desc, access) else 0 + private val dataIndex = continuationIndex + 1 + private val exceptionIndex = dataIndex + 1 + override fun performTransformations(methodNode: MethodNode) { val customCoroutineStartMarker = methodNode.instructions.toArray().filterIsInstance().firstOrNull { it.owner == COROUTINE_MARKER_OWNER && it.name == ACTUAL_COROUTINE_START_MARKER_NAME @@ -61,7 +68,7 @@ class CoroutineTransformerMethodVisitor( } // Spill stack to variables before suspension points, try/catch blocks - FixStackWithLabelNormalizationMethodTransformer().transform(classBuilder.thisName, methodNode) + FixStackWithLabelNormalizationMethodTransformer().transform(containingClassInternalName, methodNode) // Remove unreachable suspension points // If we don't do this, then relevant frames will not be analyzed, that is unexpected from point of view of next steps (e.g. variable spilling) @@ -86,7 +93,7 @@ class CoroutineTransformerMethodVisitor( insnListOf( *withInstructionAdapter { loadCoroutineSuspendedMarker() }.toArray(), VarInsnNode(Opcodes.ASTORE, suspendMarkerVarIndex), - VarInsnNode(Opcodes.ALOAD, 0), + VarInsnNode(Opcodes.ALOAD, continuationIndex), FieldInsnNode( Opcodes.GETFIELD, COROUTINE_IMPL_ASM_TYPE.internalName, @@ -101,7 +108,7 @@ class CoroutineTransformerMethodVisitor( ) ) - insert(startLabel, withInstructionAdapter(InstructionAdapter::generateResumeWithExceptionCheck)) + insert(startLabel, withInstructionAdapter { generateResumeWithExceptionCheck(exceptionIndex) }) insert(last, withInstructionAdapter { visitLabel(defaultLabel.label) @@ -116,7 +123,7 @@ class CoroutineTransformerMethodVisitor( } private fun removeUnreachableSuspensionPointsAndExitPoints(methodNode: MethodNode, suspensionPoints: MutableList) { - val dceResult = DeadCodeEliminationMethodTransformer().transformWithResult(classBuilder.thisName, methodNode) + val dceResult = DeadCodeEliminationMethodTransformer().transformWithResult(containingClassInternalName, methodNode) // If the suspension call begin is alive and suspension call end is dead // (e.g., an inlined suspend function call ends with throwing a exception -- see KT-15017), @@ -166,7 +173,7 @@ class CoroutineTransformerMethodVisitor( private fun spillVariables(suspensionPoints: List, methodNode: MethodNode) { val instructions = methodNode.instructions - val frames = performRefinedTypeAnalysis(methodNode, classBuilder.thisName) + val frames = performRefinedTypeAnalysis(methodNode, containingClassInternalName) fun AbstractInsnNode.index() = instructions.indexOf(this) // We postpone these actions because they change instruction indices that we use when obtaining frames @@ -200,13 +207,15 @@ class CoroutineTransformerMethodVisitor( val livenessFrame = livenessFrames[suspensionCallBegin.index()] // 0 - this - // 1 - continuation argument - // 2 - continuation exception + // 1 - parameter + // ... + // k - continuation + // k + 1 - data + // k + 2 - exception val variablesToSpill = - (3 until localsCount) + ((exceptionIndex + 1) until localsCount) .map { Pair(it, frame.getLocal(it)) } - .filter { - val (index, value) = it + .filter { (index, value) -> value != StrictBasicValue.UNINITIALIZED_VALUE && livenessFrame.isAlive(index) } @@ -235,16 +244,16 @@ class CoroutineTransformerMethodVisitor( with(instructions) { // store variable before suspension call insertBefore(suspension.suspensionCallBegin, withInstructionAdapter { - load(0, AsmTypes.OBJECT_TYPE) + load(continuationIndex, AsmTypes.OBJECT_TYPE) load(index, type) StackValue.coerce(type, normalizedType, this) - putfield(classBuilder.thisName, fieldName, normalizedType.descriptor) + putfield(classBuilderForCoroutineState.thisName, fieldName, normalizedType.descriptor) }) // restore variable after suspension call insert(suspension.tryCatchBlockEndLabelAfterSuspensionCall, withInstructionAdapter { - load(0, AsmTypes.OBJECT_TYPE) - getfield(classBuilder.thisName, fieldName, normalizedType.descriptor) + load(continuationIndex, AsmTypes.OBJECT_TYPE) + getfield(classBuilderForCoroutineState.thisName, fieldName, normalizedType.descriptor) StackValue.coerce(normalizedType, type, this) store(index, type) }) @@ -262,8 +271,8 @@ class CoroutineTransformerMethodVisitor( maxVarsCountByType.forEach { entry -> val (type, maxIndex) = entry for (index in 0..maxIndex) { - classBuilder.newField( - JvmDeclarationOrigin.NO_ORIGIN, Opcodes.ACC_PRIVATE, + classBuilderForCoroutineState.newField( + JvmDeclarationOrigin.NO_ORIGIN, AsmUtil.NO_FLAG_PACKAGE_PRIVATE, type.fieldNameForVar(index), type.descriptor, null, null) } } @@ -294,7 +303,7 @@ class CoroutineTransformerMethodVisitor( // Save state insertBefore(suspension.suspensionCallBegin, insnListOf( - VarInsnNode(Opcodes.ALOAD, 0), + VarInsnNode(Opcodes.ALOAD, continuationIndex), *withInstructionAdapter { iconst(id) }.toArray(), FieldInsnNode( Opcodes.PUTFIELD, COROUTINE_IMPL_ASM_TYPE.internalName, COROUTINE_LABEL_FIELD_NAME, @@ -327,10 +336,10 @@ class CoroutineTransformerMethodVisitor( remove(possibleTryCatchBlockStart.previous) insert(possibleTryCatchBlockStart, withInstructionAdapter { - generateResumeWithExceptionCheck() + generateResumeWithExceptionCheck(exceptionIndex) // Load continuation argument just like suspending function returns it - load(1, AsmTypes.OBJECT_TYPE) + load(dataIndex, AsmTypes.OBJECT_TYPE) visitLabel(continuationLabelAfterLoadedResult.label) }) @@ -389,9 +398,9 @@ class CoroutineTransformerMethodVisitor( } } -private fun InstructionAdapter.generateResumeWithExceptionCheck() { +private fun InstructionAdapter.generateResumeWithExceptionCheck(exceptionIndex: Int) { // Check if resumeWithException has been called - load(2, AsmTypes.OBJECT_TYPE) + load(exceptionIndex, AsmTypes.OBJECT_TYPE) dup() val noExceptionLabel = Label() ifnull(noExceptionLabel) @@ -432,3 +441,6 @@ private class SuspensionPoint( ) { lateinit var tryCatchBlocksContinuationLabel: LabelNode } + +private fun getLastParameterIndex(desc: String, access: Int) = + Type.getArgumentTypes(desc).dropLast(1).map { it.size }.sum() + (if (access and Opcodes.ACC_STATIC != 0) 0 else 1) diff --git a/compiler/backend/src/org/jetbrains/kotlin/codegen/coroutines/SuspendFunctionGenerationStrategy.kt b/compiler/backend/src/org/jetbrains/kotlin/codegen/coroutines/SuspendFunctionGenerationStrategy.kt index fe16a57acdb..1f1d702551e 100644 --- a/compiler/backend/src/org/jetbrains/kotlin/codegen/coroutines/SuspendFunctionGenerationStrategy.kt +++ b/compiler/backend/src/org/jetbrains/kotlin/codegen/coroutines/SuspendFunctionGenerationStrategy.kt @@ -18,33 +18,116 @@ package org.jetbrains.kotlin.codegen.coroutines import org.jetbrains.kotlin.codegen.ExpressionCodegen import org.jetbrains.kotlin.codegen.FunctionGenerationStrategy +import org.jetbrains.kotlin.codegen.binding.CodegenBinding import org.jetbrains.kotlin.codegen.state.GenerationState import org.jetbrains.kotlin.descriptors.FunctionDescriptor import org.jetbrains.kotlin.psi.KtFunction import org.jetbrains.kotlin.resolve.jvm.AsmTypes +import org.jetbrains.kotlin.resolve.jvm.diagnostics.OtherOrigin import org.jetbrains.kotlin.resolve.jvm.jvmSignature.JvmMethodSignature +import org.jetbrains.org.objectweb.asm.Label +import org.jetbrains.org.objectweb.asm.MethodVisitor +import org.jetbrains.org.objectweb.asm.Opcodes import org.jetbrains.org.objectweb.asm.Type class SuspendFunctionGenerationStrategy( state: GenerationState, private val originalSuspendDescriptor: FunctionDescriptor, - private val declaration: KtFunction + private val declaration: KtFunction, + private val containingClassInternalName: String ) : FunctionGenerationStrategy.CodegenBased(state) { + private val containsNonTailSuspensionCalls = originalSuspendDescriptor.containsNonTailSuspensionCalls(state.bindingContext) + + private val classBuilderForCoroutineState by lazy { + state.factory.newVisitor( + OtherOrigin(declaration, originalSuspendDescriptor), + CodegenBinding.asmTypeForAnonymousClass(state.bindingContext, originalSuspendDescriptor), + declaration.containingFile + ) + } + + override fun wrapMethodVisitor(mv: MethodVisitor, access: Int, name: String, desc: String): MethodVisitor { + if (containsNonTailSuspensionCalls) { + return CoroutineTransformerMethodVisitor( + mv, access, name, desc, null, null, containingClassInternalName, classBuilderForCoroutineState, + isForNamedFunction = true + ) + } + + return super.wrapMethodVisitor(mv, access, name, desc) + } override fun doGenerateBody(codegen: ExpressionCodegen, signature: JvmMethodSignature) { - if (!originalSuspendDescriptor.containsNonTailSuspensionCalls(state.bindingContext)) { + if (!containsNonTailSuspensionCalls) { codegen.returnExpression(declaration.bodyExpression) return } - val coroutineCodegen = CoroutineCodegen.create(codegen, originalSuspendDescriptor, declaration, state) + val coroutineCodegen = + CoroutineForNamedFunctionCodegen.create(classBuilderForCoroutineState, codegen, originalSuspendDescriptor, declaration) coroutineCodegen.generate() - codegen.markLineNumber(declaration, false) - codegen.putClosureInstanceOnStack(coroutineCodegen, null).put(Type.getObjectType(coroutineCodegen.className), codegen.v) - codegen.v.invokeDoResumeWithUnit(coroutineCodegen.v.thisName) + val functionDescriptor = codegen.context.functionDescriptor - codegen.markLineNumber(declaration, true) - codegen.v.areturn(AsmTypes.OBJECT_TYPE) + val continuationIndex = codegen.frameMap.getIndex(functionDescriptor.valueParameters.last()) + val objectTypeForState = Type.getObjectType(classBuilderForCoroutineState.thisName) + + val dataIndex = codegen.frameMap.enterTemp(AsmTypes.OBJECT_TYPE) + val exceptionIndex = codegen.frameMap.enterTemp(AsmTypes.OBJECT_TYPE) + + with(codegen.v) { + val createStateInstance = Label() + val storeStateObject = Label() + + // We have to distinguish the following situations: + // - Our function got called in a common way (e.g. from another function or via recursive call) and we should execute our + // code from the beginning + // - We got called from `doResume` of our continuation, i.e. we need to continue from the last suspension point + // + // Also in the first case we wrap the completion into a special anonymous class instance (let's call it X$1) + // that we'll use as a continuation argument for suspension points + // + // How we distinguish the cases: + // - If the continuation is not an instance of X$1 we know exactly it's not the second case, because when resuming + // the continuation we pass an instance of that class + // - Otherwise it's still can be a recursive call. To check it's not the case we set the last bit in the label in + // `doResume` just before calling the suspend function (see kotlin.coroutines.experimental.jvm.internal.CoroutineImplForNamedFunction). + // So, if it's set we're in continuation. + visitVarInsn(Opcodes.ALOAD, continuationIndex) + instanceOf(objectTypeForState) + ifeq(createStateInstance) + + visitVarInsn(Opcodes.ALOAD, continuationIndex) + checkcast(objectTypeForState) + invokevirtual( + COROUTINE_IMPL_FOR_NAMED_ASM_TYPE.internalName, + "checkAndFlushLastBit", Type.getMethodDescriptor(Type.BOOLEAN_TYPE), false + ) + ifeq(createStateInstance) + + visitVarInsn(Opcodes.ALOAD, continuationIndex) + checkcast(objectTypeForState) + goTo(storeStateObject) + + visitLabel(createStateInstance) + coroutineCodegen.putInstanceOnStack(codegen, null).put(objectTypeForState, this) + + visitLabel(storeStateObject) + visitVarInsn(Opcodes.ASTORE, continuationIndex) + + visitVarInsn(Opcodes.ALOAD, continuationIndex) + checkcast(objectTypeForState) + getfield(COROUTINE_IMPL_FOR_NAMED_ASM_TYPE.internalName, "data", AsmTypes.OBJECT_TYPE.descriptor) + visitVarInsn(Opcodes.ASTORE, dataIndex) + + visitVarInsn(Opcodes.ALOAD, continuationIndex) + checkcast(objectTypeForState) + getfield(COROUTINE_IMPL_FOR_NAMED_ASM_TYPE.internalName, "exception", AsmTypes.JAVA_THROWABLE_TYPE.descriptor) + visitVarInsn(Opcodes.ASTORE, exceptionIndex) + + invokestatic(COROUTINE_MARKER_OWNER, ACTUAL_COROUTINE_START_MARKER_NAME, "()V", false) + } + + codegen.returnExpression(declaration.bodyExpression) } } diff --git a/compiler/backend/src/org/jetbrains/kotlin/codegen/coroutines/coroutineCodegenUtil.kt b/compiler/backend/src/org/jetbrains/kotlin/codegen/coroutines/coroutineCodegenUtil.kt index 37c43e63a89..2c5ceae84d4 100644 --- a/compiler/backend/src/org/jetbrains/kotlin/codegen/coroutines/coroutineCodegenUtil.kt +++ b/compiler/backend/src/org/jetbrains/kotlin/codegen/coroutines/coroutineCodegenUtil.kt @@ -72,6 +72,10 @@ val CONTINUATION_ASM_TYPE = DescriptorUtils.CONTINUATION_INTERFACE_FQ_NAME.topLe @JvmField val COROUTINE_IMPL_ASM_TYPE = COROUTINES_JVM_INTERNAL_PACKAGE_FQ_NAME.child(Name.identifier("CoroutineImpl")).topLevelClassAsmType() +@JvmField +val COROUTINE_IMPL_FOR_NAMED_ASM_TYPE = + COROUTINES_JVM_INTERNAL_PACKAGE_FQ_NAME.child(Name.identifier("CoroutineImplForNamedFunction")).topLevelClassAsmType() + private val COROUTINES_INTRINSICS_FILE_FACADE_INTERNAL_NAME = COROUTINES_INTRINSICS_PACKAGE_FQ_NAME.child(Name.identifier("IntrinsicsKt")).topLevelClassAsmType() diff --git a/compiler/backend/src/org/jetbrains/kotlin/codegen/inline/InlineCodegen.java b/compiler/backend/src/org/jetbrains/kotlin/codegen/inline/InlineCodegen.java index 8469a00f802..9762ba38b6c 100644 --- a/compiler/backend/src/org/jetbrains/kotlin/codegen/inline/InlineCodegen.java +++ b/compiler/backend/src/org/jetbrains/kotlin/codegen/inline/InlineCodegen.java @@ -598,7 +598,7 @@ public class InlineCodegen extends CallGenerator { new SuspendFunctionGenerationStrategy( state, CoroutineCodegenUtilKt.unwrapInitialDescriptorForSuspendFunction(descriptor), - (KtFunction) expression + (KtFunction) expression, null ); } else { diff --git a/libraries/stdlib/src/kotlin/coroutines/experimental/jvm/internal/CoroutineImpl.kt b/libraries/stdlib/src/kotlin/coroutines/experimental/jvm/internal/CoroutineImpl.kt index f8f9964ded3..cab5882cb06 100644 --- a/libraries/stdlib/src/kotlin/coroutines/experimental/jvm/internal/CoroutineImpl.kt +++ b/libraries/stdlib/src/kotlin/coroutines/experimental/jvm/internal/CoroutineImpl.kt @@ -35,7 +35,7 @@ abstract class CoroutineImpl( // label == -1 when coroutine cannot be started (it is just a factory object) or has already finished execution // label == 0 in initial part of the coroutine @JvmField - protected var label: Int = if (completion != null) 0 else -1 + var label: Int = if (completion != null) 0 else -1 private val _context: CoroutineContext? = completion?.context @@ -71,3 +71,35 @@ abstract class CoroutineImpl( throw IllegalStateException("create(Any?;Continuation) has not been overridden") } } + +abstract class CoroutineImplForNamedFunction( + completion: Continuation? +) : CoroutineImpl(0, completion), Continuation { + companion object { + private const val LAST_BIT_MASK = 1 shl 31 + } + + @JvmField + var data: Any? = null + + @JvmField + var exception: Throwable? = null + + fun checkAndFlushLastBit(): Boolean { + if (label and LAST_BIT_MASK != 0) { + label -= LAST_BIT_MASK + return true + } + + return false + } + + override fun doResume(data: Any?, exception: Throwable?): Any? { + this.data = data + this.exception = exception + this.label = this.label or LAST_BIT_MASK + return doResume() + } + + protected abstract fun doResume(): Any? +}