FIR2IR: towards comprehensive visits in implicit cast inserter

This commit is contained in:
Jinseong Jeon
2020-10-25 23:58:20 -07:00
committed by teamcityserver
parent 5f6d2c5362
commit 2424f2438c
9 changed files with 240 additions and 64 deletions
@@ -6,25 +6,22 @@
package org.jetbrains.kotlin.fir.backend
import org.jetbrains.kotlin.fir.FirElement
import org.jetbrains.kotlin.fir.expressions.FirExpression
import org.jetbrains.kotlin.fir.expressions.FirExpressionWithSmartcast
import org.jetbrains.kotlin.fir.expressions.FirThisReceiverExpression
import org.jetbrains.kotlin.fir.references.FirReference
import org.jetbrains.kotlin.fir.references.FirResolvedNamedReference
import org.jetbrains.kotlin.fir.declarations.*
import org.jetbrains.kotlin.fir.expressions.*
import org.jetbrains.kotlin.fir.expressions.impl.*
import org.jetbrains.kotlin.fir.references.*
import org.jetbrains.kotlin.fir.render
import org.jetbrains.kotlin.fir.resolve.scope
import org.jetbrains.kotlin.fir.scopes.FakeOverrideTypeCalculator
import org.jetbrains.kotlin.fir.symbols.AbstractFirBasedSymbol
import org.jetbrains.kotlin.fir.symbols.impl.FirAnonymousFunctionSymbol
import org.jetbrains.kotlin.fir.symbols.impl.FirCallableSymbol
import org.jetbrains.kotlin.fir.symbols.impl.FirFunctionSymbol
import org.jetbrains.kotlin.fir.symbols.impl.FirPropertySymbol
import org.jetbrains.kotlin.fir.symbols.impl.*
import org.jetbrains.kotlin.fir.types.*
import org.jetbrains.kotlin.fir.visitors.FirDefaultVisitor
import org.jetbrains.kotlin.ir.IrElement
import org.jetbrains.kotlin.ir.expressions.*
import org.jetbrains.kotlin.ir.expressions.impl.IrTypeOperatorCallImpl
import org.jetbrains.kotlin.ir.types.IrType
import org.jetbrains.kotlin.ir.util.coerceToUnit
import org.jetbrains.kotlin.ir.util.coerceToUnitIfNeeded
import org.jetbrains.kotlin.name.Name
@@ -38,20 +35,184 @@ class Fir2IrImplicitCastInserter(
private fun ConeKotlinType.toIrType(): IrType = with(typeConverter) { toIrType() }
override fun visitElement(element: FirElement, data: IrElement): IrElement {
TODO("Should not be here: ${element.render()}")
TODO("Should not be here: ${element::class}: ${element.render()}")
}
// TODO: can be private once this visitor becomes more comprehensive, also generalized
internal fun IrExpression.insertImplicitNotNullCastIfNeeded(expression: FirExpression): IrExpression {
if (this is IrTypeOperatorCall && this.operator == IrTypeOperator.IMPLICIT_NOTNULL) {
override fun visitAnnotationCall(annotationCall: FirAnnotationCall, data: IrElement): IrElement = data
override fun visitAnonymousObject(anonymousObject: FirAnonymousObject, data: IrElement): IrElement = data
override fun visitAnonymousFunction(anonymousFunction: FirAnonymousFunction, data: IrElement): IrElement = data
override fun visitBinaryLogicExpression(binaryLogicExpression: FirBinaryLogicExpression, data: IrElement): IrElement = data
// TODO: maybe a place to do coerceIntToAnotherIntegerType?
override fun visitComparisonExpression(comparisonExpression: FirComparisonExpression, data: IrElement): IrElement = data
override fun visitTypeOperatorCall(typeOperatorCall: FirTypeOperatorCall, data: IrElement): IrElement = data
override fun visitEqualityOperatorCall(equalityOperatorCall: FirEqualityOperatorCall, data: IrElement): IrElement = data
override fun <T> visitConstExpression(constExpression: FirConstExpression<T>, data: IrElement): IrElement = data
override fun visitThisReceiverExpression(thisReceiverExpression: FirThisReceiverExpression, data: IrElement): IrElement = data
override fun visitQualifiedAccessExpression(qualifiedAccessExpression: FirQualifiedAccessExpression, data: IrElement): IrElement = data
override fun visitResolvedQualifier(resolvedQualifier: FirResolvedQualifier, data: IrElement): IrElement = data
override fun visitGetClassCall(getClassCall: FirGetClassCall, data: IrElement): IrElement = data
override fun visitFunctionCall(functionCall: FirFunctionCall, data: IrElement): IrElement = data
override fun visitCheckNotNullCall(checkNotNullCall: FirCheckNotNullCall, data: IrElement): IrElement = data
override fun visitCheckedSafeCallSubject(checkedSafeCallSubject: FirCheckedSafeCallSubject, data: IrElement): IrElement = data
override fun visitSafeCallExpression(safeCallExpression: FirSafeCallExpression, data: IrElement): IrElement = data
override fun visitStringConcatenationCall(stringConcatenationCall: FirStringConcatenationCall, data: IrElement): IrElement = data
// TODO: element-wise cast?
override fun visitArrayOfCall(arrayOfCall: FirArrayOfCall, data: IrElement): IrElement = data
// TODO: something to do w.r.t. SAM?
override fun visitLambdaArgumentExpression(lambdaArgumentExpression: FirLambdaArgumentExpression, data: IrElement): IrElement = data
// TODO: element-wise cast?
override fun visitNamedArgumentExpression(namedArgumentExpression: FirNamedArgumentExpression, data: IrElement): IrElement = data
// TODO: element-wise cast?
override fun visitVarargArgumentsExpression(varargArgumentsExpression: FirVarargArgumentsExpression, data: IrElement): IrElement = data
// TODO: element-wise cast?
override fun visitSpreadArgumentExpression(spreadArgumentExpression: FirSpreadArgumentExpression, data: IrElement): IrElement = data
// ==================================================================================
override fun visitExpression(expression: FirExpression, data: IrElement): IrElement {
return when(expression) {
is FirBlock -> (data as IrContainerExpression).insertImplicitCasts()
is FirUnitExpression -> (data as IrExpression).let { it.coerceToUnitIfNeeded(it.type, irBuiltIns) }
else -> data
}
}
override fun visitStatement(statement: FirStatement, data: IrElement): IrElement {
return when (statement) {
is FirTypeAlias -> data
FirStubStatement -> data
is FirUnitExpression -> (data as IrExpression).let { it.coerceToUnitIfNeeded(it.type, irBuiltIns) }
is FirBlock -> (data as IrContainerExpression).insertImplicitCasts()
else -> statement.accept(this, data)
}
}
// ==================================================================================
override fun visitWhenExpression(whenExpression: FirWhenExpression, data: IrElement): IrElement {
if (data is IrBlock) {
return data.insertImplicitCasts()
}
val irWhen = data as IrWhen
if (irWhen.branches.size != whenExpression.branches.size) {
return data
}
val firBranchMap = irWhen.branches.zip(whenExpression.branches).toMap()
irWhen.branches.replaceAll {
visitWhenBranch(firBranchMap[it]!!, it)
}
return data
}
override fun visitWhenSubjectExpression(whenSubjectExpression: FirWhenSubjectExpression, data: IrElement): IrElement = data
// TODO: cast `condition` expression to boolean?
override fun visitWhenBranch(whenBranch: FirWhenBranch, data: IrElement): IrBranch {
val irBranch = data as IrBranch
(irBranch.result as? IrContainerExpression)?.let {
irBranch.result = it.insertImplicitCasts()
}
return data
}
// TODO: Need to visit lhs/rhs branches?
override fun visitElvisExpression(elvisExpression: FirElvisExpression, data: IrElement): IrElement = data
// ==================================================================================
// TODO: cast `condition` expression to boolean?
override fun visitDoWhileLoop(doWhileLoop: FirDoWhileLoop, data: IrElement): IrElement {
val loop = data as IrDoWhileLoop
(loop.body as? IrContainerExpression)?.let {
loop.body = it.insertImplicitCasts()
}
return data
}
// TODO: cast `condition` expression to boolean?
override fun visitWhileLoop(whileLoop: FirWhileLoop, data: IrElement): IrElement {
val loop = data as IrWhileLoop
(loop.body as? IrContainerExpression)?.let {
loop.body = it.insertImplicitCasts()
}
return data
}
override fun visitBreakExpression(breakExpression: FirBreakExpression, data: IrElement): IrElement = data
override fun visitContinueExpression(continueExpression: FirContinueExpression, data: IrElement): IrElement = data
// ==================================================================================
override fun visitTryExpression(tryExpression: FirTryExpression, data: IrElement): IrElement {
val irTry = data as IrTry
(irTry.finallyExpression as? IrContainerExpression)?.let {
irTry.finallyExpression = it.insertImplicitCasts()
}
return data
}
override fun visitThrowExpression(throwExpression: FirThrowExpression, data: IrElement): IrElement =
(data as IrThrow).cast(throwExpression, throwExpression.exception.typeRef, throwExpression.typeRef)
override fun visitBlock(block: FirBlock, data: IrElement): IrElement =
(data as? IrContainerExpression)?.insertImplicitCasts() ?: data
override fun visitReturnExpression(returnExpression: FirReturnExpression, data: IrElement): IrElement {
val irReturn = data as IrReturn
val expectedType = returnExpression.target.labeledElement.returnTypeRef
irReturn.value = irReturn.value.cast(returnExpression.result, returnExpression.result.typeRef, expectedType)
return data
}
// ==================================================================================
internal fun IrExpression.cast(expression: FirExpression, valueType: FirTypeRef, expectedType: FirTypeRef): IrExpression {
if (this is IrTypeOperatorCall) {
return this
}
// TODO: Other conditions to check?
return when {
this is IrContainerExpression -> {
insertImplicitCasts()
}
expectedType.isUnit -> {
coerceToUnitIfNeeded(type, irBuiltIns)
}
// TODO: Not exactly matched with psi2ir yet...
valueType.hasEnhancedNullability() -> {
insertImplicitNotNullCastIfNeeded(expression)
}
// TODO: coerceIntToAnotherIntegerType
// TODO: even implicitCast call can be here?
else -> this
}
}
private fun IrExpression.insertImplicitNotNullCastIfNeeded(expression: FirExpression): IrExpression {
// [TypeOperatorLowering] will retrieve the source (from start offset to end offset) as an assertion message.
// Avoid type casting if we can't determine the source for some reasons, e.g., implicit `this` receiver.
if (expression.source == null ||
expression.typeRef.coneTypeSafe<ConeKotlinType>()?.hasEnhancedNullability != true
) {
if (expression.source == null) {
return this
}
return IrTypeOperatorCallImpl(
@@ -64,8 +225,7 @@ class Fir2IrImplicitCastInserter(
)
}
// TODO: can be private once this visitor becomes more comprehensive
internal fun IrContainerExpression.insertImplicitCasts(): IrContainerExpression {
private fun IrContainerExpression.insertImplicitCasts(): IrContainerExpression {
if (statements.isEmpty()) return this
val lastIndex = statements.lastIndex
@@ -53,7 +53,7 @@ class Fir2IrVisitor(
private val annotationGenerator = AnnotationGenerator(this)
private val implicitCastInserter = Fir2IrImplicitCastInserter(components, this)
internal val implicitCastInserter = Fir2IrImplicitCastInserter(components, this)
private val memberGenerator = ClassMemberGenerator(components, this, conversionScope)
@@ -64,7 +64,7 @@ class Fir2IrVisitor(
private fun <T : IrDeclaration> applyParentFromStackTo(declaration: T): T = conversionScope.applyParentFromStackTo(declaration)
override fun visitElement(element: FirElement, data: Any?): IrElement {
TODO("Should not be here: ${element.render()}")
TODO("Should not be here: ${element::class} ${element.render()}")
}
override fun visitField(field: FirField, data: Any?): IrField {
@@ -260,7 +260,10 @@ class Fir2IrVisitor(
variable, conversionScope.parentFromStack(), if (isNextVariable) IrDeclarationOrigin.FOR_LOOP_VARIABLE else null
)
if (initializer != null) {
irVariable.initializer = convertToIrExpression(initializer)
irVariable.initializer =
with(implicitCastInserter) {
convertToIrExpression(initializer).cast(initializer, initializer.typeRef, variable.returnTypeRef)
}
}
return irVariable
}
@@ -288,6 +291,8 @@ class Fir2IrVisitor(
},
convertToIrExpression(result)
)
}.let {
returnExpression.accept(implicitCastInserter, it)
}
}
@@ -470,10 +475,7 @@ class Fir2IrVisitor(
}
}
}.let {
// TODO: expression(implicitCastInserter, it) as IrExpression
with(implicitCastInserter) {
it.insertImplicitNotNullCastIfNeeded(expression)
}
expression.accept(implicitCastInserter, it) as IrExpression
}
}
@@ -563,11 +565,6 @@ class Fir2IrVisitor(
startOffset, endOffset, type, origin,
mapToIrStatements().filterNotNull()
)
}.also {
// TODO: can remove this once implicit cast inserter visits more expression kinds directly
with(implicitCastInserter) {
it.insertImplicitCasts()
}
}
}
}
@@ -655,7 +652,7 @@ class Fir2IrVisitor(
val irBranches = whenExpression.branches.mapNotNullTo(mutableListOf()) { branch ->
branch.takeIf {
it.condition !is FirElseIfTrueCondition || it.result.statements.isNotEmpty()
}?.toIrWhenBranch()
}?.toIrWhenBranch(whenExpression.typeRef)
}
if (whenExpression.isExhaustive && whenExpression.branches.none { it.condition is FirElseIfTrueCondition }) {
val irResult = IrCallImpl(
@@ -677,6 +674,8 @@ class Fir2IrVisitor(
) whenExpression.typeRef.toIrType() else irBuiltIns.unitType
)
}
}.also {
whenExpression.accept(implicitCastInserter, it)
}
}
@@ -708,10 +707,12 @@ class Fir2IrVisitor(
}
}
private fun FirWhenBranch.toIrWhenBranch(): IrBranch {
private fun FirWhenBranch.toIrWhenBranch(whenExpressionType: FirTypeRef): IrBranch {
return convertWithOffsets { startOffset, endOffset ->
val condition = condition
val irResult = convertToIrExpression(result)
val irResult = with(implicitCastInserter) {
convertToIrExpression(result).cast(result, result.typeRef, whenExpressionType)
}
if (condition is FirElseIfTrueCondition) {
IrElseBranchImpl(IrConstImpl.boolean(irResult.startOffset, irResult.endOffset, irBuiltIns.booleanType, true), irResult)
} else {
@@ -741,6 +742,8 @@ class Fir2IrVisitor(
condition = convertToIrExpression(doWhileLoop.condition)
loopMap.remove(doWhileLoop)
}
}.also {
doWhileLoop.accept(implicitCastInserter, it)
}
}
@@ -756,6 +759,8 @@ class Fir2IrVisitor(
body = whileLoop.block.convertToIrExpressionOrBlock(origin.takeIf { it != IrStatementOrigin.WHILE_LOOP })
loopMap.remove(whileLoop)
}
}.also {
whileLoop.accept(implicitCastInserter, it)
}
}
@@ -545,13 +545,20 @@ class CallAndReferenceGenerator(
argument: FirExpression,
parameter: FirValueParameter?,
annotationMode: Boolean = false
): IrExpression =
with(adapterGenerator) {
visitor.convertToIrExpression(argument, annotationMode)
.applySamConversionIfNeeded(argument, parameter)
.applySuspendConversionIfNeeded(argument, parameter)
.applyAssigningArrayElementsToVarargInNamedForm(argument, parameter)
): IrExpression {
var irArgument = visitor.convertToIrExpression(argument, annotationMode)
if (parameter != null) {
with(visitor.implicitCastInserter) {
irArgument = irArgument.cast(argument, argument.typeRef, parameter.returnTypeRef)
}
}
with(adapterGenerator) {
irArgument = irArgument.applySuspendConversionIfNeeded(argument, parameter)
}
return irArgument
.applySamConversionIfNeeded(argument, parameter)
.applyAssigningArrayElementsToVarargInNamedForm(argument, parameter)
}
private fun IrExpression.applySamConversionIfNeeded(
argument: FirExpression,
@@ -187,7 +187,18 @@ internal class ClassMemberGenerator(
declarationStorage.enterScope(this@initializeBackingField)
// NB: initializer can be already converted
if (initializer == null && initializerExpression != null) {
initializer = irFactory.createExpressionBody(visitor.convertToIrExpression(initializerExpression))
initializer = irFactory.createExpressionBody(
run {
val irExpression = visitor.convertToIrExpression(initializerExpression)
if (property.delegate == null) {
with(visitor.implicitCastInserter) {
irExpression.cast(initializerExpression, initializerExpression.typeRef, property.returnTypeRef)
}
} else {
irExpression
}
}
)
}
declarationStorage.leaveScope(this@initializeBackingField)
}
@@ -183,6 +183,9 @@ fun FirTypeRef.isUnsafeVarianceType(session: FirSession): Boolean {
return coneTypeSafe<ConeKotlinType>()?.isUnsafeVarianceType(session) == true
}
fun FirTypeRef.hasEnhancedNullability(): Boolean =
coneTypeSafe<ConeKotlinType>()?.hasEnhancedNullability == true
// Unlike other cases, return types may be implicit, i.e. unresolved
// But in that cases newType should also be `null`
fun FirTypeRef.withReplacedReturnType(newType: ConeKotlinType?): FirTypeRef {
+5 -4
View File
@@ -44,10 +44,11 @@ FILE fqName:<root> fileName:/bangbang.kt
BRANCH
if: TYPE_OP type=kotlin.Boolean origin=INSTANCEOF typeOperand=kotlin.String?
GET_VAR 'a: X of <root>.test4 declared in <root>.test4' type=X of <root>.test4 origin=null
then: CALL 'public final fun CHECK_NOT_NULL <T0> (arg0: T0 of kotlin.internal.ir.CHECK_NOT_NULL?): T0 of kotlin.internal.ir.CHECK_NOT_NULL declared in kotlin.internal.ir' type=kotlin.String origin=EXCLEXCL
<T0>: kotlin.String
arg0: TYPE_OP type=kotlin.String? origin=IMPLICIT_CAST typeOperand=kotlin.String?
GET_VAR 'a: X of <root>.test4 declared in <root>.test4' type=X of <root>.test4 origin=null
then: TYPE_OP type=kotlin.Unit origin=IMPLICIT_COERCION_TO_UNIT typeOperand=kotlin.Unit
CALL 'public final fun CHECK_NOT_NULL <T0> (arg0: T0 of kotlin.internal.ir.CHECK_NOT_NULL?): T0 of kotlin.internal.ir.CHECK_NOT_NULL declared in kotlin.internal.ir' type=kotlin.String origin=EXCLEXCL
<T0>: kotlin.String
arg0: TYPE_OP type=kotlin.String? origin=IMPLICIT_CAST typeOperand=kotlin.String?
GET_VAR 'a: X of <root>.test4 declared in <root>.test4' type=X of <root>.test4 origin=null
WHEN type=kotlin.Unit origin=IF
BRANCH
if: TYPE_OP type=kotlin.Boolean origin=INSTANCEOF typeOperand=kotlin.String?
@@ -1,13 +0,0 @@
FILE fqName:<root> fileName:/whenCoercedToUnit.kt
FUN name:foo visibility:public modality:FINAL <> (x:kotlin.Int) returnType:kotlin.Unit
VALUE_PARAMETER name:x index:0 type:kotlin.Int
BLOCK_BODY
BLOCK type=kotlin.Unit origin=WHEN
VAR IR_TEMPORARY_VARIABLE name:tmp_0 type:kotlin.Int [val]
GET_VAR 'x: kotlin.Int declared in <root>.foo' type=kotlin.Int origin=null
WHEN type=kotlin.Unit origin=WHEN
BRANCH
if: CALL 'public final fun EQEQ (arg0: kotlin.Any?, arg1: kotlin.Any?): kotlin.Boolean declared in kotlin.internal.ir' type=kotlin.Boolean origin=EQEQ
arg0: GET_VAR 'val tmp_0: kotlin.Int [val] declared in <root>.foo' type=kotlin.Int origin=null
arg1: CONST Int type=kotlin.Int value=0
then: CONST Int type=kotlin.Int value=0
@@ -1,3 +1,4 @@
// FIR_IDENTICAL
// WITH_RUNTIME
fun foo(x: Int) {
+7 -6
View File
@@ -27,9 +27,10 @@ FILE fqName:<root> fileName:/builtinMap.kt
$receiver: VALUE_PARAMETER name:<this> type:java.util.LinkedHashMap<K1 of <root>.plus?, V1 of <root>.plus?>
BLOCK_BODY
RETURN type=kotlin.Nothing from='local final fun <anonymous> (): kotlin.Unit declared in <root>.plus'
CALL 'public open fun put (p0: K of java.util.LinkedHashMap?, p1: V of java.util.LinkedHashMap?): V of java.util.LinkedHashMap? declared in java.util.LinkedHashMap' type=V1 of <root>.plus? origin=null
$this: GET_VAR '<this>: java.util.LinkedHashMap<K1 of <root>.plus?, V1 of <root>.plus?> declared in <root>.plus.<anonymous>' type=java.util.LinkedHashMap<K1 of <root>.plus?, V1 of <root>.plus?> origin=null
p0: CALL 'public final fun <get-first> (): A of kotlin.Pair declared in kotlin.Pair' type=K1 of <root>.plus origin=GET_PROPERTY
$this: GET_VAR 'pair: kotlin.Pair<K1 of <root>.plus, V1 of <root>.plus> declared in <root>.plus' type=kotlin.Pair<K1 of <root>.plus, V1 of <root>.plus> origin=null
p1: CALL 'public final fun <get-second> (): B of kotlin.Pair declared in kotlin.Pair' type=V1 of <root>.plus origin=GET_PROPERTY
$this: GET_VAR 'pair: kotlin.Pair<K1 of <root>.plus, V1 of <root>.plus> declared in <root>.plus' type=kotlin.Pair<K1 of <root>.plus, V1 of <root>.plus> origin=null
TYPE_OP type=kotlin.Unit origin=IMPLICIT_COERCION_TO_UNIT typeOperand=kotlin.Unit
CALL 'public open fun put (p0: K of java.util.LinkedHashMap?, p1: V of java.util.LinkedHashMap?): V of java.util.LinkedHashMap? declared in java.util.LinkedHashMap' type=V1 of <root>.plus? origin=null
$this: GET_VAR '<this>: java.util.LinkedHashMap<K1 of <root>.plus?, V1 of <root>.plus?> declared in <root>.plus.<anonymous>' type=java.util.LinkedHashMap<K1 of <root>.plus?, V1 of <root>.plus?> origin=null
p0: CALL 'public final fun <get-first> (): A of kotlin.Pair declared in kotlin.Pair' type=K1 of <root>.plus origin=GET_PROPERTY
$this: GET_VAR 'pair: kotlin.Pair<K1 of <root>.plus, V1 of <root>.plus> declared in <root>.plus' type=kotlin.Pair<K1 of <root>.plus, V1 of <root>.plus> origin=null
p1: CALL 'public final fun <get-second> (): B of kotlin.Pair declared in kotlin.Pair' type=V1 of <root>.plus origin=GET_PROPERTY
$this: GET_VAR 'pair: kotlin.Pair<K1 of <root>.plus, V1 of <root>.plus> declared in <root>.plus' type=kotlin.Pair<K1 of <root>.plus, V1 of <root>.plus> origin=null