diff --git a/compiler/fir/checkers/checkers-component-generator/src/org/jetbrains/kotlin/fir/checkers/generator/diagnostics/DiagnosticData.kt b/compiler/fir/checkers/checkers-component-generator/src/org/jetbrains/kotlin/fir/checkers/generator/diagnostics/DiagnosticData.kt index 68ad1b5be95..baaf4432535 100644 --- a/compiler/fir/checkers/checkers-component-generator/src/org/jetbrains/kotlin/fir/checkers/generator/diagnostics/DiagnosticData.kt +++ b/compiler/fir/checkers/checkers-component-generator/src/org/jetbrains/kotlin/fir/checkers/generator/diagnostics/DiagnosticData.kt @@ -58,7 +58,7 @@ enum class PositioningStrategy(private val strategy: String? = null) { VALUE_ARGUMENTS, SUPERTYPES_LIST, RETURN_WITH_LABEL, - ASSIGNMENT_VALUE, + PROPERTY_INITIALIZER, WHOLE_ELEMENT, INT_LITERAL_OUT_OF_RANGE, FLOAT_LITERAL_OUT_OF_RANGE, @@ -71,6 +71,7 @@ enum class PositioningStrategy(private val strategy: String? = null) { RESERVED_UNDERSCORE, QUESTION_MARK_BY_TYPE, ANNOTATION_USE_SITE, + ASSIGNMENT_LHS, ; diff --git a/compiler/fir/checkers/checkers-component-generator/src/org/jetbrains/kotlin/fir/checkers/generator/diagnostics/FirDiagnosticsList.kt b/compiler/fir/checkers/checkers-component-generator/src/org/jetbrains/kotlin/fir/checkers/generator/diagnostics/FirDiagnosticsList.kt index f5a3cb87b99..8a11ad2ad4f 100644 --- a/compiler/fir/checkers/checkers-component-generator/src/org/jetbrains/kotlin/fir/checkers/generator/diagnostics/FirDiagnosticsList.kt +++ b/compiler/fir/checkers/checkers-component-generator/src/org/jetbrains/kotlin/fir/checkers/generator/diagnostics/FirDiagnosticsList.kt @@ -55,7 +55,7 @@ object DIAGNOSTICS_LIST : DiagnosticList() { val ASSIGNMENT_IN_EXPRESSION_CONTEXT by error() val BREAK_OR_CONTINUE_OUTSIDE_A_LOOP by error() val NOT_A_LOOP_LABEL by error() - val VARIABLE_EXPECTED by error() + val VARIABLE_EXPECTED by error(PositioningStrategy.ASSIGNMENT_LHS) val DELEGATION_IN_INTERFACE by error() val NESTED_CLASS_NOT_ALLOWED by error(PositioningStrategy.DECLARATION_NAME) { parameter("declaration") @@ -629,7 +629,7 @@ object DIAGNOSTICS_LIST : DiagnosticList() { parameter("expectedType") parameter("actualType") } - val INITIALIZER_TYPE_MISMATCH by error(PositioningStrategy.ASSIGNMENT_VALUE) { + val INITIALIZER_TYPE_MISMATCH by error(PositioningStrategy.PROPERTY_INITIALIZER) { parameter("expectedType") parameter("actualType") } diff --git a/compiler/fir/checkers/gen/org/jetbrains/kotlin/fir/analysis/diagnostics/FirErrors.kt b/compiler/fir/checkers/gen/org/jetbrains/kotlin/fir/analysis/diagnostics/FirErrors.kt index d75ff3eef86..1d0a3cf7959 100644 --- a/compiler/fir/checkers/gen/org/jetbrains/kotlin/fir/analysis/diagnostics/FirErrors.kt +++ b/compiler/fir/checkers/gen/org/jetbrains/kotlin/fir/analysis/diagnostics/FirErrors.kt @@ -87,7 +87,7 @@ object FirErrors { val ASSIGNMENT_IN_EXPRESSION_CONTEXT by error0() val BREAK_OR_CONTINUE_OUTSIDE_A_LOOP by error0() val NOT_A_LOOP_LABEL by error0() - val VARIABLE_EXPECTED by error0() + val VARIABLE_EXPECTED by error0(SourceElementPositioningStrategies.ASSIGNMENT_LHS) val DELEGATION_IN_INTERFACE by error0() val NESTED_CLASS_NOT_ALLOWED by error1(SourceElementPositioningStrategies.DECLARATION_NAME) val INCORRECT_CHARACTER_LITERAL by error0() @@ -384,7 +384,7 @@ object FirErrors { val CONST_VAL_WITHOUT_INITIALIZER by error0(SourceElementPositioningStrategies.CONST_MODIFIER) val CONST_VAL_WITH_NON_CONST_INITIALIZER by error0() val WRONG_SETTER_PARAMETER_TYPE by error2() - val INITIALIZER_TYPE_MISMATCH by error2(SourceElementPositioningStrategies.ASSIGNMENT_VALUE) + val INITIALIZER_TYPE_MISMATCH by error2(SourceElementPositioningStrategies.PROPERTY_INITIALIZER) val GETTER_VISIBILITY_DIFFERS_FROM_PROPERTY_VISIBILITY by error0(SourceElementPositioningStrategies.VISIBILITY_MODIFIER) val SETTER_VISIBILITY_INCONSISTENT_WITH_PROPERTY_VISIBILITY by error0(SourceElementPositioningStrategies.VISIBILITY_MODIFIER) val WRONG_SETTER_RETURN_TYPE by error0() diff --git a/compiler/fir/checkers/src/org/jetbrains/kotlin/fir/analysis/collectors/components/ErrorNodeDiagnosticCollectorComponent.kt b/compiler/fir/checkers/src/org/jetbrains/kotlin/fir/analysis/collectors/components/ErrorNodeDiagnosticCollectorComponent.kt index d938d20cbda..d65df3a264e 100644 --- a/compiler/fir/checkers/src/org/jetbrains/kotlin/fir/analysis/collectors/components/ErrorNodeDiagnosticCollectorComponent.kt +++ b/compiler/fir/checkers/src/org/jetbrains/kotlin/fir/analysis/collectors/components/ErrorNodeDiagnosticCollectorComponent.kt @@ -47,9 +47,9 @@ class ErrorNodeDiagnosticCollectorComponent( override fun visitErrorNamedReference(errorNamedReference: FirErrorNamedReference, data: CheckerContext) { val source = errorNamedReference.source ?: return val qualifiedAccessOrAnnotationCall = data.qualifiedAccessOrAnnotationCalls.lastOrNull()?.takeIf { - // Use the source of the enclosing FirQualifiedAccessExpression if it is exactly the call to the erroneous callee. + // Use the source of the enclosing FirQualifiedAccess if it is exactly the call to the erroneous callee. when (it) { - is FirQualifiedAccessExpression -> it.calleeReference == errorNamedReference + is FirQualifiedAccess -> it.calleeReference == errorNamedReference is FirAnnotationCall -> it.calleeReference == errorNamedReference else -> false } diff --git a/compiler/fir/checkers/src/org/jetbrains/kotlin/fir/analysis/diagnostics/LightTreePositioningStrategies.kt b/compiler/fir/checkers/src/org/jetbrains/kotlin/fir/analysis/diagnostics/LightTreePositioningStrategies.kt index fe6cd5706db..1b7e1c42c5f 100644 --- a/compiler/fir/checkers/src/org/jetbrains/kotlin/fir/analysis/diagnostics/LightTreePositioningStrategies.kt +++ b/compiler/fir/checkers/src/org/jetbrains/kotlin/fir/analysis/diagnostics/LightTreePositioningStrategies.kt @@ -425,6 +425,13 @@ object LightTreePositioningStrategies { endOffset: Int, tree: FlyweightCapableTreeStructure ): List { + if (node.tokenType == KtNodeTypes.BINARY_EXPRESSION && + tree.findDescendantByTypes(node, KtTokens.ALL_ASSIGNMENTS) != null + ) { + tree.findDescendantByType(node, KtNodeTypes.DOT_QUALIFIED_EXPRESSION)?.let { + return markElement(tree.dotOperator(it) ?: it, startOffset, endOffset, tree, node) + } + } if (node.tokenType == KtNodeTypes.DOT_QUALIFIED_EXPRESSION) { return markElement(tree.dotOperator(node) ?: node, startOffset, endOffset, tree, node) } @@ -707,6 +714,33 @@ object LightTreePositioningStrategies { } } + val ASSIGNMENT_LHS: LightTreePositioningStrategy = object : LightTreePositioningStrategy() { + override fun mark( + node: LighterASTNode, + startOffset: Int, + endOffset: Int, + tree: FlyweightCapableTreeStructure + ): List { + if ((node.tokenType == KtNodeTypes.BINARY_EXPRESSION && + tree.findDescendantByTypes(node, KtTokens.ALL_ASSIGNMENTS) != null) || + ((node.tokenType == KtNodeTypes.PREFIX_EXPRESSION || node.tokenType == KtNodeTypes.POSTFIX_EXPRESSION) && + tree.findDescendantByTypes(node, KtTokens.INCREMENT_AND_DECREMENT) != null) + ) { + val lhs = if (node.tokenType == KtNodeTypes.PREFIX_EXPRESSION) { + tree.lastChildExpression(node) + } else { + tree.firstChildExpression(node) + } + lhs?.let { + tree.unwrapParenthesesLabelsAndAnnotations(it)?.let { unwrapped -> + return markElement(unwrapped, startOffset, endOffset, tree, node) + } + } + } + return super.mark(node, startOffset, endOffset, tree) + } + } + val ANNOTATION_USE_SITE: LightTreePositioningStrategy = object : LightTreePositioningStrategy() { override fun mark( node: LighterASTNode, @@ -810,6 +844,18 @@ private fun FlyweightCapableTreeStructure.referenceExpression( return result } +private fun FlyweightCapableTreeStructure.unwrapParenthesesLabelsAndAnnotations(node: LighterASTNode): LighterASTNode? { + var unwrapped = node + while (true) { + unwrapped = when (unwrapped.tokenType) { + KtNodeTypes.PARENTHESIZED -> firstChildExpression(unwrapped) ?: return unwrapped + KtNodeTypes.LABELED_EXPRESSION -> lastChildExpression(unwrapped) ?: return unwrapped + KtNodeTypes.ANNOTATED_EXPRESSION -> firstChildExpression(unwrapped) ?: return unwrapped + else -> return unwrapped + } + } +} + private fun FlyweightCapableTreeStructure.findExpressionDeep(node: LighterASTNode): LighterASTNode? = findFirstDescendant(node) { it.isExpression() } @@ -898,6 +944,18 @@ fun FlyweightCapableTreeStructure.selector(node: LighterASTNode) } +fun FlyweightCapableTreeStructure.firstChildExpression(node: LighterASTNode): LighterASTNode? { + val childrenRef = Ref>() + getChildren(node, childrenRef) + return childrenRef.get()?.firstOrNull { it?.isExpression() == true } +} + +fun FlyweightCapableTreeStructure.lastChildExpression(node: LighterASTNode): LighterASTNode? { + val childrenRef = Ref>() + getChildren(node, childrenRef) + return childrenRef.get()?.lastOrNull { it?.isExpression() == true } +} + fun FlyweightCapableTreeStructure.findChildByType(node: LighterASTNode, type: IElementType): LighterASTNode? { val childrenRef = Ref>() getChildren(node, childrenRef) diff --git a/compiler/fir/checkers/src/org/jetbrains/kotlin/fir/analysis/diagnostics/SourceElementPositioningStrategies.kt b/compiler/fir/checkers/src/org/jetbrains/kotlin/fir/analysis/diagnostics/SourceElementPositioningStrategies.kt index caf28ff409e..64e98d5b8f8 100644 --- a/compiler/fir/checkers/src/org/jetbrains/kotlin/fir/analysis/diagnostics/SourceElementPositioningStrategies.kt +++ b/compiler/fir/checkers/src/org/jetbrains/kotlin/fir/analysis/diagnostics/SourceElementPositioningStrategies.kt @@ -208,9 +208,9 @@ object SourceElementPositioningStrategies { PositioningStrategies.RETURN_WITH_LABEL ) - val ASSIGNMENT_VALUE = SourceElementPositioningStrategy( + val PROPERTY_INITIALIZER = SourceElementPositioningStrategy( LightTreePositioningStrategies.LAST_CHILD, - PositioningStrategies.ASSIGNMENT_VALUE + PositioningStrategies.PROPERTY_INITIALIZER ) val WHOLE_ELEMENT = SourceElementPositioningStrategy( @@ -247,4 +247,9 @@ object SourceElementPositioningStrategies { LightTreePositioningStrategies.ANNOTATION_USE_SITE, PositioningStrategies.ANNOTATION_USE_SITE ) + + val ASSIGNMENT_LHS = SourceElementPositioningStrategy( + LightTreePositioningStrategies.ASSIGNMENT_LHS, + PositioningStrategies.ASSIGNMENT_LHS + ) } diff --git a/compiler/fir/raw-fir/psi2fir/tests/org/jetbrains/kotlin/fir/builder/AbstractRawFirBuilderTestCase.kt b/compiler/fir/raw-fir/psi2fir/tests/org/jetbrains/kotlin/fir/builder/AbstractRawFirBuilderTestCase.kt index c93d34223cd..3d7665828c3 100644 --- a/compiler/fir/raw-fir/psi2fir/tests/org/jetbrains/kotlin/fir/builder/AbstractRawFirBuilderTestCase.kt +++ b/compiler/fir/raw-fir/psi2fir/tests/org/jetbrains/kotlin/fir/builder/AbstractRawFirBuilderTestCase.kt @@ -83,20 +83,31 @@ abstract class AbstractRawFirBuilderTestCase : KtParsingTestCase( if (!result.add(this)) { return result } - propertyLoop@ for (property in this::class.memberProperties) { - val childElement = property.getter.apply { isAccessible = true }.call(this) + for (property in this::class.memberProperties) { + if (hasNoAcceptAndTransform(this::class.simpleName, property.name)) continue - when (childElement) { - is FirNoReceiverExpression -> continue@propertyLoop + when (val childElement = property.getter.apply { isAccessible = true }.call(this)) { + is FirNoReceiverExpression -> continue is FirElement -> childElement.traverseChildren(result) is List<*> -> childElement.filterIsInstance().forEach { it.traverseChildren(result) } - else -> continue@propertyLoop + else -> continue } } return result } + private val firImplClassPropertiesWithNoAcceptAndTransform = mapOf( + "FirResolvedImportImpl" to "delegate", + "FirErrorTypeRefImpl" to "delegatedTypeRef", + "FirResolvedTypeRefImpl" to "delegatedTypeRef" + ) + + private fun hasNoAcceptAndTransform(className: String?, propertyName: String): Boolean { + if (className == null) return false + return firImplClassPropertiesWithNoAcceptAndTransform[className] == propertyName + } + private fun FirFile.visitChildren(): Set = ConsistencyVisitor().let { this@visitChildren.accept(it) diff --git a/compiler/fir/raw-fir/raw-fir.common/src/org/jetbrains/kotlin/fir/builder/BaseFirBuilder.kt b/compiler/fir/raw-fir/raw-fir.common/src/org/jetbrains/kotlin/fir/builder/BaseFirBuilder.kt index 8484912bbce..1921944987e 100644 --- a/compiler/fir/raw-fir/raw-fir.common/src/org/jetbrains/kotlin/fir/builder/BaseFirBuilder.kt +++ b/compiler/fir/raw-fir/raw-fir.common/src/org/jetbrains/kotlin/fir/builder/BaseFirBuilder.kt @@ -445,9 +445,8 @@ abstract class BaseFirBuilder(val baseSession: FirSession, val context: Conte */ // TODO: - // 1. Support receiver capturing for `array.b++` (elementType == ARRAY_ACCESS_EXPRESSION). - // 2. Support receiver capturing for `a?.b++` (elementType == SAFE_ACCESS_EXPRESSION). - // 3. Add box test cases for #1 and #2 where receiver expression has side effects. + // 1. Support receiver capturing for `a?.b++` (elementType == SAFE_ACCESS_EXPRESSION). + // 2. Add box test cases for #1 where receiver expression has side effects. fun generateIncrementOrDecrementBlock( baseExpression: T, operationReference: T?, @@ -456,21 +455,8 @@ abstract class BaseFirBuilder(val baseSession: FirSession, val context: Conte prefix: Boolean, convert: T.() -> FirExpression ): FirExpression { - // NOTE: By removing surrounding parentheses and labels, FirLabels will NOT be created for those labels. - // This should be fine since the label is meaningless and unusable for a ++/-- argument. - var unwrappedArgument = argument - while (true) { - unwrappedArgument = when (unwrappedArgument?.elementType) { - PARENTHESIZED -> unwrappedArgument?.getExpressionInParentheses() - LABELED_EXPRESSION -> unwrappedArgument?.getLabeledExpression() - else -> break - } - } - - if (unwrappedArgument == null) { - return buildErrorExpression { - diagnostic = ConeSimpleDiagnostic("Inc/dec without operand", DiagnosticKind.Syntax) - } + val unwrappedArgument = argument.unwrap() ?: return buildErrorExpression { + diagnostic = ConeSimpleDiagnostic("Inc/dec without operand", DiagnosticKind.Syntax) } if (unwrappedArgument.elementType == DOT_QUALIFIED_EXPRESSION) { @@ -566,6 +552,20 @@ abstract class BaseFirBuilder(val baseSession: FirSession, val context: Conte } } + private fun T?.unwrap(): T? { + // NOTE: By removing surrounding parentheses and labels, FirLabels will NOT be created for those labels. + // This should be fine since the label is meaningless and unusable for a ++/-- argument or assignment LHS. + var unwrapped = this + while (true) { + unwrapped = when (unwrapped?.elementType) { + PARENTHESIZED -> unwrapped?.getExpressionInParentheses() + LABELED_EXPRESSION -> unwrapped?.getLabeledExpression() + ANNOTATED_EXPRESSION -> unwrapped?.getAnnotatedExpression() + else -> return unwrapped + } + } + } + /** * given: * a.b++ @@ -860,12 +860,6 @@ abstract class BaseFirBuilder(val baseSession: FirSession, val context: Conte } } } - PARENTHESIZED -> { - return initializeLValue(left.getExpressionInParentheses(), convertQualified) - } - ANNOTATED_EXPRESSION -> { - return initializeLValue(left.getAnnotatedExpression(), convertQualified) - } } } return buildErrorNamedReference { @@ -881,19 +875,19 @@ abstract class BaseFirBuilder(val baseSession: FirSession, val context: Conte operation: FirOperation, convert: T.() -> FirExpression ): FirStatement { - val tokenType = this?.elementType - if (tokenType == PARENTHESIZED) { - return this!!.getExpressionInParentheses().generateAssignment(baseSource, rhs, value, operation, convert) + val unwrappedLhs = this.unwrap() ?: return buildErrorExpression { + diagnostic = ConeSimpleDiagnostic("Inc/dec without operand", DiagnosticKind.Syntax) } + + val tokenType = unwrappedLhs.elementType if (tokenType == ARRAY_ACCESS_EXPRESSION) { - require(this != null) if (operation == FirOperation.ASSIGN) { - context.arraySetArgument[this] = value + context.arraySetArgument[unwrappedLhs] = value } return if (operation == FirOperation.ASSIGN) { - this.convert() + unwrappedLhs.convert() } else { - generateAugmentedArraySetCall(baseSource, operation, rhs, convert) + generateAugmentedArraySetCall(unwrappedLhs, baseSource, operation, rhs, convert) } } @@ -924,7 +918,7 @@ abstract class BaseFirBuilder(val baseSession: FirSession, val context: Conte return buildVariableAssignment { source = baseSource rValue = value - calleeReference = initializeLValue(this@generateAssignment) { convert() as? FirQualifiedAccess } + calleeReference = initializeLValue(unwrappedLhs) { convert() as? FirQualifiedAccess } } } @@ -950,7 +944,8 @@ abstract class BaseFirBuilder(val baseSession: FirSession, val context: Conte return safeCallNonAssignment } - private fun T.generateAugmentedArraySetCall( + private fun generateAugmentedArraySetCall( + unwrappedReceiver: T, baseSource: FirSourceElement?, operation: FirOperation, rhs: T?, @@ -959,12 +954,13 @@ abstract class BaseFirBuilder(val baseSession: FirSession, val context: Conte return buildAugmentedArraySetCall { source = baseSource this.operation = operation - assignCall = generateAugmentedCallForAugmentedArraySetCall(operation, rhs, convert) - setGetBlock = generateSetGetBlockForAugmentedArraySetCall(baseSource, operation, rhs, convert) + assignCall = generateAugmentedCallForAugmentedArraySetCall(unwrappedReceiver, operation, rhs, convert) + setGetBlock = generateSetGetBlockForAugmentedArraySetCall(unwrappedReceiver, baseSource, operation, rhs, convert) } } - private fun T.generateAugmentedCallForAugmentedArraySetCall( + private fun generateAugmentedCallForAugmentedArraySetCall( + unwrappedReceiver: T, operation: FirOperation, rhs: T?, convert: T.() -> FirExpression @@ -977,7 +973,7 @@ abstract class BaseFirBuilder(val baseSession: FirSession, val context: Conte calleeReference = buildSimpleNamedReference { name = FirOperationNameConventions.ASSIGNMENTS.getValue(operation) } - explicitReceiver = convert() + explicitReceiver = unwrappedReceiver.convert() argumentList = buildArgumentList { arguments += rhs?.convert() ?: buildErrorExpression( null, @@ -989,7 +985,8 @@ abstract class BaseFirBuilder(val baseSession: FirSession, val context: Conte } - private fun T.generateSetGetBlockForAugmentedArraySetCall( + private fun generateSetGetBlockForAugmentedArraySetCall( + unwrappedReceiver: T, baseSource: FirSourceElement?, operation: FirOperation, rhs: T?, @@ -1005,7 +1002,7 @@ abstract class BaseFirBuilder(val baseSession: FirSession, val context: Conte * } */ return buildBlock { - val baseCall = convert() as FirFunctionCall + val baseCall = unwrappedReceiver.convert() as FirFunctionCall val arrayVariable = generateTemporaryVariable( baseModuleData, diff --git a/compiler/frontend/src/org/jetbrains/kotlin/diagnostics/PositioningStrategies.kt b/compiler/frontend/src/org/jetbrains/kotlin/diagnostics/PositioningStrategies.kt index 4d61d97b49d..bf6985a5101 100644 --- a/compiler/frontend/src/org/jetbrains/kotlin/diagnostics/PositioningStrategies.kt +++ b/compiler/frontend/src/org/jetbrains/kotlin/diagnostics/PositioningStrategies.kt @@ -775,11 +775,14 @@ object PositioningStrategies { val DOT_BY_QUALIFIED: PositioningStrategy = object : PositioningStrategy() { override fun mark(element: PsiElement): List { - when (element) { - is KtDotQualifiedExpression -> { - return mark(element.operationTokenNode.psi) + if (element is KtBinaryExpression && element.operationToken in KtTokens.ALL_ASSIGNMENTS) { + element.left?.let { left -> + left.findDescendantOfType()?.let { return mark(it) } } } + if (element is KtDotQualifiedExpression) { + return mark(element.operationTokenNode.psi) + } // Fallback to mark the callee reference. return REFERENCE_BY_QUALIFIED.mark(element) } @@ -848,9 +851,9 @@ object PositioningStrategies { val REIFIED_MODIFIER: PositioningStrategy = modifierSetPosition(KtTokens.REIFIED_KEYWORD) - val ASSIGNMENT_VALUE: PositioningStrategy = object : PositioningStrategy() { - override fun mark(element: PsiElement): List { - return markElement(if (element is KtProperty) element.initializer ?: element else element) + val PROPERTY_INITIALIZER: PositioningStrategy = object : PositioningStrategy() { + override fun mark(element: KtProperty): List { + return markElement(element.initializer ?: element) } } @@ -871,6 +874,17 @@ object PositioningStrategies { } } + val ASSIGNMENT_LHS: PositioningStrategy = object : PositioningStrategy() { + override fun mark(element: PsiElement): List { + if (element is KtBinaryExpression && element.operationToken in KtTokens.ALL_ASSIGNMENTS) { + element.left.let { left -> left.unwrapParenthesesLabelsAndAnnotations()?.let { return markElement(it) } } + } + if (element is KtUnaryExpression && element.operationToken in KtTokens.INCREMENT_AND_DECREMENT) { + element.baseExpression.let { arg -> arg.unwrapParenthesesLabelsAndAnnotations()?.let { return markElement(it) } } + } + return super.mark(element) + } + } /** * @param locateReferencedName whether to remove any nested parentheses while locating the reference element. This is useful for diff --git a/compiler/psi/src/org/jetbrains/kotlin/lexer/KtTokens.java b/compiler/psi/src/org/jetbrains/kotlin/lexer/KtTokens.java index 8d991122be8..c1119d53da0 100644 --- a/compiler/psi/src/org/jetbrains/kotlin/lexer/KtTokens.java +++ b/compiler/psi/src/org/jetbrains/kotlin/lexer/KtTokens.java @@ -265,4 +265,5 @@ public interface KtTokens { TokenSet AUGMENTED_ASSIGNMENTS = TokenSet.create(PLUSEQ, MINUSEQ, MULTEQ, PERCEQ, DIVEQ); TokenSet ALL_ASSIGNMENTS = TokenSet.create(EQ, PLUSEQ, MINUSEQ, MULTEQ, PERCEQ, DIVEQ); + TokenSet INCREMENT_AND_DECREMENT = TokenSet.create(PLUSPLUS, MINUSMINUS); } diff --git a/compiler/psi/src/org/jetbrains/kotlin/psi/psiUtil/psiUtils.kt b/compiler/psi/src/org/jetbrains/kotlin/psi/psiUtil/psiUtils.kt index 69efb019386..1c48a593ea7 100644 --- a/compiler/psi/src/org/jetbrains/kotlin/psi/psiUtil/psiUtils.kt +++ b/compiler/psi/src/org/jetbrains/kotlin/psi/psiUtil/psiUtils.kt @@ -498,4 +498,16 @@ fun KtExpression.isNull(): Boolean { returns(true) implies (this@isNull is KtConstantExpression) } return this is KtConstantExpression && this.node.elementType == KtNodeTypes.NULL -} \ No newline at end of file +} + +fun PsiElement?.unwrapParenthesesLabelsAndAnnotations(): PsiElement? { + var unwrapped = this + while (true) { + unwrapped = when (unwrapped) { + is KtParenthesizedExpression -> unwrapped.expression + is KtLabeledExpression -> unwrapped.baseExpression + is KtAnnotatedExpression -> unwrapped.baseExpression + else -> return unwrapped + } + } +} diff --git a/compiler/testData/diagnostics/tests/LValueAssignment.fir.kt b/compiler/testData/diagnostics/tests/LValueAssignment.fir.kt index f541a7db8f6..0126256aa7c 100644 --- a/compiler/testData/diagnostics/tests/LValueAssignment.fir.kt +++ b/compiler/testData/diagnostics/tests/LValueAssignment.fir.kt @@ -51,18 +51,25 @@ fun cannotBe() { 5 = 34 } +@Retention(AnnotationRetention.SOURCE) +@Target(AnnotationTarget.EXPRESSION) +annotation class Ann + fun canBe(i0: Int, j: Int) { var i = i0 - (label@ i) = 34 + (label@ i) = 34 - (label@ j) = 34 //repeat for j + (label@ j) = 34 //repeat for j val a = A() - (l@ a.a) = 3894 + (l@ a.a) = 3894 + + @Ann + l@ (i) = 123 } fun canBe2(j: Int) { - (label@ j) = 34 + (label@ j) = 34 } class A() { @@ -78,10 +85,13 @@ class Test() { (f@ getInt()) += 343 1++ - (r@ 1)++ + (r@ 1)-- getInt()++ - (m@ getInt())++ + (m@ getInt())-- + + ++2 + --(r@ 2) this++ @@ -89,6 +99,9 @@ class Test() { s += "ss" s += this s += (a@ 2) + + @Ann + l@ (1) = 123 } fun testIncompleteSyntax() { @@ -106,8 +119,11 @@ class Test() { b += 34 a++ - (l@ a)++ + (@Ann l@ a)-- (a)++ + --a + ++(@Ann l@ a) + --(a) } fun testVariables1() { @@ -122,6 +138,9 @@ class Test() { a[3] = 4 a[4]++ a[6] += 43 + @Ann + a[7] = 7 + (@Ann l@ (a))[8] = 8 ab.getArray()[54] = 23 ab.getArray()[54]++ diff --git a/compiler/testData/diagnostics/tests/LValueAssignment.kt b/compiler/testData/diagnostics/tests/LValueAssignment.kt index 9d3260345bb..6345c6cfe4f 100644 --- a/compiler/testData/diagnostics/tests/LValueAssignment.kt +++ b/compiler/testData/diagnostics/tests/LValueAssignment.kt @@ -51,6 +51,10 @@ fun cannotBe() { 5 = 34 } +@Retention(AnnotationRetention.SOURCE) +@Target(AnnotationTarget.EXPRESSION) +annotation class Ann + fun canBe(i0: Int, j: Int) { var i = i0 (label@ i) = 34 @@ -59,6 +63,9 @@ fun canBe(i0: Int, j: Int) { val a = A() (l@ a.a) = 3894 + + @Ann + l@ (i) = 123 } fun canBe2(j: Int) { @@ -78,10 +85,13 @@ class Test() { (f@ getInt()) += 343 1++ - (r@ 1)++ + (r@ 1)-- getInt()++ - (m@ getInt())++ + (m@ getInt())-- + + ++2 + --(r@ 2) this++ @@ -89,6 +99,9 @@ class Test() { s += "ss" s += this s += (a@ 2) + + @Ann + l@ (1) = 123 } fun testIncompleteSyntax() { @@ -106,8 +119,11 @@ class Test() { b += 34 a++ - (l@ a)++ + (@Ann l@ a)-- (a)++ + --a + ++(@Ann l@ a) + --(a) } fun testVariables1() { @@ -122,6 +138,9 @@ class Test() { a[3] = 4 a[4]++ a[6] += 43 + @Ann + a[7] = 7 + (@Ann l@ (a))[8] = 8 ab.getArray()[54] = 23 ab.getArray()[54]++ diff --git a/compiler/testData/diagnostics/tests/LValueAssignment.txt b/compiler/testData/diagnostics/tests/LValueAssignment.txt index 1f74c392ee9..935977a3e96 100644 --- a/compiler/testData/diagnostics/tests/LValueAssignment.txt +++ b/compiler/testData/diagnostics/tests/LValueAssignment.txt @@ -24,6 +24,13 @@ package lvalue_assignment { public open override /*1*/ /*fake_override*/ fun toString(): kotlin.String } + @kotlin.annotation.Retention(value = AnnotationRetention.SOURCE) @kotlin.annotation.Target(allowedTargets = {AnnotationTarget.EXPRESSION}) public final annotation class Ann : kotlin.Annotation { + public constructor Ann() + public open override /*1*/ /*fake_override*/ fun equals(/*0*/ other: kotlin.Any?): kotlin.Boolean + public open override /*1*/ /*fake_override*/ fun hashCode(): kotlin.Int + public open override /*1*/ /*fake_override*/ fun toString(): kotlin.String + } + public open class B { public constructor B() public final var b: kotlin.Int diff --git a/compiler/testData/diagnostics/tests/controlFlowAnalysis/kt2330.fir.kt b/compiler/testData/diagnostics/tests/controlFlowAnalysis/kt2330.fir.kt index 8f5c98ebdf4..2964014cd39 100644 --- a/compiler/testData/diagnostics/tests/controlFlowAnalysis/kt2330.fir.kt +++ b/compiler/testData/diagnostics/tests/controlFlowAnalysis/kt2330.fir.kt @@ -47,7 +47,7 @@ class R { fun test() { val o = object { fun run() { - p.x = 43 + p.x = 43 } } } \ No newline at end of file diff --git a/compiler/testData/diagnostics/tests/nullabilityAndSmartCasts/QualifiedExpressionNullability.fir.kt b/compiler/testData/diagnostics/tests/nullabilityAndSmartCasts/QualifiedExpressionNullability.fir.kt index 0c7e1257a61..6916a5d33c3 100644 --- a/compiler/testData/diagnostics/tests/nullabilityAndSmartCasts/QualifiedExpressionNullability.fir.kt +++ b/compiler/testData/diagnostics/tests/nullabilityAndSmartCasts/QualifiedExpressionNullability.fir.kt @@ -1,6 +1,7 @@ // !DIAGNOSTICS: -DEBUG_INFO_SMARTCAST class Foo { fun foo(a: Foo): Foo = a + var f: Foo? = null } fun main() { @@ -27,4 +28,14 @@ fun main() { val z: Foo? = null z!!.foo(z!!) + + val w: Foo? = null + w.f = z + (w.f) = z + (label@ w.f) = z + w!!.f = z + w.f = z + w!!.f = z + w.f.f = z + w.f!!.f = z } diff --git a/compiler/testData/diagnostics/tests/nullabilityAndSmartCasts/QualifiedExpressionNullability.kt b/compiler/testData/diagnostics/tests/nullabilityAndSmartCasts/QualifiedExpressionNullability.kt index baff72dd922..775e11d0e1b 100644 --- a/compiler/testData/diagnostics/tests/nullabilityAndSmartCasts/QualifiedExpressionNullability.kt +++ b/compiler/testData/diagnostics/tests/nullabilityAndSmartCasts/QualifiedExpressionNullability.kt @@ -1,6 +1,7 @@ // !DIAGNOSTICS: -DEBUG_INFO_SMARTCAST class Foo { fun foo(a: Foo): Foo = a + var f: Foo? = null } fun main() { @@ -27,4 +28,14 @@ fun main() { val z: Foo? = null z!!.foo(z!!) + + val w: Foo? = null + w.f = z + (w.f) = z + (label@ w.f) = z + w!!.f = z + w.f = z + w!!.f = z + w.f.f = z + w.f!!.f = z } diff --git a/compiler/testData/diagnostics/tests/nullabilityAndSmartCasts/QualifiedExpressionNullability.txt b/compiler/testData/diagnostics/tests/nullabilityAndSmartCasts/QualifiedExpressionNullability.txt index 585dcd4ad75..fe62cbc0536 100644 --- a/compiler/testData/diagnostics/tests/nullabilityAndSmartCasts/QualifiedExpressionNullability.txt +++ b/compiler/testData/diagnostics/tests/nullabilityAndSmartCasts/QualifiedExpressionNullability.txt @@ -4,6 +4,7 @@ public fun main(): kotlin.Unit public final class Foo { public constructor Foo() + public final var f: Foo? public open override /*1*/ /*fake_override*/ fun equals(/*0*/ other: kotlin.Any?): kotlin.Boolean public final fun foo(/*0*/ a: Foo): Foo public open override /*1*/ /*fake_override*/ fun hashCode(): kotlin.Int diff --git a/idea/idea-fir/src/org/jetbrains/kotlin/idea/quickfix/fixes/AddExclExclCallFixFactories.kt b/idea/idea-fir/src/org/jetbrains/kotlin/idea/quickfix/fixes/AddExclExclCallFixFactories.kt index 1777a5fe69f..e792e358601 100644 --- a/idea/idea-fir/src/org/jetbrains/kotlin/idea/quickfix/fixes/AddExclExclCallFixFactories.kt +++ b/idea/idea-fir/src/org/jetbrains/kotlin/idea/quickfix/fixes/AddExclExclCallFixFactories.kt @@ -13,6 +13,7 @@ import org.jetbrains.kotlin.idea.frontend.api.fir.diagnostics.KtFirDiagnostic import org.jetbrains.kotlin.idea.frontend.api.symbols.KtFunctionSymbol import org.jetbrains.kotlin.idea.quickfix.AddExclExclCallFix import org.jetbrains.kotlin.psi.* +import org.jetbrains.kotlin.psi.psiUtil.unwrapParenthesesLabelsAndAnnotations import org.jetbrains.kotlin.util.OperatorNameConventions object AddExclExclCallFixFactories { @@ -29,46 +30,51 @@ object AddExclExclCallFixFactories { } private fun KtAnalysisSession.getFixForUnsafeCall(psi: PsiElement): List { - val (target, hasImplicitReceiver) = when (psi) { + val (target, hasImplicitReceiver) = when (val unwrapped = psi.unwrapParenthesesLabelsAndAnnotations()) { // `foo.bar` -> `foo!!.bar` - is KtDotQualifiedExpression -> psi.receiverExpression to false + is KtDotQualifiedExpression -> unwrapped.receiverExpression to false // `foo[bar]` -> `foo!![bar]` - is KtArrayAccessExpression -> psi.arrayExpression to false + is KtArrayAccessExpression -> unwrapped.arrayExpression to false - is KtCallableReferenceExpression -> psi.lhs.let { lhs -> + is KtCallableReferenceExpression -> unwrapped.lhs.let { lhs -> if (lhs != null) { // `foo::bar` -> `foo!!::bar` lhs to false } else { // `::bar -> this!!::bar` - psi to true + unwrapped to true } } // `bar` -> `this!!.bar` - is KtNameReferenceExpression -> psi to true + is KtNameReferenceExpression -> unwrapped to true // `bar()` -> `this!!.bar()` - is KtCallExpression -> psi to true + is KtCallExpression -> unwrapped to true // `-foo` -> `-foo!!` // NOTE: Unsafe unary operator call is reported as UNSAFE_CALL, _not_ UNSAFE_OPERATOR_CALL - is KtUnaryExpression -> psi.baseExpression to false + is KtUnaryExpression -> unwrapped.baseExpression to false is KtBinaryExpression -> { - val receiver = if (KtPsiUtil.isInOrNotInOperation(psi)) { - // `bar in foo` -> `bar in foo!!` - psi.right - } else { - // `foo + bar` -> `foo!! + bar` OR `foo infixFun bar` -> `foo!! infixFun bar` - psi.left + val receiver = when { + KtPsiUtil.isInOrNotInOperation(unwrapped) -> + // `bar in foo` -> `bar in foo!!` + unwrapped.right + KtPsiUtil.isAssignment(unwrapped) -> + // UNSAFE_CALL for assignments (e.g., `foo.bar = value`) is reported on the entire statement (KtBinaryExpression). + // The unsafe call is on the LHS of the assignment. + return getFixForUnsafeCall(unwrapped.left ?: return emptyList()) + else -> + // `foo + bar` -> `foo!! + bar` OR `foo infixFun bar` -> `foo!! infixFun bar` + unwrapped.left } receiver to false } // UNSAFE_INFIX_CALL/UNSAFE_OPERATOR_CALL on KtBinaryExpression is reported on the child KtOperationReferenceExpression - is KtOperationReferenceExpression -> return getFixForUnsafeCall(psi.parent) + is KtOperationReferenceExpression -> return getFixForUnsafeCall(unwrapped.parent) else -> return emptyList() } diff --git a/idea/idea-fir/src/org/jetbrains/kotlin/idea/quickfix/fixes/ReplaceCallFixFactories.kt b/idea/idea-fir/src/org/jetbrains/kotlin/idea/quickfix/fixes/ReplaceCallFixFactories.kt index 725e6dbd0e7..96892fb81b0 100644 --- a/idea/idea-fir/src/org/jetbrains/kotlin/idea/quickfix/fixes/ReplaceCallFixFactories.kt +++ b/idea/idea-fir/src/org/jetbrains/kotlin/idea/quickfix/fixes/ReplaceCallFixFactories.kt @@ -11,9 +11,9 @@ import org.jetbrains.kotlin.idea.frontend.api.types.KtTypeNullability import org.jetbrains.kotlin.idea.frontend.api.types.KtTypeWithNullability import org.jetbrains.kotlin.idea.quickfix.ReplaceImplicitReceiverCallFix import org.jetbrains.kotlin.idea.quickfix.ReplaceWithSafeCallFix -import org.jetbrains.kotlin.psi.KtDotQualifiedExpression -import org.jetbrains.kotlin.psi.KtExpression -import org.jetbrains.kotlin.psi.KtNameReferenceExpression +import org.jetbrains.kotlin.lexer.KtTokens +import org.jetbrains.kotlin.psi.* +import org.jetbrains.kotlin.psi.psiUtil.unwrapParenthesesLabelsAndAnnotations object ReplaceCallFixFactories { val unsafeCallFactory = @@ -25,14 +25,23 @@ object ReplaceCallFixFactories { return expectedType?.nullability == KtTypeNullability.NON_NULLABLE } - when (val psi = diagnostic.psi) { - is KtDotQualifiedExpression -> listOf(ReplaceWithSafeCallFix(psi, psi.shouldHaveNotNullType())) + val psi = diagnostic.psi + val target = if (psi is KtBinaryExpression && psi.operationToken in KtTokens.ALL_ASSIGNMENTS) { + // UNSAFE_CALL for assignments (e.g., `foo.bar = value`) is reported on the entire statement (KtBinaryExpression). + // The unsafe call is on the LHS of the assignment. + psi.left + } else { + psi + }.unwrapParenthesesLabelsAndAnnotations() + + when (target) { + is KtDotQualifiedExpression -> listOf(ReplaceWithSafeCallFix(target, target.shouldHaveNotNullType())) is KtNameReferenceExpression -> { // TODO: As a safety precaution, resolve the expression to determine if it is a call with an implicit receiver. // This is a defensive check to ensure that the diagnostic was reported on such a call and not some other name reference. // This isn't strictly needed because FIR checkers aren't reporting on wrong elements, but ReplaceWithSafeCallFixFactory // in FE1.0 does so. - listOf(ReplaceImplicitReceiverCallFix(psi, psi.shouldHaveNotNullType())) + listOf(ReplaceImplicitReceiverCallFix(target, target.shouldHaveNotNullType())) } else -> emptyList() } diff --git a/idea/idea-fir/tests/org/jetbrains/kotlin/idea/quickfix/HighLevelQuickFixTestGenerated.java b/idea/idea-fir/tests/org/jetbrains/kotlin/idea/quickfix/HighLevelQuickFixTestGenerated.java index 8df0121db48..2220ddbe7ba 100644 --- a/idea/idea-fir/tests/org/jetbrains/kotlin/idea/quickfix/HighLevelQuickFixTestGenerated.java +++ b/idea/idea-fir/tests/org/jetbrains/kotlin/idea/quickfix/HighLevelQuickFixTestGenerated.java @@ -229,6 +229,16 @@ public class HighLevelQuickFixTestGenerated extends AbstractHighLevelQuickFixTes runTest("idea/testData/quickfix/addExclExclCall/array4.kt"); } + @TestMetadata("assignment.kt") + public void testAssignment() throws Exception { + runTest("idea/testData/quickfix/addExclExclCall/assignment.kt"); + } + + @TestMetadata("assignmentToUnsafeCallExpression.kt") + public void testAssignmentToUnsafeCallExpression() throws Exception { + runTest("idea/testData/quickfix/addExclExclCall/assignmentToUnsafeCallExpression.kt"); + } + @TestMetadata("functionReference.kt") public void testFunctionReference() throws Exception { runTest("idea/testData/quickfix/addExclExclCall/functionReference.kt"); @@ -1271,6 +1281,11 @@ public class HighLevelQuickFixTestGenerated extends AbstractHighLevelQuickFixTes runTest("idea/testData/quickfix/replaceWithSafeCall/assignmentToPropertyWithNoExplicitType.kt"); } + @TestMetadata("assignmentToUnsafeCallExpression.kt") + public void testAssignmentToUnsafeCallExpression() throws Exception { + runTest("idea/testData/quickfix/replaceWithSafeCall/assignmentToUnsafeCallExpression.kt"); + } + @TestMetadata("comment.kt") public void testComment() throws Exception { runTest("idea/testData/quickfix/replaceWithSafeCall/comment.kt"); diff --git a/idea/testData/quickfix/addExclExclCall/assignment.kt b/idea/testData/quickfix/addExclExclCall/assignment.kt new file mode 100644 index 00000000000..503672ec8ef --- /dev/null +++ b/idea/testData/quickfix/addExclExclCall/assignment.kt @@ -0,0 +1,7 @@ +// "Add non-null asserted (!!) call" "true" +// WITH_RUNTIME +var i = 0 + +fun foo(s: String?) { + i = s.length +} diff --git a/idea/testData/quickfix/addExclExclCall/assignment.kt.after b/idea/testData/quickfix/addExclExclCall/assignment.kt.after new file mode 100644 index 00000000000..65101337f82 --- /dev/null +++ b/idea/testData/quickfix/addExclExclCall/assignment.kt.after @@ -0,0 +1,7 @@ +// "Add non-null asserted (!!) call" "true" +// WITH_RUNTIME +var i = 0 + +fun foo(s: String?) { + i = s!!.length +} diff --git a/idea/testData/quickfix/addExclExclCall/assignmentToUnsafeCallExpression.kt b/idea/testData/quickfix/addExclExclCall/assignmentToUnsafeCallExpression.kt new file mode 100644 index 00000000000..05d657e6c3c --- /dev/null +++ b/idea/testData/quickfix/addExclExclCall/assignmentToUnsafeCallExpression.kt @@ -0,0 +1,6 @@ +// "Add non-null asserted (!!) call" "true" +class A(var s: String) + +fun foo(a: A?) { + a.s = "" +} diff --git a/idea/testData/quickfix/addExclExclCall/assignmentToUnsafeCallExpression.kt.after b/idea/testData/quickfix/addExclExclCall/assignmentToUnsafeCallExpression.kt.after new file mode 100644 index 00000000000..dc16487f584 --- /dev/null +++ b/idea/testData/quickfix/addExclExclCall/assignmentToUnsafeCallExpression.kt.after @@ -0,0 +1,6 @@ +// "Add non-null asserted (!!) call" "true" +class A(var s: String) + +fun foo(a: A?) { + a!!.s = "" +} diff --git a/idea/testData/quickfix/replaceWithSafeCall/assignmentToUnsafeCallExpression.kt b/idea/testData/quickfix/replaceWithSafeCall/assignmentToUnsafeCallExpression.kt new file mode 100644 index 00000000000..9c68de212ba --- /dev/null +++ b/idea/testData/quickfix/replaceWithSafeCall/assignmentToUnsafeCallExpression.kt @@ -0,0 +1,6 @@ +// "Replace with safe (?.) call" "true" +class A(var s: String? = null) + +fun foo(a: A?) { + a.s = "" +} diff --git a/idea/testData/quickfix/replaceWithSafeCall/assignmentToUnsafeCallExpression.kt.after b/idea/testData/quickfix/replaceWithSafeCall/assignmentToUnsafeCallExpression.kt.after new file mode 100644 index 00000000000..29cdd61c952 --- /dev/null +++ b/idea/testData/quickfix/replaceWithSafeCall/assignmentToUnsafeCallExpression.kt.after @@ -0,0 +1,6 @@ +// "Replace with safe (?.) call" "true" +class A(var s: String? = null) + +fun foo(a: A?) { + a?.s = "" +} diff --git a/idea/tests/org/jetbrains/kotlin/idea/quickfix/QuickFixTestGenerated.java b/idea/tests/org/jetbrains/kotlin/idea/quickfix/QuickFixTestGenerated.java index 6f62b4ffcb6..05c481cf0b7 100644 --- a/idea/tests/org/jetbrains/kotlin/idea/quickfix/QuickFixTestGenerated.java +++ b/idea/tests/org/jetbrains/kotlin/idea/quickfix/QuickFixTestGenerated.java @@ -642,6 +642,16 @@ public class QuickFixTestGenerated extends AbstractQuickFixTest { runTest("idea/testData/quickfix/addExclExclCall/array4.kt"); } + @TestMetadata("assignment.kt") + public void testAssignment() throws Exception { + runTest("idea/testData/quickfix/addExclExclCall/assignment.kt"); + } + + @TestMetadata("assignmentToUnsafeCallExpression.kt") + public void testAssignmentToUnsafeCallExpression() throws Exception { + runTest("idea/testData/quickfix/addExclExclCall/assignmentToUnsafeCallExpression.kt"); + } + @TestMetadata("functionReference.kt") public void testFunctionReference() throws Exception { runTest("idea/testData/quickfix/addExclExclCall/functionReference.kt"); @@ -11825,6 +11835,11 @@ public class QuickFixTestGenerated extends AbstractQuickFixTest { runTest("idea/testData/quickfix/replaceWithSafeCall/assignmentToPropertyWithNoExplicitType.kt"); } + @TestMetadata("assignmentToUnsafeCallExpression.kt") + public void testAssignmentToUnsafeCallExpression() throws Exception { + runTest("idea/testData/quickfix/replaceWithSafeCall/assignmentToUnsafeCallExpression.kt"); + } + @TestMetadata("comment.kt") public void testComment() throws Exception { runTest("idea/testData/quickfix/replaceWithSafeCall/comment.kt");