JS: refactor coroutines to support inlining of suspend functions

This commit is contained in:
Alexey Andreev
2016-11-08 16:07:53 +03:00
parent e56d735723
commit 2cc299fb40
10 changed files with 193 additions and 79 deletions
@@ -1,6 +1,5 @@
// WITH_RUNTIME
// WITH_REFLECT
// TARGET_BACKEND: JVM
class Controller {
fun withValue(v: String, x: Continuation<String>) {
x.resume(v)
@@ -1,5 +1,5 @@
/*
* Copyright 2010-2014 JetBrains s.r.o.
* Copyright 2010-2016 JetBrains s.r.o.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -62,10 +62,39 @@ var JsFunction.coroutineType: ClassDescriptor? by MetadataProperty(default = nul
var JsFunction.controllerType: ClassDescriptor? by MetadataProperty(default = null)
/**
* Denotes a suspension call-site that is to be processed by coroutine transformer.
* More clearly, denotes invocation that should immediately return from coroutine state machine
*/
var JsInvocation.isSuspend: Boolean by MetadataProperty(default = false)
/**
* Denotes a pre-suspend call-site that is to be processed by coroutine transformer.
* For normal suspend call-sites both [isSuspend] and [isPreSuspend] present.
* For inlined suspend calls fake calls are generated before and after inlined function body.
*/
var JsInvocation.isPreSuspend: Boolean by MetadataProperty(default = false)
/**
* Denotes a fake suspend call for inlining purposes.
*/
var JsInvocation.isFakeSuspend: Boolean by MetadataProperty(default = false)
/**
* Denotes a call to coroutine's controller `handleResult` function.
* See coroutine spec for explanation.
*/
var JsInvocation.isHandleResult: Boolean by MetadataProperty(default = false)
/**
* Denotes a reference to coroutine's `result` field that contains result of
* last suspended invocation.
*/
var JsNameRef.coroutineResult by MetadataProperty(default = false)
/**
* Denotes a reference to coroutine's `controller` field that contains coroutines's controller
*/
var JsNameRef.coroutineController by MetadataProperty(default = false)
enum class TypeCheck {
@@ -17,8 +17,7 @@
package org.jetbrains.kotlin.js.coroutine
import com.google.dart.compiler.backend.js.ast.*
import com.google.dart.compiler.backend.js.ast.metadata.MetadataProperty
import com.google.dart.compiler.backend.js.ast.metadata.isSuspend
import com.google.dart.compiler.backend.js.ast.metadata.*
import org.jetbrains.kotlin.js.translate.utils.JsAstUtils
import org.jetbrains.kotlin.utils.DFS
import org.jetbrains.kotlin.utils.singletonOrEmptyList
@@ -45,6 +44,7 @@ class CoroutineBodyTransformer(val program: JsProgram, val scope: JsScope, val t
private lateinit var nodesToSplit: Set<JsNode>
private var currentCatchBlock = globalCatchBlock
private val tryStack = mutableListOf(TryBlock(globalCatchBlock, null))
private var suspendTarget: CoroutineBlock? = null
var hasFinallyBlocks = false
get
@@ -188,8 +188,6 @@ class CoroutineBodyTransformer(val program: JsProgram, val scope: JsScope, val t
}
override fun visitIf(x: JsIf) = splitIfNecessary(x) {
x.ifExpression = handleExpression(x.ifExpression)
val ifBlock = currentBlock
val thenEntryBlock = CoroutineBlock()
@@ -433,27 +431,19 @@ class CoroutineBodyTransformer(val program: JsProgram, val scope: JsScope, val t
override fun visitExpressionStatement(x: JsExpressionStatement) {
val expression = x.expression
if (expression is JsInvocation && expression.isSuspend) {
handleSuspend(expression)
val splitExpression = handleExpression(expression)
if (splitExpression == expression) {
currentStatements += x
}
else {
val splitExpression = handleExpression(x.expression)
currentStatements += if (splitExpression == expression) x else JsExpressionStatement(expression)
else if (splitExpression != null) {
currentStatements += JsExpressionStatement(splitExpression).apply { synthetic = true }
}
}
override fun visitVars(x: JsVars) {
super.visitVars(x)
currentStatements += x
}
override fun visit(x: JsVars.JsVar) {
val initExpression = x.initExpression
if (initExpression != null) {
x.initExpression = handleExpression(initExpression)
}
}
override fun visitReturn(x: JsReturn) {
val returnBlock = CoroutineBlock()
val isInFinally = hasEnclosingFinallyBlock()
@@ -461,11 +451,6 @@ class CoroutineBodyTransformer(val program: JsProgram, val scope: JsScope, val t
jumpWithFinally(0, returnBlock)
}
val returnExpression = x.expression
if (returnExpression != null) {
x.expression = handleExpression(returnExpression)
}
if (isInFinally) {
currentStatements += x.expression?.makeStmt().singletonOrEmptyList()
currentStatements += jump()
@@ -479,9 +464,8 @@ class CoroutineBodyTransformer(val program: JsProgram, val scope: JsScope, val t
override fun visitThrow(x: JsThrow) {
if (throwFunctionName != null) {
val exception = handleExpression(x.expression)
val methodRef = JsNameRef(throwFunctionName, JsNameRef(controllerFieldName, JsLiteral.THIS))
val invocation = JsInvocation(methodRef, exception).apply {
val invocation = JsInvocation(methodRef, x.expression).apply {
source = x.source
}
currentStatements += JsReturn(invocation)
@@ -491,28 +475,35 @@ class CoroutineBodyTransformer(val program: JsProgram, val scope: JsScope, val t
}
}
private fun handleExpression(expression: JsExpression): JsExpression {
if (expression !in nodesToSplit) return expression
val visitor = object : JsVisitorWithContextImpl() {
override fun endVisit(x: JsInvocation, ctx: JsContext<in JsExpression>) {
if (x.isSuspend) {
ctx.replaceMe(handleSuspend(x))
}
super.endVisit(x, ctx)
private fun handleExpression(expression: JsExpression): JsExpression? {
return if (expression is JsInvocation) {
var result: JsExpression? = expression
if (expression.isPreSuspend) {
result = handlePreSuspend(expression)
}
if (expression.isSuspend) {
handleSuspend(expression)
result = null
}
result
}
else {
expression
}
return visitor.accept(expression)
}
private fun handleSuspend(invocation: JsInvocation): JsExpression {
private fun handlePreSuspend(invocation: JsInvocation): JsExpression? {
val nextBlock = CoroutineBlock()
currentStatements += state(nextBlock)
currentStatements += JsReturn(invocation)
currentBlock = nextBlock
suspendTarget = nextBlock
return JsNameRef(resultFieldName, JsLiteral.THIS)
return if (invocation.isFakeSuspend) null else invocation
}
private fun handleSuspend(invocation: JsInvocation) {
val invokeExpression = if (invocation.isFakeSuspend) invocation.arguments.getOrNull(0) else invocation
currentStatements += JsReturn(invokeExpression)
currentBlock = suspendTarget!!
}
private fun state(target: CoroutineBlock): List<JsStatement> {
@@ -17,11 +17,12 @@
package org.jetbrains.kotlin.js.coroutine
import com.google.dart.compiler.backend.js.ast.*
import com.google.dart.compiler.backend.js.ast.metadata.SideEffectKind
import com.google.dart.compiler.backend.js.ast.metadata.coroutineController
import com.google.dart.compiler.backend.js.ast.metadata.isSuspend
import com.google.dart.compiler.backend.js.ast.metadata.coroutineResult
import com.google.dart.compiler.backend.js.ast.metadata.sideEffects
import org.jetbrains.kotlin.descriptors.ClassDescriptor
import org.jetbrains.kotlin.descriptors.FunctionDescriptor
import org.jetbrains.kotlin.js.inline.ExpressionDecomposer
import org.jetbrains.kotlin.js.inline.clean.FunctionPostProcessor
import org.jetbrains.kotlin.js.inline.util.collectLocalVariables
import org.jetbrains.kotlin.js.inline.util.getInnerFunction
@@ -44,19 +45,6 @@ class CoroutineFunctionTransformer(
private val className = function.scope.parent.declareFreshName("Coroutine\$${function.name}")
fun transform(): List<JsStatement> {
val visitor = object : JsVisitorWithContextImpl() {
override fun <T : JsNode?> doTraverse(node: T, ctx: JsContext<in JsStatement>) {
super.doTraverse(node, ctx)
if (node is JsStatement) {
val statements = ExpressionDecomposer.preserveEvaluationOrder(function.scope, node) {
it is JsInvocation && it.isSuspend
}
ctx.addPrevious(statements)
}
}
}
visitor.accept(body)
val throwFunction = controllerType.findFunction("handleException")
val throwName = throwFunction?.let {
val throwId = nameSuggestion.suggest(it)!!.names.last()
@@ -238,12 +226,23 @@ class CoroutineFunctionTransformer(
val visitor = object : JsVisitorWithContextImpl() {
override fun endVisit(x: JsNameRef, ctx: JsContext<in JsNode>) {
if (x.coroutineController) {
ctx.replaceMe(JsNameRef(transformer.controllerFieldName, x.qualifier))
}
if (x.qualifier == null && x.name in localVariables) {
val fieldName = scope.getFieldName(x.name!!)
ctx.replaceMe(JsNameRef(fieldName, JsLiteral.THIS))
when {
x.coroutineController -> {
ctx.replaceMe(JsNameRef(transformer.controllerFieldName, x.qualifier).apply {
sideEffects = SideEffectKind.PURE
})
}
x.coroutineResult -> {
ctx.replaceMe(JsNameRef(transformer.resultFieldName, x.qualifier).apply {
sideEffects = SideEffectKind.DEPENDS_ON_STATE
})
}
x.qualifier == null && x.name in localVariables -> {
val fieldName = scope.getFieldName(x.name!!)
ctx.replaceMe(JsNameRef(fieldName, JsLiteral.THIS))
}
}
}
@@ -455,7 +455,7 @@ internal open class JsExpressionVisitor() : JsVisitorWithContextImpl() {
/**
* Returns descendants of receiver, matched by [predicate].
*/
fun JsNode.match(predicate: (JsNode) -> Boolean): Set<JsNode> {
private fun JsNode.match(predicate: (JsNode) -> Boolean): Set<JsNode> {
val visitor = object : JsExpressionVisitor() {
val matched = IdentitySet<JsNode>()
@@ -475,7 +475,7 @@ fun JsNode.match(predicate: (JsNode) -> Boolean): Set<JsNode> {
/**
* Returns set of nodes, that satisfy transitive closure of `is parent` relation, starting from [nodes].
*/
fun JsNode.withParentsOfNodes(nodes: Set<JsNode>): Set<JsNode> {
private fun JsNode.withParentsOfNodes(nodes: Set<JsNode>): Set<JsNode> {
val visitor = object : JsExpressionVisitor() {
private val stack = SmartList<JsNode>()
val matched = IdentitySet<JsNode>()
@@ -17,13 +17,14 @@
package org.jetbrains.kotlin.js.inline
import com.google.dart.compiler.backend.js.ast.*
import com.google.dart.compiler.backend.js.ast.metadata.staticRef
import com.google.dart.compiler.backend.js.ast.metadata.synthetic
import com.google.dart.compiler.backend.js.ast.metadata.*
import org.jetbrains.kotlin.js.inline.clean.removeDefaultInitializers
import org.jetbrains.kotlin.js.inline.clean.removeFakeSuspend
import org.jetbrains.kotlin.js.inline.context.InliningContext
import org.jetbrains.kotlin.js.inline.context.NamingContext
import org.jetbrains.kotlin.js.inline.util.*
import org.jetbrains.kotlin.js.inline.util.rewriters.ReturnReplacingVisitor
import org.jetbrains.kotlin.js.translate.utils.JsAstUtils
class FunctionInlineMutator
private constructor(
@@ -34,6 +35,7 @@ private constructor(
private val namingContext: NamingContext
private val body: JsBlock
private var resultExpr: JsExpression? = null
private var resultName: JsName? = null
private var breakLabel: JsLabel? = null
private val currentStatement = inliningContext.statementContext.currentNode
@@ -41,10 +43,20 @@ private constructor(
namingContext = inliningContext.newNamingContext()
val functionContext = inliningContext.functionContext
invokedFunction = uncoverClosure(functionContext.getFunctionDefinition(call).deepCopy())
body = invokedFunction.body
// Removing fakeSuspend is not just an optimization.
// Reentrant suspends are not supported by coroutine transformers.
body = if (call.isSuspend) invokedFunction.body.removeFakeSuspend() else invokedFunction.body
}
private fun process() {
if (call.isSuspend) {
val fakeSuspendCall = JsInvocation(JsAstUtils.pureFqn("fakeSuspend", JsAstUtils.pureFqn("Kotlin", null)))
fakeSuspendCall.isPreSuspend = true
fakeSuspendCall.isFakeSuspend = true
body.statements.add(0, JsAstUtils.asSyntheticStatement(fakeSuspendCall))
}
val arguments = getArguments()
val parameters = getParameters()
@@ -111,14 +123,19 @@ private constructor(
val breakName = namingContext.getFreshName(getBreakLabel())
this.breakLabel = JsLabel(breakName).apply { synthetic = true }
val visitor = ReturnReplacingVisitor(resultExpr as? JsNameRef, breakName.makeRef(), invokedFunction)
val visitor = ReturnReplacingVisitor(resultExpr as? JsNameRef, breakName.makeRef(), invokedFunction, call.isSuspend)
visitor.accept(body)
visitor.makeFakeSuspendCall(null)?.let { fakeSuspend ->
body.statements += JsAstUtils.asSyntheticStatement(fakeSuspend)
}
}
private fun getResultReference(): JsNameRef? {
if (!isResultNeeded(call)) return null
val resultName = namingContext.getFreshName(getResultLabel())
this.resultName = resultName
namingContext.newVar(resultName, null)
return resultName.makeRef()
}
@@ -0,0 +1,48 @@
/*
* Copyright 2010-2016 JetBrains s.r.o.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.jetbrains.kotlin.js.inline.clean
import com.google.dart.compiler.backend.js.ast.*
import com.google.dart.compiler.backend.js.ast.metadata.isFakeSuspend
import com.google.dart.compiler.backend.js.ast.metadata.isPreSuspend
import com.google.dart.compiler.backend.js.ast.metadata.isSuspend
import com.google.dart.compiler.backend.js.ast.metadata.synthetic
import org.jetbrains.kotlin.js.translate.context.Namer
fun <T : JsNode> T.removeFakeSuspend(): T {
val visitor = object : JsVisitorWithContextImpl() {
override fun endVisit(x: JsInvocation, ctx: JsContext<in JsNode>) {
if (x.isFakeSuspend) {
ctx.replaceMe(x.arguments.getOrElse(0) { Namer.getUndefinedExpression() })
}
else {
x.isSuspend = false
x.isPreSuspend = false
}
super.endVisit(x, ctx)
}
override fun visit(x: JsExpressionStatement, ctx: JsContext<*>): Boolean {
val expression = x.expression
if (expression is JsInvocation && expression.isFakeSuspend && expression.arguments.isEmpty()) {
x.synthetic = true
}
return super.visit(x, ctx)
}
}
return visitor.accept(this)
}
@@ -17,15 +17,14 @@
package org.jetbrains.kotlin.js.inline.util.rewriters
import com.google.dart.compiler.backend.js.ast.*
import com.google.dart.compiler.backend.js.ast.metadata.functionDescriptor
import com.google.dart.compiler.backend.js.ast.metadata.returnTarget
import com.google.dart.compiler.backend.js.ast.metadata.synthetic
import com.google.dart.compiler.backend.js.ast.metadata.*
import org.jetbrains.kotlin.js.translate.utils.JsAstUtils
class ReturnReplacingVisitor(
private val resultRef: JsNameRef?,
private val breakLabel: JsNameRef?,
private val function: JsFunction
private val function: JsFunction,
private val isSuspend: Boolean
) : JsVisitorWithContextImpl() {
/**
@@ -55,13 +54,26 @@ class ReturnReplacingVisitor(
private fun getReturnReplacement(returnExpression: JsExpression?): JsExpression? {
return if (returnExpression != null) {
val assignment = resultRef?.let {
JsAstUtils.assignment(resultRef, returnExpression).apply { synthetic = true }
val assignment = resultRef?.let { lhs ->
val rhs = makeFakeSuspendCall(returnExpression)!!
JsAstUtils.assignment(lhs, rhs).apply { synthetic = true }
}
assignment ?: returnExpression
assignment ?: makeFakeSuspendCall(returnExpression)
}
else {
null
makeFakeSuspendCall(null)
}
}
fun makeFakeSuspendCall(expression: JsExpression?): JsExpression? {
if (!isSuspend) return expression
val fakeSuspendCall = JsInvocation(JsAstUtils.pureFqn("fakeSuspend", JsAstUtils.pureFqn("Kotlin", null)))
fakeSuspendCall.isSuspend = true
fakeSuspendCall.isFakeSuspend = true
if (expression != null) {
fakeSuspendCall.arguments += expression
}
return fakeSuspendCall
}
}
@@ -18,7 +18,9 @@ package org.jetbrains.kotlin.js.translate.callTranslator
import com.google.dart.compiler.backend.js.ast.JsExpression
import com.google.dart.compiler.backend.js.ast.JsInvocation
import com.google.dart.compiler.backend.js.ast.metadata.isSuspend
import com.google.dart.compiler.backend.js.ast.JsLiteral
import com.google.dart.compiler.backend.js.ast.JsNameRef
import com.google.dart.compiler.backend.js.ast.metadata.*
import org.jetbrains.kotlin.descriptors.CallableDescriptor
import org.jetbrains.kotlin.descriptors.FunctionDescriptor
import org.jetbrains.kotlin.descriptors.VariableDescriptor
@@ -26,10 +28,12 @@ import org.jetbrains.kotlin.js.translate.context.TranslationContext
import org.jetbrains.kotlin.js.translate.general.Translation
import org.jetbrains.kotlin.js.translate.reference.CallArgumentTranslator
import org.jetbrains.kotlin.js.translate.utils.AnnotationsUtils
import org.jetbrains.kotlin.js.translate.utils.JsAstUtils
import org.jetbrains.kotlin.psi.Call.CallType
import org.jetbrains.kotlin.resolve.calls.callResolverUtil.isInvokeCallOnVariable
import org.jetbrains.kotlin.resolve.calls.model.ResolvedCall
import org.jetbrains.kotlin.resolve.calls.model.VariableAsFunctionResolvedCall
import org.jetbrains.kotlin.resolve.calls.resolvedCallUtil.getImplicitReceiverValue
import org.jetbrains.kotlin.resolve.calls.tasks.ExplicitReceiverKind
import org.jetbrains.kotlin.resolve.calls.tasks.ExplicitReceiverKind.NO_EXPLICIT_RECEIVER
import org.jetbrains.kotlin.resolve.scopes.receivers.ExpressionReceiver
@@ -118,8 +122,17 @@ private fun translateFunctionCall(context: TranslationContext,
explicitReceivers: ExplicitReceivers
): JsExpression {
val callExpression = context.getCallInfo(resolvedCall, explicitReceivers).translateFunctionCall()
if (resolvedCall.resultingDescriptor.isSuspend) {
(callExpression as JsInvocation).isSuspend = true
if (resolvedCall.resultingDescriptor.isSuspend && resolvedCall.resultingDescriptor.initialSignatureDescriptor != null) {
context.currentBlock.statements += JsAstUtils.asSyntheticStatement((callExpression as JsInvocation).apply {
isSuspend = true
isPreSuspend = true
})
val coroutineDescriptor = resolvedCall.getImplicitReceiverValue()!!.declarationDescriptor
val coroutineRef = context.getAliasForDescriptor(coroutineDescriptor) ?: JsLiteral.THIS
return context.defineTemporary(JsNameRef("\$\$coroutineResult\$\$", coroutineRef).apply {
sideEffects = SideEffectKind.DEPENDS_ON_STATE
coroutineResult = true
})
}
return callExpression
}
@@ -423,6 +423,12 @@ public final class StaticContext {
@Nullable
@Override
public JsName apply(@NotNull DeclarationDescriptor descriptor) {
if (descriptor instanceof FunctionDescriptor) {
FunctionDescriptor initialDescriptor = ((FunctionDescriptor) descriptor).getInitialSignatureDescriptor();
if (initialDescriptor != null) {
return getInnerNameForDescriptor(initialDescriptor);
}
}
if (descriptor instanceof ModuleDescriptor) {
return getModuleInnerName(descriptor);
}