diff --git a/compiler/backend/src/org/jetbrains/kotlin/codegen/InterfaceImplBodyCodegen.kt.202 b/compiler/backend/src/org/jetbrains/kotlin/codegen/InterfaceImplBodyCodegen.kt.202 new file mode 100644 index 00000000000..8ba3fd56ffe --- /dev/null +++ b/compiler/backend/src/org/jetbrains/kotlin/codegen/InterfaceImplBodyCodegen.kt.202 @@ -0,0 +1,182 @@ +/* + * Copyright 2010-2015 JetBrains s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.jetbrains.kotlin.codegen + +import com.intellij.util.ArrayUtil +import org.jetbrains.kotlin.backend.common.bridges.findImplementationFromInterface +import org.jetbrains.kotlin.backend.common.bridges.firstSuperMethodFromKotlin +import org.jetbrains.kotlin.codegen.context.ClassContext +import org.jetbrains.kotlin.codegen.state.GenerationState +import org.jetbrains.kotlin.descriptors.* +import org.jetbrains.kotlin.load.java.descriptors.JavaMethodDescriptor +import org.jetbrains.kotlin.psi.KtPureClassOrObject +import org.jetbrains.kotlin.resolve.DescriptorToSourceUtils +import org.jetbrains.kotlin.resolve.jvm.diagnostics.JvmDeclarationOrigin +import org.jetbrains.kotlin.resolve.jvm.diagnostics.JvmDeclarationOriginKind +import org.jetbrains.kotlin.resolve.jvm.jvmSignature.JvmMethodSignature +import org.jetbrains.org.objectweb.asm.MethodVisitor +import org.jetbrains.org.objectweb.asm.Opcodes.* + +class InterfaceImplBodyCodegen( + aClass: KtPureClassOrObject, + context: ClassContext, + v: ClassBuilder, + state: GenerationState, + parentCodegen: MemberCodegen<*>? +) : ClassBodyCodegen(aClass, context, InterfaceImplBodyCodegen.InterfaceImplClassBuilder(v), state, parentCodegen) { + private var isAnythingGenerated: Boolean = false + get() = (v as InterfaceImplClassBuilder).isAnythingGenerated + + private val defaultImplType = typeMapper.mapDefaultImpls(descriptor) + + override fun generateDeclaration() { + val codegenFlags = ACC_PUBLIC or ACC_FINAL or ACC_SUPER + val flags = if (state.classBuilderMode == ClassBuilderMode.LIGHT_CLASSES) codegenFlags or ACC_STATIC else codegenFlags + v.defineClass( + myClass.psiOrParent, state.classFileVersion, flags, + defaultImplType.internalName, + null, "java/lang/Object", ArrayUtil.EMPTY_STRING_ARRAY + ) + v.visitSource(myClass.containingKtFile.name, null) + } + + override fun classForInnerClassRecord(): ClassDescriptor? { + if (!isAnythingGenerated) return null + return InnerClassConsumer.classForInnerClassRecord(descriptor, true) + } + + override fun generateSyntheticPartsAfterBody() { + for (memberDescriptor in descriptor.defaultType.memberScope.getContributedDescriptors()) { + if (memberDescriptor !is CallableMemberDescriptor) continue + + if (memberDescriptor.kind.isReal) continue + if (memberDescriptor.visibility == Visibilities.INVISIBLE_FAKE) continue + if (memberDescriptor.modality == Modality.ABSTRACT) continue + + val implementation = findImplementationFromInterface(memberDescriptor) ?: continue + + // If implementation is a default interface method (JVM 8 only) + if (implementation.isDefinitelyNotDefaultImplsMethod()) continue + + if (memberDescriptor is FunctionDescriptor) { + generateDelegationToSuperDefaultImpls(memberDescriptor, implementation as FunctionDescriptor) + } + else if (memberDescriptor is PropertyDescriptor) { + implementation as PropertyDescriptor + val getter = memberDescriptor.getter + val implGetter = implementation.getter + if (getter != null && implGetter != null) { + generateDelegationToSuperDefaultImpls(getter, implGetter) + } + val setter = memberDescriptor.setter + val implSetter = implementation.setter + if (setter != null && implSetter != null) { + generateDelegationToSuperDefaultImpls(setter, implSetter) + } + } + } + + generateSyntheticAccessors() + } + + private fun generateDelegationToSuperDefaultImpls(descriptor: FunctionDescriptor, implementation: FunctionDescriptor) { + val delegateTo = firstSuperMethodFromKotlin(descriptor, implementation) as FunctionDescriptor? ?: return + + // We can't call super methods from Java 1.8 interfaces because that requires INVOKESPECIAL which is forbidden from TImpl class + if (delegateTo is JavaMethodDescriptor) return + + functionCodegen.generateMethod( + JvmDeclarationOrigin( + JvmDeclarationOriginKind.DEFAULT_IMPL_DELEGATION_TO_SUPERINTERFACE_DEFAULT_IMPL, + DescriptorToSourceUtils.descriptorToDeclaration(descriptor), descriptor + ), + descriptor, + object : FunctionGenerationStrategy.CodegenBased(state) { + override fun doGenerateBody(codegen: ExpressionCodegen, signature: JvmMethodSignature) { + val iv = codegen.v + + val method = typeMapper.mapToCallableMethod(delegateTo, true) + val myParameters = signature.valueParameters + val calleeParameters = method.getValueParameters() + + if (myParameters.size != calleeParameters.size) { + throw AssertionError( + "Method from super interface has a different signature.\n" + + "This method:\n%s\n%s\n%s\nSuper method:\n%s\n%s\n%s".format( + descriptor, signature, myParameters, delegateTo, method, calleeParameters + ) + ) + } + + var k = 0 + val it = calleeParameters.iterator() + for (parameter in myParameters) { + val type = parameter.asmType + StackValue.local(k, type).put(it.next().asmType, iv) + k += type.size + } + + method.genInvokeInstruction(iv) + StackValue.coerce(method.returnType, signature.returnType, iv) + iv.areturn(signature.returnType) + } + }) + } + + override fun generateKotlinMetadataAnnotation() { + (v as InterfaceImplClassBuilder).stopCounting() + + writeSyntheticClassMetadata(v, state) + } + + override fun done() { + super.done() + if (!isAnythingGenerated) { + state.factory.removeClasses(setOf(defaultImplType.internalName)) + } + } + + private class InterfaceImplClassBuilder(private val v: ClassBuilder) : DelegatingClassBuilder() { + private var shouldCount: Boolean = true + var isAnythingGenerated: Boolean = false + private set + + fun stopCounting() { + shouldCount = false + } + + override fun getDelegate() = v + + override fun newMethod( + origin: JvmDeclarationOrigin, + access: Int, + name: String, + desc: String, + signature: String?, + exceptions: Array? + ): MethodVisitor { + if (shouldCount) { + isAnythingGenerated = true + } + return super.newMethod(origin, access, name, desc, signature, exceptions) + } + } + + override fun generateSyntheticPartsBeforeBody() { + generatePropertyMetadataArrayFieldIfNeeded(defaultImplType) + } +} diff --git a/compiler/backend/src/org/jetbrains/kotlin/codegen/OriginCollectingClassBuilderFactory.kt.202 b/compiler/backend/src/org/jetbrains/kotlin/codegen/OriginCollectingClassBuilderFactory.kt.202 new file mode 100644 index 00000000000..b8254176580 --- /dev/null +++ b/compiler/backend/src/org/jetbrains/kotlin/codegen/OriginCollectingClassBuilderFactory.kt.202 @@ -0,0 +1,68 @@ +/* + * Copyright 2010-2018 JetBrains s.r.o. and Kotlin Programming Language contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + */ + +package org.jetbrains.kotlin.codegen + +import org.jetbrains.kotlin.resolve.jvm.diagnostics.JvmDeclarationOrigin +import org.jetbrains.org.objectweb.asm.* +import org.jetbrains.org.objectweb.asm.tree.* + +class OriginCollectingClassBuilderFactory(private val builderMode: ClassBuilderMode) : ClassBuilderFactory { + val compiledClasses = mutableListOf() + val origins = mutableMapOf() + + override fun getClassBuilderMode(): ClassBuilderMode = builderMode + + override fun newClassBuilder(origin: JvmDeclarationOrigin): AbstractClassBuilder.Concrete { + val classNode = ClassNode() + compiledClasses += classNode + origins[classNode] = origin + return OriginCollectingClassBuilder(classNode) + } + + private inner class OriginCollectingClassBuilder(val classNode: ClassNode) : AbstractClassBuilder.Concrete(classNode) { + override fun newField( + origin: JvmDeclarationOrigin, + access: Int, + name: String, + desc: String, + signature: String?, + value: Any? + ): FieldVisitor { + val fieldNode = super.newField(origin, access, name, desc, signature, value) as FieldNode + origins[fieldNode] = origin + return fieldNode + } + + override fun newMethod( + origin: JvmDeclarationOrigin, + access: Int, + name: String, + desc: String, + signature: String?, + exceptions: Array? + ): MethodVisitor { + val methodNode = super.newMethod(origin, access, name, desc, signature, exceptions) as MethodNode + origins[methodNode] = origin + + // ASM doesn't read information about local variables for the `abstract` methods so we need to get it manually + if ((access and Opcodes.ACC_ABSTRACT) != 0 && methodNode.localVariables == null) { + methodNode.localVariables = mutableListOf() + } + + return methodNode + } + } + + override fun asBytes(builder: ClassBuilder): ByteArray { + val classWriter = ClassWriter(0) + (builder as OriginCollectingClassBuilder).classNode.accept(classWriter) + return classWriter.toByteArray() + } + + override fun asText(builder: ClassBuilder) = throw UnsupportedOperationException() + + override fun close() {} +} \ No newline at end of file diff --git a/compiler/backend/src/org/jetbrains/kotlin/codegen/SignatureCollectingClassBuilderFactory.kt.202 b/compiler/backend/src/org/jetbrains/kotlin/codegen/SignatureCollectingClassBuilderFactory.kt.202 new file mode 100644 index 00000000000..96cb43312a0 --- /dev/null +++ b/compiler/backend/src/org/jetbrains/kotlin/codegen/SignatureCollectingClassBuilderFactory.kt.202 @@ -0,0 +1,89 @@ +/* + * Copyright 2010-2016 JetBrains s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.jetbrains.kotlin.codegen + +import com.intellij.psi.PsiElement +import com.intellij.util.containers.LinkedMultiMap +import com.intellij.util.containers.MultiMap +import org.jetbrains.kotlin.resolve.jvm.diagnostics.ConflictingJvmDeclarationsData +import org.jetbrains.kotlin.resolve.jvm.diagnostics.JvmDeclarationOrigin +import org.jetbrains.kotlin.resolve.jvm.diagnostics.MemberKind +import org.jetbrains.kotlin.resolve.jvm.diagnostics.RawSignature +import org.jetbrains.org.objectweb.asm.FieldVisitor +import org.jetbrains.org.objectweb.asm.MethodVisitor + +abstract class SignatureCollectingClassBuilderFactory( + delegate: ClassBuilderFactory, val shouldGenerate: (JvmDeclarationOrigin) -> Boolean +) : DelegatingClassBuilderFactory(delegate) { + + protected abstract fun handleClashingSignatures(data: ConflictingJvmDeclarationsData) + protected abstract fun onClassDone(classOrigin: JvmDeclarationOrigin, + classInternalName: String, + signatures: MultiMap) + + override fun newClassBuilder(origin: JvmDeclarationOrigin): DelegatingClassBuilder { + return SignatureCollectingClassBuilder(origin, delegate.newClassBuilder(origin)) + } + + private inner class SignatureCollectingClassBuilder( + private val classCreatedFor: JvmDeclarationOrigin, + internal val _delegate: ClassBuilder + ) : DelegatingClassBuilder() { + + override fun getDelegate() = _delegate + + private lateinit var classInternalName: String + + private val signatures = LinkedMultiMap() + + override fun defineClass(origin: PsiElement?, version: Int, access: Int, name: String, signature: String?, superName: String, interfaces: Array) { + classInternalName = name + super.defineClass(origin, version, access, name, signature, superName, interfaces) + } + + override fun newField(origin: JvmDeclarationOrigin, access: Int, name: String, desc: String, signature: String?, value: Any?): FieldVisitor { + signatures.putValue(RawSignature(name, desc, MemberKind.FIELD), origin) + if (!shouldGenerate(origin)) { + return AbstractClassBuilder.EMPTY_FIELD_VISITOR + } + return super.newField(origin, access, name, desc, signature, value) + } + + override fun newMethod(origin: JvmDeclarationOrigin, access: Int, name: String, desc: String, signature: String?, exceptions: Array?): MethodVisitor { + signatures.putValue(RawSignature(name, desc, MemberKind.METHOD), origin) + if (!shouldGenerate(origin)) { + return AbstractClassBuilder.EMPTY_METHOD_VISITOR + } + return super.newMethod(origin, access, name, desc, signature, exceptions) + } + + override fun done() { + for ((signature, elementsAndDescriptors) in signatures.entrySet()) { + if (elementsAndDescriptors.size == 1) continue // no clash + handleClashingSignatures(ConflictingJvmDeclarationsData( + classInternalName, + classCreatedFor, + signature, + elementsAndDescriptors + )) + } + onClassDone(classCreatedFor, classInternalName, signatures) + super.done() + } + + } +} diff --git a/compiler/backend/src/org/jetbrains/kotlin/codegen/coroutines/CoroutineTransformerMethodVisitor.kt.202 b/compiler/backend/src/org/jetbrains/kotlin/codegen/coroutines/CoroutineTransformerMethodVisitor.kt.202 new file mode 100644 index 00000000000..53d3e45a0e5 --- /dev/null +++ b/compiler/backend/src/org/jetbrains/kotlin/codegen/coroutines/CoroutineTransformerMethodVisitor.kt.202 @@ -0,0 +1,1236 @@ +/* + * Copyright 2010-2019 JetBrains s.r.o. and Kotlin Programming Language contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + */ + +package org.jetbrains.kotlin.codegen.coroutines + +import org.jetbrains.kotlin.backend.common.CodegenUtil +import org.jetbrains.kotlin.codegen.AsmUtil +import org.jetbrains.kotlin.codegen.ClassBuilder +import org.jetbrains.kotlin.codegen.StackValue +import org.jetbrains.kotlin.codegen.TransformationMethodVisitor +import org.jetbrains.kotlin.codegen.inline.* +import org.jetbrains.kotlin.codegen.optimization.boxing.isUnitInstance +import org.jetbrains.kotlin.codegen.optimization.common.* +import org.jetbrains.kotlin.codegen.optimization.fixStack.FixStackMethodTransformer +import org.jetbrains.kotlin.codegen.optimization.fixStack.top +import org.jetbrains.kotlin.codegen.optimization.transformer.MethodTransformer +import org.jetbrains.kotlin.config.LanguageVersionSettings +import org.jetbrains.kotlin.config.isReleaseCoroutines +import org.jetbrains.kotlin.diagnostics.DiagnosticSink +import org.jetbrains.kotlin.psi.KtElement +import org.jetbrains.kotlin.resolve.jvm.AsmTypes +import org.jetbrains.kotlin.resolve.jvm.diagnostics.ErrorsJvm +import org.jetbrains.kotlin.resolve.jvm.diagnostics.JvmDeclarationOrigin +import org.jetbrains.kotlin.utils.sure +import org.jetbrains.org.objectweb.asm.Label +import org.jetbrains.org.objectweb.asm.MethodVisitor +import org.jetbrains.org.objectweb.asm.Opcodes +import org.jetbrains.org.objectweb.asm.Type +import org.jetbrains.org.objectweb.asm.commons.InstructionAdapter +import org.jetbrains.org.objectweb.asm.tree.* +import org.jetbrains.org.objectweb.asm.tree.analysis.* +import kotlin.math.max + +private const val COROUTINES_DEBUG_METADATA_VERSION = 1 + +private const val COROUTINES_METADATA_SOURCE_FILE_JVM_NAME = "f" +private const val COROUTINES_METADATA_LINE_NUMBERS_JVM_NAME = "l" +private const val COROUTINES_METADATA_LOCAL_NAMES_JVM_NAME = "n" +private const val COROUTINES_METADATA_SPILLED_JVM_NAME = "s" +private const val COROUTINES_METADATA_INDEX_TO_LABEL_JVM_NAME = "i" +private const val COROUTINES_METADATA_METHOD_NAME_JVM_NAME = "m" +private const val COROUTINES_METADATA_CLASS_NAME_JVM_NAME = "c" +private const val COROUTINES_METADATA_VERSION_JVM_NAME = "v" + +const val SUSPEND_FUNCTION_COMPLETION_PARAMETER_NAME = "\$completion" +const val SUSPEND_CALL_RESULT_NAME = "\$result" +const val ILLEGAL_STATE_ERROR_MESSAGE = "call to 'resume' before 'invoke' with coroutine" + +class CoroutineTransformerMethodVisitor( + delegate: MethodVisitor, + access: Int, + name: String, + desc: String, + signature: String?, + exceptions: Array?, + private val containingClassInternalName: String, + obtainClassBuilderForCoroutineState: () -> ClassBuilder, + private val isForNamedFunction: Boolean, + private val shouldPreserveClassInitialization: Boolean, + private val languageVersionSettings: LanguageVersionSettings, + // Since tail-call optimization of functions with Unit return type relies on ability of call-site to recognize them, + // in order to ignore return value and push Unit, when we cannot ensure this ability, for example, when the function overrides function, + // returning Any, we need to disable tail-call optimization for these functions. + private val disableTailCallOptimizationForFunctionReturningUnit: Boolean, + private val reportSuspensionPointInsideMonitor: (String) -> Unit, + private val lineNumber: Int, + private val sourceFile: String, + // It's only matters for named functions, may differ from '!isStatic(access)' in case of DefaultImpls + private val needDispatchReceiver: Boolean = false, + // May differ from containingClassInternalName in case of DefaultImpls + private val internalNameForDispatchReceiver: String? = null, + // JVM_IR backend generates $completion, while old backend does not + private val putContinuationParameterToLvt: Boolean = true +) : TransformationMethodVisitor(delegate, access, name, desc, signature, exceptions) { + + private val classBuilderForCoroutineState: ClassBuilder by lazy(obtainClassBuilderForCoroutineState) + + private var continuationIndex = if (isForNamedFunction) -1 else 0 + private var dataIndex = if (isForNamedFunction) -1 else 1 + private var exceptionIndex = if (isForNamedFunction || languageVersionSettings.isReleaseCoroutines()) -1 else 2 + + override fun performTransformations(methodNode: MethodNode) { + removeFakeContinuationConstructorCall(methodNode) + + replaceReturnsUnitMarkersWithPushingUnitOnStack(methodNode) + + replaceFakeContinuationsWithRealOnes( + methodNode, + if (isForNamedFunction) getLastParameterIndex(methodNode.desc, methodNode.access) else 0 + ) + + FixStackMethodTransformer().transform(containingClassInternalName, methodNode) + RedundantLocalsEliminationMethodTransformer(languageVersionSettings).transform(containingClassInternalName, methodNode) + if (languageVersionSettings.isReleaseCoroutines()) { + ChangeBoxingMethodTransformer.transform(containingClassInternalName, methodNode) + } + updateMaxStack(methodNode) + + val suspensionPoints = collectSuspensionPoints(methodNode) + + checkForSuspensionPointInsideMonitor(methodNode, suspensionPoints) + + // First instruction in the method node may change in case of named function + val actualCoroutineStart = methodNode.instructions.first + + if (isForNamedFunction) { + if (putContinuationParameterToLvt) { + addCompletionParameterToLVT(methodNode) + } + + val examiner = MethodNodeExaminer( + languageVersionSettings, + containingClassInternalName, + methodNode, + disableTailCallOptimizationForFunctionReturningUnit + ) + if (examiner.allSuspensionPointsAreTailCalls(suspensionPoints)) { + examiner.replacePopsBeforeSafeUnitInstancesWithCoroutineSuspendedChecks() + dropSuspensionMarkers(methodNode) + return + } + + dataIndex = methodNode.maxLocals++ + if (!languageVersionSettings.isReleaseCoroutines()) { + exceptionIndex = methodNode.maxLocals++ + } + continuationIndex = methodNode.maxLocals++ + + prepareMethodNodePreludeForNamedFunction(methodNode) + } + + for (suspensionPoint in suspensionPoints) { + splitTryCatchBlocksContainingSuspensionPoint(methodNode, suspensionPoint) + } + + // Actual max stack might be increased during the previous phases + updateMaxStack(methodNode) + + UninitializedStoresProcessor(methodNode, shouldPreserveClassInitialization).run() + + val spilledToVariableMapping = spillVariables(suspensionPoints, methodNode) + + val suspendMarkerVarIndex = methodNode.maxLocals++ + + val suspensionPointLineNumbers = suspensionPoints.map { findSuspensionPointLineNumber(it) } + + val continuationLabels = suspensionPoints.withIndex().map { + transformCallAndReturnContinuationLabel( + it.index + 1, it.value, methodNode, suspendMarkerVarIndex, suspensionPointLineNumbers[it.index]) + } + + methodNode.instructions.apply { + val tableSwitchLabel = LabelNode() + val firstStateLabel = LabelNode() + val defaultLabel = LabelNode() + + // tableswitch(this.label) + insertBefore( + actualCoroutineStart, + insnListOf( + *withInstructionAdapter { loadCoroutineSuspendedMarker(languageVersionSettings) }.toArray(), + tableSwitchLabel, + // Allow debugger to stop on enter into suspend function + LineNumberNode(lineNumber, tableSwitchLabel), + VarInsnNode(Opcodes.ASTORE, suspendMarkerVarIndex), + VarInsnNode(Opcodes.ALOAD, continuationIndex), + *withInstructionAdapter { getLabel() }.toArray(), + TableSwitchInsnNode( + 0, + suspensionPoints.size, + defaultLabel, + firstStateLabel, *continuationLabels.toTypedArray() + ), + firstStateLabel + ) + ) + + insert(firstStateLabel, withInstructionAdapter { + generateResumeWithExceptionCheck(languageVersionSettings.isReleaseCoroutines(), dataIndex, exceptionIndex) + }) + insert(last, defaultLabel) + + insert(last, withInstructionAdapter { + AsmUtil.genThrow(this, "java/lang/IllegalStateException", ILLEGAL_STATE_ERROR_MESSAGE) + areturn(Type.VOID_TYPE) + }) + } + + dropSuspensionMarkers(methodNode) + methodNode.removeEmptyCatchBlocks() + + // The parameters (and 'this') shall live throughout the method, otherwise, d8 emits warning about invalid debug info + val startLabel = LabelNode() + val endLabel = LabelNode() + methodNode.instructions.insertBefore(methodNode.instructions.first, startLabel) + methodNode.instructions.insert(methodNode.instructions.last, endLabel) + + fixLvtForParameters(methodNode, startLabel, endLabel) + + if (languageVersionSettings.isReleaseCoroutines()) { + writeDebugMetadata(methodNode, suspensionPointLineNumbers, spilledToVariableMapping) + } + } + + private fun addCompletionParameterToLVT(methodNode: MethodNode) { + val index = + /* all args */ Type.getMethodType(methodNode.desc).argumentTypes.fold(0) { a, b -> a + b.size } + + /* this */ (if (isStatic(methodNode.access)) 0 else 1) - + /* only last */ 1 + val startLabel = with(methodNode.instructions) { + if (first is LabelNode) first as LabelNode + else LabelNode().also { insertBefore(first, it) } + } + + val endLabel = with(methodNode.instructions) { + if (last is LabelNode) last as LabelNode + else LabelNode().also { insert(last, it) } + } + methodNode.localVariables.add( + LocalVariableNode( + SUSPEND_FUNCTION_COMPLETION_PARAMETER_NAME, + languageVersionSettings.continuationAsmType().descriptor, + null, + startLabel, + endLabel, + index + ) + ) + } + + /* Put { POP, GETSTATIC Unit } after suspension point if suspension point is a call of suspend function, that returns Unit. + * + * Otherwise, upon resume, the function would seem to not return Unit, despite being declared as returning Unit. + * + * This happens when said function is tail-call and its callee does not return Unit. + * + * Let's have an example + * + * suspend fun int(): Int = suspendCoroutine { ...; 1 } + * + * suspend fun unit() { + * int() + * } + * + * suspend fun main() { + * println(unit()) + * } + * + * So, in order to understand the necessity of { POP, GETSTATIC Unit } inside `main`, we need to consider two different scenarios + * + * 1. `unit` is not a tail-call function. + * 2. `unit` is a tail-call function. + * + * When `unit` is a not tail-call function, calling `resumeWith` on its continuation will resume `unit`, + * it will hit { GETSTATIC Unit; ARETURN } and this Unit will be the result of the suspend call. `unit`'s continuation will then call + * `main` continuation's `resumeWith`, passing the Unit instance. The continuation in turn will resume `main` and the Unit will be + * the result of `unit()` call. This result will then printed. + * + * However, when `unit` is a tail-call function, there is no continuation, generated for it. This is the point of tail-call + * optimization. Thus, resume call will skip `unit` and land direcly in `main` continuation's `resumeWith`. And its result is not + * Unit. Thus, we must ignore this result on call-site and use Unit instead. In other words, POP the result and GETSTATIC Unit + * instead. + */ + private fun replaceReturnsUnitMarkersWithPushingUnitOnStack(methodNode: MethodNode) { + for (marker in methodNode.instructions.asSequence().filter(::isReturnsUnitMarker).toList()) { + assert(marker.next?.next?.let { isAfterSuspendMarker(it) } == true) { + "Expected AfterSuspendMarker after ReturnUnitMarker, got ${marker.next?.next}" + } + methodNode.instructions.insert( + marker.next.next, + withInstructionAdapter { + pop() + getstatic("kotlin/Unit", "INSTANCE", "Lkotlin/Unit;") + } + ) + methodNode.instructions.removeAll(listOf(marker.previous, marker)) + } + } + + private fun findSuspensionPointLineNumber(suspensionPoint: SuspensionPoint) = + suspensionPoint.suspensionCallBegin.findPreviousOrNull { it is LineNumberNode } as LineNumberNode? + + private fun checkForSuspensionPointInsideMonitor(methodNode: MethodNode, suspensionPoints: List) { + if (methodNode.instructions.asSequence().none { it.opcode == Opcodes.MONITORENTER }) return + + val cfg = ControlFlowGraph.build(methodNode) + val monitorDepthMap = hashMapOf() + fun addMonitorDepthToSuccs(index: Int, depth: Int) { + val insn = methodNode.instructions[index] + monitorDepthMap[insn] = depth + val newDepth = when (insn.opcode) { + Opcodes.MONITORENTER -> depth + 1 + Opcodes.MONITOREXIT -> depth - 1 + else -> depth + } + for (succIndex in cfg.getSuccessorsIndices(index)) { + if (monitorDepthMap[methodNode.instructions[succIndex]] == null) { + addMonitorDepthToSuccs(succIndex, newDepth) + } + } + } + + addMonitorDepthToSuccs(0, 0) + + for (suspensionPoint in suspensionPoints) { + if (monitorDepthMap[suspensionPoint.suspensionCallBegin]?.let { it > 0 } == true) { + // TODO: Support crossinline suspend lambdas + val stackTraceElement = StackTraceElement( + containingClassInternalName, + methodNode.name, + sourceFile, + findSuspensionPointLineNumber(suspensionPoint)?.line ?: -1 + ) + reportSuspensionPointInsideMonitor("$stackTraceElement") + return + } + } + } + + private fun fixLvtForParameters(methodNode: MethodNode, startLabel: LabelNode, endLabel: LabelNode) { + val paramsNum = + /* this */ (if (isStatic(methodNode.access)) 0 else 1) + + /* real params */ Type.getArgumentTypes(methodNode.desc).fold(0) { a, b -> a + b.size } + + for (i in 0 until paramsNum) { + fixRangeOfLvtRecord(methodNode, i, startLabel, endLabel) + } + } + + private fun fixRangeOfLvtRecord(methodNode: MethodNode, index: Int, startLabel: LabelNode, endLabel: LabelNode) { + val vars = methodNode.localVariables.filter { it.index == index } + assert(vars.size <= 1) { + "Someone else occupies parameter's slot at $index" + } + vars.firstOrNull()?.let { + it.start = startLabel + it.end = endLabel + } + } + + private fun writeDebugMetadata( + methodNode: MethodNode, + suspensionPointLineNumbers: List, + spilledToLocalMapping: List> + ) { + val lines = suspensionPointLineNumbers.map { it?.line ?: -1 } + val metadata = classBuilderForCoroutineState.newAnnotation(DEBUG_METADATA_ANNOTATION_ASM_TYPE.descriptor, true) + metadata.visit(COROUTINES_METADATA_SOURCE_FILE_JVM_NAME, sourceFile) + metadata.visit(COROUTINES_METADATA_LINE_NUMBERS_JVM_NAME, lines.toIntArray()) + + val debugIndexToLabel = spilledToLocalMapping.withIndex().flatMap { (labelIndex, list) -> + list.map { labelIndex } + } + val variablesMapping = spilledToLocalMapping.flatten() + metadata.visit(COROUTINES_METADATA_INDEX_TO_LABEL_JVM_NAME, debugIndexToLabel.toIntArray()) + metadata.visitArray(COROUTINES_METADATA_SPILLED_JVM_NAME).also { v -> + variablesMapping.forEach { v.visit(null, it.fieldName) } + }.visitEnd() + metadata.visitArray(COROUTINES_METADATA_LOCAL_NAMES_JVM_NAME).also { v -> + variablesMapping.forEach { v.visit(null, it.variableName) } + }.visitEnd() + metadata.visit(COROUTINES_METADATA_METHOD_NAME_JVM_NAME, methodNode.name) + metadata.visit(COROUTINES_METADATA_CLASS_NAME_JVM_NAME, Type.getObjectType(containingClassInternalName).className) + @Suppress("ConstantConditionIf") + if (COROUTINES_DEBUG_METADATA_VERSION != 1) { + metadata.visit(COROUTINES_METADATA_VERSION_JVM_NAME, COROUTINES_DEBUG_METADATA_VERSION) + } + metadata.visitEnd() + } + + // Warning! This is _continuation_, not _completion_, it can be allocated inside the method, thus, it is incorrect to treat it + // as a parameter + private fun addContinuationAndResultToLvt( + methodNode: MethodNode, + startLabel: Label, + resultStartLabel: Label + ) { + val endLabel = Label() + methodNode.instructions.add(withInstructionAdapter { mark(endLabel) }) + methodNode.visitLocalVariable( + CONTINUATION_VARIABLE_NAME, + languageVersionSettings.continuationAsmType().descriptor, + null, + startLabel, + endLabel, + continuationIndex + ) + methodNode.visitLocalVariable( + SUSPEND_CALL_RESULT_NAME, + AsmTypes.OBJECT_TYPE.descriptor, + null, + resultStartLabel, + endLabel, + dataIndex + ) + } + + private fun removeFakeContinuationConstructorCall(methodNode: MethodNode) { + val seq = methodNode.instructions.asSequence() + val first = seq.firstOrNull(::isBeforeFakeContinuationConstructorCallMarker)?.previous ?: return + val last = seq.firstOrNull(::isAfterFakeContinuationConstructorCallMarker).sure { + "BeforeFakeContinuationConstructorCallMarker without AfterFakeContinuationConstructorCallMarker" + } + val toRemove = InsnSequence(first, last).toList() + methodNode.instructions.removeAll(toRemove) + methodNode.instructions.set(last, InsnNode(Opcodes.ACONST_NULL)) + } + + private fun InstructionAdapter.getLabel() { + if (isForNamedFunction && !languageVersionSettings.isReleaseCoroutines()) + invokevirtual( + classBuilderForCoroutineState.thisName, + "getLabel", + Type.getMethodDescriptor(Type.INT_TYPE), + false + ) + else + getfield( + computeLabelOwner(languageVersionSettings, classBuilderForCoroutineState.thisName).internalName, + COROUTINE_LABEL_FIELD_NAME, Type.INT_TYPE.descriptor + ) + } + + private fun InstructionAdapter.setLabel() { + if (isForNamedFunction && !languageVersionSettings.isReleaseCoroutines()) + invokevirtual( + classBuilderForCoroutineState.thisName, + "setLabel", + Type.getMethodDescriptor(Type.VOID_TYPE, Type.INT_TYPE), + false + ) + else + putfield( + computeLabelOwner(languageVersionSettings, classBuilderForCoroutineState.thisName).internalName, + COROUTINE_LABEL_FIELD_NAME, Type.INT_TYPE.descriptor + ) + } + + private fun updateMaxStack(methodNode: MethodNode) { + methodNode.instructions.resetLabels() + methodNode.accept( + MaxStackFrameSizeAndLocalsCalculator( + Opcodes.API_VERSION, methodNode.access, methodNode.desc, + object : MethodVisitor(Opcodes.API_VERSION) { + override fun visitMaxs(maxStack: Int, maxLocals: Int) { + methodNode.maxStack = maxStack + } + } + ) + ) + } + + private fun prepareMethodNodePreludeForNamedFunction(methodNode: MethodNode) { + val objectTypeForState = Type.getObjectType(classBuilderForCoroutineState.thisName) + val continuationArgumentIndex = getLastParameterIndex(methodNode.desc, methodNode.access) + methodNode.instructions.asSequence().filterIsInstance().forEach { + if (it.`var` != continuationArgumentIndex) return@forEach + assert(it.opcode == Opcodes.ALOAD) { "Only ALOADs are allowed for continuation arguments" } + it.`var` = continuationIndex + } + + methodNode.instructions.insert(withInstructionAdapter { + val createStateInstance = Label() + val afterCoroutineStateCreated = Label() + + // We have to distinguish the following situations: + // - Our function got called in a common way (e.g. from another function or via recursive call) and we should execute our + // code from the beginning + // - We got called from `doResume` of our continuation, i.e. we need to continue from the last suspension point + // + // Also in the first case we wrap the completion into a special anonymous class instance (let's call it X$1) + // that we'll use as a continuation argument for suspension points + // + // How we distinguish the cases: + // - If the continuation is not an instance of X$1 we know exactly it's not the second case, because when resuming + // the continuation we pass an instance of that class + // - Otherwise it's still can be a recursive call. To check it's not the case we set the last bit in the label in + // `doResume` just before calling the suspend function (see kotlin.coroutines.experimental.jvm.internal.CoroutineImplForNamedFunction). + // So, if it's set we're in continuation. + + visitVarInsn(Opcodes.ALOAD, continuationArgumentIndex) + instanceOf(objectTypeForState) + ifeq(createStateInstance) + + visitVarInsn(Opcodes.ALOAD, continuationArgumentIndex) + checkcast(objectTypeForState) + visitVarInsn(Opcodes.ASTORE, continuationIndex) + + visitVarInsn(Opcodes.ALOAD, continuationIndex) + getLabel() + + iconst(1 shl 31) + and(Type.INT_TYPE) + ifeq(createStateInstance) + + visitVarInsn(Opcodes.ALOAD, continuationIndex) + dup() + getLabel() + + iconst(1 shl 31) + sub(Type.INT_TYPE) + setLabel() + + goTo(afterCoroutineStateCreated) + + visitLabel(createStateInstance) + + generateContinuationConstructorCall( + objectTypeForState, + methodNode, + needDispatchReceiver, + internalNameForDispatchReceiver, + containingClassInternalName, + classBuilderForCoroutineState, + languageVersionSettings + ) + + visitVarInsn(Opcodes.ASTORE, continuationIndex) + + visitLabel(afterCoroutineStateCreated) + + visitVarInsn(Opcodes.ALOAD, continuationIndex) + getfield(classBuilderForCoroutineState.thisName, languageVersionSettings.dataFieldName(), AsmTypes.OBJECT_TYPE.descriptor) + visitVarInsn(Opcodes.ASTORE, dataIndex) + + val resultStartLabel = Label() + visitLabel(resultStartLabel) + + addContinuationAndResultToLvt(methodNode, afterCoroutineStateCreated, resultStartLabel) + + if (!languageVersionSettings.isReleaseCoroutines()) { + visitVarInsn(Opcodes.ALOAD, continuationIndex) + getfield(classBuilderForCoroutineState.thisName, EXCEPTION_FIELD_NAME, AsmTypes.JAVA_THROWABLE_TYPE.descriptor) + visitVarInsn(Opcodes.ASTORE, exceptionIndex) + } + }) + } + + /* + * Every suspension point should be surrounded by two markers: before suspension point marker (start marker) + * and after suspension point marker (end marker) + * + * However, if suspension point comes from inline function and its end marker is unreachable, the end marker is removed by + * either inliner or bytecode optimization. + * + * If this happens, we should restore end marker. + * + * Since in both cases (when end marker is reachable and when it is not) all paths should lead to + * either a single end marker or to ATHROWs and ARETURNs, we just compute all paths from start marker until they reach + * these instructions. + */ + private fun collectSuspensionPoints(methodNode: MethodNode): List { + // Exception paths lead outside suspension points, thus we should ignore them + val cfg = ControlFlowGraph.build(methodNode, followExceptions = false) + + // DFS until end marker or ATHROW or ARETURN. + // return true if it contains nested suspension points, which happens when we inline suspend lambda + // with multiple suspension points via several inlines. See boxInline/state/stateMachine/passLambda.kt as an example. + // In this case we simply ignore them. + fun collectSuspensionPointEnds( + insn: AbstractInsnNode, + visited: MutableSet, + ends: MutableSet + ): Boolean { + if (!visited.add(insn)) return false + if (insn.opcode == Opcodes.ARETURN || insn.opcode == Opcodes.ATHROW || isAfterSuspendMarker(insn)) { + ends.add(insn) + } else { + for (index in cfg.getSuccessorsIndices(insn)) { + val succ = methodNode.instructions[index] + if (isBeforeSuspendMarker(succ)) return true + if (collectSuspensionPointEnds(succ, visited, ends)) return true + } + } + return false + } + + val starts = methodNode.instructions.asSequence().filter { + isBeforeSuspendMarker(it) && + cfg.getPredecessorsIndices(it).isNotEmpty() // Ignore unreachable start markers + }.toList() + return starts.mapNotNull { start -> + val ends = mutableSetOf() + if (collectSuspensionPointEnds(start, mutableSetOf(), ends)) return@mapNotNull null + // Ignore suspension points, if the suspension call begin is alive and suspension call end is dead + // (e.g., an inlined suspend function call ends with throwing a exception -- see KT-15017), + // (also see boxInline/suspend/stateMachine/unreachableSuspendMarker.kt) + // this is an exit point for the corresponding coroutine. + val end = ends.find { isAfterSuspendMarker(it) } ?: return@mapNotNull null + SuspensionPoint(start.previous, end) + } + } + + private fun dropSuspensionMarkers(methodNode: MethodNode) { + // Drop markers, including ones, which we ignored in recognizing phase + for (marker in methodNode.instructions.asSequence().filter { isBeforeSuspendMarker(it) || isAfterSuspendMarker(it) }.toList()) { + methodNode.instructions.removeAll(listOf(marker.previous, marker)) + } + } + + private fun spillVariables(suspensionPoints: List, methodNode: MethodNode): List> { + val instructions = methodNode.instructions + val frames = performRefinedTypeAnalysis(methodNode, containingClassInternalName) + fun AbstractInsnNode.index() = instructions.indexOf(this) + + // We postpone these actions because they change instruction indices that we use when obtaining frames + val postponedActions = mutableListOf<() -> Unit>() + val maxVarsCountByType = mutableMapOf() + val livenessFrames = analyzeLiveness(methodNode) + val spilledToVariableMapping = arrayListOf>() + + for (suspension in suspensionPoints) { + val suspensionCallBegin = suspension.suspensionCallBegin + + assert(frames[suspension.suspensionCallEnd.next.index()]?.stackSize == 1) { + "Stack should be spilled before suspension call" + } + + val frame = frames[suspensionCallBegin.index()].sure { "Suspension points containing in dead code must be removed" } + val localsCount = frame.locals + val varsCountByType = mutableMapOf() + + // We consider variable liveness to avoid problems with inline suspension functions: + // + // * + // RETURN (appears only on further transformation phase) + // ... + // + // + // The problem is that during current phase (before inserting RETURN opcode) we suppose variables generated + // within inline suspension point as correctly initialized, thus trying to spill them. + // While after RETURN introduction these variables become uninitialized (at the same time they can't be used further). + // So we only spill variables that are alive at the begin of suspension point. + // NB: it's also rather useful for sake of optimization + val livenessFrame = livenessFrames[suspensionCallBegin.index()] + + val spilledToVariable = arrayListOf() + + // 0 - this + // 1 - parameter + // ... + // k - continuation + // k + 1 - data + // k + 2 - exception + val variablesToSpill = + (0 until localsCount) + .filterNot { it in setOf(continuationIndex, dataIndex, exceptionIndex) } + .map { Pair(it, frame.getLocal(it)) } + .filter { (index, value) -> + (index == 0 && needDispatchReceiver && isForNamedFunction) || + (value != StrictBasicValue.UNINITIALIZED_VALUE && livenessFrame.isAlive(index)) + } + + for ((index, basicValue) in variablesToSpill) { + if (basicValue === StrictBasicValue.NULL_VALUE) { + postponedActions.add { + with(instructions) { + insert(suspension.tryCatchBlockEndLabelAfterSuspensionCall, withInstructionAdapter { + aconst(null) + store(index, AsmTypes.OBJECT_TYPE) + }) + } + } + continue + } + + val type = basicValue.type + val normalizedType = type.normalize() + + val indexBySort = varsCountByType[normalizedType]?.plus(1) ?: 0 + varsCountByType[normalizedType] = indexBySort + + val fieldName = normalizedType.fieldNameForVar(indexBySort) + localVariableName(methodNode, index, suspension.suspensionCallEnd.next.index()) + ?.let { spilledToVariable.add(SpilledVariableDescriptor(fieldName, it)) } + + postponedActions.add { + with(instructions) { + // store variable before suspension call + insertBefore(suspension.suspensionCallBegin, withInstructionAdapter { + load(continuationIndex, AsmTypes.OBJECT_TYPE) + load(index, type) + StackValue.coerce(type, normalizedType, this) + putfield(classBuilderForCoroutineState.thisName, fieldName, normalizedType.descriptor) + }) + + // restore variable after suspension call + insert(suspension.tryCatchBlockEndLabelAfterSuspensionCall, withInstructionAdapter { + load(continuationIndex, AsmTypes.OBJECT_TYPE) + getfield(classBuilderForCoroutineState.thisName, fieldName, normalizedType.descriptor) + StackValue.coerce(normalizedType, type, this) + store(index, type) + }) + } + } + } + + spilledToVariableMapping.add(spilledToVariable) + + varsCountByType.forEach { + maxVarsCountByType[it.key] = max(maxVarsCountByType[it.key] ?: 0, it.value) + } + } + + postponedActions.forEach(Function0::invoke) + + maxVarsCountByType.forEach { entry -> + val (type, maxIndex) = entry + for (index in 0..maxIndex) { + classBuilderForCoroutineState.newField( + JvmDeclarationOrigin.NO_ORIGIN, AsmUtil.NO_FLAG_PACKAGE_PRIVATE, + type.fieldNameForVar(index), type.descriptor, null, null + ) + } + } + return spilledToVariableMapping + } + + private fun localVariableName( + methodNode: MethodNode, + index: Int, + suspensionCallIndex: Int + ): String? { + val variable = methodNode.localVariables.find { + index == it.index && methodNode.instructions.indexOf(it.start) <= suspensionCallIndex + && suspensionCallIndex < methodNode.instructions.indexOf(it.end) + } + return variable?.name + } + + /** + * See 'splitTryCatchBlocksContainingSuspensionPoint' + */ + private val SuspensionPoint.tryCatchBlockEndLabelAfterSuspensionCall: LabelNode + get() { + assert(suspensionCallEnd.next is LabelNode) { + "Next instruction after ${this} should be a label, but " + + "${suspensionCallEnd.next::class.java}/${suspensionCallEnd.next.opcode} was found" + } + + return suspensionCallEnd.next as LabelNode + } + + private fun transformCallAndReturnContinuationLabel( + id: Int, + suspension: SuspensionPoint, + methodNode: MethodNode, + suspendMarkerVarIndex: Int, + suspendPointLineNumber: LineNumberNode? + ): LabelNode { + val continuationLabel = LabelNode() + val continuationLabelAfterLoadedResult = LabelNode() + val suspendElementLineNumber = lineNumber + var nextLineNumberNode = nextDefinitelyHitLineNumber(suspension) + with(methodNode.instructions) { + // Save state + insertBefore( + suspension.suspensionCallBegin, + withInstructionAdapter { + visitVarInsn(Opcodes.ALOAD, continuationIndex) + iconst(id) + setLabel() + } + ) + + insert(suspension.tryCatchBlockEndLabelAfterSuspensionCall, withInstructionAdapter { + dup() + load(suspendMarkerVarIndex, AsmTypes.OBJECT_TYPE) + ifacmpne(continuationLabelAfterLoadedResult.label) + + // Exit + val returnLabel = LabelNode() + visitLabel(returnLabel.label) + // Special line number to stop in debugger before suspend return + visitLineNumber(suspendElementLineNumber, returnLabel.label) + load(suspendMarkerVarIndex, AsmTypes.OBJECT_TYPE) + areturn(AsmTypes.OBJECT_TYPE) + // Mark place for continuation + visitLabel(continuationLabel.label) + }) + + // After suspension point there is always three nodes: L1, NOP, L2 + // And if there are relevant exception handlers, they always start at L2 + // See 'splitTryCatchBlocksContainingSuspensionPoint' + val possibleTryCatchBlockStart = suspension.tryCatchBlocksContinuationLabel + + // Move NOP, which is inserted in `splitTryCatchBlocksContainingSuspentionPoint`, inside the try catch block, + // so the inliner can transform suspend lambdas during inlining + assert(possibleTryCatchBlockStart.previous.opcode == Opcodes.NOP) { + "NOP expected but ${possibleTryCatchBlockStart.previous.opcode} was found" + } + remove(possibleTryCatchBlockStart.previous) + + insert(possibleTryCatchBlockStart, withInstructionAdapter { + nop() + generateResumeWithExceptionCheck(languageVersionSettings.isReleaseCoroutines(), dataIndex, exceptionIndex) + + // Load continuation argument just like suspending function returns it + load(dataIndex, AsmTypes.OBJECT_TYPE) + + visitLabel(continuationLabelAfterLoadedResult.label) + + if (nextLineNumberNode != null) { + // If there is a clear next linenumber instruction, extend it. Can't use line number of suspension point + // here because both non-suspended execution and re-entering after suspension passes this label. + if (possibleTryCatchBlockStart.next?.opcode?.let { + it != Opcodes.ASTORE && it != Opcodes.CHECKCAST && it != Opcodes.INVOKESTATIC && + it != Opcodes.INVOKEVIRTUAL && it != Opcodes.INVOKEINTERFACE + } == true + ) { + visitLineNumber(nextLineNumberNode!!.line, continuationLabelAfterLoadedResult.label) + } else { + // But keep the linenumber if the result of the call is used afterwards + nextLineNumberNode = null + } + } else if (suspendPointLineNumber != null) { + // If there is no clear next linenumber instruction, the continuation is still on the + // same line as the suspend point. + visitLineNumber(suspendPointLineNumber.line, continuationLabelAfterLoadedResult.label) + } + }) + + if (nextLineNumberNode != null) { + // Remove the line number instruction as it now covered with line number on continuation label. + // If both linenumber are present in bytecode, debugger will trigger line specific events twice. + remove(nextLineNumberNode) + } + } + + return continuationLabel + } + + // Find the next line number instruction that is defintely hit. That is, a line number + // that comes before any branch or method call. + private fun nextDefinitelyHitLineNumber(suspension: SuspensionPoint): LineNumberNode? { + var next = suspension.suspensionCallEnd.next + while (next != null) { + if (next.isBranchOrCall) return null + else if (next is LineNumberNode) return next + else next = next.next + } + return next + } + + // It's necessary to preserve some sensible invariants like there should be no jump in the middle of try-catch-block + // Also it's important that spilled variables are being restored outside of TCB, + // otherwise they would be treated as uninitialized within catch-block while they can be used there + // How suspension point area will look like after all transformations: + // + // INVOKESTATIC beforeSuspensionMarker + // INVOKEVIRTUAL suspensionMethod()Ljava/lang/Object; + // CHECKCAST SomeType + // INVOKESTATIC afterSuspensionMarker + // L1: -- end of all TCB's that are containing the suspension point (inserted by this method) + // RETURN + // L2: -- continuation label (used for the TABLESWITCH) + // (no try-catch blocks here) + // L3: begin/continuation of all TCB's that are containing the suspension point (inserted by this method) + // ... + private fun splitTryCatchBlocksContainingSuspensionPoint(methodNode: MethodNode, suspensionPoint: SuspensionPoint) { + val instructions = methodNode.instructions + val beginIndex = instructions.indexOf(suspensionPoint.suspensionCallBegin) + val endIndex = instructions.indexOf(suspensionPoint.suspensionCallEnd) + + val firstLabel = LabelNode() + val secondLabel = LabelNode() + instructions.insert(suspensionPoint.suspensionCallEnd, firstLabel) + // NOP is needed to preventing these label merge + // Here between these labels additional instructions are supposed to be inserted (variables spilling, etc.) + instructions.insert(firstLabel, InsnNode(Opcodes.NOP)) + instructions.insert(firstLabel.next, secondLabel) + + methodNode.tryCatchBlocks = + methodNode.tryCatchBlocks.flatMap { + val isContainingSuspensionPoint = + instructions.indexOf(it.start) < beginIndex && beginIndex < instructions.indexOf(it.end) + + if (isContainingSuspensionPoint) { + assert(instructions.indexOf(it.start) < endIndex && endIndex < instructions.indexOf(it.end)) { + "Try catch block ${instructions.indexOf(it.start)}:${instructions.indexOf(it.end)} containing marker before " + + "suspension point $beginIndex should also contain the marker after suspension point $endIndex" + } + listOf( + TryCatchBlockNode(it.start, firstLabel, it.handler, it.type), + TryCatchBlockNode(secondLabel, it.end, it.handler, it.type) + ) + } else + listOf(it) + } + + suspensionPoint.tryCatchBlocksContinuationLabel = secondLabel + + return + } + + private data class SpilledVariableDescriptor(val fieldName: String, val variableName: String) +} + +// TODO Use this in variable liveness analysis +private class MethodNodeExaminer( + val languageVersionSettings: LanguageVersionSettings, + val containingClassInternalName: String, + val methodNode: MethodNode, + disableTailCallOptimizationForFunctionReturningUnit: Boolean +) { + private val sourceFrames: Array> = + MethodTransformer.analyze(containingClassInternalName, methodNode, IgnoringCopyOperationSourceInterpreter()) + private val controlFlowGraph = ControlFlowGraph.build(methodNode) + + private val safeUnitInstances = mutableSetOf() + private val popsBeforeSafeUnitInstances = mutableSetOf() + private val areturnsAfterSafeUnitInstances = mutableSetOf() + private val meaningfulSuccessorsCache = hashMapOf>() + private val meaningfulPredecessorsCache = hashMapOf>() + + init { + if (!disableTailCallOptimizationForFunctionReturningUnit) { + // retrieve all POP insns + val pops = methodNode.instructions.asSequence().filter { it.opcode == Opcodes.POP } + // for each of them check that all successors are PUSH Unit + val popsBeforeUnitInstances = pops.map { it to it.meaningfulSuccessors() } + .filter { (_, succs) -> succs.all { it.isUnitInstance() } } + .map { it.first }.toList() + for (pop in popsBeforeUnitInstances) { + val units = pop.meaningfulSuccessors() + val allUnitsAreSafe = units.all { unit -> + // check no other predecessor exists + unit.meaningfulPredecessors().all { it in popsBeforeUnitInstances } && + // check they have only returns among successors + unit.meaningfulSuccessors().all { it.opcode == Opcodes.ARETURN } + } + if (!allUnitsAreSafe) continue + // save them all to the properties + popsBeforeSafeUnitInstances += pop + safeUnitInstances += units + units.flatMapTo(areturnsAfterSafeUnitInstances) { it.meaningfulSuccessors() } + } + } + } + + private fun AbstractInsnNode.index() = methodNode.instructions.indexOf(this) + + // GETSTATIC kotlin/Unit.INSTANCE is considered safe iff + // it is part of POP, PUSH Unit, ARETURN sequence. + private fun AbstractInsnNode.isSafeUnitInstance(): Boolean = this in safeUnitInstances + + private fun AbstractInsnNode.isPopBeforeSafeUnitInstance(): Boolean = this in popsBeforeSafeUnitInstances + private fun AbstractInsnNode.isAreturnAfterSafeUnitInstance(): Boolean = this in areturnsAfterSafeUnitInstances + + private fun AbstractInsnNode.meaningfulSuccessors(): List = meaningfulSuccessorsCache.getOrPut(this) { + meaningfulSuccessorsOrPredecessors(true) + } + + private fun AbstractInsnNode.meaningfulPredecessors(): List = meaningfulPredecessorsCache.getOrPut(this) { + meaningfulSuccessorsOrPredecessors(false) + } + + private fun AbstractInsnNode.meaningfulSuccessorsOrPredecessors(isSuccessors: Boolean): List { + fun AbstractInsnNode.isMeaningful() = isMeaningful && opcode != Opcodes.NOP && opcode != Opcodes.GOTO && this !is LineNumberNode + + fun AbstractInsnNode.getIndices() = + if (isSuccessors) controlFlowGraph.getSuccessorsIndices(this) + else controlFlowGraph.getPredecessorsIndices(this) + + val visited = arrayListOf() + fun dfs(insn: AbstractInsnNode) { + if (insn in visited) return + visited += insn + if (!insn.isMeaningful()) { + for (succIndex in insn.getIndices()) { + dfs(methodNode.instructions[succIndex]) + } + } + } + + for (succIndex in getIndices()) { + dfs(methodNode.instructions[succIndex]) + } + return visited.filter { it.isMeaningful() } + } + + fun replacePopsBeforeSafeUnitInstancesWithCoroutineSuspendedChecks() { + val basicAnalyser = Analyzer(BasicInterpreter()) + basicAnalyser.analyze(containingClassInternalName, methodNode) + val typedFrames = basicAnalyser.frames + + val isReferenceMap = popsBeforeSafeUnitInstances + .map { it to (!isUnreachable(it.index(), sourceFrames) && typedFrames[it.index()]?.top()?.isReference == true) } + .toMap() + + for (pop in popsBeforeSafeUnitInstances) { + if (isReferenceMap[pop] == true) { + val label = Label() + methodNode.instructions.insertBefore(pop, withInstructionAdapter { + dup() + loadCoroutineSuspendedMarker(languageVersionSettings) + ifacmpne(label) + areturn(AsmTypes.OBJECT_TYPE) + mark(label) + }) + } + } + } + + fun allSuspensionPointsAreTailCalls(suspensionPoints: List): Boolean { + val safelyReachableReturns = findSafelyReachableReturns() + + val instructions = methodNode.instructions + return suspensionPoints.all { suspensionPoint -> + val beginIndex = instructions.indexOf(suspensionPoint.suspensionCallBegin) + val endIndex = instructions.indexOf(suspensionPoint.suspensionCallEnd) + + if (isUnreachable(endIndex, sourceFrames)) return@all true + + val insideTryBlock = methodNode.tryCatchBlocks.any { block -> + val tryBlockStartIndex = instructions.indexOf(block.start) + val tryBlockEndIndex = instructions.indexOf(block.end) + + beginIndex in tryBlockStartIndex..tryBlockEndIndex + } + if (insideTryBlock) return@all false + + safelyReachableReturns[endIndex + 1]?.all { returnIndex -> + sourceFrames[returnIndex]?.top().sure { + "There must be some value on stack to return" + }.insns.any { sourceInsn -> + sourceInsn?.let(instructions::indexOf) in beginIndex..endIndex + } + } ?: false + } + } + + /** + * Let's call an instruction safe if its execution is always invisible: stack modifications, branching, variable insns (invisible in debug) + * + * For some instruction `insn` define the result as following: + * - if there is a path leading to the non-safe instruction then result is `null` + * - Otherwise result contains all the reachable ARETURN indices + * + * @return indices of safely reachable returns for each instruction in the method node + */ + private fun findSafelyReachableReturns(): Array?> { + val insns = methodNode.instructions + val reachableReturnsIndices = Array?>(insns.size()) init@{ index -> + val insn = insns[index] + + if (insn.opcode == Opcodes.ARETURN && !insn.isAreturnAfterSafeUnitInstance()) { + if (isUnreachable(index, sourceFrames)) return@init null + return@init setOf(index) + } + + // Since POP, PUSH Unit, ARETURN behaves like normal return in terms of tail-call optimization, set return index to POP + if (insn.isPopBeforeSafeUnitInstance()) { + return@init setOf(index) + } + + if (!insn.isMeaningful || insn.opcode in SAFE_OPCODES || insn.isInvisibleInDebugVarInsn(methodNode) || isInlineMarker(insn) + || insn.isSafeUnitInstance() || insn.isAreturnAfterSafeUnitInstance() + ) { + setOf() + } else null + } + + var changed: Boolean + do { + changed = false + for (index in 0 until insns.size()) { + if (insns[index].opcode == Opcodes.ARETURN) continue + + @Suppress("RemoveExplicitTypeArguments") + val newResult = + controlFlowGraph + .getSuccessorsIndices(index).plus(index) + .map(reachableReturnsIndices::get) + .fold?, Set?>(mutableSetOf()) { acc, successorsResult -> + if (acc != null && successorsResult != null) acc + successorsResult else null + } + + if (newResult != reachableReturnsIndices[index]) { + reachableReturnsIndices[index] = newResult + changed = true + } + } + } while (changed) + + return reachableReturnsIndices + } +} + +internal fun InstructionAdapter.generateContinuationConstructorCall( + objectTypeForState: Type?, + methodNode: MethodNode, + needDispatchReceiver: Boolean, + internalNameForDispatchReceiver: String?, + containingClassInternalName: String, + classBuilderForCoroutineState: ClassBuilder, + languageVersionSettings: LanguageVersionSettings +) { + anew(objectTypeForState) + dup() + + val parameterTypesAndIndices = + getParameterTypesIndicesForCoroutineConstructor( + methodNode.desc, + methodNode.access, + needDispatchReceiver, internalNameForDispatchReceiver ?: containingClassInternalName, + languageVersionSettings + ) + for ((type, index) in parameterTypesAndIndices) { + load(index, type) + } + + invokespecial( + classBuilderForCoroutineState.thisName, + "", + Type.getMethodDescriptor( + Type.VOID_TYPE, + *getParameterTypesForCoroutineConstructor( + methodNode.desc, needDispatchReceiver, + internalNameForDispatchReceiver ?: containingClassInternalName + ) + ), + false + ) +} + +private fun InstructionAdapter.generateResumeWithExceptionCheck(isReleaseCoroutines: Boolean, dataIndex: Int, exceptionIndex: Int) { + // Check if resumeWithException has been called + + if (isReleaseCoroutines) { + load(dataIndex, AsmTypes.OBJECT_TYPE) + invokestatic("kotlin/ResultKt", "throwOnFailure", "(Ljava/lang/Object;)V", false) + } else { + load(exceptionIndex, AsmTypes.OBJECT_TYPE) + dup() + val noExceptionLabel = Label() + ifnull(noExceptionLabel) + athrow() + + mark(noExceptionLabel) + pop() + } +} + +private fun Type.fieldNameForVar(index: Int) = descriptor.first() + "$" + index + +inline fun withInstructionAdapter(block: InstructionAdapter.() -> Unit): InsnList { + val tmpMethodNode = MethodNode() + + InstructionAdapter(tmpMethodNode).apply(block) + + return tmpMethodNode.instructions +} + +private fun Type.normalize() = + when (sort) { + Type.ARRAY, Type.OBJECT -> AsmTypes.OBJECT_TYPE + else -> this + } + +/** + * Suspension call may consists of several instructions: + * ICONST_0 + * INVOKESTATIC InlineMarker.mark() + * INVOKEVIRTUAL suspensionMethod()Ljava/lang/Object; // actually it could be some inline method instead of plain call + * CHECKCAST Type + * ICONST_1 + * INVOKESTATIC InlineMarker.mark() + */ +private class SuspensionPoint( + // ICONST_0 + val suspensionCallBegin: AbstractInsnNode, + // INVOKESTATIC InlineMarker.mark() + val suspensionCallEnd: AbstractInsnNode +) { + lateinit var tryCatchBlocksContinuationLabel: LabelNode +} + +internal fun getLastParameterIndex(desc: String, access: Int) = + Type.getArgumentTypes(desc).dropLast(1).map { it.size }.sum() + (if (!isStatic(access)) 1 else 0) + +private fun getParameterTypesForCoroutineConstructor(desc: String, hasDispatchReceiver: Boolean, thisName: String) = + listOfNotNull(if (!hasDispatchReceiver) null else Type.getObjectType(thisName)).toTypedArray() + + Type.getArgumentTypes(desc).last() + +private fun isStatic(access: Int) = access and Opcodes.ACC_STATIC != 0 + +private fun getParameterTypesIndicesForCoroutineConstructor( + desc: String, + containingFunctionAccess: Int, + needDispatchReceiver: Boolean, + thisName: String, + languageVersionSettings: LanguageVersionSettings +): Collection> { + return mutableListOf>().apply { + if (needDispatchReceiver) { + add(Type.getObjectType(thisName) to 0) + } + val continuationIndex = + getAllParameterTypes(desc, !isStatic(containingFunctionAccess), thisName).dropLast(1).map(Type::getSize).sum() + add(languageVersionSettings.continuationAsmType() to continuationIndex) + } +} + +private fun getAllParameterTypes(desc: String, hasDispatchReceiver: Boolean, thisName: String) = + listOfNotNull(if (!hasDispatchReceiver) null else Type.getObjectType(thisName)).toTypedArray() + + Type.getArgumentTypes(desc) + +internal class IgnoringCopyOperationSourceInterpreter : SourceInterpreter(Opcodes.API_VERSION) { + override fun copyOperation(insn: AbstractInsnNode?, value: SourceValue?) = value +} + +// Check whether this instruction is unreachable, i.e. there is no path leading to this instruction +internal fun isUnreachable(index: Int, sourceFrames: Array?>): Boolean = + sourceFrames.size <= index || sourceFrames[index] == null + +private fun AbstractInsnNode?.isInvisibleInDebugVarInsn(methodNode: MethodNode): Boolean { + val insns = methodNode.instructions + val index = insns.indexOf(this) + return (this is VarInsnNode && methodNode.localVariables.none { + it.index == `var` && index in it.start.let(insns::indexOf)..it.end.let(insns::indexOf) + }) +} + +private val SAFE_OPCODES = + ((Opcodes.DUP..Opcodes.DUP2_X2) + Opcodes.NOP + Opcodes.POP + Opcodes.POP2 + (Opcodes.IFEQ..Opcodes.GOTO)).toSet() + +internal fun replaceFakeContinuationsWithRealOnes(methodNode: MethodNode, continuationIndex: Int) { + val fakeContinuations = methodNode.instructions.asSequence().filter(::isFakeContinuationMarker).toList() + for (fakeContinuation in fakeContinuations) { + methodNode.instructions.removeAll(listOf(fakeContinuation.previous.previous, fakeContinuation.previous)) + methodNode.instructions.set(fakeContinuation, VarInsnNode(Opcodes.ALOAD, continuationIndex)) + } +} diff --git a/compiler/backend/src/org/jetbrains/kotlin/codegen/optimization/common/variableLiveness.kt.202 b/compiler/backend/src/org/jetbrains/kotlin/codegen/optimization/common/variableLiveness.kt.202 new file mode 100644 index 00000000000..b9b0de93bb8 --- /dev/null +++ b/compiler/backend/src/org/jetbrains/kotlin/codegen/optimization/common/variableLiveness.kt.202 @@ -0,0 +1,101 @@ +/* + * Copyright 2010-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license + * that can be found in the license/LICENSE.txt file. + */ + +package org.jetbrains.kotlin.codegen.optimization.common + +import org.jetbrains.kotlin.codegen.coroutines.SUSPEND_FUNCTION_COMPLETION_PARAMETER_NAME +import org.jetbrains.kotlin.codegen.optimization.transformer.MethodTransformer +import org.jetbrains.kotlin.load.java.JvmAbi +import org.jetbrains.org.objectweb.asm.Type +import org.jetbrains.org.objectweb.asm.tree.AbstractInsnNode +import org.jetbrains.org.objectweb.asm.tree.IincInsnNode +import org.jetbrains.org.objectweb.asm.tree.MethodNode +import org.jetbrains.org.objectweb.asm.tree.VarInsnNode +import org.jetbrains.org.objectweb.asm.tree.analysis.BasicValue +import org.jetbrains.org.objectweb.asm.tree.analysis.Frame +import java.util.* + + +class VariableLivenessFrame(val maxLocals: Int) : VarFrame { + private val bitSet = BitSet(maxLocals) + + override fun mergeFrom(other: VariableLivenessFrame) { + bitSet.or(other.bitSet) + } + + fun markAlive(varIndex: Int) { + bitSet.set(varIndex, true) + } + + fun markAllAlive(bitSet: BitSet) { + this.bitSet.or(bitSet) + } + + fun markDead(varIndex: Int) { + bitSet.set(varIndex, false) + } + + fun isAlive(varIndex: Int): Boolean = bitSet.get(varIndex) + + override fun equals(other: Any?): Boolean { + if (other !is VariableLivenessFrame) return false + return bitSet == other.bitSet + } + + override fun hashCode() = bitSet.hashCode() +} + +fun analyzeLiveness(node: MethodNode): List { + val typeAnnotatedFrames = MethodTransformer.analyze("fake", node, OptimizationBasicInterpreter()) + val visibleByDebuggerVariables = analyzeVisibleByDebuggerVariables(node, typeAnnotatedFrames) + return analyze(node, object : BackwardAnalysisInterpreter { + override fun newFrame(maxLocals: Int) = VariableLivenessFrame(maxLocals) + override fun def(frame: VariableLivenessFrame, insn: AbstractInsnNode) = defVar(frame, insn) + override fun use(frame: VariableLivenessFrame, insn: AbstractInsnNode) = + useVar(frame, insn, node, visibleByDebuggerVariables[node.instructions.indexOf(insn)]) + }) +} + +private fun analyzeVisibleByDebuggerVariables( + node: MethodNode, + typeAnnotatedFrames: Array> +): Array { + val res = Array(node.instructions.size()) { BitSet(node.maxLocals) } + for (local in node.localVariables) { + if (local.name.isInvisibleDebuggerVariable()) continue + for (index in node.instructions.indexOf(local.start) until node.instructions.indexOf(local.end)) { + if (Type.getType(local.desc).sort == typeAnnotatedFrames[index]?.getLocal(local.index)?.type?.sort) { + res[index].set(local.index) + } + } + } + return res +} + +private fun defVar(frame: VariableLivenessFrame, insn: AbstractInsnNode) { + if (insn is VarInsnNode && insn.isStoreOperation()) { + frame.markDead(insn.`var`) + } +} + +private fun useVar( + frame: VariableLivenessFrame, + insn: AbstractInsnNode, + node: MethodNode, + visibleByDebuggerVariables: BitSet +) { + frame.markAllAlive(visibleByDebuggerVariables) + + if (insn is VarInsnNode && insn.isLoadOperation()) { + frame.markAlive(insn.`var`) + } else if (insn is IincInsnNode) { + frame.markAlive(insn.`var`) + } +} + +private fun String.isInvisibleDebuggerVariable(): Boolean = + startsWith(JvmAbi.LOCAL_VARIABLE_NAME_PREFIX_INLINE_ARGUMENT) || + startsWith(JvmAbi.LOCAL_VARIABLE_NAME_PREFIX_INLINE_FUNCTION) || + this == SUSPEND_FUNCTION_COMPLETION_PARAMETER_NAME diff --git a/compiler/backend/src/org/jetbrains/kotlin/codegen/state/SignatureDumpingBuilderFactory.kt.202 b/compiler/backend/src/org/jetbrains/kotlin/codegen/state/SignatureDumpingBuilderFactory.kt.202 new file mode 100644 index 00000000000..b3981a00857 --- /dev/null +++ b/compiler/backend/src/org/jetbrains/kotlin/codegen/state/SignatureDumpingBuilderFactory.kt.202 @@ -0,0 +1,151 @@ +/* + * Copyright 2010-2016 JetBrains s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.jetbrains.kotlin.codegen.state + +import com.intellij.psi.PsiElement +import org.jetbrains.kotlin.codegen.ClassBuilder +import org.jetbrains.kotlin.codegen.ClassBuilderFactory +import org.jetbrains.kotlin.codegen.DelegatingClassBuilder +import org.jetbrains.kotlin.codegen.DelegatingClassBuilderFactory +import org.jetbrains.kotlin.descriptors.DeclarationDescriptor +import org.jetbrains.kotlin.descriptors.DeclarationDescriptorWithVisibility +import org.jetbrains.kotlin.renderer.DescriptorRenderer +import org.jetbrains.kotlin.renderer.DescriptorRendererModifier +import org.jetbrains.kotlin.resolve.jvm.diagnostics.JvmDeclarationOrigin +import org.jetbrains.kotlin.resolve.jvm.diagnostics.MemberKind +import org.jetbrains.kotlin.resolve.jvm.diagnostics.RawSignature +import org.jetbrains.org.objectweb.asm.FieldVisitor +import org.jetbrains.org.objectweb.asm.MethodVisitor +import org.jetbrains.kotlin.codegen.coroutines.unwrapInitialDescriptorForSuspendFunction +import org.jetbrains.kotlin.descriptors.CallableDescriptor +import java.io.BufferedWriter +import java.io.File + + +class SignatureDumpingBuilderFactory( + builderFactory: ClassBuilderFactory, + val destination: File +) : DelegatingClassBuilderFactory(builderFactory) { + + companion object { + val MEMBER_RENDERER = DescriptorRenderer.withOptions { + withDefinedIn = false + modifiers -= DescriptorRendererModifier.VISIBILITY + } + val TYPE_RENDERER = DescriptorRenderer.withOptions { + withSourceFileForTopLevel = false + modifiers -= DescriptorRendererModifier.VISIBILITY + } + } + + private val outputStream: BufferedWriter by lazy { + // TODO: Replace with LOG.info and make log output go to MessageCollector + println("[INFO] Dumping signatures to $destination") + destination.parentFile?.mkdirs() + destination.bufferedWriter().apply { append("[\n") } + } + private var firstClassWritten: Boolean = false + + override fun close() { + outputStream.append("\n]\n") + outputStream.close() + super.close() + } + + override fun newClassBuilder(origin: JvmDeclarationOrigin): DelegatingClassBuilder { + return SignatureDumpingClassBuilder(origin, delegate.newClassBuilder(origin)) + } + + + private inner class SignatureDumpingClassBuilder(val origin: JvmDeclarationOrigin, val _delegate: ClassBuilder) : DelegatingClassBuilder() { + override fun getDelegate() = _delegate + + private val signatures = mutableListOf>() + private lateinit var javaClassName: String + + override fun defineClass(origin: PsiElement?, version: Int, access: Int, name: String, signature: String?, superName: String, interfaces: Array) { + javaClassName = name + + super.defineClass(origin, version, access, name, signature, superName, interfaces) + } + + override fun newMethod(origin: JvmDeclarationOrigin, access: Int, name: String, desc: String, signature: String?, exceptions: Array?): MethodVisitor { + signatures += RawSignature(name, desc, MemberKind.METHOD) to origin.descriptor?.let { + if (it is CallableDescriptor) it.unwrapInitialDescriptorForSuspendFunction() else it + } + return super.newMethod(origin, access, name, desc, signature, exceptions) + } + + override fun newField(origin: JvmDeclarationOrigin, access: Int, name: String, desc: String, signature: String?, value: Any?): FieldVisitor { + signatures += RawSignature(name, desc, MemberKind.FIELD) to origin.descriptor + return super.newField(origin, access, name, desc, signature, value) + } + + override fun done() { + if (firstClassWritten) outputStream.append(",\n") else firstClassWritten = true + outputStream.append("\t{\n") + origin.descriptor?.let { + outputStream.append("\t\t").appendNameValue("declaration", TYPE_RENDERER.render(it)).append(",\n") + (it as? DeclarationDescriptorWithVisibility)?.visibility?.let { + outputStream.append("\t\t").appendNameValue("visibility", it.internalDisplayName).append(",\n") + } + } + outputStream.append("\t\t").appendNameValue("class", javaClassName).append(",\n") + + outputStream.append("\t\t").appendQuoted("members").append(": [\n") + signatures.joinTo(outputStream, ",\n") { buildString { + val (signature, descriptor) = it + append("\t\t\t{") + descriptor?.let { + (it as? DeclarationDescriptorWithVisibility)?.visibility?.let { + appendNameValue("visibility", it.internalDisplayName).append(",\t") + } + appendNameValue("declaration", MEMBER_RENDERER.render(it)).append(", ") + + } + appendNameValue("name", signature.name).append(", ") + appendNameValue("desc", signature.desc).append("}") + }} + outputStream.append("\n\t\t]\n\t}") + + super.done() + } + } +} + +private fun Appendable.appendQuoted(value: String?): Appendable = value?.let { append('"').append(jsonEscape(it)).append('"') } ?: append("null") +private fun Appendable.appendNameValue(name: String, value: String?): Appendable = appendQuoted(name).append(": ").appendQuoted(value) + +private fun jsonEscape(value: String): String = buildString { + for (index in 0..value.length - 1) { + val ch = value[index] + when (ch) { + '\b' -> append("\\b") + '\t' -> append("\\t") + '\n' -> append("\\n") + '\r' -> append("\\r") + '\"' -> append("\\\"") + '\\' -> append("\\\\") + else -> if (ch.toInt() < 32) { + append("\\u" + Integer.toHexString(ch.toInt()).padStart(4, '0')) + } + else { + append(ch) + } + } + } +} \ No newline at end of file diff --git a/compiler/tests-common/tests/org/jetbrains/kotlin/test/testFramework/KtParsingTestCase.java.202 b/compiler/tests-common/tests/org/jetbrains/kotlin/test/testFramework/KtParsingTestCase.java.202 new file mode 100644 index 00000000000..912863a2f55 --- /dev/null +++ b/compiler/tests-common/tests/org/jetbrains/kotlin/test/testFramework/KtParsingTestCase.java.202 @@ -0,0 +1,345 @@ +/* + * Copyright 2000-2016 JetBrains s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.jetbrains.kotlin.test.testFramework; + +import com.intellij.core.CoreASTFactory; +import com.intellij.ide.util.AppPropertiesComponentImpl; +import com.intellij.ide.util.PropertiesComponent; +import com.intellij.lang.*; +import com.intellij.lang.impl.PsiBuilderFactoryImpl; +import com.intellij.mock.*; +import com.intellij.openapi.Disposable; +import com.intellij.openapi.application.PathManager; +import com.intellij.openapi.editor.Document; +import com.intellij.openapi.editor.EditorFactory; +import com.intellij.openapi.extensions.ExtensionPointName; +import com.intellij.openapi.extensions.Extensions; +import com.intellij.openapi.fileEditor.FileDocumentManager; +import com.intellij.openapi.fileEditor.impl.LoadTextUtil; +import com.intellij.openapi.fileTypes.FileTypeFactory; +import com.intellij.openapi.fileTypes.FileTypeManager; +import com.intellij.openapi.fileTypes.FileTypeRegistry; +import com.intellij.openapi.options.SchemeManagerFactory; +import com.intellij.openapi.progress.EmptyProgressIndicator; +import com.intellij.openapi.progress.ProgressManager; +import com.intellij.openapi.progress.impl.CoreProgressManager; +import com.intellij.openapi.util.Disposer; +import com.intellij.openapi.util.Key; +import com.intellij.openapi.util.TextRange; +import com.intellij.openapi.util.io.FileUtil; +import com.intellij.openapi.vfs.CharsetToolkit; +import com.intellij.pom.PomModel; +import com.intellij.pom.core.impl.PomModelImpl; +import com.intellij.pom.tree.TreeAspect; +import com.intellij.psi.*; +import com.intellij.psi.impl.*; +import com.intellij.psi.util.CachedValuesManager; +import com.intellij.testFramework.LightVirtualFile; +import com.intellij.testFramework.MockSchemeManagerFactory; +import com.intellij.testFramework.TestDataFile; +import com.intellij.util.CachedValuesManagerImpl; +import com.intellij.util.Function; +import junit.framework.TestCase; +import org.jetbrains.annotations.NonNls; +import org.jetbrains.annotations.NotNull; +import org.jetbrains.kotlin.idea.KotlinFileType; +import org.picocontainer.ComponentAdapter; +import org.picocontainer.MutablePicoContainer; + +import java.io.File; +import java.io.IOException; +import java.util.Set; + +@SuppressWarnings("ALL") +public abstract class KtParsingTestCase extends KtPlatformLiteFixture { + public static final Key HARD_REF_TO_DOCUMENT_KEY = Key.create("HARD_REF_TO_DOCUMENT_KEY"); + protected String myFilePrefix = ""; + protected String myFileExt; + protected final String myFullDataPath; + protected PsiFile myFile; + private MockPsiManager myPsiManager; + private PsiFileFactoryImpl myFileFactory; + protected Language myLanguage; + private final ParserDefinition[] myDefinitions; + private final boolean myLowercaseFirstLetter; + + protected KtParsingTestCase(@NonNls @NotNull String dataPath, @NotNull String fileExt, @NotNull ParserDefinition... definitions) { + this(dataPath, fileExt, false, definitions); + } + + protected KtParsingTestCase(@NonNls @NotNull String dataPath, @NotNull String fileExt, boolean lowercaseFirstLetter, @NotNull ParserDefinition... definitions) { + myDefinitions = definitions; + myFullDataPath = getTestDataPath() + "/" + dataPath; + myFileExt = fileExt; + myLowercaseFirstLetter = lowercaseFirstLetter; + } + + @Override + protected void setUp() throws Exception { + super.setUp(); + initApplication(); + ComponentAdapter component = getApplication().getPicoContainer().getComponentAdapter(ProgressManager.class.getName()); + + Extensions.registerAreaClass("IDEA_PROJECT", null); + myProject = new MockProjectEx(getTestRootDisposable()); + myPsiManager = new MockPsiManager(myProject); + myFileFactory = new PsiFileFactoryImpl(myPsiManager); + MutablePicoContainer appContainer = getApplication().getPicoContainer(); + final MockEditorFactory editorFactory = new MockEditorFactory(); + MockFileTypeManager mockFileTypeManager = new MockFileTypeManager(KotlinFileType.INSTANCE); + MockFileDocumentManagerImpl mockFileDocumentManager = new MockFileDocumentManagerImpl(new Function() { + @Override + public Document fun(CharSequence charSequence) { + return editorFactory.createDocument(charSequence); + } + }, HARD_REF_TO_DOCUMENT_KEY); + + registerApplicationService(PropertiesComponent.class, new AppPropertiesComponentImpl()); + registerApplicationService(PsiBuilderFactory.class, new PsiBuilderFactoryImpl()); + registerApplicationService(DefaultASTFactory.class, new CoreASTFactory()); + registerApplicationService(SchemeManagerFactory.class, new MockSchemeManagerFactory()); + registerApplicationService(FileTypeManager.class, mockFileTypeManager); + registerApplicationService(FileDocumentManager.class, mockFileDocumentManager); + + registerApplicationService(ProgressManager.class, new CoreProgressManager()); + + registerComponentInstance(appContainer, FileTypeRegistry.class, mockFileTypeManager); + registerComponentInstance(appContainer, FileTypeManager.class, mockFileTypeManager); + registerComponentInstance(appContainer, EditorFactory.class, editorFactory); + registerComponentInstance(appContainer, FileDocumentManager.class, mockFileDocumentManager); + registerComponentInstance(appContainer, PsiDocumentManager.class, new MockPsiDocumentManager()); + + + myProject.registerService(CachedValuesManager.class, new CachedValuesManagerImpl(myProject, new PsiCachedValuesFactory(myPsiManager))); + myProject.registerService(PsiManager.class, myPsiManager); + + this.registerExtensionPoint(FileTypeFactory.FILE_TYPE_FACTORY_EP, FileTypeFactory.class); + registerExtensionPoint(MetaLanguage.EP_NAME, MetaLanguage.class); + + for (ParserDefinition definition : myDefinitions) { + addExplicitExtension(LanguageParserDefinitions.INSTANCE, definition.getFileNodeType().getLanguage(), definition); + } + if (myDefinitions.length > 0) { + configureFromParserDefinition(myDefinitions[0], myFileExt); + } + + // That's for reparse routines + final PomModelImpl pomModel = new PomModelImpl(myProject); + myProject.registerService(PomModel.class, pomModel); + } + + public void configureFromParserDefinition(ParserDefinition definition, String extension) { + myLanguage = definition.getFileNodeType().getLanguage(); + myFileExt = extension; + addExplicitExtension(LanguageParserDefinitions.INSTANCE, this.myLanguage, definition); + registerComponentInstance( + getApplication().getPicoContainer(), FileTypeManager.class, + new MockFileTypeManager(new MockLanguageFileType(myLanguage, myFileExt))); + } + + protected void addExplicitExtension(final LanguageExtension instance, final Language language, final T object) { + instance.addExplicitExtension(language, object); + Disposer.register(myProject, new Disposable() { + @Override + public void dispose() { + instance.removeExplicitExtension(language, object); + } + }); + } + + @Override + protected void registerExtensionPoint(final ExtensionPointName extensionPointName, Class aClass) { + super.registerExtensionPoint(extensionPointName, aClass); + Disposer.register(myProject, new Disposable() { + @Override + public void dispose() { + Extensions.getRootArea().unregisterExtensionPoint(extensionPointName.getName()); + } + }); + } + + protected void registerApplicationService(final Class aClass, T object) { + getApplication().registerService(aClass, object); + Disposer.register(myProject, new Disposable() { + @Override + public void dispose() { + getApplication().getPicoContainer().unregisterComponent(aClass.getName()); + } + }); + } + + public MockProjectEx getProject() { + return myProject; + } + + public MockPsiManager getPsiManager() { + return myPsiManager; + } + + @Override + protected void tearDown() throws Exception { + super.tearDown(); + myFile = null; + myProject = null; + myPsiManager = null; + } + + protected String getTestDataPath() { + return PathManager.getHomePath(); + } + + @NotNull + public final String getTestName() { + return getTestName(myLowercaseFirstLetter); + } + + protected boolean includeRanges() { + return false; + } + + protected boolean skipSpaces() { + return false; + } + + protected boolean checkAllPsiRoots() { + return true; + } + + protected void doTest(boolean checkResult) { + String name = getTestName(); + try { + String text = loadFile(name + "." + myFileExt); + myFile = createPsiFile(name, text); + ensureParsed(myFile); + assertEquals("light virtual file text mismatch", text, ((LightVirtualFile)myFile.getVirtualFile()).getContent().toString()); + assertEquals("virtual file text mismatch", text, LoadTextUtil.loadText(myFile.getVirtualFile())); + assertEquals("doc text mismatch", text, myFile.getViewProvider().getDocument().getText()); + assertEquals("psi text mismatch", text, myFile.getText()); + ensureCorrectReparse(myFile); + if (checkResult){ + checkResult(name, myFile); + } + else{ + toParseTreeText(myFile, skipSpaces(), includeRanges()); + } + } + catch (IOException e) { + throw new RuntimeException(e); + } + } + + protected void doTest(String suffix) throws IOException { + String name = getTestName(); + String text = loadFile(name + "." + myFileExt); + myFile = createPsiFile(name, text); + ensureParsed(myFile); + assertEquals(text, myFile.getText()); + checkResult(name + suffix, myFile); + } + + protected void doCodeTest(String code) throws IOException { + String name = getTestName(); + myFile = createPsiFile("a", code); + ensureParsed(myFile); + assertEquals(code, myFile.getText()); + checkResult(myFilePrefix + name, myFile); + } + + protected PsiFile createPsiFile(String name, String text) { + return createFile(name + "." + myFileExt, text); + } + + protected PsiFile createFile(@NonNls String name, String text) { + LightVirtualFile virtualFile = new LightVirtualFile(name, myLanguage, text); + virtualFile.setCharset(CharsetToolkit.UTF8_CHARSET); + return createFile(virtualFile); + } + + protected PsiFile createFile(LightVirtualFile virtualFile) { + return myFileFactory.trySetupPsiForFile(virtualFile, myLanguage, true, false); + } + + protected void checkResult(@NonNls @TestDataFile String targetDataName, final PsiFile file) throws IOException { + doCheckResult(myFullDataPath, file, checkAllPsiRoots(), targetDataName, skipSpaces(), includeRanges()); + } + + public static void doCheckResult(String testDataDir, + PsiFile file, + boolean checkAllPsiRoots, + String targetDataName, + boolean skipSpaces, + boolean printRanges) throws IOException { + FileViewProvider provider = file.getViewProvider(); + Set languages = provider.getLanguages(); + + if (!checkAllPsiRoots || languages.size() == 1) { + doCheckResult(testDataDir, targetDataName + ".txt", toParseTreeText(file, skipSpaces, printRanges).trim()); + return; + } + + for (Language language : languages) { + PsiFile root = provider.getPsi(language); + String expectedName = targetDataName + "." + language.getID() + ".txt"; + doCheckResult(testDataDir, expectedName, toParseTreeText(root, skipSpaces, printRanges).trim()); + } + } + + protected void checkResult(String actual) throws IOException { + String name = getTestName(); + doCheckResult(myFullDataPath, myFilePrefix + name + ".txt", actual); + } + + protected void checkResult(@TestDataFile @NonNls String targetDataName, String actual) throws IOException { + doCheckResult(myFullDataPath, targetDataName, actual); + } + + public static void doCheckResult(String fullPath, String targetDataName, String actual) throws IOException { + String expectedFileName = fullPath + File.separatorChar + targetDataName; + KtUsefulTestCase.assertSameLinesWithFile(expectedFileName, actual); + } + + protected static String toParseTreeText(PsiElement file, boolean skipSpaces, boolean printRanges) { + return DebugUtil.psiToString(file, skipSpaces, printRanges); + } + + protected String loadFile(@NonNls @TestDataFile String name) throws IOException { + return loadFileDefault(myFullDataPath, name); + } + + public static String loadFileDefault(String dir, String name) throws IOException { + return FileUtil.loadFile(new File(dir, name), CharsetToolkit.UTF8, true).trim(); + } + + public static void ensureParsed(PsiFile file) { + file.accept(new PsiElementVisitor() { + @Override + public void visitElement(PsiElement element) { + element.acceptChildren(this); + } + }); + } + + public static void ensureCorrectReparse(@NotNull PsiFile file) { + String psiToStringDefault = DebugUtil.psiToString(file, false, false); + String fileText = file.getText(); + DiffLog diffLog = (new BlockSupportImpl()).reparseRange( + file, file.getNode(), TextRange.allOf(fileText), fileText, new EmptyProgressIndicator(), fileText); + diffLog.performActualPsiChange(file); + + TestCase.assertEquals(psiToStringDefault, DebugUtil.psiToString(file, false, false)); + } +} \ No newline at end of file diff --git a/compiler/tests-common/tests/org/jetbrains/kotlin/test/testFramework/KtUsefulTestCase.java.202 b/compiler/tests-common/tests/org/jetbrains/kotlin/test/testFramework/KtUsefulTestCase.java.202 new file mode 100644 index 00000000000..f04e03d66d4 --- /dev/null +++ b/compiler/tests-common/tests/org/jetbrains/kotlin/test/testFramework/KtUsefulTestCase.java.202 @@ -0,0 +1,1189 @@ +/* + * Copyright 2010-2016 JetBrains s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.jetbrains.kotlin.test.testFramework; + +import com.intellij.codeInsight.CodeInsightSettings; +import com.intellij.concurrency.IdeaForkJoinWorkerThreadFactory; +import com.intellij.diagnostic.PerformanceWatcher; +import com.intellij.openapi.Disposable; +import com.intellij.openapi.application.Application; +import com.intellij.openapi.application.ApplicationManager; +import com.intellij.openapi.application.PathManager; +import com.intellij.openapi.application.impl.ApplicationInfoImpl; +import com.intellij.openapi.diagnostic.Logger; +import com.intellij.openapi.fileTypes.StdFileTypes; +import com.intellij.openapi.project.Project; +import com.intellij.openapi.util.Comparing; +import com.intellij.openapi.util.Disposer; +import com.intellij.openapi.util.IconLoader; +import com.intellij.openapi.util.JDOMUtil; +import com.intellij.openapi.util.io.FileUtil; +import com.intellij.openapi.util.text.StringUtil; +import com.intellij.openapi.vfs.LocalFileSystem; +import com.intellij.openapi.vfs.VfsUtilCore; +import com.intellij.openapi.vfs.VirtualFile; +import com.intellij.openapi.vfs.VirtualFileVisitor; +import com.intellij.psi.PsiDocumentManager; +import com.intellij.psi.codeStyle.CodeStyleSettings; +import com.intellij.psi.impl.DocumentCommitProcessor; +import com.intellij.psi.impl.DocumentCommitThread; +import com.intellij.psi.impl.source.PostprocessReformattingAspect; +import com.intellij.rt.execution.junit.FileComparisonFailure; +import com.intellij.testFramework.*; +import com.intellij.testFramework.exceptionCases.AbstractExceptionCase; +import com.intellij.testFramework.fixtures.IdeaTestExecutionPolicy; +import com.intellij.util.*; +import com.intellij.util.containers.ContainerUtil; +import com.intellij.util.containers.PeekableIterator; +import com.intellij.util.containers.PeekableIteratorWrapper; +import com.intellij.util.indexing.FileBasedIndex; +import com.intellij.util.indexing.FileBasedIndexImpl; +import com.intellij.util.lang.CompoundRuntimeException; +import com.intellij.util.ui.UIUtil; +import gnu.trove.Equality; +import gnu.trove.THashSet; +import junit.framework.AssertionFailedError; +import junit.framework.TestCase; +import org.jdom.Element; +import org.jetbrains.annotations.Contract; +import org.jetbrains.annotations.NotNull; +import org.jetbrains.annotations.Nullable; +import org.jetbrains.kotlin.testFramework.MockComponentManagerCreationTracer; +import org.jetbrains.kotlin.types.AbstractTypeChecker; +import org.jetbrains.kotlin.types.FlexibleTypeImpl; +import org.junit.Assert; +import org.junit.ComparisonFailure; + +import java.io.File; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.lang.annotation.Annotation; +import java.lang.reflect.Field; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.nio.charset.StandardCharsets; +import java.nio.file.Path; +import java.util.*; +import java.util.concurrent.TimeUnit; +import java.util.function.Supplier; + +/** + * @author peter + */ +@SuppressWarnings("ALL") +public abstract class KtUsefulTestCase extends TestCase { + public static final boolean IS_UNDER_TEAMCITY = System.getenv("TEAMCITY_VERSION") != null; + public static final String TEMP_DIR_MARKER = "unitTest_"; + public static final boolean OVERWRITE_TESTDATA = Boolean.getBoolean("idea.tests.overwrite.data"); + + private static final String ORIGINAL_TEMP_DIR = FileUtil.getTempDirectory(); + + private static final Map TOTAL_SETUP_COST_MILLIS = new HashMap<>(); + private static final Map TOTAL_TEARDOWN_COST_MILLIS = new HashMap<>(); + + private Application application; + + static { + IdeaForkJoinWorkerThreadFactory.setupPoisonFactory(); + Logger.setFactory(TestLoggerFactory.class); + } + protected static final Logger LOG = Logger.getInstance(KtUsefulTestCase.class); + + @NotNull + private final Disposable myTestRootDisposable = new TestDisposable(); + + static Path ourPathToKeep; + private final List myPathsToKeep = new ArrayList<>(); + + private String myTempDir; + + private static final String DEFAULT_SETTINGS_EXTERNALIZED; + private static final CodeInsightSettings defaultSettings = new CodeInsightSettings(); + static { + // Radar #5755208: Command line Java applications need a way to launch without a Dock icon. + System.setProperty("apple.awt.UIElement", "true"); + + try { + Element oldS = new Element("temp"); + defaultSettings.writeExternal(oldS); + DEFAULT_SETTINGS_EXTERNALIZED = JDOMUtil.writeElement(oldS); + } + catch (Exception e) { + throw new RuntimeException(e); + } + + // -- KOTLIN ADDITIONAL START -- + + FlexibleTypeImpl.RUN_SLOW_ASSERTIONS = true; + AbstractTypeChecker.RUN_SLOW_ASSERTIONS = true; + + // -- KOTLIN ADDITIONAL END -- + } + + /** + * Pass here the exception you want to be thrown first + * E.g.
+     * {@code
+     *   void tearDown() {
+     *     try {
+     *       doTearDowns();
+     *     }
+     *     catch(Exception e) {
+     *       addSuppressedException(e);
+     *     }
+     *     finally {
+     *       super.tearDown();
+     *     }
+     *   }
+     * }
+     * 
+ * + */ + protected void addSuppressedException(@NotNull Throwable e) { + List list = mySuppressedExceptions; + if (list == null) { + mySuppressedExceptions = list = new SmartList<>(); + } + list.add(e); + } + private List mySuppressedExceptions; + + + public KtUsefulTestCase() { + } + + public KtUsefulTestCase(@NotNull String name) { + super(name); + } + + protected boolean shouldContainTempFiles() { + return true; + } + + @Override + protected void setUp() throws Exception { + // -- KOTLIN ADDITIONAL START -- + application = ApplicationManager.getApplication(); + + if (application != null && application.isDisposed()) { + MockComponentManagerCreationTracer.diagnoseDisposedButNotClearedApplication(application); + } + // -- KOTLIN ADDITIONAL END -- + + super.setUp(); + + if (shouldContainTempFiles()) { + IdeaTestExecutionPolicy policy = IdeaTestExecutionPolicy.current(); + String testName = null; + if (policy != null) { + testName = policy.getPerTestTempDirName(); + } + if (testName == null) { + testName = FileUtil.sanitizeFileName(getTestName(true)); + } + testName = new File(testName).getName(); // in case the test name contains file separators + myTempDir = FileUtil.createTempDirectory(TEMP_DIR_MARKER + testName, "", false).getPath(); + FileUtil.resetCanonicalTempPathCache(myTempDir); + } + + boolean isStressTest = isStressTest(); + ApplicationInfoImpl.setInStressTest(isStressTest); + if (isPerformanceTest()) { + Timings.getStatistics(); + } + + // turn off Disposer debugging for performance tests + Disposer.setDebugMode(!isStressTest); + + if (isIconRequired()) { + // ensure that IconLoader will use dummy empty icon + IconLoader.deactivate(); + //IconManager.activate(); + } + } + + protected boolean isIconRequired() { + return false; + } + + @Override + protected void tearDown() throws Exception { + try { + // don't use method references here to make stack trace reading easier + //noinspection Convert2MethodRef + new RunAll( + () -> { + if (isIconRequired()) { + //IconManager.deactivate(); + } + }, + () -> disposeRootDisposable(), + () -> cleanupSwingDataStructures(), + () -> cleanupDeleteOnExitHookList(), + () -> Disposer.setDebugMode(true), + () -> { + if (shouldContainTempFiles()) { + FileUtil.resetCanonicalTempPathCache(ORIGINAL_TEMP_DIR); + if (hasTmpFilesToKeep()) { + File[] files = new File(myTempDir).listFiles(); + if (files != null) { + for (File file : files) { + if (!shouldKeepTmpFile(file)) { + FileUtil.delete(file); + } + } + } + } + else { + FileUtil.delete(new File(myTempDir)); + } + } + }, + () -> waitForAppLeakingThreads(10, TimeUnit.SECONDS) + ).run(ObjectUtils.notNull(mySuppressedExceptions, Collections.emptyList())); + } + finally { + // -- KOTLIN ADDITIONAL START -- + TestApplicationUtilKt.resetApplicationToNull(application); + application = null; + // -- KOTLIN ADDITIONAL END -- + } + } + + protected final void disposeRootDisposable() { + Disposer.dispose(getTestRootDisposable()); + } + + protected void addTmpFileToKeep(@NotNull File file) { + myPathsToKeep.add(file.getPath()); + } + + private boolean hasTmpFilesToKeep() { + return ourPathToKeep != null && FileUtil.isAncestor(myTempDir, ourPathToKeep.toString(), false) || !myPathsToKeep.isEmpty(); + } + + private boolean shouldKeepTmpFile(@NotNull File file) { + String path = file.getPath(); + if (FileUtil.pathsEqual(path, ourPathToKeep.toString())) return true; + for (String pathToKeep : myPathsToKeep) { + if (FileUtil.pathsEqual(path, pathToKeep)) return true; + } + return false; + } + + private static final Set DELETE_ON_EXIT_HOOK_DOT_FILES; + private static final Class DELETE_ON_EXIT_HOOK_CLASS; + static { + Class aClass; + try { + aClass = Class.forName("java.io.DeleteOnExitHook"); + } + catch (Exception e) { + throw new RuntimeException(e); + } + @SuppressWarnings("unchecked") Set files = ReflectionUtil.getStaticFieldValue(aClass, Set.class, "files"); + DELETE_ON_EXIT_HOOK_CLASS = aClass; + DELETE_ON_EXIT_HOOK_DOT_FILES = files; + } + + @SuppressWarnings("SynchronizeOnThis") + private static void cleanupDeleteOnExitHookList() { + // try to reduce file set retained by java.io.DeleteOnExitHook + List list; + synchronized (DELETE_ON_EXIT_HOOK_CLASS) { + if (DELETE_ON_EXIT_HOOK_DOT_FILES.isEmpty()) return; + list = new ArrayList<>(DELETE_ON_EXIT_HOOK_DOT_FILES); + } + for (int i = list.size() - 1; i >= 0; i--) { + String path = list.get(i); + File file = new File(path); + if (file.delete() || !file.exists()) { + synchronized (DELETE_ON_EXIT_HOOK_CLASS) { + DELETE_ON_EXIT_HOOK_DOT_FILES.remove(path); + } + } + } + } + + @SuppressWarnings("ConstantConditions") + private static void cleanupSwingDataStructures() throws Exception { + Object manager = ReflectionUtil.getDeclaredMethod(Class.forName("javax.swing.KeyboardManager"), "getCurrentManager").invoke(null); + Map componentKeyStrokeMap = ReflectionUtil.getField(manager.getClass(), manager, Hashtable.class, "componentKeyStrokeMap"); + componentKeyStrokeMap.clear(); + Map containerMap = ReflectionUtil.getField(manager.getClass(), manager, Hashtable.class, "containerMap"); + containerMap.clear(); + } + + static void doCheckForSettingsDamage(@NotNull CodeStyleSettings oldCodeStyleSettings, @NotNull CodeStyleSettings currentCodeStyleSettings) { + final CodeInsightSettings settings = CodeInsightSettings.getInstance(); + // don't use method references here to make stack trace reading easier + //noinspection Convert2MethodRef + new RunAll() + .append(() -> { + try { + checkCodeInsightSettingsEqual(defaultSettings, settings); + } + catch (AssertionError error) { + CodeInsightSettings clean = new CodeInsightSettings(); + for (Field field : clean.getClass().getFields()) { + try { + ReflectionUtil.copyFieldValue(clean, settings, field); + } + catch (Exception ignored) { + } + } + throw error; + } + }) + .append(() -> { + currentCodeStyleSettings.getIndentOptions(StdFileTypes.JAVA); + try { + checkCodeStyleSettingsEqual(oldCodeStyleSettings, currentCodeStyleSettings); + } + finally { + currentCodeStyleSettings.clearCodeStyleSettings(); + } + }) + .run(); + } + + @NotNull + public Disposable getTestRootDisposable() { + return myTestRootDisposable; + } + + @Override + protected void runTest() throws Throwable { + final Throwable[] throwables = new Throwable[1]; + + Runnable runnable = () -> { + try { + TestLoggerFactory.onTestStarted(); + super.runTest(); + TestLoggerFactory.onTestFinished(true); + } + catch (InvocationTargetException e) { + TestLoggerFactory.onTestFinished(false); + e.fillInStackTrace(); + throwables[0] = e.getTargetException(); + } + catch (IllegalAccessException e) { + TestLoggerFactory.onTestFinished(false); + e.fillInStackTrace(); + throwables[0] = e; + } + catch (Throwable e) { + TestLoggerFactory.onTestFinished(false); + throwables[0] = e; + } + }; + + invokeTestRunnable(runnable); + + if (throwables[0] != null) { + throw throwables[0]; + } + } + + protected boolean shouldRunTest() { + return TestFrameworkUtil.canRunTest(getClass()); + } + + protected void invokeTestRunnable(@NotNull Runnable runnable) throws Exception { + if (runInDispatchThread()) { + EdtTestUtilKt.runInEdtAndWait(() -> { + runnable.run(); + return null; + }); + } + else { + runnable.run(); + } + } + + protected void defaultRunBare() throws Throwable { + Throwable exception = null; + try { + long setupStart = System.nanoTime(); + setUp(); + long setupCost = (System.nanoTime() - setupStart) / 1000000; + logPerClassCost(setupCost, TOTAL_SETUP_COST_MILLIS); + + runTest(); + } + catch (Throwable running) { + exception = running; + } + finally { + try { + long teardownStart = System.nanoTime(); + tearDown(); + long teardownCost = (System.nanoTime() - teardownStart) / 1000000; + logPerClassCost(teardownCost, TOTAL_TEARDOWN_COST_MILLIS); + } + catch (Throwable tearingDown) { + if (exception == null) { + exception = tearingDown; + } + else { + exception = new CompoundRuntimeException(Arrays.asList(exception, tearingDown)); + } + } + } + if (exception != null) { + throw exception; + } + } + + /** + * Logs the setup cost grouped by test fixture class (superclass of the current test class). + * + * @param cost setup cost in milliseconds + */ + private void logPerClassCost(long cost, @NotNull Map costMap) { + Class superclass = getClass().getSuperclass(); + Long oldCost = costMap.get(superclass.getName()); + long newCost = oldCost == null ? cost : oldCost + cost; + costMap.put(superclass.getName(), newCost); + } + + @SuppressWarnings("UseOfSystemOutOrSystemErr") + static void logSetupTeardownCosts() { + System.out.println("Setup costs"); + long totalSetup = 0; + for (Map.Entry entry : TOTAL_SETUP_COST_MILLIS.entrySet()) { + System.out.println(String.format(" %s: %d ms", entry.getKey(), entry.getValue())); + totalSetup += entry.getValue(); + } + System.out.println("Teardown costs"); + long totalTeardown = 0; + for (Map.Entry entry : TOTAL_TEARDOWN_COST_MILLIS.entrySet()) { + System.out.println(String.format(" %s: %d ms", entry.getKey(), entry.getValue())); + totalTeardown += entry.getValue(); + } + System.out.println(String.format("Total overhead: setup %d ms, teardown %d ms", totalSetup, totalTeardown)); + System.out.println(String.format("##teamcity[buildStatisticValue key='ideaTests.totalSetupMs' value='%d']", totalSetup)); + System.out.println(String.format("##teamcity[buildStatisticValue key='ideaTests.totalTeardownMs' value='%d']", totalTeardown)); + } + + @Override + public void runBare() throws Throwable { + if (!shouldRunTest()) return; + + if (runInDispatchThread()) { + TestRunnerUtil.replaceIdeEventQueueSafely(); + EdtTestUtil.runInEdtAndWait(this::defaultRunBare); + } + else { + defaultRunBare(); + } + } + + protected boolean runInDispatchThread() { + IdeaTestExecutionPolicy policy = IdeaTestExecutionPolicy.current(); + if (policy != null) { + return policy.runInDispatchThread(); + } + return true; + } + + /** + * If you want a more shorter name than runInEdtAndWait. + */ + protected void edt(@NotNull ThrowableRunnable runnable) { + EdtTestUtil.runInEdtAndWait(runnable); + } + + @NotNull + public static String toString(@NotNull Iterable collection) { + if (!collection.iterator().hasNext()) { + return ""; + } + + final StringBuilder builder = new StringBuilder(); + for (final Object o : collection) { + if (o instanceof THashSet) { + builder.append(new TreeSet<>((THashSet)o)); + } + else { + builder.append(o); + } + builder.append('\n'); + } + return builder.toString(); + } + + @SafeVarargs + public static void assertOrderedEquals(@NotNull T[] actual, @NotNull T... expected) { + assertOrderedEquals(Arrays.asList(actual), expected); + } + + @SafeVarargs + public static void assertOrderedEquals(@NotNull Iterable actual, @NotNull T... expected) { + assertOrderedEquals("", actual, expected); + } + + public static void assertOrderedEquals(@NotNull byte[] actual, @NotNull byte[] expected) { + assertEquals(expected.length, actual.length); + for (int i = 0; i < actual.length; i++) { + byte a = actual[i]; + byte e = expected[i]; + assertEquals("not equals at index: "+i, e, a); + } + } + + public static void assertOrderedEquals(@NotNull int[] actual, @NotNull int[] expected) { + if (actual.length != expected.length) { + fail("Expected size: "+expected.length+"; actual: "+actual.length+"\nexpected: "+Arrays.toString(expected)+"\nactual : "+Arrays.toString(actual)); + } + for (int i = 0; i < actual.length; i++) { + int a = actual[i]; + int e = expected[i]; + assertEquals("not equals at index: "+i, e, a); + } + } + + @SafeVarargs + public static void assertOrderedEquals(@NotNull String errorMsg, @NotNull Iterable actual, @NotNull T... expected) { + assertOrderedEquals(errorMsg, actual, Arrays.asList(expected)); + } + + public static void assertOrderedEquals(@NotNull Iterable actual, @NotNull Iterable expected) { + assertOrderedEquals("", actual, expected); + } + + public static void assertOrderedEquals(@NotNull String errorMsg, + @NotNull Iterable actual, + @NotNull Iterable expected) { + //noinspection unchecked + assertOrderedEquals(errorMsg, actual, expected, Equality.CANONICAL); + } + + public static void assertOrderedEquals(@NotNull String errorMsg, + @NotNull Iterable actual, + @NotNull Iterable expected, + @NotNull Equality comparator) { + if (!equals(actual, expected, comparator)) { + String expectedString = toString(expected); + String actualString = toString(actual); + Assert.assertEquals(errorMsg, expectedString, actualString); + Assert.fail("Warning! 'toString' does not reflect the difference.\nExpected: " + expectedString + "\nActual: " + actualString); + } + } + + private static boolean equals(@NotNull Iterable a1, + @NotNull Iterable a2, + @NotNull Equality comparator) { + Iterator it1 = a1.iterator(); + Iterator it2 = a2.iterator(); + while (it1.hasNext() || it2.hasNext()) { + if (!it1.hasNext() || !it2.hasNext()) return false; + if (!comparator.equals(it1.next(), it2.next())) return false; + } + return true; + } + + @SafeVarargs + public static void assertOrderedCollection(@NotNull T[] collection, @NotNull Consumer... checkers) { + assertOrderedCollection(Arrays.asList(collection), checkers); + } + + /** + * Checks {@code actual} contains same elements (in {@link #equals(Object)} meaning) as {@code expected} irrespective of their order + */ + @SafeVarargs + public static void assertSameElements(@NotNull T[] actual, @NotNull T... expected) { + assertSameElements(Arrays.asList(actual), expected); + } + + /** + * Checks {@code actual} contains same elements (in {@link #equals(Object)} meaning) as {@code expected} irrespective of their order + */ + @SafeVarargs + public static void assertSameElements(@NotNull Collection actual, @NotNull T... expected) { + assertSameElements(actual, Arrays.asList(expected)); + } + + /** + * Checks {@code actual} contains same elements (in {@link #equals(Object)} meaning) as {@code expected} irrespective of their order + */ + public static void assertSameElements(@NotNull Collection actual, @NotNull Collection expected) { + assertSameElements("", actual, expected); + } + + /** + * Checks {@code actual} contains same elements (in {@link #equals(Object)} meaning) as {@code expected} irrespective of their order + */ + public static void assertSameElements(@NotNull String message, @NotNull Collection actual, @NotNull Collection expected) { + if (actual.size() != expected.size() || !new HashSet<>(expected).equals(new HashSet(actual))) { + Assert.assertEquals(message, new HashSet<>(expected), new HashSet(actual)); + } + } + + @SafeVarargs + public static void assertContainsOrdered(@NotNull Collection collection, @NotNull T... expected) { + assertContainsOrdered(collection, Arrays.asList(expected)); + } + + public static void assertContainsOrdered(@NotNull Collection collection, @NotNull Collection expected) { + PeekableIterator expectedIt = new PeekableIteratorWrapper<>(expected.iterator()); + PeekableIterator actualIt = new PeekableIteratorWrapper<>(collection.iterator()); + + while (actualIt.hasNext() && expectedIt.hasNext()) { + T expectedElem = expectedIt.peek(); + T actualElem = actualIt.peek(); + if (expectedElem.equals(actualElem)) { + expectedIt.next(); + } + actualIt.next(); + } + if (expectedIt.hasNext()) { + throw new ComparisonFailure("", toString(expected), toString(collection)); + } + } + + @SafeVarargs + public static void assertContainsElements(@NotNull Collection collection, @NotNull T... expected) { + assertContainsElements(collection, Arrays.asList(expected)); + } + + public static void assertContainsElements(@NotNull Collection collection, @NotNull Collection expected) { + ArrayList copy = new ArrayList<>(collection); + copy.retainAll(expected); + assertSameElements(toString(collection), copy, expected); + } + + @NotNull + public static String toString(@NotNull Object[] collection, @NotNull String separator) { + return toString(Arrays.asList(collection), separator); + } + + @SafeVarargs + public static void assertDoesntContain(@NotNull Collection collection, @NotNull T... notExpected) { + assertDoesntContain(collection, Arrays.asList(notExpected)); + } + + public static void assertDoesntContain(@NotNull Collection collection, @NotNull Collection notExpected) { + ArrayList expected = new ArrayList<>(collection); + expected.removeAll(notExpected); + assertSameElements(collection, expected); + } + + @NotNull + public static String toString(@NotNull Collection collection, @NotNull String separator) { + List list = ContainerUtil.map2List(collection, String::valueOf); + Collections.sort(list); + StringBuilder builder = new StringBuilder(); + boolean flag = false; + for (final String o : list) { + if (flag) { + builder.append(separator); + } + builder.append(o); + flag = true; + } + return builder.toString(); + } + + @SafeVarargs + public static void assertOrderedCollection(@NotNull Collection collection, @NotNull Consumer... checkers) { + if (collection.size() != checkers.length) { + Assert.fail(toString(collection)); + } + int i = 0; + for (final T actual : collection) { + try { + checkers[i].consume(actual); + } + catch (AssertionFailedError e) { + //noinspection UseOfSystemOutOrSystemErr + System.out.println(i + ": " + actual); + throw e; + } + i++; + } + } + + @SafeVarargs + public static void assertUnorderedCollection(@NotNull T[] collection, @NotNull Consumer... checkers) { + assertUnorderedCollection(Arrays.asList(collection), checkers); + } + + @SafeVarargs + public static void assertUnorderedCollection(@NotNull Collection collection, @NotNull Consumer... checkers) { + if (collection.size() != checkers.length) { + Assert.fail(toString(collection)); + } + Set> checkerSet = ContainerUtil.set(checkers); + int i = 0; + Throwable lastError = null; + for (final T actual : collection) { + boolean flag = true; + for (final Consumer condition : checkerSet) { + Throwable error = accepts(condition, actual); + if (error == null) { + checkerSet.remove(condition); + flag = false; + break; + } + else { + lastError = error; + } + } + if (flag) { + //noinspection ConstantConditions,CallToPrintStackTrace + lastError.printStackTrace(); + Assert.fail("Incorrect element(" + i + "): " + actual); + } + i++; + } + } + + private static Throwable accepts(@NotNull Consumer condition, final T actual) { + try { + condition.consume(actual); + return null; + } + catch (Throwable e) { + return e; + } + } + + @Contract("null, _ -> fail") + @NotNull + public static T assertInstanceOf(Object o, @NotNull Class aClass) { + Assert.assertNotNull("Expected instance of: " + aClass.getName() + " actual: " + null, o); + Assert.assertTrue("Expected instance of: " + aClass.getName() + " actual: " + o.getClass().getName(), aClass.isInstance(o)); + @SuppressWarnings("unchecked") T t = (T)o; + return t; + } + + public static T assertOneElement(@NotNull Collection collection) { + Iterator iterator = collection.iterator(); + String toString = toString(collection); + Assert.assertTrue(toString, iterator.hasNext()); + T t = iterator.next(); + Assert.assertFalse(toString, iterator.hasNext()); + return t; + } + + public static T assertOneElement(@NotNull T[] ts) { + Assert.assertEquals(Arrays.asList(ts).toString(), 1, ts.length); + return ts[0]; + } + + @SafeVarargs + public static void assertOneOf(T value, @NotNull T... values) { + for (T v : values) { + if (Objects.equals(value, v)) { + return; + } + } + Assert.fail(value + " should be equal to one of " + Arrays.toString(values)); + } + + public static void printThreadDump() { + PerformanceWatcher.dumpThreadsToConsole("Thread dump:"); + } + + public static void assertEmpty(@NotNull Object[] array) { + assertOrderedEquals(array); + } + + public static void assertNotEmpty(final Collection collection) { + assertNotNull(collection); + assertFalse(collection.isEmpty()); + } + + public static void assertEmpty(@NotNull Collection collection) { + assertEmpty(collection.toString(), collection); + } + + public static void assertNullOrEmpty(@Nullable Collection collection) { + if (collection == null) return; + assertEmpty("", collection); + } + + public static void assertEmpty(final String s) { + assertTrue(s, StringUtil.isEmpty(s)); + } + + public static void assertEmpty(@NotNull String errorMsg, @NotNull Collection collection) { + assertOrderedEquals(errorMsg, collection, Collections.emptyList()); + } + + public static void assertSize(int expectedSize, @NotNull Object[] array) { + if (array.length != expectedSize) { + assertEquals(toString(Arrays.asList(array)), expectedSize, array.length); + } + } + + public static void assertSize(int expectedSize, @NotNull Collection c) { + if (c.size() != expectedSize) { + assertEquals(toString(c), expectedSize, c.size()); + } + } + + @NotNull + protected T disposeOnTearDown(@NotNull T disposable) { + Disposer.register(getTestRootDisposable(), disposable); + return disposable; + } + + public static void assertSameLines(@NotNull String expected, @NotNull String actual) { + assertSameLines(null, expected, actual); + } + + public static void assertSameLines(@Nullable String message, @NotNull String expected, @NotNull String actual) { + String expectedText = StringUtil.convertLineSeparators(expected.trim()); + String actualText = StringUtil.convertLineSeparators(actual.trim()); + Assert.assertEquals(message, expectedText, actualText); + } + + public static void assertExists(@NotNull File file){ + assertTrue("File should exist " + file, file.exists()); + } + + public static void assertDoesntExist(@NotNull File file){ + assertFalse("File should not exist " + file, file.exists()); + } + + @NotNull + protected String getTestName(boolean lowercaseFirstLetter) { + return getTestName(getName(), lowercaseFirstLetter); + } + + @NotNull + public static String getTestName(@Nullable String name, boolean lowercaseFirstLetter) { + return name == null ? "" : PlatformTestUtil.getTestName(name, lowercaseFirstLetter); + } + + @NotNull + protected String getTestDirectoryName() { + final String testName = getTestName(true); + return testName.replaceAll("_.*", ""); + } + + public static void assertSameLinesWithFile(@NotNull String filePath, @NotNull String actualText) { + assertSameLinesWithFile(filePath, actualText, true); + } + + public static void assertSameLinesWithFile(@NotNull String filePath, + @NotNull String actualText, + @NotNull Supplier messageProducer) { + assertSameLinesWithFile(filePath, actualText, true, messageProducer); + } + + public static void assertSameLinesWithFile(@NotNull String filePath, @NotNull String actualText, boolean trimBeforeComparing) { + assertSameLinesWithFile(filePath, actualText, trimBeforeComparing, null); + } + + public static void assertSameLinesWithFile(@NotNull String filePath, + @NotNull String actualText, + boolean trimBeforeComparing, + @Nullable Supplier messageProducer) { + String fileText; + try { + if (OVERWRITE_TESTDATA) { + VfsTestUtil.overwriteTestData(filePath, actualText); + //noinspection UseOfSystemOutOrSystemErr + System.out.println("File " + filePath + " created."); + } + fileText = FileUtil.loadFile(new File(filePath), StandardCharsets.UTF_8); + } + catch (FileNotFoundException e) { + VfsTestUtil.overwriteTestData(filePath, actualText); + throw new AssertionFailedError("No output text found. File " + filePath + " created."); + } + catch (IOException e) { + throw new RuntimeException(e); + } + String expected = StringUtil.convertLineSeparators(trimBeforeComparing ? fileText.trim() : fileText); + String actual = StringUtil.convertLineSeparators(trimBeforeComparing ? actualText.trim() : actualText); + if (!Comparing.equal(expected, actual)) { + throw new FileComparisonFailure(messageProducer == null ? null : messageProducer.get(), expected, actual, filePath); + } + } + + protected static void clearFields(@NotNull Object test) throws IllegalAccessException { + Class aClass = test.getClass(); + while (aClass != null) { + clearDeclaredFields(test, aClass); + aClass = aClass.getSuperclass(); + } + } + + public static void clearDeclaredFields(@NotNull Object test, @NotNull Class aClass) throws IllegalAccessException { + for (final Field field : aClass.getDeclaredFields()) { + final String name = field.getDeclaringClass().getName(); + if (!name.startsWith("junit.framework.") && !name.startsWith("com.intellij.testFramework.")) { + final int modifiers = field.getModifiers(); + if ((modifiers & Modifier.FINAL) == 0 && (modifiers & Modifier.STATIC) == 0 && !field.getType().isPrimitive()) { + field.setAccessible(true); + field.set(test, null); + } + } + } + } + + private static void checkCodeStyleSettingsEqual(@NotNull CodeStyleSettings expected, @NotNull CodeStyleSettings settings) { + if (!expected.equals(settings)) { + Element oldS = new Element("temp"); + expected.writeExternal(oldS); + Element newS = new Element("temp"); + settings.writeExternal(newS); + + String newString = JDOMUtil.writeElement(newS); + String oldString = JDOMUtil.writeElement(oldS); + Assert.assertEquals("Code style settings damaged", oldString, newString); + } + } + + private static void checkCodeInsightSettingsEqual(@NotNull CodeInsightSettings oldSettings, @NotNull CodeInsightSettings settings) { + if (!oldSettings.equals(settings)) { + Element newS = new Element("temp"); + settings.writeExternal(newS); + Assert.assertEquals("Code insight settings damaged", DEFAULT_SETTINGS_EXTERNALIZED, JDOMUtil.writeElement(newS)); + } + } + + public boolean isPerformanceTest() { + String testName = getName(); + String className = getClass().getSimpleName(); + return TestFrameworkUtil.isPerformanceTest(testName, className); + } + + /** + * @return true for a test which performs A LOT of computations. + * Such test should typically avoid performing expensive checks, e.g. data structure consistency complex validations. + * If you want your test to be treated as "Stress", please mention one of these words in its name: "Stress", "Slow". + * For example: {@code public void testStressPSIFromDifferentThreads()} + */ + public boolean isStressTest() { + return isStressTest(getName(), getClass().getName()); + } + + private static boolean isStressTest(String testName, String className) { + return TestFrameworkUtil.isPerformanceTest(testName, className) || + containsStressWords(testName) || + containsStressWords(className); + } + + private static boolean containsStressWords(@Nullable String name) { + return name != null && (name.contains("Stress") || name.contains("Slow")); + } + + public static void doPostponedFormatting(@NotNull Project project) { + DocumentUtil.writeInRunUndoTransparentAction(() -> { + PsiDocumentManager.getInstance(project).commitAllDocuments(); + PostprocessReformattingAspect.getInstance(project).doPostponedFormatting(); + }); + } + + /** + * Checks that code block throw corresponding exception. + * + * @param exceptionCase Block annotated with some exception type + */ + protected void assertException(@NotNull AbstractExceptionCase exceptionCase) { + assertException(exceptionCase, null); + } + + /** + * Checks that code block throw corresponding exception with expected error msg. + * If expected error message is null it will not be checked. + * + * @param exceptionCase Block annotated with some exception type + * @param expectedErrorMsg expected error message + */ + protected void assertException(@NotNull AbstractExceptionCase exceptionCase, @Nullable String expectedErrorMsg) { + //noinspection unchecked + assertExceptionOccurred(true, exceptionCase, expectedErrorMsg); + } + + /** + * Checks that the code block throws an exception of the specified class. + * + * @param exceptionClass Expected exception type + * @param runnable Block annotated with some exception type + */ + public static void assertThrows(@NotNull Class exceptionClass, + @NotNull ThrowableRunnable runnable) { + assertThrows(exceptionClass, null, runnable); + } + + /** + * Checks that the code block throws an exception of the specified class with expected error msg. + * If expected error message is null it will not be checked. + * + * @param exceptionClass Expected exception type + * @param expectedErrorMsgPart expected error message, of any + * @param runnable Block annotated with some exception type + */ + @SuppressWarnings({"unchecked", "SameParameterValue"}) + public static void assertThrows(@NotNull Class exceptionClass, + @Nullable String expectedErrorMsgPart, + @NotNull ThrowableRunnable runnable) { + assertExceptionOccurred(true, new AbstractExceptionCase() { + @Override + public Class getExpectedExceptionClass() { + return (Class)exceptionClass; + } + + @Override + public void tryClosure() throws Throwable { + runnable.run(); + } + }, expectedErrorMsgPart); + } + + /** + * Checks that code block doesn't throw corresponding exception. + * + * @param exceptionCase Block annotated with some exception type + */ + protected void assertNoException(@NotNull AbstractExceptionCase exceptionCase) throws T { + assertExceptionOccurred(false, exceptionCase, null); + } + + protected void assertNoThrowable(@NotNull Runnable closure) { + String throwableName = null; + try { + closure.run(); + } + catch (Throwable thr) { + throwableName = thr.getClass().getName(); + } + assertNull(throwableName); + } + + private static void assertExceptionOccurred(boolean shouldOccur, + @NotNull AbstractExceptionCase exceptionCase, + String expectedErrorMsgPart) throws T { + boolean wasThrown = false; + try { + exceptionCase.tryClosure(); + } + catch (Throwable e) { + Throwable cause = e; + + if (shouldOccur) { + wasThrown = true; + assertInstanceOf(cause, exceptionCase.getExpectedExceptionClass()); + if (expectedErrorMsgPart != null) { + assertTrue(cause.getMessage(), cause.getMessage().contains(expectedErrorMsgPart)); + } + } + else if (exceptionCase.getExpectedExceptionClass().equals(cause.getClass())) { + wasThrown = true; + + //noinspection UseOfSystemOutOrSystemErr + System.out.println(); + //noinspection UseOfSystemOutOrSystemErr + e.printStackTrace(System.out); + + fail("Exception isn't expected here. Exception message: " + cause.getMessage()); + } + else { + throw e; + } + } + finally { + if (shouldOccur && !wasThrown) { + fail(exceptionCase.getExpectedExceptionClass().getName() + " must be thrown."); + } + } + } + + protected boolean annotatedWith(@NotNull Class annotationClass) { + Class aClass = getClass(); + String methodName = "test" + getTestName(false); + boolean methodChecked = false; + while (aClass != null && aClass != Object.class) { + if (aClass.getAnnotation(annotationClass) != null) return true; + if (!methodChecked) { + Method method = ReflectionUtil.getDeclaredMethod(aClass, methodName); + if (method != null) { + if (method.getAnnotation(annotationClass) != null) return true; + methodChecked = true; + } + } + aClass = aClass.getSuperclass(); + } + return false; + } + + @NotNull + protected String getHomePath() { + return PathManager.getHomePath().replace(File.separatorChar, '/'); + } + + public static void refreshRecursively(@NotNull VirtualFile file) { + VfsUtilCore.visitChildrenRecursively(file, new VirtualFileVisitor() { + @Override + public boolean visitFile(@NotNull VirtualFile file) { + file.getChildren(); + return true; + } + }); + file.refresh(false, true); + } + + public static VirtualFile refreshAndFindFile(@NotNull final File file) { + return UIUtil.invokeAndWaitIfNeeded(() -> LocalFileSystem.getInstance().refreshAndFindFileByIoFile(file)); + } + + public static void waitForAppLeakingThreads(long timeout, @NotNull TimeUnit timeUnit) { + EdtTestUtil.runInEdtAndWait(() -> { + Application app = ApplicationManager.getApplication(); + if (app != null && !app.isDisposed()) { + FileBasedIndexImpl index = (FileBasedIndexImpl)app.getServiceIfCreated(FileBasedIndex.class); + if (index != null) { + index.getChangedFilesCollector().waitForVfsEventsExecuted(timeout, timeUnit); + } + + DocumentCommitThread commitThread = (DocumentCommitThread)app.getServiceIfCreated(DocumentCommitProcessor.class); + if (commitThread != null) { + commitThread.waitForAllCommits(timeout, timeUnit); + } + } + }); + } + + protected class TestDisposable implements Disposable { + private volatile boolean myDisposed; + + public TestDisposable() { + } + + @Override + public void dispose() { + myDisposed = true; + } + + public boolean isDisposed() { + return myDisposed; + } + + @Override + public String toString() { + String testName = getTestName(false); + return KtUsefulTestCase.this.getClass() + (StringUtil.isEmpty(testName) ? "" : ".test" + testName); + } + } +} \ No newline at end of file diff --git a/idea/src/org/jetbrains/kotlin/idea/highlighter/KotlinHighlightExitPointsHandlerFactory.kt.202 b/idea/src/org/jetbrains/kotlin/idea/highlighter/KotlinHighlightExitPointsHandlerFactory.kt.202 new file mode 100644 index 00000000000..85feef012f7 --- /dev/null +++ b/idea/src/org/jetbrains/kotlin/idea/highlighter/KotlinHighlightExitPointsHandlerFactory.kt.202 @@ -0,0 +1,174 @@ +/* + * Copyright 2010-2015 JetBrains s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.jetbrains.kotlin.idea.highlighter + +import com.intellij.codeInsight.highlighting.HighlightUsagesHandlerBase +import com.intellij.codeInsight.highlighting.HighlightUsagesHandlerFactoryBase +import com.intellij.openapi.editor.Editor +import com.intellij.psi.PsiElement +import com.intellij.psi.PsiFile +import com.intellij.psi.impl.source.tree.LeafPsiElement +import com.intellij.psi.tree.TokenSet +import com.intellij.psi.util.PsiTreeUtil +import com.intellij.util.Consumer +import org.jetbrains.kotlin.idea.caches.resolve.analyze +import org.jetbrains.kotlin.idea.references.mainReference +import org.jetbrains.kotlin.lexer.KtTokens +import org.jetbrains.kotlin.psi.* +import org.jetbrains.kotlin.psi.psiUtil.parents +import org.jetbrains.kotlin.resolve.bindingContextUtil.isUsedAsResultOfLambda +import org.jetbrains.kotlin.resolve.inline.InlineUtil +import org.jetbrains.kotlin.resolve.lazy.BodyResolveMode + +class KotlinHighlightExitPointsHandlerFactory : HighlightUsagesHandlerFactoryBase() { + companion object { + private val RETURN_AND_THROW = TokenSet.create(KtTokens.RETURN_KEYWORD, KtTokens.THROW_KEYWORD) + + private fun getOnReturnOrThrowUsageHandler(editor: Editor, file: PsiFile, target: PsiElement): HighlightUsagesHandlerBase<*>? { + if (target !is LeafPsiElement || target.elementType !in RETURN_AND_THROW) { + return null + } + + val returnOrThrow = PsiTreeUtil.getParentOfType( + target, + KtReturnExpression::class.java, + KtThrowExpression::class.java + ) ?: return null + + return OnExitUsagesHandler(editor, file, returnOrThrow) + } + + private fun getOnLambdaCallUsageHandler(editor: Editor, file: PsiFile, target: PsiElement): HighlightUsagesHandlerBase<*>? { + if (target !is LeafPsiElement || target.elementType != KtTokens.IDENTIFIER) { + return null + } + + val refExpr = target.parent as? KtNameReferenceExpression ?: return null + val call = refExpr.parent as? KtCallExpression ?: return null + if (call.calleeExpression != refExpr) return null + + val lambda = call.lambdaArguments.singleOrNull() ?: return null + val literal = lambda.getLambdaExpression()?.functionLiteral ?: return null + + return OnExitUsagesHandler(editor, file, literal, highlightReferences = true) + } + } + + override fun createHighlightUsagesHandler(editor: Editor, file: PsiFile, target: PsiElement): HighlightUsagesHandlerBase<*>? { + return getOnReturnOrThrowUsageHandler(editor, file, target) + ?: getOnLambdaCallUsageHandler(editor, file, target) + } + + private class OnExitUsagesHandler(editor: Editor, file: PsiFile, val target: KtExpression, val highlightReferences: Boolean = false) : + HighlightUsagesHandlerBase(editor, file) { + + override fun getTargets() = listOf(target) + + override fun selectTargets(targets: MutableList, selectionConsumer: Consumer>) { + selectionConsumer.consume(targets) + } + + override fun computeUsages(targets: MutableList) { + val relevantFunction: KtDeclarationWithBody? = + if (target is KtFunctionLiteral) { + target + } else { + target.getRelevantDeclaration() + } + + relevantFunction?.accept(object : KtVisitorVoid() { + override fun visitKtElement(element: KtElement) { + element.acceptChildren(this) + } + + override fun visitExpression(expression: KtExpression) { + if (relevantFunction is KtFunctionLiteral) { + if (occurrenceForFunctionLiteralReturnExpression(expression)) { + return + } + } + + super.visitExpression(expression) + } + + private fun occurrenceForFunctionLiteralReturnExpression(expression: KtExpression): Boolean { + if (!KtPsiUtil.isStatement(expression)) return false + + if (expression is KtIfExpression || expression is KtWhenExpression || expression is KtBlockExpression) { + return false + } + + val bindingContext = expression.analyze(BodyResolveMode.FULL) + if (!expression.isUsedAsResultOfLambda(bindingContext)) { + return false + } + + if (expression.getRelevantDeclaration() != relevantFunction) { + return false + } + + addOccurrence(expression) + return true + } + + private fun visitReturnOrThrow(expression: KtExpression) { + if (expression.getRelevantDeclaration() == relevantFunction) { + addOccurrence(expression) + } + } + + override fun visitReturnExpression(expression: KtReturnExpression) { + visitReturnOrThrow(expression) + } + + override fun visitThrowExpression(expression: KtThrowExpression) { + visitReturnOrThrow(expression) + } + }) + } + + override fun highlightReferences() = highlightReferences + } +} + +private fun KtExpression.getRelevantDeclaration(): KtDeclarationWithBody? { + if (this is KtReturnExpression) { + (this.getTargetLabel()?.mainReference?.resolve() as? KtFunction)?.let { + return it + } + } + + if (this is KtThrowExpression || this is KtReturnExpression) { + for (parent in parents) { + if (parent is KtDeclarationWithBody) { + if (parent is KtPropertyAccessor) { + return parent + } + + if (InlineUtil.canBeInlineArgument(parent) && + !InlineUtil.isInlinedArgument(parent as KtFunction, parent.analyze(BodyResolveMode.FULL), false) + ) { + return parent + } + } + } + + return null + } + + return parents.filterIsInstance().firstOrNull() +} \ No newline at end of file diff --git a/idea/src/org/jetbrains/kotlin/idea/highlighter/KotlinHighlightImplicitItHandlerFactory.kt.202 b/idea/src/org/jetbrains/kotlin/idea/highlighter/KotlinHighlightImplicitItHandlerFactory.kt.202 new file mode 100644 index 00000000000..896d4c34918 --- /dev/null +++ b/idea/src/org/jetbrains/kotlin/idea/highlighter/KotlinHighlightImplicitItHandlerFactory.kt.202 @@ -0,0 +1,58 @@ +/* + * Copyright 2010-2017 JetBrains s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.jetbrains.kotlin.idea.highlighter + +import com.intellij.codeInsight.highlighting.HighlightUsagesHandlerBase +import com.intellij.codeInsight.highlighting.HighlightUsagesHandlerFactoryBase +import com.intellij.openapi.editor.Editor +import com.intellij.psi.PsiElement +import com.intellij.psi.PsiFile +import com.intellij.psi.impl.source.tree.LeafPsiElement +import com.intellij.util.Consumer +import org.jetbrains.kotlin.idea.intentions.getLambdaByImplicitItReference +import org.jetbrains.kotlin.lexer.KtTokens +import org.jetbrains.kotlin.psi.KtNameReferenceExpression +import org.jetbrains.kotlin.psi.KtSimpleNameExpression +import org.jetbrains.kotlin.psi.KtTreeVisitorVoid + +class KotlinHighlightImplicitItHandlerFactory : HighlightUsagesHandlerFactoryBase() { + override fun createHighlightUsagesHandler(editor: Editor, file: PsiFile, target: PsiElement): HighlightUsagesHandlerBase<*>? { + if (!(target is LeafPsiElement && target.elementType == KtTokens.IDENTIFIER)) return null + val refExpr = target.parent as? KtNameReferenceExpression ?: return null + val lambda = getLambdaByImplicitItReference(refExpr) ?: return null + return object : HighlightUsagesHandlerBase(editor, file) { + override fun getTargets() = listOf(refExpr) + + override fun selectTargets( + targets: MutableList, + selectionConsumer: Consumer> + ) = selectionConsumer.consume(targets) + + override fun computeUsages(targets: MutableList) { + lambda.accept( + object : KtTreeVisitorVoid() { + override fun visitSimpleNameExpression(expression: KtSimpleNameExpression) { + if (expression is KtNameReferenceExpression && getLambdaByImplicitItReference(expression) == lambda) { + addOccurrence(expression) + } + } + } + ) + } + } + } +} \ No newline at end of file diff --git a/idea/src/org/jetbrains/kotlin/idea/highlighter/KotlinRecursiveCallLineMarkerProvider.kt.202 b/idea/src/org/jetbrains/kotlin/idea/highlighter/KotlinRecursiveCallLineMarkerProvider.kt.202 new file mode 100644 index 00000000000..a6e776f5460 --- /dev/null +++ b/idea/src/org/jetbrains/kotlin/idea/highlighter/KotlinRecursiveCallLineMarkerProvider.kt.202 @@ -0,0 +1,188 @@ +/* + * Copyright 2010-2015 JetBrains s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.jetbrains.kotlin.idea.highlighter + +import com.intellij.codeHighlighting.Pass +import com.intellij.codeInsight.daemon.LineMarkerInfo +import com.intellij.codeInsight.daemon.LineMarkerProvider +import com.intellij.icons.AllIcons +import com.intellij.openapi.editor.markup.GutterIconRenderer +import com.intellij.openapi.progress.ProgressManager +import com.intellij.psi.PsiDocumentManager +import com.intellij.psi.PsiElement +import org.jetbrains.kotlin.descriptors.ClassDescriptor +import org.jetbrains.kotlin.descriptors.SimpleFunctionDescriptor +import org.jetbrains.kotlin.idea.KotlinBundle +import org.jetbrains.kotlin.idea.caches.resolve.analyze +import org.jetbrains.kotlin.idea.inspections.RecursivePropertyAccessorInspection +import org.jetbrains.kotlin.idea.util.getReceiverTargetDescriptor +import org.jetbrains.kotlin.lexer.KtToken +import org.jetbrains.kotlin.name.Name +import org.jetbrains.kotlin.psi.* +import org.jetbrains.kotlin.psi.psiUtil.parents +import org.jetbrains.kotlin.resolve.BindingContext +import org.jetbrains.kotlin.resolve.inline.InlineUtil +import org.jetbrains.kotlin.resolve.scopes.receivers.Receiver +import org.jetbrains.kotlin.resolve.scopes.receivers.ReceiverValue +import org.jetbrains.kotlin.types.expressions.OperatorConventions +import org.jetbrains.kotlin.util.OperatorNameConventions +import java.util.* + +class KotlinRecursiveCallLineMarkerProvider : LineMarkerProvider { + override fun getLineMarkerInfo(element: PsiElement) = null + + override fun collectSlowLineMarkers(elements: MutableList, result: MutableCollection>) { + val markedLineNumbers = HashSet() + + for (element in elements) { + ProgressManager.checkCanceled() + if (element is KtElement) { + val lineNumber = element.getLineNumber() + if (lineNumber !in markedLineNumbers && isRecursiveCall(element)) { + markedLineNumbers.add(lineNumber) + result.add(RecursiveMethodCallMarkerInfo(getElementForLineMark(element))) + } + } + } + } + + private fun getEnclosingFunction(element: KtElement, stopOnNonInlinedLambdas: Boolean): KtNamedFunction? { + for (parent in element.parents) { + when (parent) { + is KtFunctionLiteral -> if (stopOnNonInlinedLambdas && !InlineUtil.isInlinedArgument( + parent, + parent.analyze(), + false + ) + ) return null + is KtNamedFunction -> { + when (parent.parent) { + is KtBlockExpression, is KtClassBody, is KtFile, is KtScript -> return parent + else -> if (stopOnNonInlinedLambdas && !InlineUtil.isInlinedArgument(parent, parent.analyze(), false)) return null + } + } + is KtClassOrObject -> return null + } + } + return null + } + + private fun isRecursiveCall(element: KtElement): Boolean { + if (RecursivePropertyAccessorInspection.isRecursivePropertyAccess(element)) return true + if (RecursivePropertyAccessorInspection.isRecursiveSyntheticPropertyAccess(element)) return true + // Fast check for names without resolve + val resolveName = getCallNameFromPsi(element) ?: return false + val enclosingFunction = getEnclosingFunction(element, false) ?: return false + + val enclosingFunctionName = enclosingFunction.name + if (enclosingFunctionName != OperatorNameConventions.INVOKE.asString() + && enclosingFunctionName != resolveName.asString() + ) return false + + // Check that there were no not-inlined lambdas on the way to enclosing function + if (enclosingFunction != getEnclosingFunction(element, true)) return false + + val bindingContext = element.analyze() + val enclosingFunctionDescriptor = bindingContext[BindingContext.FUNCTION, enclosingFunction] ?: return false + + val call = bindingContext[BindingContext.CALL, element] ?: return false + val resolvedCall = bindingContext[BindingContext.RESOLVED_CALL, call] ?: return false + + if (resolvedCall.candidateDescriptor.original != enclosingFunctionDescriptor) return false + + fun isDifferentReceiver(receiver: Receiver?): Boolean { + if (receiver !is ReceiverValue) return false + + val receiverOwner = receiver.getReceiverTargetDescriptor(bindingContext) ?: return true + + return when (receiverOwner) { + is SimpleFunctionDescriptor -> receiverOwner != enclosingFunctionDescriptor + is ClassDescriptor -> receiverOwner != enclosingFunctionDescriptor.containingDeclaration + else -> return true + } + } + + if (isDifferentReceiver(resolvedCall.dispatchReceiver)) return false + return true + } + + private class RecursiveMethodCallMarkerInfo(callElement: PsiElement) : LineMarkerInfo( + callElement, + callElement.textRange, + AllIcons.Gutter.RecursiveMethod, + Pass.LINE_MARKERS, + { KotlinBundle.message("highlighter.tool.tip.text.recursive.call") }, + null, + GutterIconRenderer.Alignment.RIGHT + ) { + + override fun createGutterRenderer(): GutterIconRenderer? { + return object : LineMarkerInfo.LineMarkerGutterIconRenderer(this) { + override fun getClickAction() = null // to place breakpoint on mouse click + } + } + } + +} + +internal fun getElementForLineMark(callElement: PsiElement): PsiElement = + when (callElement) { + is KtSimpleNameExpression -> callElement.getReferencedNameElement() + else -> + // a fallback, + //but who knows what to reference in KtArrayAccessExpression ? + generateSequence(callElement, { it.firstChild }).last() + } + +private fun PsiElement.getLineNumber(): Int { + return PsiDocumentManager.getInstance(project).getDocument(containingFile)!!.getLineNumber(textOffset) +} + +private fun getCallNameFromPsi(element: KtElement): Name? { + when (element) { + is KtSimpleNameExpression -> { + val elementParent = element.getParent() + when (elementParent) { + is KtCallExpression -> return Name.identifier(element.getText()) + is KtOperationExpression -> { + val operationReference = elementParent.operationReference + if (element == operationReference) { + val node = operationReference.getReferencedNameElementType() + return if (node is KtToken) { + val conventionName = if (elementParent is KtPrefixExpression) + OperatorConventions.getNameForOperationSymbol(node, true, false) + else + OperatorConventions.getNameForOperationSymbol(node) + + conventionName ?: Name.identifier(element.getText()) + } else { + Name.identifier(element.getText()) + } + } + } + } + } + is KtArrayAccessExpression -> + return OperatorNameConventions.GET + is KtThisExpression -> + if (element.getParent() is KtCallExpression) { + return OperatorNameConventions.INVOKE + } + } + + return null +} diff --git a/idea/src/org/jetbrains/kotlin/idea/highlighter/KotlinSuspendCallLineMarkerProvider.kt.202 b/idea/src/org/jetbrains/kotlin/idea/highlighter/KotlinSuspendCallLineMarkerProvider.kt.202 new file mode 100644 index 00000000000..24ea7867362 --- /dev/null +++ b/idea/src/org/jetbrains/kotlin/idea/highlighter/KotlinSuspendCallLineMarkerProvider.kt.202 @@ -0,0 +1,125 @@ +/* + * Copyright 2010-2018 JetBrains s.r.o. and Kotlin Programming Language contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + */ + +package org.jetbrains.kotlin.idea.highlighter + +import com.intellij.codeHighlighting.Pass +import com.intellij.codeInsight.daemon.LineMarkerInfo +import com.intellij.codeInsight.daemon.LineMarkerProvider +import com.intellij.openapi.actionSystem.AnAction +import com.intellij.openapi.editor.markup.GutterIconRenderer +import com.intellij.openapi.progress.ProgressManager +import com.intellij.psi.PsiElement +import org.jetbrains.kotlin.descriptors.FunctionDescriptor +import org.jetbrains.kotlin.descriptors.PropertyDescriptor +import org.jetbrains.kotlin.descriptors.VariableDescriptorWithAccessors +import org.jetbrains.kotlin.descriptors.accessors +import org.jetbrains.kotlin.idea.KotlinBundle +import org.jetbrains.kotlin.idea.KotlinIcons +import org.jetbrains.kotlin.idea.caches.resolve.analyze +import org.jetbrains.kotlin.idea.refactoring.getLineNumber +import org.jetbrains.kotlin.psi.* +import org.jetbrains.kotlin.resolve.BindingContext +import org.jetbrains.kotlin.resolve.BindingContext.* +import org.jetbrains.kotlin.resolve.calls.callUtil.getResolvedCall +import org.jetbrains.kotlin.resolve.descriptorUtil.fqNameSafe +import org.jetbrains.kotlin.resolve.lazy.BodyResolveMode + +class KotlinSuspendCallLineMarkerProvider : LineMarkerProvider { + private class SuspendCallMarkerInfo(callElement: PsiElement, message: String) : LineMarkerInfo( + callElement, + callElement.textRange, + KotlinIcons.SUSPEND_CALL, + Pass.LINE_MARKERS, + { message }, + null, + GutterIconRenderer.Alignment.RIGHT + ) { + override fun createGutterRenderer(): GutterIconRenderer? { + return object : LineMarkerInfo.LineMarkerGutterIconRenderer(this) { + override fun getClickAction(): AnAction? = null + } + } + } + + override fun getLineMarkerInfo(element: PsiElement): LineMarkerInfo<*>? = null + + override fun collectSlowLineMarkers( + elements: MutableList, + result: MutableCollection> + ) { + val markedLineNumbers = HashSet() + + for (element in elements) { + ProgressManager.checkCanceled() + + if (element !is KtExpression) continue + + val containingFile = element.containingFile + if (containingFile !is KtFile || containingFile is KtCodeFragment) { + continue + } + + val lineNumber = element.getLineNumber() + if (lineNumber in markedLineNumbers) continue + if (!element.hasSuspendCalls()) continue + + markedLineNumbers += lineNumber + result += if (element is KtForExpression) { + SuspendCallMarkerInfo( + getElementForLineMark(element.loopRange!!), + KotlinBundle.message("highlighter.message.suspending.iteration") + ) + } else { + SuspendCallMarkerInfo(getElementForLineMark(element), KotlinBundle.message("highlighter.message.suspend.function.call")) + } + } + } +} + +private fun KtExpression.isValidCandidateExpression(): Boolean { + if (this is KtParenthesizedExpression) return false + if (this is KtOperationReferenceExpression || this is KtForExpression || this is KtProperty || this is KtNameReferenceExpression) return true + val parent = parent + if (parent is KtCallExpression && parent.calleeExpression == this) return true + if (this is KtCallExpression && (calleeExpression is KtCallExpression || calleeExpression is KtParenthesizedExpression)) return true + return false +} + +fun KtExpression.hasSuspendCalls(bindingContext: BindingContext = analyze(BodyResolveMode.PARTIAL)): Boolean { + if (!isValidCandidateExpression()) return false + + return when (this) { + is KtForExpression -> { + val iteratorResolvedCall = bindingContext[LOOP_RANGE_ITERATOR_RESOLVED_CALL, loopRange] + val loopRangeHasNextResolvedCall = bindingContext[LOOP_RANGE_HAS_NEXT_RESOLVED_CALL, loopRange] + val loopRangeNextResolvedCall = bindingContext[LOOP_RANGE_NEXT_RESOLVED_CALL, loopRange] + listOf(iteratorResolvedCall, loopRangeHasNextResolvedCall, loopRangeNextResolvedCall).any { + it?.resultingDescriptor?.isSuspend == true + } + } + is KtProperty -> { + if (hasDelegateExpression()) { + val variableDescriptor = bindingContext[DECLARATION_TO_DESCRIPTOR, this] as? VariableDescriptorWithAccessors + val accessors = variableDescriptor?.accessors ?: emptyList() + accessors.any { accessor -> + val delegatedFunctionDescriptor = bindingContext[DELEGATED_PROPERTY_RESOLVED_CALL, accessor]?.resultingDescriptor + delegatedFunctionDescriptor?.isSuspend == true + } + } else { + false + } + } + else -> { + val resolvedCall = getResolvedCall(bindingContext) + if ((resolvedCall?.resultingDescriptor as? FunctionDescriptor)?.isSuspend == true) true + else { + val propertyDescriptor = resolvedCall?.resultingDescriptor as? PropertyDescriptor + val s = propertyDescriptor?.fqNameSafe?.asString() + s?.startsWith("kotlin.coroutines.") == true && s.endsWith(".coroutineContext") + } + } + } +} diff --git a/idea/src/org/jetbrains/kotlin/idea/highlighter/markers/DslHighlightingMarker.kt.202 b/idea/src/org/jetbrains/kotlin/idea/highlighter/markers/DslHighlightingMarker.kt.202 new file mode 100644 index 00000000000..355b4972bd1 --- /dev/null +++ b/idea/src/org/jetbrains/kotlin/idea/highlighter/markers/DslHighlightingMarker.kt.202 @@ -0,0 +1,65 @@ +/* + * Copyright 2000-2018 JetBrains s.r.o. and Kotlin Programming Language contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + */ + +package org.jetbrains.kotlin.idea.highlighter.markers + +import com.intellij.application.options.colors.ColorAndFontOptions +import com.intellij.codeHighlighting.Pass +import com.intellij.codeInsight.daemon.GutterIconNavigationHandler +import com.intellij.codeInsight.daemon.LineMarkerInfo +import com.intellij.ide.DataManager +import com.intellij.openapi.editor.markup.GutterIconRenderer +import com.intellij.psi.PsiElement +import com.intellij.util.Function +import org.jetbrains.kotlin.descriptors.ClassDescriptor +import org.jetbrains.kotlin.descriptors.ClassKind +import org.jetbrains.kotlin.idea.KotlinLanguage +import org.jetbrains.kotlin.idea.core.toDescriptor +import org.jetbrains.kotlin.idea.KotlinBundle +import org.jetbrains.kotlin.idea.highlighter.dsl.DslHighlighterExtension +import org.jetbrains.kotlin.idea.highlighter.dsl.isDslHighlightingMarker +import org.jetbrains.kotlin.psi.KtClass +import javax.swing.JComponent + +private val navHandler = GutterIconNavigationHandler { event, element -> + val dataContext = (event.component as? JComponent)?.let { DataManager.getInstance().getDataContext(it) } + ?: return@GutterIconNavigationHandler + val ktClass = element?.parent as? KtClass ?: return@GutterIconNavigationHandler + val styleId = ktClass.styleIdForMarkerAnnotation() ?: return@GutterIconNavigationHandler + ColorAndFontOptions.selectOrEditColor(dataContext, DslHighlighterExtension.styleOptionDisplayName(styleId), KotlinLanguage.NAME) +} + +private val toolTipHandler = Function { + KotlinBundle.message("highlighter.tool.tip.marker.annotation.for.dsl") +} + +fun collectHighlightingColorsMarkers( + ktClass: KtClass, + result: MutableCollection> +) { + if (!KotlinLineMarkerOptions.dslOption.isEnabled) return + + val styleId = ktClass.styleIdForMarkerAnnotation() ?: return + + val anchor = ktClass.nameIdentifier ?: return + + result.add( + LineMarkerInfo( + anchor, + anchor.textRange, + createDslStyleIcon(styleId), + Pass.LINE_MARKERS, + toolTipHandler, navHandler, + GutterIconRenderer.Alignment.RIGHT + ) + ) +} + +private fun KtClass.styleIdForMarkerAnnotation(): Int? { + val classDescriptor = toDescriptor() as? ClassDescriptor ?: return null + if (classDescriptor.kind != ClassKind.ANNOTATION_CLASS) return null + if (!classDescriptor.isDslHighlightingMarker()) return null + return DslHighlighterExtension.styleIdByMarkerAnnotation(classDescriptor) +} \ No newline at end of file diff --git a/idea/src/org/jetbrains/kotlin/idea/highlighter/markers/KotlinLineMarkerProvider.kt.202 b/idea/src/org/jetbrains/kotlin/idea/highlighter/markers/KotlinLineMarkerProvider.kt.202 new file mode 100644 index 00000000000..acbbc399a25 --- /dev/null +++ b/idea/src/org/jetbrains/kotlin/idea/highlighter/markers/KotlinLineMarkerProvider.kt.202 @@ -0,0 +1,561 @@ +/* + * Copyright 2000-2018 JetBrains s.r.o. and Kotlin Programming Language contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + */ + +package org.jetbrains.kotlin.idea.highlighter.markers + +import com.intellij.codeHighlighting.Pass +import com.intellij.codeInsight.daemon.* +import com.intellij.codeInsight.daemon.impl.LineMarkerNavigator +import com.intellij.codeInsight.daemon.impl.MarkerType +import com.intellij.codeInsight.daemon.impl.PsiElementListNavigator +import com.intellij.codeInsight.navigation.ListBackgroundUpdaterTask +import com.intellij.openapi.actionSystem.IdeActions +import com.intellij.openapi.editor.Document +import com.intellij.openapi.editor.colors.CodeInsightColors +import com.intellij.openapi.editor.colors.EditorColorsManager +import com.intellij.openapi.editor.markup.GutterIconRenderer +import com.intellij.openapi.editor.markup.SeparatorPlacement +import com.intellij.openapi.progress.ProgressManager +import com.intellij.openapi.project.DumbService +import com.intellij.openapi.util.text.StringUtil +import com.intellij.psi.* +import com.intellij.psi.search.searches.ClassInheritorsSearch +import com.intellij.psi.util.PsiTreeUtil +import org.jetbrains.kotlin.asJava.LightClassUtil +import org.jetbrains.kotlin.asJava.toLightClass +import org.jetbrains.kotlin.descriptors.CallableMemberDescriptor +import org.jetbrains.kotlin.descriptors.MemberDescriptor +import org.jetbrains.kotlin.descriptors.Modality +import org.jetbrains.kotlin.idea.caches.lightClasses.KtFakeLightClass +import org.jetbrains.kotlin.idea.caches.lightClasses.KtFakeLightMethod +import org.jetbrains.kotlin.idea.caches.project.implementedDescriptors +import org.jetbrains.kotlin.idea.caches.project.implementingDescriptors +import org.jetbrains.kotlin.idea.caches.resolve.findModuleDescriptor +import org.jetbrains.kotlin.idea.core.isInheritable +import org.jetbrains.kotlin.idea.core.isOverridable +import org.jetbrains.kotlin.idea.core.toDescriptor +import org.jetbrains.kotlin.idea.editor.fixers.startLine +import org.jetbrains.kotlin.idea.KotlinBundle +import org.jetbrains.kotlin.idea.presentation.DeclarationByModuleRenderer +import org.jetbrains.kotlin.idea.search.declarationsSearch.toPossiblyFakeLightMethods +import org.jetbrains.kotlin.idea.util.* +import org.jetbrains.kotlin.lexer.KtTokens +import org.jetbrains.kotlin.psi.* +import org.jetbrains.kotlin.psi.psiUtil.containingClassOrObject +import org.jetbrains.kotlin.psi.psiUtil.getPrevSiblingIgnoringWhitespaceAndComments +import java.awt.event.MouseEvent +import java.util.* +import javax.swing.ListCellRenderer + +class KotlinLineMarkerProvider : LineMarkerProviderDescriptor() { + override fun getName() = KotlinBundle.message("highlighter.name.kotlin.line.markers") + + override fun getOptions(): Array