diff --git a/compiler/fir/lightTree/src/org/jetbrains/kotlin/fir/lightTree/converter/DeclarationsConverter.kt b/compiler/fir/lightTree/src/org/jetbrains/kotlin/fir/lightTree/converter/DeclarationsConverter.kt index 32186d07968..ae0d94c3a5c 100644 --- a/compiler/fir/lightTree/src/org/jetbrains/kotlin/fir/lightTree/converter/DeclarationsConverter.kt +++ b/compiler/fir/lightTree/src/org/jetbrains/kotlin/fir/lightTree/converter/DeclarationsConverter.kt @@ -39,6 +39,7 @@ import org.jetbrains.kotlin.fir.types.FirUserTypeRef import org.jetbrains.kotlin.fir.types.impl.* import org.jetbrains.kotlin.lexer.KtModifierKeywordToken import org.jetbrains.kotlin.lexer.KtTokens.* +import org.jetbrains.kotlin.name.ClassId import org.jetbrains.kotlin.name.FqName import org.jetbrains.kotlin.name.Name import org.jetbrains.kotlin.name.SpecialNames @@ -1393,7 +1394,12 @@ class DeclarationsConverter( isNullable, receiverTypeReference, returnTypeReference - ).apply { valueParameters += valueParametersList.map { it.firValueParameter } } + ).apply { + valueParameters += valueParametersList.map { it.firValueParameter } + if (receiverTypeReference != null) { + annotations += extensionFunctionAnnotation + } + } } /** @@ -1444,4 +1450,17 @@ class DeclarationsConverter( ).apply { annotations += modifiers.annotations } return ValueParameter(isVal, isVar, modifiers, firValueParameter, destructuringDeclaration) } + + private val extensionFunctionAnnotation = FirAnnotationCallImpl( + null, + null, + FirResolvedTypeRefImpl( + null, + ConeClassTypeImpl( + ConeClassLikeLookupTagImpl(ClassId.fromString("kotlin/ExtensionFunctionType")), + emptyArray(), + false + ) + ) + ) } \ No newline at end of file diff --git a/compiler/fir/psi2fir/src/org/jetbrains/kotlin/fir/builder/RawFirBuilder.kt b/compiler/fir/psi2fir/src/org/jetbrains/kotlin/fir/builder/RawFirBuilder.kt index 103a5854414..b69fe4c17a2 100644 --- a/compiler/fir/psi2fir/src/org/jetbrains/kotlin/fir/builder/RawFirBuilder.kt +++ b/compiler/fir/psi2fir/src/org/jetbrains/kotlin/fir/builder/RawFirBuilder.kt @@ -25,6 +25,7 @@ import org.jetbrains.kotlin.fir.types.FirTypeProjection import org.jetbrains.kotlin.fir.types.FirTypeRef import org.jetbrains.kotlin.fir.types.impl.* import org.jetbrains.kotlin.lexer.KtTokens.* +import org.jetbrains.kotlin.name.ClassId import org.jetbrains.kotlin.name.Name import org.jetbrains.kotlin.psi.* import org.jetbrains.kotlin.psi.psiUtil.getStrictParentOfType @@ -850,6 +851,9 @@ class RawFirBuilder(session: FirSession, val stubMode: Boolean) : BaseFirBuilder for (valueParameter in unwrappedElement.parameters) { functionType.valueParameters += valueParameter.convert() } + if (functionType.receiverTypeRef != null) { + functionType.annotations += extensionFunctionAnnotation + } functionType } null -> FirErrorTypeRefImpl(source, "Unwrapped type is null") @@ -1356,4 +1360,17 @@ class RawFirBuilder(session: FirSession, val stubMode: Boolean) : BaseFirBuilder return FirExpressionStub(expression.toFirSourceElement()) } } + + private val extensionFunctionAnnotation = FirAnnotationCallImpl( + null, + null, + FirResolvedTypeRefImpl( + null, + ConeClassTypeImpl( + ConeClassLikeLookupTagImpl(ClassId.fromString("kotlin/ExtensionFunctionType")), + emptyArray(), + false + ) + ) + ) } diff --git a/compiler/fir/resolve/src/org/jetbrains/kotlin/fir/resolve/inference/FirCallCompleter.kt b/compiler/fir/resolve/src/org/jetbrains/kotlin/fir/resolve/inference/FirCallCompleter.kt index d7d91ebb22a..553399404c7 100644 --- a/compiler/fir/resolve/src/org/jetbrains/kotlin/fir/resolve/inference/FirCallCompleter.kt +++ b/compiler/fir/resolve/src/org/jetbrains/kotlin/fir/resolve/inference/FirCallCompleter.kt @@ -108,13 +108,16 @@ class FirCallCompleter( stubsForPostponedVariables: Map ): Pair, InferenceSession> { + val needItParam = lambdaArgument.valueParameters.isEmpty() && parameters.size == (if (receiverType != null) 2 else 1) + val itParam = when { - lambdaArgument.valueParameters.isEmpty() && parameters.size == 1 -> { + needItParam -> { val name = Name.identifier("it") + val itType = if (receiverType != null) parameters[1] else parameters.single() FirValueParameterImpl( null, session, - FirResolvedTypeRefImpl(null, parameters.single()), + FirResolvedTypeRefImpl(null, itType), name, FirVariableSymbol(name), defaultValue = null, @@ -129,7 +132,7 @@ class FirCallCompleter( val expectedReturnTypeRef = expectedReturnType?.let { lambdaArgument.returnTypeRef.resolvedTypeFromPrototype(it) } val newLambdaExpression = lambdaArgument.copy( - receiverTypeRef = receiverType?.let { lambdaArgument.receiverTypeRef!!.resolvedTypeFromPrototype(it) }, + receiverTypeRef = receiverType?.let { lambdaArgument.receiverTypeRef?.resolvedTypeFromPrototype(it) }, valueParameters = lambdaArgument.valueParameters.mapIndexed { index, parameter -> parameter.transformReturnTypeRef(StoreType, parameter.returnTypeRef.resolvedTypeFromPrototype(parameters[index])) parameter diff --git a/compiler/fir/resolve/testData/diagnostics/callableReferences/fromBasicDiagnosticTests/chooseCallableReferenceDependingOnInferredReceiver.txt b/compiler/fir/resolve/testData/diagnostics/callableReferences/fromBasicDiagnosticTests/chooseCallableReferenceDependingOnInferredReceiver.txt index 7dca31be72c..bac374f7a4b 100644 --- a/compiler/fir/resolve/testData/diagnostics/callableReferences/fromBasicDiagnosticTests/chooseCallableReferenceDependingOnInferredReceiver.txt +++ b/compiler/fir/resolve/testData/diagnostics/callableReferences/fromBasicDiagnosticTests/chooseCallableReferenceDependingOnInferredReceiver.txt @@ -30,19 +30,19 @@ FILE: chooseCallableReferenceDependingOnInferredReceiver.kt ^bar R|kotlin/TODO|() } public final fun test(): R|kotlin/Unit| { - R|/myWith|(R|/A.A|(), = myWith@fun (it: R|A|): R|kotlin/Unit| { - lval t1: = #(::foo#) - lval t2: = #(::baz#) - R|/myWith|(R|/B.B|(), = myWith@fun (it: R|B|): R|kotlin/Unit| { - lval a: R|A| = #(::foo#) - lval b: R|B| = #(::foo#) - lval t3: = #(::baz#) - #(::foo#) + R|/myWith|(R|/A.A|(), = myWith@fun R|A|.(): R|kotlin/Unit| { + lval t1: R|A| = R|/bar|(::R|/A.foo|) + lval t2: R|A| = R|/bar|(::R|/A.baz|) + R|/myWith|(R|/B.B|(), = myWith@fun R|B|.(): R|kotlin/Unit| { + lval a: R|A| = R|/bar|(::R|/A.foo|) + lval b: R|B| = R|/bar|(::R|/B.foo|) + lval t3: R|B| = R|/bar|(::R|/B.baz|) + R|/bar|(::#) } ) } ) } - public final inline fun myWith(receiver: R|T|, block: R|kotlin/Function1|): R|R| { + public final inline fun myWith(receiver: R|T|, block: @R|kotlin/ExtensionFunctionType|() R|kotlin/Function1|): R|R| { ^myWith R|kotlin/TODO|() } diff --git a/compiler/fir/resolve/testData/diagnostics/callableReferences/fromBasicDiagnosticTests/moreSpecificAmbiguousExtensions.txt b/compiler/fir/resolve/testData/diagnostics/callableReferences/fromBasicDiagnosticTests/moreSpecificAmbiguousExtensions.txt index f169da5f2b2..11b5d6195f7 100644 --- a/compiler/fir/resolve/testData/diagnostics/callableReferences/fromBasicDiagnosticTests/moreSpecificAmbiguousExtensions.txt +++ b/compiler/fir/resolve/testData/diagnostics/callableReferences/fromBasicDiagnosticTests/moreSpecificAmbiguousExtensions.txt @@ -12,9 +12,9 @@ FILE: moreSpecificAmbiguousExtensions.kt lval extFun2: R|kotlin/reflect/KFunction2| = Q|IB|::R|/extFun| } public final fun testWithExpectedType(): R|kotlin/Unit| { - lval extFun_AB_A: R|kotlin/Function2| = Q|IA|::R|/extFun| - lval extFun_AA_B: R|kotlin/Function2| = Q|IB|::R|/extFun| - lval extFun_BB_A: R|kotlin/Function2| = Q|IA|::R|/extFun| - lval extFun_BA_B: R|kotlin/Function2| = Q|IB|::R|/extFun| - lval extFun_BB_B: R|kotlin/Function2| = Q|IB|::R|/extFun| + lval extFun_AB_A: @R|kotlin/ExtensionFunctionType|() R|kotlin/Function2| = Q|IA|::R|/extFun| + lval extFun_AA_B: @R|kotlin/ExtensionFunctionType|() R|kotlin/Function2| = Q|IB|::R|/extFun| + lval extFun_BB_A: @R|kotlin/ExtensionFunctionType|() R|kotlin/Function2| = Q|IA|::R|/extFun| + lval extFun_BA_B: @R|kotlin/ExtensionFunctionType|() R|kotlin/Function2| = Q|IB|::R|/extFun| + lval extFun_BB_B: @R|kotlin/ExtensionFunctionType|() R|kotlin/Function2| = Q|IB|::R|/extFun| } diff --git a/compiler/fir/resolve/testData/resolve/expresssions/lambdaWithReceiver.kt b/compiler/fir/resolve/testData/resolve/expresssions/lambdaWithReceiver.kt new file mode 100644 index 00000000000..d38dbf0d514 --- /dev/null +++ b/compiler/fir/resolve/testData/resolve/expresssions/lambdaWithReceiver.kt @@ -0,0 +1,41 @@ +interface A { + fun foo() +} + +fun myWith(receiver: T, block: T.() -> Unit) { + receiver.block() +} + +fun T.myApply(block: T.() -> Unit) { + this.block() +} + +fun withA(block: A.() -> Unit) {} + +fun test_1() { + withA { + foo() + } +} + +fun test_2(a: A) { + myWith(a) { + foo() + } +} + +fun test_3(a: A) { + a.myApply { + foo() + } +} + +fun complexLambda(block: Int.(String) -> Unit) {} + +fun test_4() { + complexLambda { + inc() + this.inc() + it.length + } +} \ No newline at end of file diff --git a/compiler/fir/resolve/testData/resolve/expresssions/lambdaWithReceiver.txt b/compiler/fir/resolve/testData/resolve/expresssions/lambdaWithReceiver.txt new file mode 100644 index 00000000000..9434ca87de5 --- /dev/null +++ b/compiler/fir/resolve/testData/resolve/expresssions/lambdaWithReceiver.txt @@ -0,0 +1,41 @@ +FILE: lambdaWithReceiver.kt + public abstract interface A : R|kotlin/Any| { + public abstract fun foo(): R|kotlin/Unit| + + } + public final fun myWith(receiver: R|T|, block: @R|kotlin/ExtensionFunctionType|() R|kotlin/Function1|): R|kotlin/Unit| { + R|/receiver|.#() + } + public final fun R|T|.myApply(block: @R|kotlin/ExtensionFunctionType|() R|kotlin/Function1|): R|kotlin/Unit| { + this@R|/myApply|.#() + } + public final fun withA(block: @R|kotlin/ExtensionFunctionType|() R|kotlin/Function1|): R|kotlin/Unit| { + } + public final fun test_1(): R|kotlin/Unit| { + R|/withA|( = withA@fun R|A|.(): R|kotlin/Unit| { + this@R|/A|.R|/A.foo|() + } + ) + } + public final fun test_2(a: R|A|): R|kotlin/Unit| { + R|/myWith|(R|/a|, = myWith@fun R|A|.(): R|kotlin/Unit| { + this@R|/A|.R|/A.foo|() + } + ) + } + public final fun test_3(a: R|A|): R|kotlin/Unit| { + R|/a|.R|/myApply|( = myApply@fun R|A|.(): R|kotlin/Unit| { + this@R|/A|.R|/A.foo|() + } + ) + } + public final fun complexLambda(block: @R|kotlin/ExtensionFunctionType|() R|kotlin/Function2|): R|kotlin/Unit| { + } + public final fun test_4(): R|kotlin/Unit| { + R|/complexLambda|( = complexLambda@fun R|kotlin/Int|.(it: R|kotlin/String|): R|kotlin/Unit| { + this@R|kotlin/Int|.R|kotlin/Int.inc|() + this@R|special/anonymous|.R|kotlin/Int.inc|() + R|/it|.R|kotlin/String.length| + } + ) + } diff --git a/compiler/fir/resolve/testData/resolve/functionTypes.txt b/compiler/fir/resolve/testData/resolve/functionTypes.txt index 3f2ecccc3a7..c3213480f73 100644 --- a/compiler/fir/resolve/testData/resolve/functionTypes.txt +++ b/compiler/fir/resolve/testData/resolve/functionTypes.txt @@ -4,7 +4,7 @@ FILE: functionTypes.kt } public final fun R|kotlin/collections/List|.simpleMap(f: R|kotlin/Function1|): R|R| { } - public final fun simpleWith(t: R|T|, f: R|kotlin/Function1|): R|kotlin/Unit| { + public final fun simpleWith(t: R|T|, f: @R|kotlin/ExtensionFunctionType|() R|kotlin/Function1|): R|kotlin/Unit| { ^simpleWith R|/t|.#() } public abstract interface KMutableProperty1 : R|KProperty1|, R|KMutableProperty| { diff --git a/compiler/fir/resolve/testData/resolve/smartcasts/implicitReceivers.txt b/compiler/fir/resolve/testData/resolve/smartcasts/implicitReceivers.txt index d321f2c8d89..8554874eae0 100644 --- a/compiler/fir/resolve/testData/resolve/smartcasts/implicitReceivers.txt +++ b/compiler/fir/resolve/testData/resolve/smartcasts/implicitReceivers.txt @@ -8,7 +8,7 @@ FILE: implicitReceivers.kt } } - public final fun R|T|.with(block: R|kotlin/Function1|): R|kotlin/Unit| { + public final fun R|T|.with(block: @R|kotlin/ExtensionFunctionType|() R|kotlin/Function1|): R|kotlin/Unit| { } public final fun R|kotlin/Any?|.test_1(): R|kotlin/Unit| { when () { @@ -41,9 +41,9 @@ FILE: implicitReceivers.kt #() } public final fun test_3(a: R|kotlin/Any|, b: R|kotlin/Any|, c: R|kotlin/Any|): R|kotlin/Unit| { - R|kotlin/with|(R|/a|, = wa@fun R|kotlin/Any|.(it: R|kotlin/Any|): R|kotlin/Unit| { - R|kotlin/with|(R|/b|, = wb@fun R|kotlin/Any|.(it: R|kotlin/Any|): R|kotlin/Unit| { - R|kotlin/with|(R|/c|, = wc@fun R|kotlin/Any|.(it: R|kotlin/Any|): R|kotlin/Unit| { + R|kotlin/with|(R|/a|, = wa@fun R|kotlin/Any|.(): R|kotlin/Unit| { + R|kotlin/with|(R|/b|, = wb@fun R|kotlin/Any|.(): R|kotlin/Unit| { + R|kotlin/with|(R|/c|, = wc@fun R|kotlin/Any|.(): R|kotlin/Unit| { (this@R|special/anonymous| as R|A|) this@R|special/anonymous|.R|/A.foo|() this@R|/A|.R|/A.foo|() diff --git a/compiler/fir/resolve/testData/resolve/stdlib/implicitReceiverOrder.txt b/compiler/fir/resolve/testData/resolve/stdlib/implicitReceiverOrder.txt index 7c84a2976de..76e121c35cd 100644 --- a/compiler/fir/resolve/testData/resolve/stdlib/implicitReceiverOrder.txt +++ b/compiler/fir/resolve/testData/resolve/stdlib/implicitReceiverOrder.txt @@ -28,16 +28,16 @@ FILE: implicitReceiverOrder.kt } public final fun test(a: R|A|, b: R|B|): R|kotlin/Unit| { - R|kotlin/with|(R|/b|, = with@fun R|B|.(it: R|B|): R|kotlin/Unit| { - R|kotlin/with|(R|/a|, = with@fun R|A|.(it: R|A|): R|kotlin/Unit| { + R|kotlin/with|(R|/b|, = with@fun R|B|.(): R|kotlin/Unit| { + R|kotlin/with|(R|/a|, = with@fun R|A|.(): R|kotlin/Unit| { this@R|/A|.R|/A.foo|() (this@R|/B|, this@R|special/anonymous|).R|/B.bar|() } ) } ) - R|kotlin/with|(R|/a|, = with@fun R|A|.(it: R|A|): R|kotlin/Unit| { - R|kotlin/with|(R|/b|, = with@fun R|B|.(it: R|B|): R|kotlin/Unit| { + R|kotlin/with|(R|/a|, = with@fun R|A|.(): R|kotlin/Unit| { + R|kotlin/with|(R|/b|, = with@fun R|B|.(): R|kotlin/Unit| { this@R|/B|.R|/B.foo|() (this@R|/A|, this@R|special/anonymous|).R|/A.bar|() } diff --git a/compiler/fir/resolve/testData/resolve/stdlib/mapList.txt b/compiler/fir/resolve/testData/resolve/stdlib/mapList.txt index 2a52cc6a452..72876a2221d 100644 --- a/compiler/fir/resolve/testData/resolve/stdlib/mapList.txt +++ b/compiler/fir/resolve/testData/resolve/stdlib/mapList.txt @@ -5,7 +5,7 @@ FILE: mapList.kt R|/it|.R|kotlin/Int.plus|(R|/it|) } ) - R|/u|.R|/applyX||>( = applyX@fun R|kotlin/collections/List|.(it: R|kotlin/collections/List|): R|kotlin/Unit| { + R|/u|.R|/applyX||>( = applyX@fun R|kotlin/collections/List|.(): R|kotlin/Unit| { this@R|special/anonymous|.R|FakeOverride|(Int(1)) this@R|kotlin/collections/List|.R|FakeOverride|(Int(1)) } diff --git a/compiler/fir/resolve/testData/resolve/stdlib/multipleImplicitReceivers.txt b/compiler/fir/resolve/testData/resolve/stdlib/multipleImplicitReceivers.txt index d0dbc2eb2cd..ef1167da8b3 100644 --- a/compiler/fir/resolve/testData/resolve/stdlib/multipleImplicitReceivers.txt +++ b/compiler/fir/resolve/testData/resolve/stdlib/multipleImplicitReceivers.txt @@ -25,10 +25,10 @@ FILE: multipleImplicitReceivers.kt } public final fun test(fooImpl: R|IFoo|, invokeImpl: R|IInvoke|): R|kotlin/Unit| { - R|kotlin/with|(Q|A|, = with@fun R|A|.(it: R|A|): R|kotlin/Unit| { - R|kotlin/with|(R|/fooImpl|, = with@fun R|IFoo|.(it: R|IFoo|): R|kotlin/Unit| { + R|kotlin/with|(Q|A|, = with@fun R|A|.(): R|kotlin/Unit| { + R|kotlin/with|(R|/fooImpl|, = with@fun R|IFoo|.(): R|kotlin/Unit| { (this@R|/IFoo|, this@R|special/anonymous|).R|/IFoo.foo| - R|kotlin/with|(R|/invokeImpl|, = with@fun R|IInvoke|.(it: R|IInvoke|): R|kotlin/Unit| { + R|kotlin/with|(R|/invokeImpl|, = with@fun R|IInvoke|.(): R|kotlin/Unit| { (this@R|/IInvoke|, (this@R|/IFoo|, this@R|special/anonymous|).R|/IFoo.foo|).R|/IInvoke.invoke|() } ) diff --git a/compiler/fir/resolve/testData/resolve/stdlib/recursiveBug.txt b/compiler/fir/resolve/testData/resolve/stdlib/recursiveBug.txt index 8d2452a314d..f616d969d6d 100644 --- a/compiler/fir/resolve/testData/resolve/stdlib/recursiveBug.txt +++ b/compiler/fir/resolve/testData/resolve/stdlib/recursiveBug.txt @@ -4,7 +4,7 @@ FILE: recursiveBug.kt super() } - public final val result: R|kotlin/String| = this@R|/Foo|.R|kotlin/run|( = run@fun R|Foo|.(it: R|Foo|): R|kotlin/String| { + public final val result: R|kotlin/String| = this@R|/Foo|.R|kotlin/run|( = run@fun R|Foo|.(): R|kotlin/String| { R|/name|.R|FakeOverride|() } ) diff --git a/compiler/fir/resolve/testData/resolve/stdlib/topLevelResolve.txt b/compiler/fir/resolve/testData/resolve/stdlib/topLevelResolve.txt index 8c42d62b4a2..73ccc6557ab 100644 --- a/compiler/fir/resolve/testData/resolve/stdlib/topLevelResolve.txt +++ b/compiler/fir/resolve/testData/resolve/stdlib/topLevelResolve.txt @@ -35,7 +35,7 @@ FILE: topLevelResolve.kt R|/it|.R|kotlin/String.length| } ) - lval viaWith: R|kotlin/collections/List| = R|kotlin/with||, R|kotlin/collections/List|>(R|kotlin/collections/listOf|(Int(42)), = with@fun R|kotlin/collections/List|.(it: R|kotlin/collections/List|): R|kotlin/collections/List| { + lval viaWith: R|kotlin/collections/List| = R|kotlin/with||, R|kotlin/collections/List|>(R|kotlin/collections/listOf|(Int(42)), = with@fun R|kotlin/collections/List|.(): R|kotlin/collections/List| { this@R|special/anonymous|.R|kotlin/collections/map|( = map@fun (it: R|kotlin/Int|): R|kotlin/Int| { R|/it|.R|kotlin/Int.times|(R|/it|) } @@ -44,11 +44,11 @@ FILE: topLevelResolve.kt ) } public final fun testWith(): R|kotlin/Unit| { - lval length: R|kotlin/Int| = R|kotlin/with|(String(), = with@fun R|kotlin/String|.(it: R|kotlin/String|): R|kotlin/Int| { + lval length: R|kotlin/Int| = R|kotlin/with|(String(), = with@fun R|kotlin/String|.(): R|kotlin/Int| { this@R|kotlin/String|.R|kotlin/String.length| } ) - lval indices: R|kotlin/ranges/IntRange| = R|kotlin/with|(String(), = with@fun R|kotlin/String|.(it: R|kotlin/String|): R|kotlin/ranges/IntRange| { + lval indices: R|kotlin/ranges/IntRange| = R|kotlin/with|(String(), = with@fun R|kotlin/String|.(): R|kotlin/ranges/IntRange| { this@R|special/anonymous|.R|kotlin/text/indices| } ) diff --git a/compiler/fir/resolve/tests/org/jetbrains/kotlin/fir/FirResolveTestCaseGenerated.java b/compiler/fir/resolve/tests/org/jetbrains/kotlin/fir/FirResolveTestCaseGenerated.java index d2990c3b600..c6eb468a2ee 100644 --- a/compiler/fir/resolve/tests/org/jetbrains/kotlin/fir/FirResolveTestCaseGenerated.java +++ b/compiler/fir/resolve/tests/org/jetbrains/kotlin/fir/FirResolveTestCaseGenerated.java @@ -326,6 +326,11 @@ public class FirResolveTestCaseGenerated extends AbstractFirResolveTestCase { runTest("compiler/fir/resolve/testData/resolve/expresssions/lambda.kt"); } + @TestMetadata("lambdaWithReceiver.kt") + public void testLambdaWithReceiver() throws Exception { + runTest("compiler/fir/resolve/testData/resolve/expresssions/lambdaWithReceiver.kt"); + } + @TestMetadata("localConstructor.kt") public void testLocalConstructor() throws Exception { runTest("compiler/fir/resolve/testData/resolve/expresssions/localConstructor.kt");