From a049fda75bb30bed6ad5c87d9f1cdc9028579f1b Mon Sep 17 00:00:00 2001 From: Vladimir Dolzhenko Date: Thu, 2 Feb 2023 17:15:12 +0100 Subject: [PATCH] Create SimpleFunctionDescriptorImpl under nonCancelableSection SimpleFunctionDescriptorImpl initialization consists of two phases: ctor + initialize. When SimpleFunctionDescriptorImpl is created wrapped descriptor (e.g. ValueParameterDescriptorImpl) is leaked through bindingTrace with not fully initialized `containingDeclaration` (that is SimpleFunctionDescriptorImpl). If PCE happens after this unsafe publication prior to `initialize` then it will be case with NPE on fully initialized instance reading. #KT-56364 Fixed --- .../resolve/FunctionDescriptorResolver.kt | 347 ++++++++++-------- .../expressions/FunctionsTypingVisitor.kt | 38 +- 2 files changed, 210 insertions(+), 175 deletions(-) diff --git a/compiler/frontend/src/org/jetbrains/kotlin/resolve/FunctionDescriptorResolver.kt b/compiler/frontend/src/org/jetbrains/kotlin/resolve/FunctionDescriptorResolver.kt index 1a43ce515f0..3b100898f12 100644 --- a/compiler/frontend/src/org/jetbrains/kotlin/resolve/FunctionDescriptorResolver.kt +++ b/compiler/frontend/src/org/jetbrains/kotlin/resolve/FunctionDescriptorResolver.kt @@ -17,6 +17,8 @@ package org.jetbrains.kotlin.resolve import com.google.common.collect.HashMultimap +import com.intellij.openapi.diagnostic.ControlFlowException +import com.intellij.openapi.progress.ProgressManager import com.intellij.openapi.util.ThrowableComputable import com.intellij.psi.PsiElement import com.intellij.util.AstLoadingFilter @@ -71,6 +73,7 @@ import org.jetbrains.kotlin.types.isError import org.jetbrains.kotlin.types.typeUtil.replaceAnnotations import java.util.* + class FunctionDescriptorResolver( private val typeResolver: TypeResolver, private val descriptorResolver: DescriptorResolver, @@ -129,19 +132,21 @@ class FunctionDescriptorResolver( CallableMemberDescriptor.Kind.DECLARATION, function.toSourceElement() ) - initializeFunctionDescriptorAndExplicitReturnType( - containingDescriptor, - scope, - function, - functionDescriptor, - trace, - expectedFunctionType, - dataFlowInfo, - inferenceSession - ) - initializeFunctionReturnTypeBasedOnFunctionBody(scope, function, functionDescriptor, trace, dataFlowInfo, inferenceSession) - BindingContextUtils.recordFunctionDeclarationToDescriptor(trace, function, functionDescriptor) - return functionDescriptor + return computeInNonCancelableSection { + initializeFunctionDescriptorAndExplicitReturnType( + containingDescriptor, + scope, + function, + functionDescriptor, + trace, + expectedFunctionType, + dataFlowInfo, + inferenceSession + ) + initializeFunctionReturnTypeBasedOnFunctionBody(scope, function, functionDescriptor, trace, dataFlowInfo, inferenceSession) + BindingContextUtils.recordFunctionDeclarationToDescriptor(trace, function, functionDescriptor) + functionDescriptor + } } private fun initializeFunctionReturnTypeBasedOnFunctionBody( @@ -182,117 +187,124 @@ class FunctionDescriptorResolver( dataFlowInfo: DataFlowInfo, inferenceSession: InferenceSession? ) { - val headerScope = LexicalWritableScope( - scope, functionDescriptor, true, - TraceBasedLocalRedeclarationChecker(trace, overloadChecker), LexicalScopeKind.FUNCTION_HEADER - ) + try { + val headerScope = LexicalWritableScope( + scope, functionDescriptor, true, + TraceBasedLocalRedeclarationChecker(trace, overloadChecker), LexicalScopeKind.FUNCTION_HEADER + ) - val typeParameterDescriptors = - descriptorResolver.resolveTypeParametersForDescriptor(functionDescriptor, headerScope, scope, function.typeParameters, trace) - descriptorResolver.resolveGenericBounds(function, functionDescriptor, headerScope, typeParameterDescriptors, trace) + val typeParameterDescriptors = + descriptorResolver.resolveTypeParametersForDescriptor(functionDescriptor, headerScope, scope, function.typeParameters, trace) + descriptorResolver.resolveGenericBounds(function, functionDescriptor, headerScope, typeParameterDescriptors, trace) - val receiverTypeRef = function.receiverTypeReference - val receiverType = - if (receiverTypeRef != null) { - typeResolver.resolveType(headerScope, receiverTypeRef, trace, true) - } else { - if (function is KtFunctionLiteral) expectedFunctionType.getReceiverType() else null - } - - val contextReceivers = function.contextReceivers - val contextReceiverTypes = - if (function is KtFunctionLiteral) expectedFunctionType.getContextReceiversTypes() - else contextReceivers - .mapNotNull { - val typeReference = it.typeReference() ?: return@mapNotNull null - val type = typeResolver.resolveType(headerScope, typeReference, trace, true) - ContextReceiverTypeWithLabel(type, it.labelNameAsName()) + val receiverTypeRef = function.receiverTypeReference + val receiverType = + if (receiverTypeRef != null) { + typeResolver.resolveType(headerScope, receiverTypeRef, trace, true) + } else { + if (function is KtFunctionLiteral) expectedFunctionType.getReceiverType() else null } + val contextReceivers = function.contextReceivers + val contextReceiverTypes = + if (function is KtFunctionLiteral) expectedFunctionType.getContextReceiversTypes() + else contextReceivers + .mapNotNull { + val typeReference = it.typeReference() ?: return@mapNotNull null + val type = typeResolver.resolveType(headerScope, typeReference, trace, true) + ContextReceiverTypeWithLabel(type, it.labelNameAsName()) + } - val valueParameterDescriptors = - createValueParameterDescriptors(function, functionDescriptor, headerScope, trace, expectedFunctionType, inferenceSession) - headerScope.freeze() + val valueParameterDescriptors = + createValueParameterDescriptors(function, functionDescriptor, headerScope, trace, expectedFunctionType, inferenceSession) - val returnType = function.typeReference?.let { typeResolver.resolveType(headerScope, it, trace, true) } + headerScope.freeze() - val visibility = resolveVisibilityFromModifiers(function, getDefaultVisibility(function, container)) - val modality = resolveMemberModalityFromModifiers( - function, getDefaultModality(container, visibility, function.hasBody()), - trace.bindingContext, container - ) + val returnType = function.typeReference?.let { typeResolver.resolveType(headerScope, it, trace, true) } - val contractProvider = getContractProvider(functionDescriptor, trace, scope, dataFlowInfo, function, inferenceSession) - val userData = mutableMapOf, Any>().apply { - if (contractProvider != null) { - put(ContractProviderKey, contractProvider) - } - - if (receiverType != null && expectedFunctionType.functionTypeExpected() && !expectedFunctionType.annotations.isEmpty()) { - put(DslMarkerUtils.FunctionTypeAnnotationsKey, expectedFunctionType.annotations) - } - } - - val extensionReceiver = receiverType?.let { - val splitter = AnnotationSplitter(storageManager, it.annotations, EnumSet.of(AnnotationUseSiteTarget.RECEIVER)) - DescriptorFactory.createExtensionReceiverParameterForCallable( - functionDescriptor, it, splitter.getAnnotationsForTarget(AnnotationUseSiteTarget.RECEIVER) + val visibility = resolveVisibilityFromModifiers(function, getDefaultVisibility(function, container)) + val modality = resolveMemberModalityFromModifiers( + function, getDefaultModality(container, visibility, function.hasBody()), + trace.bindingContext, container ) - } - val contextReceiverDescriptors = contextReceiverTypes.mapIndexedNotNull { index, contextReceiver -> - val splitter = AnnotationSplitter(storageManager, contextReceiver.type.annotations, EnumSet.of(AnnotationUseSiteTarget.RECEIVER)) - DescriptorFactory.createContextReceiverParameterForCallable( - functionDescriptor, - contextReceiver.type, - contextReceiver.label, - splitter.getAnnotationsForTarget(AnnotationUseSiteTarget.RECEIVER), - index - ) - } - if (languageVersionSettings.supportsFeature(LanguageFeature.ContextReceivers)) { - val labelNameToReceiverMap = HashMultimap.create() - if (receiverTypeRef != null && extensionReceiver != null) { - receiverTypeRef.nameForReceiverLabel()?.let { - labelNameToReceiverMap.put(it, extensionReceiver) + val contractProvider = getContractProvider(functionDescriptor, trace, scope, dataFlowInfo, function, inferenceSession) + val userData = mutableMapOf, Any>().apply { + if (contractProvider != null) { + put(ContractProviderKey, contractProvider) + } + + if (receiverType != null && expectedFunctionType.functionTypeExpected() && !expectedFunctionType.annotations.isEmpty()) { + put(DslMarkerUtils.FunctionTypeAnnotationsKey, expectedFunctionType.annotations) } } - contextReceiverDescriptors.zip(0 until contextReceivers.size).reversed() - .forEach { (contextReceiverDescriptor, i) -> - contextReceivers[i].name()?.let { - labelNameToReceiverMap.put(it, contextReceiverDescriptor) + + val extensionReceiver = receiverType?.let { + val splitter = AnnotationSplitter(storageManager, it.annotations, EnumSet.of(AnnotationUseSiteTarget.RECEIVER)) + DescriptorFactory.createExtensionReceiverParameterForCallable( + functionDescriptor, it, splitter.getAnnotationsForTarget(AnnotationUseSiteTarget.RECEIVER) + ) + } + val contextReceiverDescriptors = contextReceiverTypes.mapIndexedNotNull { index, contextReceiver -> + val splitter = AnnotationSplitter(storageManager, contextReceiver.type.annotations, EnumSet.of(AnnotationUseSiteTarget.RECEIVER)) + DescriptorFactory.createContextReceiverParameterForCallable( + functionDescriptor, + contextReceiver.type, + contextReceiver.label, + splitter.getAnnotationsForTarget(AnnotationUseSiteTarget.RECEIVER), + index + ) + } + + if (languageVersionSettings.supportsFeature(LanguageFeature.ContextReceivers)) { + val labelNameToReceiverMap = HashMultimap.create() + if (receiverTypeRef != null && extensionReceiver != null) { + receiverTypeRef.nameForReceiverLabel()?.let { + labelNameToReceiverMap.put(it, extensionReceiver) } } + contextReceiverDescriptors.zip(0 until contextReceivers.size).reversed() + .forEach { (contextReceiverDescriptor, i) -> + contextReceivers[i].name()?.let { + labelNameToReceiverMap.put(it, contextReceiverDescriptor) + } + } - trace.record(BindingContext.DESCRIPTOR_TO_CONTEXT_RECEIVER_MAP, functionDescriptor, labelNameToReceiverMap) - } + trace.record(BindingContext.DESCRIPTOR_TO_CONTEXT_RECEIVER_MAP, functionDescriptor, labelNameToReceiverMap) + } - functionDescriptor.initialize( - extensionReceiver, - getDispatchReceiverParameterIfNeeded(container), - contextReceiverDescriptors, - typeParameterDescriptors, - valueParameterDescriptors, - returnType, - modality, - visibility, - userData.takeIf { it.isNotEmpty() } - ) + functionDescriptor.initialize( + extensionReceiver, + getDispatchReceiverParameterIfNeeded(container), + contextReceiverDescriptors, + typeParameterDescriptors, + valueParameterDescriptors, + returnType, + modality, + visibility, + userData.takeIf { it.isNotEmpty() } + ) - functionDescriptor.isOperator = function.hasModifier(KtTokens.OPERATOR_KEYWORD) - functionDescriptor.isInfix = function.hasModifier(KtTokens.INFIX_KEYWORD) - functionDescriptor.isExternal = function.hasModifier(KtTokens.EXTERNAL_KEYWORD) - functionDescriptor.isInline = function.hasModifier(KtTokens.INLINE_KEYWORD) - functionDescriptor.isTailrec = function.hasModifier(KtTokens.TAILREC_KEYWORD) - functionDescriptor.isSuspend = function.hasModifier(KtTokens.SUSPEND_KEYWORD) - functionDescriptor.isExpect = container is PackageFragmentDescriptor && function.hasExpectModifier() || - container is ClassDescriptor && container.isExpect - functionDescriptor.isActual = function.hasActualModifier() + functionDescriptor.isOperator = function.hasModifier(KtTokens.OPERATOR_KEYWORD) + functionDescriptor.isInfix = function.hasModifier(KtTokens.INFIX_KEYWORD) + functionDescriptor.isExternal = function.hasModifier(KtTokens.EXTERNAL_KEYWORD) + functionDescriptor.isInline = function.hasModifier(KtTokens.INLINE_KEYWORD) + functionDescriptor.isTailrec = function.hasModifier(KtTokens.TAILREC_KEYWORD) + functionDescriptor.isSuspend = function.hasModifier(KtTokens.SUSPEND_KEYWORD) + functionDescriptor.isExpect = container is PackageFragmentDescriptor && function.hasExpectModifier() || + container is ClassDescriptor && container.isExpect + functionDescriptor.isActual = function.hasActualModifier() - receiverType?.let { ForceResolveUtil.forceResolveAllContents(it.annotations) } - for (valueParameterDescriptor in valueParameterDescriptors) { - ForceResolveUtil.forceResolveAllContents(valueParameterDescriptor.type.annotations) + receiverType?.let { ForceResolveUtil.forceResolveAllContents(it.annotations) } + for (valueParameterDescriptor in valueParameterDescriptors) { + ForceResolveUtil.forceResolveAllContents(valueParameterDescriptor.type.annotations) + } + } catch (e: Exception) { + if (e is ControlFlowException) { + throw IllegalStateException("Method should be run under nonCancelableSection", e) + } + throw e } } @@ -444,8 +456,6 @@ class FunctionDescriptorResolver( constructorDescriptor.isActual = modifierList?.hasActualModifier() == true || // We don't require 'actual' for constructors of actual annotations classDescriptor.kind == ClassKind.ANNOTATION_CLASS && classDescriptor.isActual - if (declarationToTrace is PsiElement) - trace.record(BindingContext.CONSTRUCTOR, declarationToTrace, constructorDescriptor) val parameterScope = LexicalWritableScope( scope, constructorDescriptor, @@ -453,20 +463,27 @@ class FunctionDescriptorResolver( TraceBasedLocalRedeclarationChecker(trace, overloadChecker), LexicalScopeKind.CONSTRUCTOR_HEADER ) - val constructor = constructorDescriptor.initialize( - resolveValueParameters( - constructorDescriptor, parameterScope, valueParameters, trace, null, inferenceSession - ), - resolveVisibilityFromModifiers( - modifierList, - DescriptorUtils.getDefaultConstructorVisibility(classDescriptor, languageVersionSettings.supportsFeature(LanguageFeature.AllowSealedInheritorsInDifferentFilesOfSamePackage)) + return computeInNonCancelableSection { + if (declarationToTrace is PsiElement) + trace.record(BindingContext.CONSTRUCTOR, declarationToTrace, constructorDescriptor) + val constructor = constructorDescriptor.initialize( + resolveValueParameters( + constructorDescriptor, parameterScope, valueParameters, trace, null, inferenceSession + ), + resolveVisibilityFromModifiers( + modifierList, + DescriptorUtils.getDefaultConstructorVisibility( + classDescriptor, + languageVersionSettings.supportsFeature(LanguageFeature.AllowSealedInheritorsInDifferentFilesOfSamePackage) + ) + ) ) - ) - constructor.returnType = classDescriptor.defaultType - if (DescriptorUtils.isAnnotationClass(classDescriptor)) { - CompileTimeConstantUtils.checkConstructorParametersType(valueParameters, trace) + constructor.returnType = classDescriptor.defaultType + if (DescriptorUtils.isAnnotationClass(classDescriptor)) { + CompileTimeConstantUtils.checkConstructorParametersType(valueParameters, trace) + } + constructor } - return constructor } private fun resolveValueParameters( @@ -477,57 +494,67 @@ class FunctionDescriptorResolver( expectedParameterTypes: List?, inferenceSession: InferenceSession? ): List { - val result = ArrayList() + try { + val result = ArrayList() - for (i in valueParameters.indices) { - val valueParameter = valueParameters[i] - val typeReference = valueParameter.typeReference - val expectedType = expectedParameterTypes?.let { if (i < it.size) it[i] else null }?.takeUnless { TypeUtils.noExpectedType(it) } + for (i in valueParameters.indices) { + val valueParameter = valueParameters[i] + val typeReference = valueParameter.typeReference + val expectedType = expectedParameterTypes?.let { if (i < it.size) it[i] else null }?.takeUnless { TypeUtils.noExpectedType(it) } - val type: KotlinType - if (typeReference != null) { - type = typeResolver.resolveType(parameterScope, typeReference, trace, true) - if (expectedType != null) { - if (!KotlinTypeChecker.DEFAULT.isSubtypeOf(expectedType, type)) { - trace.report(EXPECTED_PARAMETER_TYPE_MISMATCH.on(valueParameter, expectedType)) + val type: KotlinType + if (typeReference != null) { + type = typeResolver.resolveType(parameterScope, typeReference, trace, true) + if (expectedType != null) { + if (!KotlinTypeChecker.DEFAULT.isSubtypeOf(expectedType, type)) { + trace.report(EXPECTED_PARAMETER_TYPE_MISMATCH.on(valueParameter, expectedType)) + } } - } - } else { - type = if (isFunctionLiteral(functionDescriptor) || isFunctionExpression(functionDescriptor)) { - val containsErrorType = TypeUtils.contains(expectedType) { it.isError } - if (expectedType == null || containsErrorType) { - trace.report(CANNOT_INFER_PARAMETER_TYPE.on(valueParameter)) - } - - expectedType ?: TypeUtils.CANNOT_INFER_FUNCTION_PARAM_TYPE } else { - trace.report(VALUE_PARAMETER_WITH_NO_TYPE_ANNOTATION.on(valueParameter)) - ErrorUtils.createErrorType(ErrorTypeKind.MISSED_TYPE_FOR_PARAMETER, valueParameter.nameAsSafeName.toString()) + type = if (isFunctionLiteral(functionDescriptor) || isFunctionExpression(functionDescriptor)) { + val containsErrorType = TypeUtils.contains(expectedType) { it.isError } + if (expectedType == null || containsErrorType) { + trace.report(CANNOT_INFER_PARAMETER_TYPE.on(valueParameter)) + } + + expectedType ?: TypeUtils.CANNOT_INFER_FUNCTION_PARAM_TYPE + } else { + trace.report(VALUE_PARAMETER_WITH_NO_TYPE_ANNOTATION.on(valueParameter)) + ErrorUtils.createErrorType(ErrorTypeKind.MISSED_TYPE_FOR_PARAMETER, valueParameter.nameAsSafeName.toString()) + } } - } - if (functionDescriptor !is ConstructorDescriptor || !functionDescriptor.isPrimary) { - val isConstructor = functionDescriptor is ConstructorDescriptor - with(modifiersChecker.withTrace(trace)) { - checkParameterHasNoValOrVar( - valueParameter, - if (isConstructor) VAL_OR_VAR_ON_SECONDARY_CONSTRUCTOR_PARAMETER else VAL_OR_VAR_ON_FUN_PARAMETER - ) + if (functionDescriptor !is ConstructorDescriptor || !functionDescriptor.isPrimary) { + val isConstructor = functionDescriptor is ConstructorDescriptor + with(modifiersChecker.withTrace(trace)) { + checkParameterHasNoValOrVar( + valueParameter, + if (isConstructor) VAL_OR_VAR_ON_SECONDARY_CONSTRUCTOR_PARAMETER else VAL_OR_VAR_ON_FUN_PARAMETER + ) + } } + + val valueParameterDescriptor = descriptorResolver.resolveValueParameterDescriptor( + parameterScope, functionDescriptor, valueParameter, i, type, trace, Annotations.EMPTY, inferenceSession + ) + + // Do not report NAME_SHADOWING for lambda destructured parameters as they may be not fully resolved at this time + ExpressionTypingUtils.checkVariableShadowing(parameterScope, trace, valueParameterDescriptor) + + parameterScope.addVariableDescriptor(valueParameterDescriptor) + result.add(valueParameterDescriptor) } - - val valueParameterDescriptor = descriptorResolver.resolveValueParameterDescriptor( - parameterScope, functionDescriptor, valueParameter, i, type, trace, Annotations.EMPTY, inferenceSession - ) - - // Do not report NAME_SHADOWING for lambda destructured parameters as they may be not fully resolved at this time - ExpressionTypingUtils.checkVariableShadowing(parameterScope, trace, valueParameterDescriptor) - - parameterScope.addVariableDescriptor(valueParameterDescriptor) - result.add(valueParameterDescriptor) + return result + } catch (e: Exception) { + if (e is ControlFlowException) { + throw IllegalStateException("Method should be run under nonCancelableSection", e) + } + throw e } - return result } private data class ContextReceiverTypeWithLabel(val type: KotlinType, val label: Name?) } + +private fun computeInNonCancelableSection(action: () -> T): T = + ProgressManager.getInstance().computeInNonCancelableSection(action) diff --git a/compiler/frontend/src/org/jetbrains/kotlin/types/expressions/FunctionsTypingVisitor.kt b/compiler/frontend/src/org/jetbrains/kotlin/types/expressions/FunctionsTypingVisitor.kt index 3a123dcac69..6b65f22d541 100644 --- a/compiler/frontend/src/org/jetbrains/kotlin/types/expressions/FunctionsTypingVisitor.kt +++ b/compiler/frontend/src/org/jetbrains/kotlin/types/expressions/FunctionsTypingVisitor.kt @@ -6,6 +6,7 @@ package org.jetbrains.kotlin.types.expressions import com.google.common.collect.Lists +import com.intellij.openapi.progress.ProgressManager import com.intellij.psi.PsiElement import org.jetbrains.kotlin.builtins.* import org.jetbrains.kotlin.config.LanguageFeature @@ -209,23 +210,30 @@ internal class FunctionsTypingVisitor(facade: ExpressionTypingInternals) : Expre context: ExpressionTypingContext ): AnonymousFunctionDescriptor { val functionLiteral = expression.functionLiteral - val functionDescriptor = AnonymousFunctionDescriptor( - context.scope.ownerDescriptor, - components.annotationResolver.resolveAnnotationsWithArguments(context.scope, expression.getAnnotationEntries(), context.trace), - CallableMemberDescriptor.Kind.DECLARATION, functionLiteral.toSourceElement(), - context.expectedType.isSuspendFunctionType() - ).let { - facade.components.typeResolutionInterceptor.interceptFunctionLiteralDescriptor(expression, context, it) - } - components.functionDescriptorResolver.initializeFunctionDescriptorAndExplicitReturnType( - context.scope.ownerDescriptor, context.scope, functionLiteral, - functionDescriptor, context.trace, context.expectedType, context.dataFlowInfo, context.inferenceSession + val annotations = components.annotationResolver.resolveAnnotationsWithArguments( + context.scope, + expression.getAnnotationEntries(), + context.trace ) - for (parameterDescriptor in functionDescriptor.valueParameters) { - ForceResolveUtil.forceResolveAllContents(parameterDescriptor.annotations) + return ProgressManager.getInstance().computeInNonCancelableSection { + val functionDescriptor = AnonymousFunctionDescriptor( + context.scope.ownerDescriptor, + annotations, + CallableMemberDescriptor.Kind.DECLARATION, functionLiteral.toSourceElement(), + context.expectedType.isSuspendFunctionType() + ).let { + facade.components.typeResolutionInterceptor.interceptFunctionLiteralDescriptor(expression, context, it) + } + components.functionDescriptorResolver.initializeFunctionDescriptorAndExplicitReturnType( + context.scope.ownerDescriptor, context.scope, functionLiteral, + functionDescriptor, context.trace, context.expectedType, context.dataFlowInfo, context.inferenceSession + ) + for (parameterDescriptor in functionDescriptor.valueParameters) { + ForceResolveUtil.forceResolveAllContents(parameterDescriptor.annotations) + } + BindingContextUtils.recordFunctionDeclarationToDescriptor(context.trace, functionLiteral, functionDescriptor) + functionDescriptor } - BindingContextUtils.recordFunctionDeclarationToDescriptor(context.trace, functionLiteral, functionDescriptor) - return functionDescriptor } private fun KotlinType.isBuiltinFunctionalType() =