diff --git a/analysis/low-level-api-fir/tests/org/jetbrains/kotlin/analysis/low/level/api/fir/diagnostic/compiler/based/DiagnosisCompilerFirTestdataTestGenerated.java b/analysis/low-level-api-fir/tests/org/jetbrains/kotlin/analysis/low/level/api/fir/diagnostic/compiler/based/DiagnosisCompilerFirTestdataTestGenerated.java index d060b97196f..dda07b90f43 100644 --- a/analysis/low-level-api-fir/tests/org/jetbrains/kotlin/analysis/low/level/api/fir/diagnostic/compiler/based/DiagnosisCompilerFirTestdataTestGenerated.java +++ b/analysis/low-level-api-fir/tests/org/jetbrains/kotlin/analysis/low/level/api/fir/diagnostic/compiler/based/DiagnosisCompilerFirTestdataTestGenerated.java @@ -3551,6 +3551,18 @@ public class DiagnosisCompilerFirTestdataTestGenerated extends AbstractDiagnosis runTest("compiler/fir/analysis-tests/testData/resolve/smartcasts/smartcastToTypeParameter.kt"); } + @Test + @TestMetadata("smartcastsFromEquals_differentModule.kt") + public void testSmartcastsFromEquals_differentModule() throws Exception { + runTest("compiler/fir/analysis-tests/testData/resolve/smartcasts/smartcastsFromEquals_differentModule.kt"); + } + + @Test + @TestMetadata("smartcastsFromEquals_sameModule.kt") + public void testSmartcastsFromEquals_sameModule() throws Exception { + runTest("compiler/fir/analysis-tests/testData/resolve/smartcasts/smartcastsFromEquals_sameModule.kt"); + } + @Nested @TestMetadata("compiler/fir/analysis-tests/testData/resolve/smartcasts/booleans") @TestDataPath("$PROJECT_ROOT") diff --git a/compiler/fir/analysis-tests/legacy-fir-tests/tests-gen/org/jetbrains/kotlin/fir/LazyBodyIsNotTouchedTilContractsPhaseTestGenerated.java b/compiler/fir/analysis-tests/legacy-fir-tests/tests-gen/org/jetbrains/kotlin/fir/LazyBodyIsNotTouchedTilContractsPhaseTestGenerated.java index 4d07376ac64..a71c2e22c11 100644 --- a/compiler/fir/analysis-tests/legacy-fir-tests/tests-gen/org/jetbrains/kotlin/fir/LazyBodyIsNotTouchedTilContractsPhaseTestGenerated.java +++ b/compiler/fir/analysis-tests/legacy-fir-tests/tests-gen/org/jetbrains/kotlin/fir/LazyBodyIsNotTouchedTilContractsPhaseTestGenerated.java @@ -3139,6 +3139,16 @@ public class LazyBodyIsNotTouchedTilContractsPhaseTestGenerated extends Abstract runTest("compiler/fir/analysis-tests/testData/resolve/smartcasts/smartcastToTypeParameter.kt"); } + @TestMetadata("smartcastsFromEquals_differentModule.kt") + public void testSmartcastsFromEquals_differentModule() throws Exception { + runTest("compiler/fir/analysis-tests/testData/resolve/smartcasts/smartcastsFromEquals_differentModule.kt"); + } + + @TestMetadata("smartcastsFromEquals_sameModule.kt") + public void testSmartcastsFromEquals_sameModule() throws Exception { + runTest("compiler/fir/analysis-tests/testData/resolve/smartcasts/smartcastsFromEquals_sameModule.kt"); + } + @TestMetadata("compiler/fir/analysis-tests/testData/resolve/smartcasts/booleans") @TestDataPath("$PROJECT_ROOT") @RunWith(JUnit3RunnerWithInners.class) diff --git a/compiler/fir/analysis-tests/testData/resolve/smartcasts/smartcastsFromEquals_differentModule.fir.txt b/compiler/fir/analysis-tests/testData/resolve/smartcasts/smartcastsFromEquals_differentModule.fir.txt new file mode 100644 index 00000000000..6dcf1dd39b3 --- /dev/null +++ b/compiler/fir/analysis-tests/testData/resolve/smartcasts/smartcastsFromEquals_differentModule.fir.txt @@ -0,0 +1,96 @@ +Module: lib +FILE: module_lib_smartcastsFromEquals_differentModule.kt + public final class Final : R|kotlin/Any| { + public constructor(): R|Final| { + super() + } + + } + public open class Base : R|kotlin/Any| { + public constructor(): R|Base| { + super() + } + + } + public final class Derived : R|Base| { + public constructor(): R|Derived| { + super|>() + } + + } + public final class FinalWithOverride : R|kotlin/Any| { + public constructor(): R|FinalWithOverride| { + super() + } + + public final override operator fun equals(other: R|kotlin/Any?|): R|kotlin/Boolean| { + ^equals ===(this@R|/FinalWithOverride|, R|/other|) + } + + } +Module: main +FILE: module_main_smartcastsFromEquals_differentModule.kt + public final fun testFinal(x: R|Final<*>|, y: R|Final|): R|kotlin/Unit| { + when () { + ==(R|/x|, R|/y|) -> { + #(R|/x|) + } + } + + when () { + ===(R|/x|, R|/y|) -> { + R|/takeIntFinal|(R|/x|) + } + } + + } + public final fun testBase(x: R|Base<*>|, y: R|Base|): R|kotlin/Unit| { + when () { + ==(R|/x|, R|/y|) -> { + #(R|/x|) + } + } + + when () { + ===(R|/x|, R|/y|) -> { + R|/takeIntBase|(R|/x|) + } + } + + } + public final fun testDerived(x: R|Derived<*>|, y: R|Derived|): R|kotlin/Unit| { + when () { + ==(R|/x|, R|/y|) -> { + #(R|/x|) + } + } + + when () { + ===(R|/x|, R|/y|) -> { + R|/takeIntDerived|(R|/x|) + } + } + + } + public final fun testFinalWithOverride(x: R|FinalWithOverride<*>|, y: R|FinalWithOverride|): R|kotlin/Unit| { + when () { + ==(R|/x|, R|/y|) -> { + #(R|/x|) + } + } + + when () { + ===(R|/x|, R|/y|) -> { + R|/takeIntFinalWithOverride|(R|/x|) + } + } + + } + public final fun takeIntFinal(x: R|Final|): R|kotlin/Unit| { + } + public final fun takeIntBase(x: R|Base|): R|kotlin/Unit| { + } + public final fun takeIntDerived(x: R|Derived|): R|kotlin/Unit| { + } + public final fun takeIntFinalWithOverride(x: R|FinalWithOverride|): R|kotlin/Unit| { + } diff --git a/compiler/fir/analysis-tests/testData/resolve/smartcasts/smartcastsFromEquals_differentModule.kt b/compiler/fir/analysis-tests/testData/resolve/smartcasts/smartcastsFromEquals_differentModule.kt new file mode 100644 index 00000000000..e6c3e3c79e5 --- /dev/null +++ b/compiler/fir/analysis-tests/testData/resolve/smartcasts/smartcastsFromEquals_differentModule.kt @@ -0,0 +1,60 @@ +// SKIP_JAVAC +// This directive is needed to skip this test in LazyBodyIsNotTouchedTilContractsPhaseTestGenerated, +// because it fails to parse module structure of multimodule test +// ISSUE: KT-49127 + +// MODULE: lib +class Final + +open class Base + +class Derived : Base() + +class FinalWithOverride { + override fun equals(other: Any?): Boolean { + // some custom implementation + return this === other + } +} + +// MODULE: main(lib) +fun testFinal(x: Final<*>, y: Final) { + if (x == y) { + takeIntFinal(x) // Error + } + if (x === y) { + takeIntFinal(x) // OK + } +} + +fun testBase(x: Base<*>, y: Base) { + if (x == y) { + takeIntBase(x) // Error + } + if (x === y) { + takeIntBase(x) // OK + } +} + +fun testDerived(x: Derived<*>, y: Derived) { + if (x == y) { + takeIntDerived(x) // Error + } + if (x === y) { + takeIntDerived(x) // OK + } +} + +fun testFinalWithOverride(x: FinalWithOverride<*>, y: FinalWithOverride) { + if (x == y) { + takeIntFinalWithOverride(x) // Error + } + if (x === y) { + takeIntFinalWithOverride(x) // OK + } +} + +fun takeIntFinal(x: Final) {} +fun takeIntBase(x: Base) {} +fun takeIntDerived(x: Derived) {} +fun takeIntFinalWithOverride(x: FinalWithOverride) {} diff --git a/compiler/fir/analysis-tests/testData/resolve/smartcasts/smartcastsFromEquals_sameModule.fir.txt b/compiler/fir/analysis-tests/testData/resolve/smartcasts/smartcastsFromEquals_sameModule.fir.txt new file mode 100644 index 00000000000..c786a6ec95d --- /dev/null +++ b/compiler/fir/analysis-tests/testData/resolve/smartcasts/smartcastsFromEquals_sameModule.fir.txt @@ -0,0 +1,93 @@ +FILE: smartcastsFromEquals_sameModule.kt + public final class Final : R|kotlin/Any| { + public constructor(): R|Final| { + super() + } + + } + public open class Base : R|kotlin/Any| { + public constructor(): R|Base| { + super() + } + + } + public final class Derived : R|Base| { + public constructor(): R|Derived| { + super|>() + } + + } + public final class FinalWithOverride : R|kotlin/Any| { + public constructor(): R|FinalWithOverride| { + super() + } + + public final override operator fun equals(other: R|kotlin/Any?|): R|kotlin/Boolean| { + ^equals ===(this@R|/FinalWithOverride|, R|/other|) + } + + } + public final fun testFinal(x: R|Final<*>|, y: R|Final|): R|kotlin/Unit| { + when () { + ==(R|/x|, R|/y|) -> { + R|/takeIntFinal|(R|/x|) + } + } + + when () { + ===(R|/x|, R|/y|) -> { + R|/takeIntFinal|(R|/x|) + } + } + + } + public final fun testBase(x: R|Base<*>|, y: R|Base|): R|kotlin/Unit| { + when () { + ==(R|/x|, R|/y|) -> { + #(R|/x|) + } + } + + when () { + ===(R|/x|, R|/y|) -> { + R|/takeIntBase|(R|/x|) + } + } + + } + public final fun testDerived(x: R|Derived<*>|, y: R|Derived|): R|kotlin/Unit| { + when () { + ==(R|/x|, R|/y|) -> { + R|/takeIntDerived|(R|/x|) + } + } + + when () { + ===(R|/x|, R|/y|) -> { + R|/takeIntDerived|(R|/x|) + } + } + + } + public final fun testFinalWithOverride(x: R|FinalWithOverride<*>|, y: R|FinalWithOverride|): R|kotlin/Unit| { + when () { + ==(R|/x|, R|/y|) -> { + #(R|/x|) + } + } + + when () { + ===(R|/x|, R|/y|) -> { + R|/takeIntFinalWithOverride|(R|/x|) + } + } + + } + public final fun takeIntFinal(x: R|Final|): R|kotlin/Unit| { + } + public final fun takeIntBase(x: R|Base|): R|kotlin/Unit| { + } + public final fun takeIntDerived(x: R|Derived|): R|kotlin/Unit| { + } + public final fun takeIntFinalWithOverride(x: R|FinalWithOverride|): R|kotlin/Unit| { + } diff --git a/compiler/fir/analysis-tests/testData/resolve/smartcasts/smartcastsFromEquals_sameModule.kt b/compiler/fir/analysis-tests/testData/resolve/smartcasts/smartcastsFromEquals_sameModule.kt new file mode 100644 index 00000000000..12bf5d7aeb5 --- /dev/null +++ b/compiler/fir/analysis-tests/testData/resolve/smartcasts/smartcastsFromEquals_sameModule.kt @@ -0,0 +1,55 @@ +// ISSUE: KT-49127 + +class Final + +open class Base + +class Derived : Base() + +class FinalWithOverride { + override fun equals(other: Any?): Boolean { + // some custom implementation + return this === other + } +} + +fun testFinal(x: Final<*>, y: Final) { + if (x == y) { + takeIntFinal(x) // OK + } + if (x === y) { + takeIntFinal(x) // OK + } +} + +fun testBase(x: Base<*>, y: Base) { + if (x == y) { + takeIntBase(x) // Error + } + if (x === y) { + takeIntBase(x) // OK + } +} + +fun testDerived(x: Derived<*>, y: Derived) { + if (x == y) { + takeIntDerived(x) // OK + } + if (x === y) { + takeIntDerived(x) // OK + } +} + +fun testFinalWithOverride(x: FinalWithOverride<*>, y: FinalWithOverride) { + if (x == y) { + takeIntFinalWithOverride(x) // Error + } + if (x === y) { + takeIntFinalWithOverride(x) // OK + } +} + +fun takeIntFinal(x: Final) {} +fun takeIntBase(x: Base) {} +fun takeIntDerived(x: Derived) {} +fun takeIntFinalWithOverride(x: FinalWithOverride) {} diff --git a/compiler/fir/analysis-tests/tests-gen/org/jetbrains/kotlin/test/runners/FirDiagnosticTestGenerated.java b/compiler/fir/analysis-tests/tests-gen/org/jetbrains/kotlin/test/runners/FirDiagnosticTestGenerated.java index 0fcb1434f6a..13e40e6c944 100644 --- a/compiler/fir/analysis-tests/tests-gen/org/jetbrains/kotlin/test/runners/FirDiagnosticTestGenerated.java +++ b/compiler/fir/analysis-tests/tests-gen/org/jetbrains/kotlin/test/runners/FirDiagnosticTestGenerated.java @@ -3551,6 +3551,18 @@ public class FirDiagnosticTestGenerated extends AbstractFirDiagnosticTest { runTest("compiler/fir/analysis-tests/testData/resolve/smartcasts/smartcastToTypeParameter.kt"); } + @Test + @TestMetadata("smartcastsFromEquals_differentModule.kt") + public void testSmartcastsFromEquals_differentModule() throws Exception { + runTest("compiler/fir/analysis-tests/testData/resolve/smartcasts/smartcastsFromEquals_differentModule.kt"); + } + + @Test + @TestMetadata("smartcastsFromEquals_sameModule.kt") + public void testSmartcastsFromEquals_sameModule() throws Exception { + runTest("compiler/fir/analysis-tests/testData/resolve/smartcasts/smartcastsFromEquals_sameModule.kt"); + } + @Nested @TestMetadata("compiler/fir/analysis-tests/testData/resolve/smartcasts/booleans") @TestDataPath("$PROJECT_ROOT") diff --git a/compiler/fir/analysis-tests/tests-gen/org/jetbrains/kotlin/test/runners/FirDiagnosticsWithLightTreeTestGenerated.java b/compiler/fir/analysis-tests/tests-gen/org/jetbrains/kotlin/test/runners/FirDiagnosticsWithLightTreeTestGenerated.java index d0a7bf1ce69..8f3be700607 100644 --- a/compiler/fir/analysis-tests/tests-gen/org/jetbrains/kotlin/test/runners/FirDiagnosticsWithLightTreeTestGenerated.java +++ b/compiler/fir/analysis-tests/tests-gen/org/jetbrains/kotlin/test/runners/FirDiagnosticsWithLightTreeTestGenerated.java @@ -3551,6 +3551,18 @@ public class FirDiagnosticsWithLightTreeTestGenerated extends AbstractFirDiagnos runTest("compiler/fir/analysis-tests/testData/resolve/smartcasts/smartcastToTypeParameter.kt"); } + @Test + @TestMetadata("smartcastsFromEquals_differentModule.kt") + public void testSmartcastsFromEquals_differentModule() throws Exception { + runTest("compiler/fir/analysis-tests/testData/resolve/smartcasts/smartcastsFromEquals_differentModule.kt"); + } + + @Test + @TestMetadata("smartcastsFromEquals_sameModule.kt") + public void testSmartcastsFromEquals_sameModule() throws Exception { + runTest("compiler/fir/analysis-tests/testData/resolve/smartcasts/smartcastsFromEquals_sameModule.kt"); + } + @Nested @TestMetadata("compiler/fir/analysis-tests/testData/resolve/smartcasts/booleans") @TestDataPath("$PROJECT_ROOT") diff --git a/compiler/fir/checkers/src/org/jetbrains/kotlin/fir/analysis/checkers/declaration/FirMethodOfAnyImplementedInInterfaceChecker.kt b/compiler/fir/checkers/src/org/jetbrains/kotlin/fir/analysis/checkers/declaration/FirMethodOfAnyImplementedInInterfaceChecker.kt index f1d8349ec3c..c72501ec5b0 100644 --- a/compiler/fir/checkers/src/org/jetbrains/kotlin/fir/analysis/checkers/declaration/FirMethodOfAnyImplementedInInterfaceChecker.kt +++ b/compiler/fir/checkers/src/org/jetbrains/kotlin/fir/analysis/checkers/declaration/FirMethodOfAnyImplementedInInterfaceChecker.kt @@ -6,19 +6,19 @@ package org.jetbrains.kotlin.fir.analysis.checkers.declaration import org.jetbrains.kotlin.builtins.StandardNames.HASHCODE_NAME +import org.jetbrains.kotlin.diagnostics.DiagnosticReporter +import org.jetbrains.kotlin.diagnostics.reportOn import org.jetbrains.kotlin.fir.analysis.checkers.FirDeclarationPresenter import org.jetbrains.kotlin.fir.analysis.checkers.context.CheckerContext -import org.jetbrains.kotlin.diagnostics.DiagnosticReporter import org.jetbrains.kotlin.fir.analysis.diagnostics.FirErrors -import org.jetbrains.kotlin.diagnostics.reportOn import org.jetbrains.kotlin.fir.analysis.diagnostics.withSuppressedDiagnostics -import org.jetbrains.kotlin.fir.declarations.* +import org.jetbrains.kotlin.fir.declarations.FirRegularClass +import org.jetbrains.kotlin.fir.declarations.FirSimpleFunction import org.jetbrains.kotlin.fir.declarations.utils.hasBody import org.jetbrains.kotlin.fir.declarations.utils.isInterface import org.jetbrains.kotlin.fir.declarations.utils.isOverride -import org.jetbrains.kotlin.fir.types.isNullableAny +import org.jetbrains.kotlin.fir.resolve.isEquals import org.jetbrains.kotlin.name.CallableId -import org.jetbrains.kotlin.util.OperatorNameConventions.EQUALS import org.jetbrains.kotlin.util.OperatorNameConventions.TO_STRING object FirMethodOfAnyImplementedInInterfaceChecker : FirRegularClassChecker(), FirDeclarationPresenter { @@ -44,11 +44,8 @@ object FirMethodOfAnyImplementedInInterfaceChecker : FirRegularClassChecker(), F (function.name == HASHCODE_NAME || function.name == TO_STRING) ) { methodOfAny = true - } else { - val singleParameter = function.valueParameters.singleOrNull() ?: continue - if (singleParameter.returnTypeRef.isNullableAny && function.name == EQUALS) { - methodOfAny = true - } + } else if (function.isEquals()) { + methodOfAny = true } if (methodOfAny) { diff --git a/compiler/fir/providers/src/org/jetbrains/kotlin/fir/resolve/SupertypeUtils.kt b/compiler/fir/providers/src/org/jetbrains/kotlin/fir/resolve/SupertypeUtils.kt index 6cebf2cf6fe..3e2faeb81a0 100644 --- a/compiler/fir/providers/src/org/jetbrains/kotlin/fir/resolve/SupertypeUtils.kt +++ b/compiler/fir/providers/src/org/jetbrains/kotlin/fir/resolve/SupertypeUtils.kt @@ -17,12 +17,14 @@ import org.jetbrains.kotlin.fir.resolve.substitution.ConeSubstitutor import org.jetbrains.kotlin.fir.resolve.substitution.ConeSubstitutorByMap import org.jetbrains.kotlin.fir.scopes.FirScope import org.jetbrains.kotlin.fir.scopes.FirTypeScope +import org.jetbrains.kotlin.fir.symbols.ConeClassLikeLookupTag import org.jetbrains.kotlin.fir.symbols.ensureResolved import org.jetbrains.kotlin.fir.symbols.impl.* import org.jetbrains.kotlin.fir.types.* import org.jetbrains.kotlin.types.model.CaptureStatus import org.jetbrains.kotlin.utils.SmartList import org.jetbrains.kotlin.utils.SmartSet +import org.jetbrains.kotlin.utils.addIfNotNull abstract class SupertypeSupplier { abstract fun forClass(firClass: FirClass, useSiteSession: FirSession): List @@ -44,6 +46,48 @@ abstract class SupertypeSupplier { } } +fun collectSymbolsForType(type: ConeKotlinType, useSiteSession: FirSession): List> { + val lookupTags = mutableListOf() + + fun ConeKotlinType.collectClassIds() { + when (val unwrappedType = lowerBoundIfFlexible().fullyExpandedType(useSiteSession)) { + is ConeClassLikeType -> lookupTags.addIfNotNull(unwrappedType.lookupTag) + is ConeIntersectionType -> unwrappedType.intersectedTypes.forEach { it.collectClassIds() } + else -> {} + } + } + + type.collectClassIds() + return lookupTags.mapNotNull { it.toSymbol(useSiteSession) as? FirClassSymbol<*> } +} + +fun lookupSuperTypes( + type: ConeKotlinType, + lookupInterfaces: Boolean, + deep: Boolean, + useSiteSession: FirSession, + substituteTypes: Boolean, + supertypeSupplier: SupertypeSupplier = SupertypeSupplier.Default, +): List { + return lookupSuperTypes(collectSymbolsForType(type, useSiteSession), lookupInterfaces, deep, useSiteSession, substituteTypes, supertypeSupplier) +} + +fun lookupSuperTypes( + symbols: List>, + lookupInterfaces: Boolean, + deep: Boolean, + useSiteSession: FirSession, + substituteTypes: Boolean, + supertypeSupplier: SupertypeSupplier = SupertypeSupplier.Default, +): List { + return SmartList().also { + val visitedSymbols = SmartSet.create>() + for (symbol in symbols) { + symbol.collectSuperTypes(it, visitedSymbols, deep, lookupInterfaces, substituteTypes, useSiteSession, supertypeSupplier) + } + } +} + fun lookupSuperTypes( klass: FirClass, lookupInterfaces: Boolean, diff --git a/compiler/fir/resolve/src/org/jetbrains/kotlin/fir/resolve/ResolveUtils.kt b/compiler/fir/resolve/src/org/jetbrains/kotlin/fir/resolve/ResolveUtils.kt index 42d884fc6bd..4fd22705353 100644 --- a/compiler/fir/resolve/src/org/jetbrains/kotlin/fir/resolve/ResolveUtils.kt +++ b/compiler/fir/resolve/src/org/jetbrains/kotlin/fir/resolve/ResolveUtils.kt @@ -34,6 +34,7 @@ import org.jetbrains.kotlin.fir.resolve.dfa.FirDataFlowAnalyzer import org.jetbrains.kotlin.fir.resolve.dfa.PropertyStability import org.jetbrains.kotlin.fir.resolve.diagnostics.ConeUnresolvedNameError import org.jetbrains.kotlin.fir.resolve.providers.symbolProvider +import org.jetbrains.kotlin.fir.resolve.transformers.ReturnTypeCalculator import org.jetbrains.kotlin.fir.resolve.transformers.body.resolve.resultType import org.jetbrains.kotlin.fir.scopes.impl.delegatedWrapperData import org.jetbrains.kotlin.fir.scopes.impl.importedFromObjectData @@ -49,6 +50,7 @@ import org.jetbrains.kotlin.name.Name import org.jetbrains.kotlin.name.StandardClassIds import org.jetbrains.kotlin.resolve.ForbiddenNamedArgumentsTarget import org.jetbrains.kotlin.types.SmartcastStability +import org.jetbrains.kotlin.util.OperatorNameConventions import org.jetbrains.kotlin.utils.addToStdlib.safeAs fun List.toTypeProjections(): Array = diff --git a/compiler/fir/resolve/src/org/jetbrains/kotlin/fir/resolve/dfa/FirDataFlowAnalyzer.kt b/compiler/fir/resolve/src/org/jetbrains/kotlin/fir/resolve/dfa/FirDataFlowAnalyzer.kt index 6e9a4d4c5a0..9a104625a92 100644 --- a/compiler/fir/resolve/src/org/jetbrains/kotlin/fir/resolve/dfa/FirDataFlowAnalyzer.kt +++ b/compiler/fir/resolve/src/org/jetbrains/kotlin/fir/resolve/dfa/FirDataFlowAnalyzer.kt @@ -6,6 +6,7 @@ package org.jetbrains.kotlin.fir.resolve.dfa import org.jetbrains.kotlin.config.LanguageFeature +import org.jetbrains.kotlin.descriptors.Modality import org.jetbrains.kotlin.fir.* import org.jetbrains.kotlin.fir.contracts.FirResolvedContractDescription import org.jetbrains.kotlin.fir.contracts.description.ConeBooleanConstantReference @@ -18,24 +19,23 @@ import org.jetbrains.kotlin.fir.declarations.utils.isLocal import org.jetbrains.kotlin.fir.expressions.* import org.jetbrains.kotlin.fir.references.FirControlFlowGraphReference import org.jetbrains.kotlin.fir.references.FirResolvedNamedReference -import org.jetbrains.kotlin.fir.resolve.PersistentImplicitReceiverStack -import org.jetbrains.kotlin.fir.resolve.ResolutionMode +import org.jetbrains.kotlin.fir.resolve.* import org.jetbrains.kotlin.fir.resolve.dfa.cfg.* import org.jetbrains.kotlin.fir.resolve.dfa.contracts.buildContractFir import org.jetbrains.kotlin.fir.resolve.dfa.contracts.createArgumentsMapping -import org.jetbrains.kotlin.fir.resolve.fullyExpandedType import org.jetbrains.kotlin.fir.resolve.substitution.ConeSubstitutor import org.jetbrains.kotlin.fir.resolve.substitution.ConeSubstitutorByMap -import org.jetbrains.kotlin.fir.resolve.toSymbol import org.jetbrains.kotlin.fir.resolve.transformers.body.resolve.FirAbstractBodyResolveTransformer import org.jetbrains.kotlin.fir.resolve.transformers.body.resolve.resultType +import org.jetbrains.kotlin.fir.scopes.getFunctions +import org.jetbrains.kotlin.fir.scopes.impl.declaredMemberScope import org.jetbrains.kotlin.fir.symbols.FirBasedSymbol -import org.jetbrains.kotlin.fir.symbols.impl.FirTypeParameterSymbol -import org.jetbrains.kotlin.fir.symbols.impl.FirVariableSymbol +import org.jetbrains.kotlin.fir.symbols.impl.* import org.jetbrains.kotlin.fir.types.* import org.jetbrains.kotlin.fir.visitors.transformSingle import org.jetbrains.kotlin.name.StandardClassIds import org.jetbrains.kotlin.types.ConstantValueKind +import org.jetbrains.kotlin.util.OperatorNameConventions import org.jetbrains.kotlin.utils.addIfNotNull import org.jetbrains.kotlin.utils.addToStdlib.firstIsInstanceOrNull @@ -570,9 +570,7 @@ abstract class FirDataFlowAnalyzer( rightIsNullable -> processEqNull(node, rightOperand, operation.invert(), ::shouldAddImplicationForStatement) } - if (operation == FirOperation.IDENTITY || operation == FirOperation.NOT_IDENTITY) { - processIdentity(node, leftOperand, rightOperand, operation) - } + processPossibleIdentity(node, leftOperand, rightOperand, operation) } /* @@ -639,8 +637,11 @@ abstract class FirDataFlowAnalyzer( node.flow = flow } - private fun processIdentity( - node: EqualityOperatorCallNode, leftOperand: FirExpression, rightOperand: FirExpression, operation: FirOperation + private fun processPossibleIdentity( + node: EqualityOperatorCallNode, + leftOperand: FirExpression, + rightOperand: FirExpression, + operation: FirOperation, ) { val flow = node.flow val expressionVariable = variableStorage.getOrCreateVariable(node.previousFlow, node.fir) @@ -648,6 +649,13 @@ abstract class FirDataFlowAnalyzer( val rightOperandVariable = variableStorage.getOrCreateVariable(node.previousFlow, rightOperand) val leftOperandType = leftOperand.coneType val rightOperandType = rightOperand.coneType + + if (!leftOperandVariable.isReal() && !rightOperandVariable.isReal()) return + + if (operation == FirOperation.EQ || operation == FirOperation.NOT_EQ) { + if (hasOverriddenEquals(leftOperandType)) return + } + val isEq = operation.isEq() if (leftOperandVariable.isReal()) { @@ -663,6 +671,40 @@ abstract class FirDataFlowAnalyzer( node.flow = flow } + private fun hasOverriddenEquals(type: ConeKotlinType): Boolean { + val session = components.session + val symbolsForType = collectSymbolsForType(type, session) + if (symbolsForType.any { it.hasEqualsOverride(session, checkModality = true) }) return true + + val superTypes = lookupSuperTypes( + symbolsForType, + lookupInterfaces = false, + deep = true, + session, + substituteTypes = false + ) + val superClassSymbols = superTypes.mapNotNull { + it.fullyExpandedType(session).toSymbol(session) as? FirRegularClassSymbol + } + + return superClassSymbols.any { it.hasEqualsOverride(session, checkModality = false) } + } + + private fun FirClassSymbol<*>.hasEqualsOverride(session: FirSession, checkModality: Boolean): Boolean { + val status = resolvedStatus + if (checkModality && status.modality != Modality.FINAL) return true + if (status.isExpect) return true + when (classId) { + StandardClassIds.Any, StandardClassIds.String -> return false + } + if (moduleData != session.moduleData) { + return true + } + return session.declaredMemberScope(this) + .getFunctions(OperatorNameConventions.EQUALS) + .any { it.fir.isEquals() } + } + // ----------------------------------- Jump ----------------------------------- fun exitJump(jump: FirJump<*>) { diff --git a/compiler/fir/semantics/src/org/jetbrains/kotlin/fir/resolve/DeclarationUtils.kt b/compiler/fir/semantics/src/org/jetbrains/kotlin/fir/resolve/DeclarationUtils.kt index 8f6043f25db..248bdda85e1 100644 --- a/compiler/fir/semantics/src/org/jetbrains/kotlin/fir/resolve/DeclarationUtils.kt +++ b/compiler/fir/semantics/src/org/jetbrains/kotlin/fir/resolve/DeclarationUtils.kt @@ -15,6 +15,7 @@ import org.jetbrains.kotlin.fir.symbols.impl.FirTypeParameterSymbol import org.jetbrains.kotlin.fir.symbols.impl.LookupTagInternals import org.jetbrains.kotlin.fir.types.* import org.jetbrains.kotlin.name.ClassId +import org.jetbrains.kotlin.util.OperatorNameConventions fun FirClassLikeDeclaration.getContainingDeclaration(session: FirSession): FirClassLikeDeclaration? { if (isLocal) { @@ -86,3 +87,10 @@ var FirConstructor.originalConstructorIfTypeAlias: FirConstructor? by FirDeclara val FirConstructorSymbol.isTypeAliasedConstructor: Boolean get() = fir.originalConstructorIfTypeAlias != null + +fun FirSimpleFunction.isEquals(): Boolean { + if (name != OperatorNameConventions.EQUALS) return false + if (valueParameters.size != 1) return false + val parameter = valueParameters.first() + return parameter.returnTypeRef.isNullableAny +} diff --git a/compiler/testData/diagnostics/tests/smartCasts/fakeSmartCastOnEquality.fir.kt b/compiler/testData/diagnostics/tests/smartCasts/fakeSmartCastOnEquality.fir.kt index 278a5576dbc..aa3f520c669 100644 --- a/compiler/testData/diagnostics/tests/smartCasts/fakeSmartCastOnEquality.fir.kt +++ b/compiler/testData/diagnostics/tests/smartCasts/fakeSmartCastOnEquality.fir.kt @@ -34,11 +34,11 @@ fun foo(x: FinalClass?, y: Any) { // OK x.hashCode() // OK - y.use() + y.use() } when (x) { // OK (equals from FinalClass) - y -> y.use() + y -> y.use() } when (y) { // ERROR (equals from Any) @@ -92,4 +92,4 @@ sealed class Sealed { gav() } } -} \ No newline at end of file +} diff --git a/compiler/tests-spec/testData/diagnostics/notLinked/dfa/neg/1.fir.kt b/compiler/tests-spec/testData/diagnostics/notLinked/dfa/neg/1.fir.kt index 1bb1e638156..4d3815eb1cd 100644 --- a/compiler/tests-spec/testData/diagnostics/notLinked/dfa/neg/1.fir.kt +++ b/compiler/tests-spec/testData/diagnostics/notLinked/dfa/neg/1.fir.kt @@ -181,11 +181,11 @@ fun case_12(x: TypealiasNullableStringIndirect, y: TypealiasNullableStringIndire // TESTCASE NUMBER: 13 fun case_13(x: otherpackage.Case13?) = - if ((x == null !is Boolean) !== true) { + if ((x == null !is Boolean) !== true) { throw Exception() } else { - x - x.equals(x) + x + x.equals(x) } // TESTCASE NUMBER: 14 diff --git a/compiler/tests-spec/testData/diagnostics/notLinked/dfa/pos/6.fir.kt b/compiler/tests-spec/testData/diagnostics/notLinked/dfa/pos/6.fir.kt index dbaa694dcef..af94732476e 100644 --- a/compiler/tests-spec/testData/diagnostics/notLinked/dfa/pos/6.fir.kt +++ b/compiler/tests-spec/testData/diagnostics/notLinked/dfa/pos/6.fir.kt @@ -203,7 +203,7 @@ fun case_11(x: TypealiasNullableString?, y: TypealiasNu } else { if (y != z) { if (nullableStringProperty == z) { - if (u != z || u != v) { + if (u != z || u != v) { x x.equals(null) x.propT @@ -790,7 +790,7 @@ fun case_42() { fun case_43(x: TypealiasNullableString) { val z = null - if (x == z && x == z) { + if (x == z && x == z) { x x.hashCode() }