diff --git a/kotlin-native/backend.native/compiler/ir/backend.native/src/org/jetbrains/kotlin/backend/konan/ir/NewIrUtils.kt b/kotlin-native/backend.native/compiler/ir/backend.native/src/org/jetbrains/kotlin/backend/konan/ir/NewIrUtils.kt index d063119df2d..4cfde62b7c5 100644 --- a/kotlin-native/backend.native/compiler/ir/backend.native/src/org/jetbrains/kotlin/backend/konan/ir/NewIrUtils.kt +++ b/kotlin-native/backend.native/compiler/ir/backend.native/src/org/jetbrains/kotlin/backend/konan/ir/NewIrUtils.kt @@ -7,6 +7,7 @@ package org.jetbrains.kotlin.backend.konan.ir import org.jetbrains.kotlin.backend.common.atMostOne import org.jetbrains.kotlin.backend.konan.DECLARATION_ORIGIN_INLINE_CLASS_SPECIAL_FUNCTION +import org.jetbrains.kotlin.backend.konan.descriptors.allOverriddenFunctions import org.jetbrains.kotlin.backend.konan.descriptors.isInteropLibrary import org.jetbrains.kotlin.backend.konan.llvm.KonanMetadata import org.jetbrains.kotlin.backend.konan.serialization.KonanFileMetadataSource @@ -19,10 +20,7 @@ import org.jetbrains.kotlin.descriptors.konan.klibModuleOrigin import org.jetbrains.kotlin.ir.IrBuiltIns import org.jetbrains.kotlin.ir.declarations.* import org.jetbrains.kotlin.ir.declarations.lazy.IrLazyDeclarationBase -import org.jetbrains.kotlin.ir.expressions.IrCall -import org.jetbrains.kotlin.ir.expressions.IrConst -import org.jetbrains.kotlin.ir.expressions.IrConstructorCall -import org.jetbrains.kotlin.ir.expressions.IrExpression +import org.jetbrains.kotlin.ir.expressions.* import org.jetbrains.kotlin.ir.expressions.impl.IrConstImpl import org.jetbrains.kotlin.ir.expressions.impl.IrConstructorCallImpl import org.jetbrains.kotlin.ir.symbols.IrClassSymbol @@ -107,6 +105,21 @@ fun buildSimpleAnnotation(irBuiltIns: IrBuiltIns, startOffset: Int, endOffset: I internal fun IrExpression.isBoxOrUnboxCall() = (this is IrCall && symbol.owner.origin == DECLARATION_ORIGIN_INLINE_CLASS_SPECIAL_FUNCTION) +internal fun IrBranch.isUnconditional(): Boolean = (condition as? IrConst<*>)?.value == true + +internal val IrFunctionAccessExpression.actualCallee: IrFunction + get() { + val callee = symbol.owner + return ((this as? IrCall)?.superQualifierSymbol?.owner?.getOverridingOf(callee) ?: callee).target + } + +internal val IrFunctionAccessExpression.isVirtualCall: Boolean + get() = this is IrCall && this.superQualifierSymbol == null && this.symbol.owner.isOverridable + +private fun IrClass.getOverridingOf(function: IrFunction) = (function as? IrSimpleFunction)?.let { + it.allOverriddenFunctions.atMostOne { it.parent == this } +} + val ModuleDescriptor.konanLibrary get() = (this.klibModuleOrigin as? DeserializedKlibModuleOrigin)?.library val IrModuleFragment.konanLibrary get() = (this as? KonanIrModuleFragmentImpl)?.konanLibrary ?: descriptor.konanLibrary diff --git a/kotlin-native/backend.native/compiler/ir/backend.native/src/org/jetbrains/kotlin/backend/konan/llvm/BitcodePhases.kt b/kotlin-native/backend.native/compiler/ir/backend.native/src/org/jetbrains/kotlin/backend/konan/llvm/BitcodePhases.kt index a1c7643c559..2fe48d50a9a 100644 --- a/kotlin-native/backend.native/compiler/ir/backend.native/src/org/jetbrains/kotlin/backend/konan/llvm/BitcodePhases.kt +++ b/kotlin-native/backend.native/compiler/ir/backend.native/src/org/jetbrains/kotlin/backend/konan/llvm/BitcodePhases.kt @@ -12,18 +12,16 @@ import org.jetbrains.kotlin.backend.common.phaser.PhaserState import org.jetbrains.kotlin.backend.common.phaser.namedUnitPhase import org.jetbrains.kotlin.backend.konan.* import org.jetbrains.kotlin.backend.konan.descriptors.GlobalHierarchyAnalysis -import org.jetbrains.kotlin.backend.konan.lower.DECLARATION_ORIGIN_FILE_GLOBAL_INITIALIZER -import org.jetbrains.kotlin.backend.konan.lower.DECLARATION_ORIGIN_FILE_STANDALONE_THREAD_LOCAL_INITIALIZER -import org.jetbrains.kotlin.backend.konan.lower.DECLARATION_ORIGIN_FILE_THREAD_LOCAL_INITIALIZER import org.jetbrains.kotlin.backend.konan.lower.RedundantCoercionsCleaner import org.jetbrains.kotlin.backend.konan.lower.ReturnsInsertionLowering import org.jetbrains.kotlin.backend.konan.optimizations.* import org.jetbrains.kotlin.ir.IrElement import org.jetbrains.kotlin.ir.IrStatement import org.jetbrains.kotlin.ir.declarations.* -import org.jetbrains.kotlin.ir.expressions.IrBlockBody -import org.jetbrains.kotlin.ir.expressions.IrCall -import org.jetbrains.kotlin.ir.util.* +import org.jetbrains.kotlin.ir.util.defaultType +import org.jetbrains.kotlin.ir.util.isFunction +import org.jetbrains.kotlin.ir.util.isReal +import org.jetbrains.kotlin.ir.util.parentAsClass import org.jetbrains.kotlin.ir.visitors.* import org.jetbrains.kotlin.util.OperatorNameConventions import org.jetbrains.kotlin.utils.addToStdlib.cast @@ -248,33 +246,11 @@ internal val removeRedundantCallsToFileInitializersPhase = makeKonanModuleOpPhas nonDevirtualizedCallSitesUnfoldFactor = Int.MAX_VALUE ).build() - val functionsBeingCalledFromOtherFiles = mutableSetOf() - for (node in callGraph.directEdges.values) { - val callerFile = node.symbol.irFile - node.callSites.forEach { - require(!it.isVirtual) { "There should be no virtual calls in the call graph, but was: ${it.actualCallee}" } - val calleeFile = it.actualCallee.irFile - if (callerFile == null || callerFile != calleeFile) - functionsBeingCalledFromOtherFiles.add(it.actualCallee.irFunction ?: error("No IR for: ${it.actualCallee}")) - } - } + val rootSet = DevirtualizationAnalysis.computeRootSet(context, moduleDFG, externalModulesDFG) + .mapNotNull { it.irFunction } + .toSet() - val rootSet = DevirtualizationAnalysis.computeRootSet(context, moduleDFG, externalModulesDFG).toSet() - context.irModule!!.transformChildrenVoid(object : IrElementTransformerVoid() { - override fun visitFunction(declaration: IrFunction): IrStatement { - declaration.transformChildrenVoid(this) - if (declaration in functionsBeingCalledFromOtherFiles - || moduleDFG.symbolTable.mapFunction(declaration) in rootSet) return declaration - val body = declaration.body ?: return declaration - (body as IrBlockBody).statements.removeAll { - val calleeOrigin = (it as? IrCall)?.symbol?.owner?.origin - calleeOrigin == DECLARATION_ORIGIN_FILE_GLOBAL_INITIALIZER - || calleeOrigin == DECLARATION_ORIGIN_FILE_THREAD_LOCAL_INITIALIZER - || calleeOrigin == DECLARATION_ORIGIN_FILE_STANDALONE_THREAD_LOCAL_INITIALIZER - } - return declaration - } - }) + FileInitializersOptimization.removeRedundantCalls(context, callGraph, rootSet) } ) diff --git a/kotlin-native/backend.native/compiler/ir/backend.native/src/org/jetbrains/kotlin/backend/konan/llvm/IrToBitcode.kt b/kotlin-native/backend.native/compiler/ir/backend.native/src/org/jetbrains/kotlin/backend/konan/llvm/IrToBitcode.kt index 6243758af00..549f8e89126 100644 --- a/kotlin-native/backend.native/compiler/ir/backend.native/src/org/jetbrains/kotlin/backend/konan/llvm/IrToBitcode.kt +++ b/kotlin-native/backend.native/compiler/ir/backend.native/src/org/jetbrains/kotlin/backend/konan/llvm/IrToBitcode.kt @@ -1271,7 +1271,7 @@ internal class CodeGeneratorVisitor(val context: Context, val lifetimes: Map // If branch condition is constant. - && (branch.condition as IrConst<*>).value as Boolean // If condition is "true" - //-------------------------------------------------------------------------// private fun evaluateWhileLoop(loop: IrWhileLoop): LLVMValueRef { diff --git a/kotlin-native/backend.native/compiler/ir/backend.native/src/org/jetbrains/kotlin/backend/konan/lower/Autoboxing.kt b/kotlin-native/backend.native/compiler/ir/backend.native/src/org/jetbrains/kotlin/backend/konan/lower/Autoboxing.kt index 3327b36a81d..79a99dc7585 100644 --- a/kotlin-native/backend.native/compiler/ir/backend.native/src/org/jetbrains/kotlin/backend/konan/lower/Autoboxing.kt +++ b/kotlin-native/backend.native/compiler/ir/backend.native/src/org/jetbrains/kotlin/backend/konan/lower/Autoboxing.kt @@ -133,8 +133,7 @@ private class AutoboxingTransformer(val context: Context) : AbstractValueUsageTr } private val IrCall.callTarget: IrFunction - get() = if (superQualifierSymbol == null && symbol.owner.isOverridable) { - // A virtual call. + get() = if (this.isVirtualCall) { symbol.owner } else { symbol.owner.target diff --git a/kotlin-native/backend.native/compiler/ir/backend.native/src/org/jetbrains/kotlin/backend/konan/lower/FileInitializersLowering.kt b/kotlin-native/backend.native/compiler/ir/backend.native/src/org/jetbrains/kotlin/backend/konan/lower/FileInitializersLowering.kt index 07d661ea86b..78a70d4c255 100644 --- a/kotlin-native/backend.native/compiler/ir/backend.native/src/org/jetbrains/kotlin/backend/konan/lower/FileInitializersLowering.kt +++ b/kotlin-native/backend.native/compiler/ir/backend.native/src/org/jetbrains/kotlin/backend/konan/lower/FileInitializersLowering.kt @@ -31,6 +31,11 @@ internal object DECLARATION_ORIGIN_FILE_GLOBAL_INITIALIZER : IrDeclarationOrigin internal object DECLARATION_ORIGIN_FILE_THREAD_LOCAL_INITIALIZER : IrDeclarationOriginImpl("FILE_THREAD_LOCAL_INITIALIZER") internal object DECLARATION_ORIGIN_FILE_STANDALONE_THREAD_LOCAL_INITIALIZER : IrDeclarationOriginImpl("FILE_STANDALONE_THREAD_LOCAL_INITIALIZER") +internal val IrFunction.isFileInitializer: Boolean + get() = origin == DECLARATION_ORIGIN_FILE_GLOBAL_INITIALIZER + || origin == DECLARATION_ORIGIN_FILE_THREAD_LOCAL_INITIALIZER + || origin == DECLARATION_ORIGIN_FILE_STANDALONE_THREAD_LOCAL_INITIALIZER + internal fun IrBuilderWithScope.irCallFileInitializer(initializer: IrFunctionSymbol) = irCall(initializer).apply { putValueArgument(0, irFalse()) } diff --git a/kotlin-native/backend.native/compiler/ir/backend.native/src/org/jetbrains/kotlin/backend/konan/optimizations/DFGBuilder.kt b/kotlin-native/backend.native/compiler/ir/backend.native/src/org/jetbrains/kotlin/backend/konan/optimizations/DFGBuilder.kt index b43e75fb17b..2dd4469bf62 100644 --- a/kotlin-native/backend.native/compiler/ir/backend.native/src/org/jetbrains/kotlin/backend/konan/optimizations/DFGBuilder.kt +++ b/kotlin-native/backend.native/compiler/ir/backend.native/src/org/jetbrains/kotlin/backend/konan/optimizations/DFGBuilder.kt @@ -42,9 +42,8 @@ internal class ExternalModulesDFG(val allTypes: List, val publicFunctions: Map, val functionDFGs: Map) -private fun IrClass.getOverridingOf(function: IrFunction) = (function as? IrSimpleFunction)?.let { - it.allOverriddenFunctions.atMostOne { it.parent == this } -} +internal object STATEMENT_ORIGIN_PRODUCER_INVOCATION : IrStatementOriginImpl("PRODUCER_INVOCATION") +internal object STATEMENT_ORIGIN_JOB_INVOCATION : IrStatementOriginImpl("JOB_INVOCATION") private fun IrTypeOperator.isCast() = this == IrTypeOperator.CAST || this == IrTypeOperator.IMPLICIT_CAST || this == IrTypeOperator.SAFE_CAST @@ -336,7 +335,8 @@ internal class ModuleDFGBuilder(val context: Context, val irModule: IrModuleFrag executeImplProducerInvoke.returnType, executeImplProducerInvoke.symbol, executeImplProducerInvoke.symbol.owner.typeParameters.size, - executeImplProducerInvoke.symbol.owner.valueParameters.size) + executeImplProducerInvoke.symbol.owner.valueParameters.size, + STATEMENT_ORIGIN_PRODUCER_INVOCATION) producerInvocation.dispatchReceiver = expression.getValueArgument(2) expressions += producerInvocation to currentLoop @@ -347,7 +347,8 @@ internal class ModuleDFGBuilder(val context: Context, val irModule: IrModuleFrag jobFunctionReference.symbol.owner.returnType, jobFunctionReference.symbol as IrSimpleFunctionSymbol, jobFunctionReference.symbol.owner.typeParameters.size, - jobFunctionReference.symbol.owner.valueParameters.size) + jobFunctionReference.symbol.owner.valueParameters.size, + STATEMENT_ORIGIN_JOB_INVOCATION) jobInvocation.putValueArgument(0, producerInvocation) expressions += jobInvocation to currentLoop @@ -706,9 +707,7 @@ internal class ModuleDFGBuilder(val context: Context, val irModule: IrModuleFrag getContinuationSymbol -> getContinuation().value in arrayGetSymbols -> { - val callee = value.symbol.owner - val actualCallee = (value.superQualifierSymbol?.owner?.getOverridingOf(callee) - ?: callee).target + val actualCallee = value.actualCallee DataFlowIR.Node.ArrayRead( symbolTable.mapFunction(actualCallee), @@ -719,9 +718,7 @@ internal class ModuleDFGBuilder(val context: Context, val irModule: IrModuleFrag } in arraySetSymbols -> { - val callee = value.symbol.owner - val actualCallee = (value.superQualifierSymbol?.owner?.getOverridingOf(callee) - ?: callee).target + val actualCallee = value.actualCallee DataFlowIR.Node.ArrayWrite( symbolTable.mapFunction(actualCallee), array = expressionToEdge(value.dispatchReceiver!!), @@ -764,7 +761,7 @@ internal class ModuleDFGBuilder(val context: Context, val irModule: IrModuleFrag if (callee is IrConstructor) { error("Constructor call should be done with IrConstructorCall") } else { - if (callee.isOverridable && value.superQualifierSymbol == null) { + if (value.isVirtualCall) { val owner = callee.parentAsClass val actualReceiverType = value.dispatchReceiver!!.type val actualReceiverClassifier = actualReceiverType.classifierOrFail @@ -813,7 +810,7 @@ internal class ModuleDFGBuilder(val context: Context, val irModule: IrModuleFrag ) } } else { - val actualCallee = (value.superQualifierSymbol?.owner?.getOverridingOf(callee) ?: callee).target + val actualCallee = value.actualCallee DataFlowIR.Node.StaticCall( symbolTable.mapFunction(actualCallee), arguments, diff --git a/kotlin-native/backend.native/compiler/ir/backend.native/src/org/jetbrains/kotlin/backend/konan/optimizations/DataFlowIR.kt b/kotlin-native/backend.native/compiler/ir/backend.native/src/org/jetbrains/kotlin/backend/konan/optimizations/DataFlowIR.kt index b2d115c0253..aeb486e5054 100644 --- a/kotlin-native/backend.native/compiler/ir/backend.native/src/org/jetbrains/kotlin/backend/konan/optimizations/DataFlowIR.kt +++ b/kotlin-native/backend.native/compiler/ir/backend.native/src/org/jetbrains/kotlin/backend/konan/optimizations/DataFlowIR.kt @@ -124,7 +124,7 @@ internal object DataFlowIR { class FunctionParameter(val type: Type, val boxFunction: FunctionSymbol?, val unboxFunction: FunctionSymbol?) - abstract class FunctionSymbol(val attributes: Int, val irFile: IrFile?, val irFunction: IrFunction?, val name: String?) { + abstract class FunctionSymbol(val attributes: Int, val irDeclaration: IrDeclaration?, val name: String?) { lateinit var parameters: Array lateinit var returnParameter: FunctionParameter @@ -134,11 +134,14 @@ internal object DataFlowIR { val returnsNothing = attributes.and(FunctionAttributes.RETURNS_NOTHING) != 0 val explicitlyExported = attributes.and(FunctionAttributes.EXPLICITLY_EXPORTED) != 0 + val irFunction: IrFunction? get() = irDeclaration as? IrFunction + val irFile: IrFile? get() = irDeclaration?.fileOrNull + var escapes: Int? = null var pointsTo: IntArray? = null - class External(val hash: Long, attributes: Int, irFile: IrFile?, irFunction: IrFunction?, name: String? = null, val isExported: Boolean) - : FunctionSymbol(attributes, irFile, irFunction, name) { + class External(val hash: Long, attributes: Int, irDeclaration: IrDeclaration?, name: String? = null, val isExported: Boolean) + : FunctionSymbol(attributes, irDeclaration, name) { override fun equals(other: Any?): Boolean { if (this === other) return true @@ -157,14 +160,14 @@ internal object DataFlowIR { } abstract class Declared(val module: Module, val symbolTableIndex: Int, - attributes: Int, irFile: IrFile?, irFunction: IrFunction?, var bridgeTarget: FunctionSymbol?, name: String?) - : FunctionSymbol(attributes, irFile, irFunction, name) { + attributes: Int, irDeclaration: IrDeclaration?, var bridgeTarget: FunctionSymbol?, name: String?) + : FunctionSymbol(attributes, irDeclaration, name) { } class Public(val hash: Long, module: Module, symbolTableIndex: Int, - attributes: Int, irFile: IrFile?, irFunction: IrFunction?, bridgeTarget: FunctionSymbol?, name: String? = null) - : Declared(module, symbolTableIndex, attributes, irFile, irFunction, bridgeTarget, name) { + attributes: Int, irDeclaration: IrDeclaration?, bridgeTarget: FunctionSymbol?, name: String? = null) + : Declared(module, symbolTableIndex, attributes, irDeclaration, bridgeTarget, name) { override fun equals(other: Any?): Boolean { if (this === other) return true @@ -183,8 +186,8 @@ internal object DataFlowIR { } class Private(val index: Int, module: Module, symbolTableIndex: Int, - attributes: Int, irFile: IrFile?, irFunction: IrFunction?, bridgeTarget: FunctionSymbol?, name: String? = null) - : Declared(module, symbolTableIndex, attributes, irFile, irFunction, bridgeTarget, name) { + attributes: Int, irDeclaration: IrDeclaration?, bridgeTarget: FunctionSymbol?, name: String? = null) + : Declared(module, symbolTableIndex, attributes, irDeclaration, bridgeTarget, name) { override fun equals(other: Any?): Boolean { if (this === other) return true @@ -595,7 +598,7 @@ internal object DataFlowIR { val escapesBitMask = (escapesAnnotation?.getValueArgument(0) as? IrConst)?.value @Suppress("UNCHECKED_CAST") val pointsToBitMask = (pointsToAnnotation?.getValueArgument(0) as? IrVararg)?.elements?.map { (it as IrConst).value } - FunctionSymbol.External(name.localHash.value, attributes, it.fileOrNull, it, takeName { name }, it.isExported()).apply { + FunctionSymbol.External(name.localHash.value, attributes, it, takeName { name }, it.isExported()).apply { escapes = escapesBitMask pointsTo = pointsToBitMask?.toIntArray() } @@ -615,9 +618,9 @@ internal object DataFlowIR { val symbolTableIndex = if (placeToFunctionsTable) module.numberOfFunctions++ else -1 val frozen = it is IrConstructor && irClass!!.annotations.findAnnotation(KonanFqNames.frozen) != null val functionSymbol = if (it.isExported()) - FunctionSymbol.Public(name.localHash.value, module, symbolTableIndex, attributes, it.fileOrNull, it, bridgeTargetSymbol, takeName { name }) + FunctionSymbol.Public(name.localHash.value, module, symbolTableIndex, attributes, it, bridgeTargetSymbol, takeName { name }) else - FunctionSymbol.Private(privateFunIndex++, module, symbolTableIndex, attributes, it.fileOrNull, it, bridgeTargetSymbol, takeName { name }) + FunctionSymbol.Private(privateFunIndex++, module, symbolTableIndex, attributes, it, bridgeTargetSymbol, takeName { name }) if (frozen) { functionSymbol.escapes = 0b1 // Assume instances of frozen classes escape. } @@ -647,7 +650,7 @@ internal object DataFlowIR { assert(irField.parent !is IrClass) { "All local properties initializers should've been lowered" } val attributes = FunctionAttributes.IS_TOP_LEVEL_FIELD_INITIALIZER or FunctionAttributes.RETURNS_UNIT -val symbol = FunctionSymbol.Private(privateFunIndex++, module, -1, attributes, irField.fileOrNull, null, null, takeName { "${irField.computeSymbolName()}_init" }) + val symbol = FunctionSymbol.Private(privateFunIndex++, module, -1, attributes, irField, null, takeName { "${irField.computeSymbolName()}_init" }) functionMap[irField] = symbol diff --git a/kotlin-native/backend.native/compiler/ir/backend.native/src/org/jetbrains/kotlin/backend/konan/optimizations/DevirtualizationAnalysis.kt b/kotlin-native/backend.native/compiler/ir/backend.native/src/org/jetbrains/kotlin/backend/konan/optimizations/DevirtualizationAnalysis.kt index b376f689488..b17e9862f02 100644 --- a/kotlin-native/backend.native/compiler/ir/backend.native/src/org/jetbrains/kotlin/backend/konan/optimizations/DevirtualizationAnalysis.kt +++ b/kotlin-native/backend.native/compiler/ir/backend.native/src/org/jetbrains/kotlin/backend/konan/optimizations/DevirtualizationAnalysis.kt @@ -45,6 +45,8 @@ inline fun BitSet.forEachBit(block: (Int) -> Unit) { } } +fun BitSet.copy() = BitSet(this.size()).apply { this.or(this@copy) } + // Devirtualization analysis is performed using Variable Type Analysis algorithm. // See http://web.cs.ucla.edu/~palsberg/tba/papers/sundaresan-et-al-oopsla00.pdf for details. internal object DevirtualizationAnalysis { @@ -264,8 +266,6 @@ internal object DevirtualizationAnalysis { else -> error("Unreachable") } - fun BitSet.copy() = BitSet(this.size()).apply { this.or(this@copy) } - fun logPathToType(reversedEdges: IntArray, node: Node, type: Int) { val nodes = constraintGraph.nodes val visited = BitSet() diff --git a/kotlin-native/backend.native/compiler/ir/backend.native/src/org/jetbrains/kotlin/backend/konan/optimizations/FileInitializersOptimization.kt b/kotlin-native/backend.native/compiler/ir/backend.native/src/org/jetbrains/kotlin/backend/konan/optimizations/FileInitializersOptimization.kt new file mode 100644 index 00000000000..6c1ef09678d --- /dev/null +++ b/kotlin-native/backend.native/compiler/ir/backend.native/src/org/jetbrains/kotlin/backend/konan/optimizations/FileInitializersOptimization.kt @@ -0,0 +1,625 @@ +/* + * Copyright 2010-2021 JetBrains s.r.o. and Kotlin Programming Language contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + */ + +package org.jetbrains.kotlin.backend.konan.optimizations + +import org.jetbrains.kotlin.backend.common.atMostOne +import org.jetbrains.kotlin.backend.common.lower.createIrBuilder +import org.jetbrains.kotlin.backend.common.lower.irBlock +import org.jetbrains.kotlin.backend.konan.Context +import org.jetbrains.kotlin.backend.konan.DirectedGraphCondensationBuilder +import org.jetbrains.kotlin.backend.konan.DirectedGraphMultiNode +import org.jetbrains.kotlin.backend.konan.ir.actualCallee +import org.jetbrains.kotlin.backend.konan.ir.isOverridable +import org.jetbrains.kotlin.backend.konan.ir.isUnconditional +import org.jetbrains.kotlin.backend.konan.ir.isVirtualCall +import org.jetbrains.kotlin.backend.konan.logMultiple +import org.jetbrains.kotlin.backend.konan.lower.* +import org.jetbrains.kotlin.backend.konan.lower.DECLARATION_ORIGIN_FILE_GLOBAL_INITIALIZER +import org.jetbrains.kotlin.backend.konan.lower.DECLARATION_ORIGIN_FILE_STANDALONE_THREAD_LOCAL_INITIALIZER +import org.jetbrains.kotlin.backend.konan.lower.DECLARATION_ORIGIN_FILE_THREAD_LOCAL_INITIALIZER +import org.jetbrains.kotlin.ir.IrElement +import org.jetbrains.kotlin.ir.IrStatement +import org.jetbrains.kotlin.ir.builders.IrBuilderWithScope +import org.jetbrains.kotlin.ir.builders.irBlock +import org.jetbrains.kotlin.ir.declarations.* +import org.jetbrains.kotlin.ir.declarations.lazy.IrLazyClass +import org.jetbrains.kotlin.ir.expressions.* +import org.jetbrains.kotlin.ir.symbols.IrReturnTargetSymbol +import org.jetbrains.kotlin.ir.util.* +import org.jetbrains.kotlin.ir.visitors.* +import java.util.* + +/* + * A data flow analysis to remove or move around calls to file initializers. + * The goal is to find for each function and for each call site the set of + * definitely initialized files before the corresponding call. + * + * This is done in three quite similar steps using global interprocedural analysis: + * 1. For each function find the set of definitely initialized files after returning from the function. + * Handle all the functions in the reverse topological order. + * 2. For each function find the set of definitely initialized files before executing the function's body. + * Handle all the functions in the topological order and use the results of the first step + * for updating the result after some function call. + * 3. For each call site find the set of definitely initialized files before the actual call is made. + * Handle all the functions in the arbitrary order and use the results of the first step + * for updating the result. Then use the results from the second step to see if the initializer call + * could be extracted from the callee to the call site. + * + * All three steps use similar local intraprocedural data flow analysis on IR using an IR visitor + * taking the set of already initialized files before evaluating some expression and returning the modified + * set after evaluating that expression. + */ + +internal object FileInitializersOptimization { + private class AnalysisResult(val functionsRequiringGlobalInitializerCall: Set, + val functionsRequiringThreadLocalInitializerCall: Set, + val callSitesRequiringGlobalInitializerCall: Set, + val callSitesRequiringThreadLocalInitializerCall: Set) + + private class InitializedFiles(val fileIds: Map) { + val afterCall = mutableMapOf() + val beforeCallGlobal = mutableMapOf() + val beforeCallThreadLocal = mutableMapOf() + } + + private val invalidFileId = 0 + + private class InterproceduralAnalysis(val context: Context, val callGraph: CallGraph, + val rootSet: Set) { + fun analyze(): AnalysisResult { + context.logMultiple { + +"CALL GRAPH" + callGraph.directEdges.forEach { (t, u) -> + +" FUN $t" + u.callSites.forEach { + val label = when { + it.isVirtual -> "VIRTUAL" + callGraph.directEdges.containsKey(it.actualCallee) -> "LOCAL" + else -> "EXTERNAL" + } + +" CALLS $label ${it.actualCallee}" + } + callGraph.reversedEdges[t]!!.forEach { +" CALLED BY $it" } + } + +"" + } + + val condensation = DirectedGraphCondensationBuilder(callGraph).build() + + context.logMultiple { + +"CONDENSATION" + condensation.topologicalOrder.forEach { multiNode -> + +" MULTI-NODE" + multiNode.nodes.forEach { +" $it" } + } + +"" + +"CONDENSATION(DETAILED)" + condensation.topologicalOrder.forEach { multiNode -> + +" MULTI-NODE" + multiNode.nodes.forEach { + +" $it" + callGraph.directEdges[it]!!.callSites + .filter { callGraph.directEdges.containsKey(it.actualCallee) } + .forEach { +" CALLS ${it.actualCallee}" } + callGraph.reversedEdges[it]!!.forEach { +" CALLED BY $it" } + } + } + +"" + } + + var fileId = invalidFileId + val fileIds = mutableMapOf() + for (node in callGraph.directEdges.values) { + val callerFile = node.symbol.irFile + if (callerFile != null && fileIds[callerFile] == null) + fileIds[callerFile] = ++fileId + for (callSite in node.callSites) { + val calleeFile = callSite.actualCallee.irFile + if (calleeFile != null && fileIds[calleeFile] == null) + fileIds[calleeFile] = ++fileId + } + } + + val initializedFiles = InitializedFiles(fileIds) + + context.log { "FIRST PHASE: compute initialized after call" } + + for (multiNode in condensation.topologicalOrder.reversed()) + analyze(multiNode, initializedFiles, AnalysisGoal.ComputeInitializedAfterCall) + + context.log { "SECOND PHASE: compute initialized before call" } + + // Each function from the root set can be called as the first one, so pessimistically assume that + // none of the files has been initialized yet. + for (node in callGraph.directEdges.values) { + val function = node.symbol.irFunction ?: continue + if (function in rootSet) { + initializedFiles.beforeCallGlobal[function] = BitSet() + initializedFiles.beforeCallThreadLocal[function] = BitSet() + } + } + + for (multiNode in condensation.topologicalOrder) + analyze(multiNode, initializedFiles, AnalysisGoal.ComputeInitializedBeforeCall) + + context.log { "THIRD PHASE: collect call sites" } + + val callSitesRequiringGlobalInitializerCall = mutableSetOf() + val callSitesRequiringThreadLocalInitializerCall = mutableSetOf() + val callSitesNotRequiringGlobalInitializerCall = mutableSetOf() + val callSitesNotRequiringThreadLocalInitializerCall = mutableSetOf() + + for (node in callGraph.directEdges.values) { + intraproceduralAnalysis(node, initializedFiles, AnalysisGoal.CollectCallSites, + callSitesRequiringGlobalInitializerCall, callSitesRequiringThreadLocalInitializerCall, + callSitesNotRequiringGlobalInitializerCall, callSitesNotRequiringThreadLocalInitializerCall) + } + + fun collectFunctionsRequiringInitializerCall( + initializedFiles: Map, + functionsWhoseInitializerCallCanBeExtractedToCallSites: Set + ): Set { + val result = mutableSetOf() + initializedFiles.forEach { (function, functionInitializedFiles) -> + val irFile = function.fileOrNull + val backingField = (function as? IrSimpleFunction)?.correspondingPropertySymbol?.owner?.backingField + val isDefaultAccessor = backingField != null && function.origin == IrDeclarationOrigin.DEFAULT_PROPERTY_ACCESSOR + if (irFile == null || + (!functionInitializedFiles.get(fileIds[irFile]!!) + && function !in functionsWhoseInitializerCallCanBeExtractedToCallSites + // Extract calls to file initializers off of default accessors to simplify their inlining. + && (!isDefaultAccessor || function in rootSet)) + ) { + result += function + } + } + return result + } + + val functionsRequiringGlobalInitializerCall = collectFunctionsRequiringInitializerCall( + initializedFiles.beforeCallGlobal, + callSitesRequiringGlobalInitializerCall.map { it.actualCallee } + .toMutableSet().intersect(callSitesNotRequiringGlobalInitializerCall.map { it.actualCallee }) + ) + val functionsRequiringThreadLocalInitializerCall = collectFunctionsRequiringInitializerCall( + initializedFiles.beforeCallThreadLocal, + callSitesRequiringThreadLocalInitializerCall.map { it.actualCallee } + .toMutableSet().intersect(callSitesNotRequiringThreadLocalInitializerCall.map { it.actualCallee }) + ) + + return AnalysisResult(functionsRequiringGlobalInitializerCall, functionsRequiringThreadLocalInitializerCall, + callSitesRequiringGlobalInitializerCall, callSitesRequiringThreadLocalInitializerCall) + } + + private fun analyze(multiNode: DirectedGraphMultiNode, + initializedFiles: InitializedFiles, + analysisGoal: AnalysisGoal) { + val nodes = multiNode.nodes.toList() + + context.logMultiple { + +"Analyzing multiNode:\n ${nodes.joinToString("\n ") { it.toString() }}" + nodes.forEach { from -> + +"IR" + +(from.irFunction?.dump() ?: "") + callGraph.directEdges[from]!!.callSites.forEach { to -> + +"CALL" + +" from $from" + +" to ${to.actualCallee}" + } + } + } + + if (nodes.size == 1) + intraproceduralAnalysis(callGraph.directEdges[nodes[0]] ?: return, initializedFiles, analysisGoal) + else { + nodes.forEach { intraproceduralAnalysis(callGraph.directEdges[it]!!, initializedFiles, analysisGoal) } + // The process is convergent since files can only be removed from the sets. + var sum = nodes.sumOf { + (initializedFiles.beforeCallGlobal[it.irFunction!!]?.cardinality() ?: 0) + + (initializedFiles.beforeCallThreadLocal[it.irFunction!!]?.cardinality() ?: 0) + + (initializedFiles.afterCall[it.irFunction!!]?.cardinality() ?: 0) + } + do { + val prevSum = sum + nodes.forEach { intraproceduralAnalysis(callGraph.directEdges[it]!!, initializedFiles, analysisGoal) } + sum = nodes.sumOf { + (initializedFiles.beforeCallGlobal[it.irFunction!!]?.cardinality() ?: 0) + + (initializedFiles.beforeCallThreadLocal[it.irFunction!!]?.cardinality() ?: 0) + + (initializedFiles.afterCall[it.irFunction!!]?.cardinality() ?: 0) + } + } while (sum != prevSum) + } + } + + private val executeImplSymbol = context.ir.symbols.executeImpl + private val getContinuationSymbol = context.ir.symbols.getContinuation + + private var dummySet = mutableSetOf() + + private enum class AnalysisGoal { + ComputeInitializedAfterCall, + ComputeInitializedBeforeCall, + CollectCallSites + } + + private fun IrFunction.callsFileInitializer() = + (body?.statements?.get(0) as? IrCall)?.symbol?.owner?.isFileInitializer == true + + private fun intraproceduralAnalysis( + node: CallGraphNode, + initializedFiles: InitializedFiles, + analysisGoal: AnalysisGoal, + callSitesRequiringGlobalInitializerCall: MutableSet = dummySet, + callSitesRequiringThreadLocalInitializerCall: MutableSet = dummySet, + callSitesNotRequiringGlobalInitializerCall: MutableSet = dummySet, + callSitesNotRequiringThreadLocalInitializerCall: MutableSet = dummySet + ) { + val irDeclaration = node.symbol.irDeclaration ?: return + val body = if (node.symbol.isTopLevelFieldInitializer) + (irDeclaration as IrField).initializer?.expression + else { + val function = irDeclaration as IrFunction + val builder = context.createIrBuilder(function.symbol) + function.body?.let { body -> builder.irBlock { (body as IrBlockBody).statements.forEach { +it } } } + } + if (body == null) return + + val filesWithInitializedGlobals = BitSet() + val filesWithInitializedThreadLocals = BitSet() + if (!node.symbol.isTopLevelFieldInitializer) { + initializedFiles.beforeCallGlobal[irDeclaration as IrFunction]?.let { filesWithInitializedGlobals.or(it) } + initializedFiles.beforeCallThreadLocal[irDeclaration]?.let { filesWithInitializedThreadLocals.or(it) } + } + + val producerInvocations = mutableMapOf() + val jobInvocations = mutableMapOf() + val virtualCallSites = mutableMapOf>() + for (callSite in node.callSites) { + val call = callSite.call + val irCall = call.irCallSite as? IrCall ?: continue + if (irCall.origin == STATEMENT_ORIGIN_PRODUCER_INVOCATION) + producerInvocations[irCall.dispatchReceiver!!] = irCall + else if (irCall.origin == STATEMENT_ORIGIN_JOB_INVOCATION) + jobInvocations[irCall.getValueArgument(0) as IrCall] = irCall + if (call !is DataFlowIR.Node.VirtualCall) continue + virtualCallSites.getOrPut(irCall) { mutableListOf() }.add(callSite) + } + val returnTargetsInitializedFiles = mutableMapOf() + val initializedFilesAtLoopsBreaks = mutableMapOf() + val initializedFilesAtLoopsContinues = mutableMapOf() + // Each visitXXX function gets as [data] parameter the set of initialized files before evaluating + // current element and returns the set of initialized files after evaluating this element. + val callerResult = body.accept(object : IrElementVisitor { + private fun intersectInitializedFiles(previous: BitSet?, current: BitSet) = + previous?.copy()?.also { it.and(current) } ?: current + + private fun intersectInitializedFiles(map: MutableMap, key: K, set: BitSet) { + val previous = map[key] + if (previous == null) + map[key] = set.copy() + else + previous.and(set) + } + + override fun visitElement(element: IrElement, data: BitSet): BitSet = TODO(element.render()) + override fun visitExpression(expression: IrExpression, data: BitSet): BitSet = TODO(expression.render()) + override fun visitDeclaration(declaration: IrDeclarationBase, data: BitSet): BitSet = TODO(declaration.render()) + + override fun visitTypeOperator(expression: IrTypeOperatorCall, data: BitSet) = expression.argument.accept(this, data) + override fun visitConst(expression: IrConst, data: BitSet) = data + override fun visitInstanceInitializerCall(expression: IrInstanceInitializerCall, data: BitSet) = data + + override fun visitGetValue(expression: IrGetValue, data: BitSet) = data + override fun visitSetValue(expression: IrSetValue, data: BitSet) = expression.value.accept(this, data) + override fun visitVariable(declaration: IrVariable, data: BitSet) = declaration.initializer?.accept(this, data) ?: data + + override fun visitSuspendableExpression(expression: IrSuspendableExpression, data: BitSet) = expression.result.accept(this, data) + override fun visitSuspensionPoint(expression: IrSuspensionPoint, data: BitSet) = expression.result.accept(this, data) + + override fun visitGetField(expression: IrGetField, data: BitSet) = expression.receiver?.accept(this, data) ?: data + override fun visitSetField(expression: IrSetField, data: BitSet) = + expression.value.accept(this, expression.receiver?.accept(this, data) ?: data) + + override fun visitFunctionReference(expression: IrFunctionReference, data: BitSet) = data + override fun visitVararg(expression: IrVararg, data: BitSet) = data + + override fun visitBreak(jump: IrBreak, data: BitSet): BitSet { + intersectInitializedFiles(initializedFilesAtLoopsBreaks, jump.loop, data) + return data + } + override fun visitContinue(jump: IrContinue, data: BitSet): BitSet { + intersectInitializedFiles(initializedFilesAtLoopsContinues, jump.loop, data) + return data + } + // A while loop might not execute even a single iteration. + override fun visitWhileLoop(loop: IrWhileLoop, data: BitSet) = + loop.condition.accept(this, data).also { loop.body?.accept(this, it) } + override fun visitDoWhileLoop(loop: IrDoWhileLoop, data: BitSet): BitSet { + val bodyFallThroughResult = loop.body?.accept(this, data) ?: data + val continuesResult = initializedFilesAtLoopsContinues[loop] + // We can end up in the condition part either by falling through the entire body or by executing one of the continue clauses. + val bodyResult = intersectInitializedFiles(continuesResult, bodyFallThroughResult) + val conditionResult = loop.condition.accept(this, bodyResult) + val breaksResult = initializedFilesAtLoopsBreaks[loop] + // A loop can be finished either by checking the condition or by executing a break clause. + return intersectInitializedFiles(breaksResult, conditionResult) + } + + private fun updateResultForReturnTarget(symbol: IrReturnTargetSymbol, set: BitSet) = + intersectInitializedFiles(returnTargetsInitializedFiles, symbol, set) + + override fun visitReturn(expression: IrReturn, data: BitSet) = + expression.value.accept(this, data).also { + updateResultForReturnTarget(expression.returnTargetSymbol, it) + } + + override fun visitContainerExpression(expression: IrContainerExpression, data: BitSet): BitSet { + val result = expression.statements.fold(data) { set, statement -> statement.accept(this, set) } + return if (expression !is IrReturnableBlock) + result + else { + updateResultForReturnTarget(expression.symbol, result) + returnTargetsInitializedFiles[expression.symbol]!! + } + } + + override fun visitWhen(expression: IrWhen, data: BitSet): BitSet { + val firstBranch = expression.branches.first() + val firstConditionResult = firstBranch.condition.accept(this, data) + val bodiesResult = firstBranch.result.accept(this, firstConditionResult) + var conditionsResult = firstConditionResult + for (i in 1 until expression.branches.size) { + val branch = expression.branches[i] + conditionsResult = branch.condition.accept(this, conditionsResult) + val branchResult = branch.result.accept(this, conditionsResult) + bodiesResult.and(branchResult) + } + val isExhaustive = expression.branches.last().isUnconditional() + return if (isExhaustive) { + // One of the branches must have been executed. + bodiesResult + } else { + // The first condition is always executed. + firstConditionResult + } + } + + override fun visitThrow(expression: IrThrow, data: BitSet): BitSet { + expression.value.accept(this, data) + return data // Conservative but correct. + } + + override fun visitTry(aTry: IrTry, data: BitSet): BitSet { + require(aTry.finallyExpression == null) + aTry.tryResult.accept(this, data) + // Catch blocks can't assume that the try part has been executed entirely, + // so only take what was known at the beginning of the try block. + aTry.catches.forEach { it.result.accept(this, data) } + // Since the try part could have been executed with an exception which then could've been caught by + // some of the catch clauses, it is incorrect to take the try block's result, + // so conservatively don't change the result. + return data + } + + private fun BitSet.withSetBit(bit: Int): BitSet = + if (this.get(bit)) this else copy().also { it.set(bit) } + + private fun getResultAfterCall(function: IrFunction, set: BitSet): BitSet { + val result = initializedFiles.afterCall[function] + if (result == null) { + if (!function.callsFileInitializer()) return set + val file = function.fileOrNull ?: return set + return set.withSetBit(initializedFiles.fileIds[file]!!) + } + return result.copy().also { it.or(set) } + } + + private fun updateResultForFunction(function: IrFunction, globalSet: BitSet, threadLocalSet: BitSet) { + if (analysisGoal != AnalysisGoal.ComputeInitializedBeforeCall) return + intersectInitializedFiles(initializedFiles.beforeCallGlobal, function, globalSet) + intersectInitializedFiles(initializedFiles.beforeCallThreadLocal, function, threadLocalSet) + } + + private fun updateResultForFunction(function: IrFunction, set: BitSet) { + if (analysisGoal != AnalysisGoal.ComputeInitializedBeforeCall) return + intersectInitializedFiles(initializedFiles.beforeCallGlobal, function, set.copy().also { it.or(filesWithInitializedGlobals) }) + intersectInitializedFiles(initializedFiles.beforeCallThreadLocal, function, set.copy().also { it.or(filesWithInitializedThreadLocals) }) + } + + override fun visitGetObjectValue(expression: IrGetObjectValue, data: BitSet): BitSet { + val objectClass = expression.symbol.owner + val constructor = objectClass.constructors.toList().atMostOne() + if (constructor != null) { + updateResultForFunction(constructor, data) + } else { + require(objectClass.isExternal || objectClass is IrLazyClass) { "No constructor for ${objectClass.render()}" } + } + val file = objectClass.fileOrNull ?: return data + val fileId = initializedFiles.fileIds[file]!! + if (data.get(fileId)) return data + return data.copy().also { it.set(fileId) } + } + + private fun processCall(expression: IrFunctionAccessExpression, actualCallee: IrFunction, data: BitSet): BitSet { + val arguments = expression.getArgumentsWithIr() + val argumentsResult = arguments.fold(data) { set, arg -> arg.second.accept(this, set) } + updateResultForFunction(actualCallee, argumentsResult) + val file = actualCallee.fileOrNull + val fileId = file?.let { initializedFiles.fileIds[it]!! } ?: invalidFileId + if (analysisGoal == AnalysisGoal.CollectCallSites && file != null + // Only extract initializer calls from non-virtual functions. + && !actualCallee.isOverridable + ) { + // The initializer won't be optimized away from the function. + if (!initializedFiles.beforeCallGlobal[actualCallee]!!.get(fileId)) { + if (argumentsResult.get(fileId) || filesWithInitializedGlobals.get(fileId)) + callSitesNotRequiringGlobalInitializerCall += expression + else + callSitesRequiringGlobalInitializerCall += expression + } + // The initializer won't be optimized away from the function. + if (!initializedFiles.beforeCallThreadLocal[actualCallee]!!.get(fileId)) { + if (argumentsResult.get(fileId) || filesWithInitializedThreadLocals.get(fileId)) + callSitesNotRequiringThreadLocalInitializerCall += expression + else + callSitesRequiringThreadLocalInitializerCall += expression + } + } + return getResultAfterCall(actualCallee, argumentsResult) + } + + private fun processExecuteImpl(expression: IrCall, data: BitSet): BitSet { + var curData = processCall(expression, expression.symbol.owner, data) + val producerInvocation = producerInvocations[expression.getValueArgument(2)!!]!! + // Producer is invoked right here in the same thread, so can update the result. + // Albeit this call site is a fictitious one, it is always a virtual one, which aren't optimized for now. + curData = visitCall(producerInvocation, curData) + val jobInvocation = jobInvocations[producerInvocation]!! + if (analysisGoal != AnalysisGoal.CollectCallSites) { + require(!jobInvocation.isVirtualCall) { "Expected a static call but was: ${jobInvocation.render()}" } + updateResultForFunction(jobInvocation.actualCallee, + curData.copy().also { it.or(filesWithInitializedGlobals) }, // Globals (= shared) visible to other threads as well. + BitSet() // A new thread is about to be created - no thread locals initialized yet. + ) + } + // Actual job could be invoked on another thread, thus can't take the result from that call. + return curData + } + + override fun visitFunctionAccess(expression: IrFunctionAccessExpression, data: BitSet) = + processCall(expression, expression.actualCallee, data) + + override fun visitCall(expression: IrCall, data: BitSet): BitSet { + if (expression.symbol.owner.isFileInitializer) + return data.withSetBit(initializedFiles.fileIds[irDeclaration.file]!!) + if (expression.symbol == executeImplSymbol) + return processExecuteImpl(expression, data) + if (expression.symbol == getContinuationSymbol) + return data + if (!expression.isVirtualCall) + return processCall(expression, expression.actualCallee, data) + val devirtualizedCallSite = virtualCallSites[expression] ?: return data + val arguments = expression.getArgumentsWithIr() + val argumentsResult = arguments.fold(data) { set, arg -> arg.second.accept(this, set) } + var callResult = BitSet() + var first = true + for (callSite in devirtualizedCallSite) { + val callee = callSite.actualCallee.irFunction ?: error("No IR for: ${callSite.actualCallee}") + updateResultForFunction(callee, argumentsResult) + if (first) { + callResult = getResultAfterCall(callee, BitSet()) + first = false + } else { + val otherSet = getResultAfterCall(callee, BitSet()) + callResult.and(otherSet) + } + } + return argumentsResult.copy().also { it.or(callResult) } + } + }, BitSet()) + + if (analysisGoal == AnalysisGoal.ComputeInitializedAfterCall) { + if (!node.symbol.isTopLevelFieldInitializer) + initializedFiles.afterCall[irDeclaration as IrFunction] = returnTargetsInitializedFiles[irDeclaration.symbol] ?: callerResult + } + } + } + + fun removeRedundantCalls(context: Context, callGraph: CallGraph, rootSet: Set) { + val analysisResult = InterproceduralAnalysis(context, callGraph, rootSet).analyze() + + var numberOfFunctionsWithGlobalInitializerCall = 0 + var numberOfFunctionsWithThreadLocalInitializerCall = 0 + var numberOfRemovedGlobalInitializerCalls = 0 + var numberOfRemovedThreadLocalInitializerCalls = 0 + var numberOfCallSitesToFunctionsWithGlobalInitializerCall = 0 + var numberOfCallSitesToFunctionsWithThreadLocalInitializerCall = 0 + var numberOfCallSitesWithExtractedGlobalInitializerCall = 0 + var numberOfCallSitesWithExtractedThreadLocalInitializerCall = 0 + + context.irModule!!.transformChildren(object : IrElementTransformer { + override fun visitDeclaration(declaration: IrDeclarationBase, data: IrBuilderWithScope?): IrStatement { + return super.visitDeclaration(declaration, context.createIrBuilder(declaration.symbol, SYNTHETIC_OFFSET, SYNTHETIC_OFFSET)) + } + + override fun visitFunctionAccess(expression: IrFunctionAccessExpression, data: IrBuilderWithScope?): IrExpression { + expression.transformChildren(this, data) + + val callee = expression.actualCallee + val body = callee.body ?: return expression + val initializerCalls = (body as IrBlockBody).statements + .take(2) // The very first statements by construction. + .filter { + val calleeOrigin = (it as? IrCall)?.symbol?.owner?.origin + val isNotOptimizedAwayGlobalInitializerCall = calleeOrigin == DECLARATION_ORIGIN_FILE_GLOBAL_INITIALIZER + && callee !in analysisResult.functionsRequiringGlobalInitializerCall + val isNotOptimizedAwayThreadLocalInitializerCall = (calleeOrigin == DECLARATION_ORIGIN_FILE_THREAD_LOCAL_INITIALIZER + || calleeOrigin == DECLARATION_ORIGIN_FILE_STANDALONE_THREAD_LOCAL_INITIALIZER) + && callee !in analysisResult.functionsRequiringThreadLocalInitializerCall + if (isNotOptimizedAwayGlobalInitializerCall) + ++numberOfCallSitesToFunctionsWithGlobalInitializerCall + if (isNotOptimizedAwayThreadLocalInitializerCall) + ++numberOfCallSitesToFunctionsWithThreadLocalInitializerCall + val canExtractGlobalInitializerCall = isNotOptimizedAwayGlobalInitializerCall + && expression in analysisResult.callSitesRequiringGlobalInitializerCall + val canExtractThreadLocalInitializerCall = isNotOptimizedAwayThreadLocalInitializerCall + && expression in analysisResult.callSitesRequiringThreadLocalInitializerCall + if (canExtractGlobalInitializerCall) + ++numberOfCallSitesWithExtractedGlobalInitializerCall + if (canExtractThreadLocalInitializerCall) + ++numberOfCallSitesWithExtractedThreadLocalInitializerCall + canExtractGlobalInitializerCall || canExtractThreadLocalInitializerCall + } + if (initializerCalls.isEmpty()) return expression + + return data!!.irBlock(expression) { + initializerCalls.forEach { +irCallFileInitializer((it as IrCall).symbol) } + +expression + } + } + }, data = null) + + context.irModule!!.transformChildrenVoid(object : IrElementTransformerVoid() { + override fun visitFunction(declaration: IrFunction): IrStatement { + val body = declaration.body ?: return declaration + val statements = (body as IrBlockBody).statements + val globalInitializerCallIndex = statements + .take(2) // The very first statements by construction. + .indexOfFirst { + val calleeOrigin = (it as? IrCall)?.symbol?.owner?.origin + calleeOrigin == DECLARATION_ORIGIN_FILE_GLOBAL_INITIALIZER + } + if (globalInitializerCallIndex >= 0) { + ++numberOfFunctionsWithGlobalInitializerCall + if (declaration !in analysisResult.functionsRequiringGlobalInitializerCall) { + ++numberOfRemovedGlobalInitializerCalls + statements.removeAt(globalInitializerCallIndex) + } + } + val threadLocalInitializerCallIndex = statements + .take(2) + .indexOfFirst { + val calleeOrigin = (it as? IrCall)?.symbol?.owner?.origin + calleeOrigin == DECLARATION_ORIGIN_FILE_THREAD_LOCAL_INITIALIZER + || calleeOrigin == DECLARATION_ORIGIN_FILE_STANDALONE_THREAD_LOCAL_INITIALIZER + } + if (threadLocalInitializerCallIndex >= 0) { + ++numberOfFunctionsWithThreadLocalInitializerCall + if (declaration !in analysisResult.functionsRequiringThreadLocalInitializerCall) { + ++numberOfRemovedThreadLocalInitializerCalls + statements.removeAt(threadLocalInitializerCallIndex) + } + } + return declaration + } + }) + + context.log { "Removed ${numberOfRemovedGlobalInitializerCalls * 100.0 / numberOfFunctionsWithGlobalInitializerCall}% global initializers" } + context.log { "Removed ${numberOfRemovedThreadLocalInitializerCalls * 100.0 / numberOfFunctionsWithThreadLocalInitializerCall}% thread local initializers" } + context.log { "Removed ${(numberOfCallSitesWithExtractedGlobalInitializerCall) * 100.0 / numberOfCallSitesToFunctionsWithGlobalInitializerCall}% global initializer calls" } + context.log { "Removed ${(numberOfCallSitesWithExtractedThreadLocalInitializerCall) * 100.0 / numberOfCallSitesToFunctionsWithThreadLocalInitializerCall}% thread local initializer calls" } + } +} \ No newline at end of file diff --git a/kotlin-native/backend.native/tests/build.gradle b/kotlin-native/backend.native/tests/build.gradle index 742df2235bd..8e38205ea52 100644 --- a/kotlin-native/backend.native/tests/build.gradle +++ b/kotlin-native/backend.native/tests/build.gradle @@ -1424,6 +1424,51 @@ standaloneTest("initializers_failInInitializer4") { flags = ['-Xir-property-lazy-initialization'] } +standaloneTest("initializers_when1") { + source = "codegen/initializers/when1.kt" + goldValue = "42\n" +} + +standaloneTest("initializers_when2") { + source = "codegen/initializers/when2.kt" + goldValue = "42\n" +} + +standaloneTest("initializers_throw1") { + source = "codegen/initializers/throw1.kt" + goldValue = "42\n" +} + +standaloneTest("initializers_throw2") { + source = "codegen/initializers/throw2.kt" + goldValue = "42\n" +} + +standaloneTest("initializers_while1") { + source = "codegen/initializers/while1.kt" + goldValue = "42\n" +} + +standaloneTest("initializers_while2") { + source = "codegen/initializers/while2.kt" + goldValue = "42\n" +} + +standaloneTest("initializers_while3") { + source = "codegen/initializers/while3.kt" + goldValue = "42\n" +} + +standaloneTest("initializers_return1") { + source = "codegen/initializers/return1.kt" + goldValue = "42\n" +} + +standaloneTest("initializers_return2") { + source = "codegen/initializers/return2.kt" + goldValue = "42\n" +} + linkTest("initializers_linkTest1") { goldValue = "1200\n" source = "codegen/initializers/linkTest1_main.kt" diff --git a/kotlin-native/backend.native/tests/codegen/initializers/return1.kt b/kotlin-native/backend.native/tests/codegen/initializers/return1.kt new file mode 100644 index 00000000000..cbaca6eb96e --- /dev/null +++ b/kotlin-native/backend.native/tests/codegen/initializers/return1.kt @@ -0,0 +1,22 @@ +/* + * Copyright 2010-2021 JetBrains s.r.o. and Kotlin Programming Language contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + */ + +// FILE: lib.kt +val y = getY() + +private fun getY() = 42 + +fun bar(x: Int) = x == 0 + +// FILE: main.kt +fun foo(x: Int) { + if (x <= 0) return + bar(x) +} + +fun main() { + foo(0) + println(y) +} \ No newline at end of file diff --git a/kotlin-native/backend.native/tests/codegen/initializers/return2.kt b/kotlin-native/backend.native/tests/codegen/initializers/return2.kt new file mode 100644 index 00000000000..222040ac716 --- /dev/null +++ b/kotlin-native/backend.native/tests/codegen/initializers/return2.kt @@ -0,0 +1,22 @@ +/* + * Copyright 2010-2021 JetBrains s.r.o. and Kotlin Programming Language contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + */ + +// FILE: lib.kt +val y = getY() + +private fun getY() = 42 + +fun bar(x: Int) = x == 0 + +// FILE: main.kt +inline fun foo(x: Int) { + if (x <= 0) return + bar(x) +} + +fun main() { + foo(0) + println(y) +} \ No newline at end of file diff --git a/kotlin-native/backend.native/tests/codegen/initializers/throw1.kt b/kotlin-native/backend.native/tests/codegen/initializers/throw1.kt new file mode 100644 index 00000000000..41b3bdbbbb8 --- /dev/null +++ b/kotlin-native/backend.native/tests/codegen/initializers/throw1.kt @@ -0,0 +1,28 @@ +/* + * Copyright 2010-2021 JetBrains s.r.o. and Kotlin Programming Language contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + */ + +// FILE: lib.kt +val x = computeX() + +private fun computeX() = 42 + +fun baz1() { } + +fun baz2() { } + +// FILE: main.kt +fun bar(x: Int) = if (x == 0) error("") else x + +fun foo(x: Int) { + try { + bar(x) + baz1() + } catch (t: Throwable) { } +} + +fun main() { + foo(0) + println(x) +} diff --git a/kotlin-native/backend.native/tests/codegen/initializers/throw2.kt b/kotlin-native/backend.native/tests/codegen/initializers/throw2.kt new file mode 100644 index 00000000000..95e17192df8 --- /dev/null +++ b/kotlin-native/backend.native/tests/codegen/initializers/throw2.kt @@ -0,0 +1,29 @@ +/* + * Copyright 2010-2021 JetBrains s.r.o. and Kotlin Programming Language contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + */ + +// FILE: lib.kt +val y = computeY() + +private fun computeY() = 42 + +fun baz1() { } + +fun baz2() { } + +// FILE: main.kt +fun bar(x: Int) = if (x == 0) error("") else x + +fun foo(x: Int) { + try { + bar(x) + baz1() + } catch (t: Throwable) { + println(y) + } +} + +fun main() { + foo(0) +} diff --git a/kotlin-native/backend.native/tests/codegen/initializers/when1.kt b/kotlin-native/backend.native/tests/codegen/initializers/when1.kt new file mode 100644 index 00000000000..bb1939c0178 --- /dev/null +++ b/kotlin-native/backend.native/tests/codegen/initializers/when1.kt @@ -0,0 +1,23 @@ +/* + * Copyright 2010-2021 JetBrains s.r.o. and Kotlin Programming Language contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + */ + +// FILE: lib.kt +val y = getY() + +private fun getY() = 42 + +fun bar(x: Int) = x == 0 + +// FILE: main.kt +fun foo(x: Int) = when { + x > 0 -> 42 + bar(x) -> 117 + else -> -1 +} + +fun main() { + foo(123) + println(y) +} \ No newline at end of file diff --git a/kotlin-native/backend.native/tests/codegen/initializers/when2.kt b/kotlin-native/backend.native/tests/codegen/initializers/when2.kt new file mode 100644 index 00000000000..98776f9fb12 --- /dev/null +++ b/kotlin-native/backend.native/tests/codegen/initializers/when2.kt @@ -0,0 +1,21 @@ +/* + * Copyright 2010-2021 JetBrains s.r.o. and Kotlin Programming Language contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + */ + +// FILE: lib.kt +val y = getY() + +private fun getY() = 42 + +fun bar(x: Int) = x == 0 + +// FILE: main.kt +fun foo(x: Int) { + if (x > 0) bar(x) +} + +fun main() { + foo(-1) + println(y) +} \ No newline at end of file diff --git a/kotlin-native/backend.native/tests/codegen/initializers/while1.kt b/kotlin-native/backend.native/tests/codegen/initializers/while1.kt new file mode 100644 index 00000000000..d231e337232 --- /dev/null +++ b/kotlin-native/backend.native/tests/codegen/initializers/while1.kt @@ -0,0 +1,25 @@ +/* + * Copyright 2010-2021 JetBrains s.r.o. and Kotlin Programming Language contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + */ + +// FILE: lib.kt +val y = getY() + +private fun getY() = 42 + +fun bar(x: Int) = x == 0 + +// FILE: main.kt +fun foo(x: Int) { + var i = 0 + while (i < x) { + bar(i) + ++i + } +} + +fun main() { + foo(0) + println(y) +} \ No newline at end of file diff --git a/kotlin-native/backend.native/tests/codegen/initializers/while2.kt b/kotlin-native/backend.native/tests/codegen/initializers/while2.kt new file mode 100644 index 00000000000..b3e03ceed06 --- /dev/null +++ b/kotlin-native/backend.native/tests/codegen/initializers/while2.kt @@ -0,0 +1,25 @@ +/* + * Copyright 2010-2021 JetBrains s.r.o. and Kotlin Programming Language contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + */ + +// FILE: lib.kt +val y = getY() + +private fun getY() = 42 + +fun bar(x: Int) = x == 0 + +// FILE: main.kt +fun foo(x: Int) { + var i = 0 + do { + if (i == x) break + ++i + } while (bar(i)) +} + +fun main() { + foo(0) + println(y) +} \ No newline at end of file diff --git a/kotlin-native/backend.native/tests/codegen/initializers/while3.kt b/kotlin-native/backend.native/tests/codegen/initializers/while3.kt new file mode 100644 index 00000000000..ea6916ef7c6 --- /dev/null +++ b/kotlin-native/backend.native/tests/codegen/initializers/while3.kt @@ -0,0 +1,26 @@ +/* + * Copyright 2010-2021 JetBrains s.r.o. and Kotlin Programming Language contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + */ + +// FILE: lib.kt +val y = getY() + +private fun getY() = 42 + +fun bar(x: Int) = x == 0 + +// FILE: main.kt +fun foo(x: Int) { + var i = 0 + do { + ++i + if (i > 0) continue + bar(i) + } while (i < x) +} + +fun main() { + foo(0) + println(y) +} \ No newline at end of file