Rewrite generator for IrBuiltInsMapGenerated

Similarly to 742fef9042 with OperationsMapGenerated, use optimized
`when` over strings instead of lambdas because lambdas lead to a lot of
bytecode.

The change in `IrConst.toPrimitive` is needed because for operations
like `Byte.plus(Int)` the IrConst instance would have the IR type for
kotlin.Byte, but the actual runtime value of type Int
(java.lang.Integer), which would lead to CCE from
`interpretBinaryFunction`. Previously it didn't fail because of
unchecked cast before calling the lambda, which allowed a value of
runtime type java.lang.Integer to sneak through to the lambda parameter
and be "unboxed" to the correct type via the `(... as Number).toByte()`
conversion which backend generates.

The main benefit of this change is that it reduces the size of the
proguarded compiler jar by ~0.69%.
This commit is contained in:
Alexander Udalov
2021-01-11 20:36:02 +03:00
committed by TeamCityServer
parent ac7a1c7762
commit e6254b51e1
6 changed files with 800 additions and 576 deletions
@@ -189,29 +189,20 @@ class IrInterpreter(val irBuiltIns: IrBuiltIns, private val bodyMap: Map<IdSigna
}
}
val signature = CompileTimeFunction(methodName, argsType.map { it.getOnlyName() })
// TODO replace unary, binary, ternary functions with vararg
val result = withExceptionHandler {
when (argsType.size) {
1 -> {
val function = unaryFunctions[signature]
?: throw InterpreterMethodNotFoundError("For given function $signature there is no entry in unary map")
function.invoke(argsValues.first())
}
2 -> {
val function = binaryFunctions[signature]
?: throw InterpreterMethodNotFoundError("For given function $signature there is no entry in binary map")
when (methodName) {
"rangeTo" -> return calculateRangeTo(irFunction.returnType)
else -> function.invoke(argsValues[0], argsValues[1])
}
}
3 -> {
val function = ternaryFunctions[signature]
?: throw InterpreterMethodNotFoundError("For given function $signature there is no entry in ternary map")
function.invoke(argsValues[0], argsValues[1], argsValues[2])
1 -> interpretUnaryFunction(methodName, argsType[0].getOnlyName(), argsValues[0])
2 -> when (methodName) {
"rangeTo" -> return calculateRangeTo(irFunction.returnType)
else -> interpretBinaryFunction(
methodName, argsType[0].getOnlyName(), argsType[1].getOnlyName(), argsValues[0], argsValues[1]
)
}
3 -> interpretTernaryFunction(
methodName, argsType[0].getOnlyName(), argsType[1].getOnlyName(), argsType[2].getOnlyName(),
argsValues[0], argsValues[1], argsValues[2]
)
else -> throw InterpreterError("Unsupported number of arguments for invocation as builtin functions")
}
}
@@ -222,14 +213,14 @@ class IrInterpreter(val irBuiltIns: IrBuiltIns, private val bodyMap: Map<IdSigna
private fun calculateRangeTo(type: IrType): ExecutionResult {
val constructor = type.classOrNull!!.owner.constructors.first()
val constructorCall = IrConstructorCallImpl.fromSymbolOwner(constructor.returnType, constructor.symbol)
val constructorValueParameters = constructor.valueParameters.map { it.symbol }
val primitiveValueParameters = stack.getAll().map { it.state as Primitive<*> }
primitiveValueParameters.forEachIndexed { index, primitive ->
constructorCall.putValueArgument(index, primitive.value.toIrConst(primitive.type))
constructorCall.putValueArgument(index, primitive.value.toIrConst(constructorValueParameters[index].owner.type))
}
val constructorValueParameters = constructor.valueParameters.map { it.symbol }.zip(primitiveValueParameters)
return stack.newFrame(initPool = constructorValueParameters.map { Variable(it.first, it.second) }) {
return stack.newFrame(initPool = constructorValueParameters.zip(primitiveValueParameters).map { Variable(it.first, it.second) }) {
constructorCall.interpret()
}
}
@@ -104,8 +104,11 @@ fun Any?.toIrConst(irType: IrType, startOffset: Int = UNDEFINED_OFFSET, endOffse
}
}
internal fun <T> IrConst<T>.toPrimitive(): Primitive<T> {
return Primitive(this.value, this.type)
@Suppress("UNCHECKED_CAST")
internal fun <T> IrConst<T>.toPrimitive(): Primitive<T> = when {
type.isByte() -> Primitive((value as Number).toByte() as T, type)
type.isShort() -> Primitive((value as Number).toShort() as T, type)
else -> Primitive(value, type)
}
fun IrAnnotationContainer?.hasAnnotation(annotation: FqName): Boolean {
@@ -10,28 +10,3 @@ import org.jetbrains.kotlin.name.FqName
val compileTimeAnnotation = FqName("kotlin.CompileTimeCalculation")
val evaluateIntrinsicAnnotation = FqName("kotlin.EvaluateIntrinsic")
val contractsDslAnnotation = FqName("kotlin.internal.ContractsDsl")
data class CompileTimeFunction(val methodName: String, val args: List<String>)
@Suppress("UNCHECKED_CAST")
fun <T> unaryOperation(
methodName: String, receiverType: String, function: (T) -> Any?
): Pair<CompileTimeFunction, Function1<Any?, Any?>> {
return CompileTimeFunction(methodName, listOf(receiverType)) to function as Function1<Any?, Any?>
}
@Suppress("UNCHECKED_CAST")
fun <T, E> binaryOperation(
methodName: String, receiverType: String, parameterType: String, function: (T, E) -> Any?
): Pair<CompileTimeFunction, Function2<Any?, Any?, Any?>> {
return CompileTimeFunction(methodName, listOfNotNull(receiverType, parameterType)) to function as Function2<Any?, Any?, Any?>
}
@Suppress("UNCHECKED_CAST")
fun <T, E, R> ternaryOperation(
methodName: String, receiverType: String, firstParameterType: String, secondParameterType: String, function: (T, E, R) -> Any?
): Pair<CompileTimeFunction, Function3<Any?, Any?, Any?, Any?>> {
return CompileTimeFunction(
methodName, listOfNotNull(receiverType, firstParameterType, secondParameterType)
) to function as Function3<Any?, Any?, Any?, Any?>
}
+153 -81
View File
@@ -14,13 +14,10 @@ import org.jetbrains.kotlin.config.LanguageVersionSettingsImpl
import org.jetbrains.kotlin.descriptors.*
import org.jetbrains.kotlin.descriptors.impl.ModuleDescriptorImpl
import org.jetbrains.kotlin.generators.util.GeneratorsFileUtil
import org.jetbrains.kotlin.ir.declarations.IrFunction
import org.jetbrains.kotlin.ir.declarations.impl.IrFactoryImpl
import org.jetbrains.kotlin.ir.descriptors.IrBuiltIns
import org.jetbrains.kotlin.ir.types.impl.originalKotlinType
import org.jetbrains.kotlin.ir.util.IdSignature
import org.jetbrains.kotlin.ir.util.IdSignatureComposer
import org.jetbrains.kotlin.ir.util.SymbolTable
import org.jetbrains.kotlin.ir.util.*
import org.jetbrains.kotlin.name.Name
import org.jetbrains.kotlin.psi2ir.generators.TypeTranslatorImpl
import org.jetbrains.kotlin.storage.LockBasedStorageManager
@@ -37,81 +34,184 @@ fun generateMap(): String {
val sb = StringBuilder()
val p = Printer(sb)
p.println(File("license/COPYRIGHT.txt").readText())
p.println("@file:Suppress(\"DEPRECATION\", \"DEPRECATION_ERROR\", \"UNCHECKED_CAST\")")
p.println()
p.println("package org.jetbrains.kotlin.ir.interpreter.builtins")
p.println()
p.println("import org.jetbrains.kotlin.ir.interpreter.exceptions.InterpreterMethodNotFoundError")
p.println("import org.jetbrains.kotlin.ir.interpreter.proxy.Proxy")
p.println()
p.println("/** This file is generated by org.jetbrains.kotlin.backend.common.interpreter.builtins.GenerateBuiltInsMap.generateMap(). DO NOT MODIFY MANUALLY */")
p.println("/** This file is generated by `./gradlew generateInterpreterMap`. DO NOT MODIFY MANUALLY */")
p.println()
val irBuiltIns = getIrBuiltIns()
val unaryOperationsMap = getOperationMap(1)
val binaryOperationsMap = getOperationMap(2)
val ternaryOperationsMap = getOperationMap(3)
val binaryIrOperationsMap = getBinaryIrOperationMap(irBuiltIns)
generateInterpretUnaryFunction(p, getOperationMap(1).apply {
val irNullCheck = irBuiltIns.checkNotNullSymbol.owner
this += Operation(irNullCheck.name.asString(), listOf("T0?"), customExpression = "a!!")
})
p.println("@Suppress(\"DEPRECATION\", \"DEPRECATION_ERROR\")")
p.println("val unaryFunctions = mapOf<CompileTimeFunction, Function1<Any?, Any?>>(")
p.println(generateUnaryBody(unaryOperationsMap, irBuiltIns))
p.println(")")
p.println()
generateInterpretBinaryFunction(p, getOperationMap(2) + getBinaryIrOperationMap(irBuiltIns))
p.println("val binaryFunctions = mapOf<CompileTimeFunction, Function2<Any?, Any?, Any?>>(")
p.println(generateBinaryBody(binaryOperationsMap, binaryIrOperationsMap))
p.println(")")
p.println()
p.println("val ternaryFunctions = mapOf<CompileTimeFunction, Function3<Any?, Any?, Any?, Any?>>(")
p.println(generateTernaryBody(ternaryOperationsMap))
p.println(")")
p.println()
generateInterpretTernaryFunction(p, getOperationMap(3))
return sb.toString()
}
private fun getOperationMap(argumentsCount: Int): MutableMap<CallableDescriptor, Pair<String, String>> {
private fun generateInterpretUnaryFunction(p: Printer, unaryOperations: List<Operation>) {
p.println("internal fun interpretUnaryFunction(name: String, type: String, a: Any?): Any? {")
p.pushIndent()
p.println("when (name) {")
p.pushIndent()
for ((name, operations) in unaryOperations.groupBy(Operation::name)) {
p.println("\"$name\" -> when (type) {")
p.pushIndent()
for (operation in operations) {
p.println("\"${operation.typeA}\" -> return ${operation.expressionString}")
}
p.popIndent()
p.println("}")
}
p.popIndent()
p.println("}")
p.println("throw InterpreterMethodNotFoundError(\"Unknown function: \$name(\$type)\")")
p.popIndent()
p.println("}")
p.println()
}
private fun generateInterpretBinaryFunction(p: Printer, binaryOperations: List<Operation>) {
p.println("internal fun interpretBinaryFunction(name: String, typeA: String, typeB: String, a: Any?, b: Any?): Any? {")
p.pushIndent()
p.println("when (name) {")
p.pushIndent()
for ((name, operations) in binaryOperations.groupBy(Operation::name)) {
p.println("\"$name\" -> when (typeA) {")
p.pushIndent()
for ((typeA, operationsOnTypeA) in operations.groupBy(Operation::typeA)) {
val singleOperation = operationsOnTypeA.singleOrNull()
if (singleOperation != null) {
// Slightly improve readability if there's only one operation with such name and typeA.
p.println("\"$typeA\" -> if (typeB == \"${singleOperation.typeB}\") return ${singleOperation.expressionString}")
} else {
p.println("\"$typeA\" -> when (typeB) {")
p.pushIndent()
for ((typeB, operationsOnTypeB) in operationsOnTypeA.groupBy(Operation::typeB)) {
for (operation in operationsOnTypeB) {
p.println("\"$typeB\" -> return ${operation.expressionString}")
}
}
p.popIndent()
p.println("}")
}
}
p.popIndent()
p.println("}")
}
p.popIndent()
p.println("}")
p.println("throw InterpreterMethodNotFoundError(\"Unknown function: \$name(\$typeA, \$typeB)\")")
p.popIndent()
p.println("}")
p.println()
}
private fun generateInterpretTernaryFunction(p: Printer, ternaryOperations: List<Operation>) {
p.println("internal fun interpretTernaryFunction(name: String, typeA: String, typeB: String, typeC: String, a: Any?, b: Any?, c: Any?): Any {")
p.pushIndent()
p.println("when (name) {")
p.pushIndent()
for ((name, operations) in ternaryOperations.groupBy(Operation::name)) {
p.println("\"$name\" -> when (typeA) {")
p.pushIndent()
for (operation in operations) {
val (typeA, typeB, typeC) = operation.parameterTypes
p.println("\"$typeA\" -> if (typeB == \"$typeB\" && typeC == \"$typeC\") return ${operation.expressionString}")
}
p.popIndent()
p.println("}")
}
p.popIndent()
p.println("}")
p.println("throw InterpreterMethodNotFoundError(\"Unknown function: \$name(\$typeA, \$typeB, \$typeC)\")")
p.popIndent()
p.println("}")
p.println()
}
private fun castValue(name: String, type: String): String = when (type) {
"Any?", "T" -> name
"Array" -> "$name as Array<Any?>"
else -> "$name as $type"
}
private fun castValueParenthesized(name: String, type: String): String =
if (type == "Any?") name else "(${castValue(name, type)})"
private data class Operation(
val name: String,
val parameterTypes: List<String>,
val isFunction: Boolean = true,
val customExpression: String? = null,
) {
val typeA: String get() = parameterTypes[0]
val typeB: String get() = parameterTypes[1]
val expressionString: String
get() {
val receiver = castValueParenthesized("a", typeA)
println(name)
return when {
name == IrBuiltIns.OperatorNames.EQEQEQ && parameterTypes.all { it == "Any?" } ->
"if (a is Proxy && b is Proxy) a.state === b.state else a === b"
customExpression != null -> customExpression
else -> buildString {
append(receiver)
append(".")
append(name)
if (isFunction) append("(")
parameterTypes.withIndex().drop(1).joinTo(this) { (index, type) ->
castValue(('a' + index).toString(), type)
}
if (isFunction) append(")")
}
}
}
}
private fun getOperationMap(argumentsCount: Int): MutableList<Operation> {
val builtIns = DefaultBuiltIns.Instance
val operationMap = mutableMapOf<CallableDescriptor, Pair<String, String>>()
val operationMap = mutableListOf<Operation>()
val allPrimitiveTypes = PrimitiveType.values().map { builtIns.getBuiltInClassByFqName(it.typeFqName) }
val arrays = PrimitiveType.values().map { builtIns.getPrimitiveArrayClassDescriptor(it) } + builtIns.array
fun CallableDescriptor.isFakeOverride(classDescriptor: ClassDescriptor): Boolean {
val isPrimitive = KotlinBuiltIns.isPrimitiveClass(classDescriptor) || KotlinBuiltIns.isString(classDescriptor.defaultType)
val isFakeOverridden = (this as? FunctionDescriptor)?.kind == CallableMemberDescriptor.Kind.FAKE_OVERRIDE
return when {
isPrimitive -> false
else -> isFakeOverridden
}
return !isPrimitive && isFakeOverridden
}
for (classDescriptor in allPrimitiveTypes + builtIns.string + arrays + builtIns.any) {
val classTypeParameters = classDescriptor.typeConstructor.parameters.map { it.name.asString() }
val typeParametersReplacedToAny =
if (classTypeParameters.isNotEmpty()) classTypeParameters.joinToString(prefix = "<", postfix = ">") { "Any?" } else ""
val classType = classDescriptor.defaultType.constructor.toString()
val compileTimeFunctions = classDescriptor.unsubstitutedMemberScope.getContributedDescriptors()
.filterIsInstance<CallableDescriptor>()
.filter { !it.isFakeOverride(classDescriptor) }
.filter { !it.isFakeOverride(classDescriptor) && it.valueParameters.size + 1 == argumentsCount }
for (function in compileTimeFunctions) {
val operationArguments = (listOf(classType) + function.valueParameters.map { it.type }).joinToString { "\"" + it + "\"" }
val typeParametersOfFun = listOf(classType + typeParametersReplacedToAny) +
function.valueParameters.map { if (classTypeParameters.contains(it.type.toString())) "Any?" else it.type.toString() }
if (function.valueParameters.size + 1 == argumentsCount) { // +1 for receiver
operationMap[function] = typeParametersOfFun.joinToString(prefix = "<", postfix = ">") to operationArguments
}
operationMap.add(
Operation(
function.name.asString(),
listOf(classDescriptor.defaultType.constructor.toString()) + function.valueParameters.map { it.type.toString() },
function is FunctionDescriptor
)
)
}
}
return operationMap
}
private fun getBinaryIrOperationMap(irBuiltIns: IrBuiltIns): MutableMap<IrFunction, Pair<String, String>> {
val operationMap = mutableMapOf<IrFunction, Pair<String, String>>()
private fun getBinaryIrOperationMap(irBuiltIns: IrBuiltIns): List<Operation> {
val operationMap = mutableListOf<Operation>()
val irFunSymbols =
(irBuiltIns.lessFunByOperandType.values + irBuiltIns.lessOrEqualFunByOperandType.values +
irBuiltIns.greaterFunByOperandType.values + irBuiltIns.greaterOrEqualFunByOperandType.values +
@@ -121,49 +221,21 @@ private fun getBinaryIrOperationMap(irBuiltIns: IrBuiltIns): MutableMap<IrFuncti
for (function in irFunSymbols) {
val parametersTypes = function.valueParameters.map { it.type.originalKotlinType!!.toString() }
val operationArguments = parametersTypes.joinToString { "\"" + it + "\"" }
val functionTypeParameters = parametersTypes.joinToString(prefix = "<", postfix = ">")
check(parametersTypes.size == 2) { "Couldn't add following method from ir builtins to operations map: ${function.name}" }
operationMap[function] = functionTypeParameters to operationArguments
operationMap.add(
Operation(
function.name.asString(), parametersTypes,
customExpression = castValueParenthesized("a", parametersTypes[0]) + " " +
getIrMethodSymbolByName(function.name.asString()) + " " +
castValueParenthesized("b", parametersTypes[1])
)
)
}
return operationMap
}
private fun generateUnaryBody(unaryOperationsMap: Map<CallableDescriptor, Pair<String, String>>, irBuiltIns: IrBuiltIns): String {
val irNullCheck = irBuiltIns.checkNotNullSymbol.owner
return unaryOperationsMap.entries.joinToString(separator = ",\n", postfix = ",\n") { (function, parameters) ->
val methodName = "${function.name}"
val parentheses = if (function is FunctionDescriptor) "()" else ""
" unaryOperation${parameters.first}(\"$methodName\", ${parameters.second}) { a -> a.$methodName$parentheses }"
} + " unaryOperation<Any?>(\"${irNullCheck.name}\", \"${irNullCheck.valueParameters.first().type.originalKotlinType}\") { a -> a!! }"
}
private fun generateBinaryBody(
binaryOperationsMap: Map<CallableDescriptor, Pair<String, String>>, binaryIrOperationsMap: Map<IrFunction, Pair<String, String>>
): String {
return binaryOperationsMap.entries.joinToString(separator = ",\n", postfix = ",\n") { (function, parameters) ->
val methodName = "${function.name}"
" binaryOperation${parameters.first}(\"$methodName\", ${parameters.second}) { a, b -> a.$methodName(b) }"
} + binaryIrOperationsMap.entries.joinToString(separator = ",\n") { (function, parameters) ->
val methodName = "${function.name}"
val methodSymbol = getIrMethodSymbolByName(methodName)
val body = when (methodName) {
IrBuiltIns.OperatorNames.EQEQEQ -> "if (a is Proxy && b is Proxy) a.state === b.state else a === b"
else -> "a $methodSymbol b"
}
" binaryOperation${parameters.first}(\"$methodName\", ${parameters.second}) { a, b -> $body }"
}
}
private fun generateTernaryBody(ternaryOperationsMap: Map<CallableDescriptor, Pair<String, String>>): String {
return ternaryOperationsMap.entries.joinToString(separator = ",\n") { (function, parameters) ->
val methodName = "${function.name}"
" ternaryOperation${parameters.first}(\"$methodName\", ${parameters.second}) { a, b, c -> a.$methodName(b, c) }"
}
}
private fun getIrMethodSymbolByName(methodName: String): String {
return when (methodName) {
IrBuiltIns.OperatorNames.LESS -> "<"
@@ -10,7 +10,7 @@ import org.jetbrains.kotlin.generators.interpreter.DESTINATION
import org.jetbrains.kotlin.generators.interpreter.generateMap
import org.jetbrains.kotlin.test.KotlinTestUtils
class GenerateBuiltInsMapTest : TestCase() {
class GenerateInterpreterMapTest : TestCase() {
fun testGeneratedDataIsUpToDate() {
val text = generateMap()
KotlinTestUtils.assertEqualsToFile(DESTINATION, text)