Optimize check for missing fields in deserialization (#3862)

Fixes Kotlin/kotlinx.serialization#662 Kotlin/kotlinx.serialization#657
This commit is contained in:
Sergey Shanshin
2020-11-25 21:50:42 +03:00
committed by GitHub
parent f9503efb74
commit b5143ba2ab
10 changed files with 425 additions and 28 deletions
@@ -16,11 +16,14 @@
package org.jetbrains.kotlinx.serialization.compiler.backend.common
import org.jetbrains.kotlin.config.ApiVersion
import org.jetbrains.kotlin.descriptors.ClassConstructorDescriptor
import org.jetbrains.kotlin.descriptors.ClassDescriptor
import org.jetbrains.kotlin.descriptors.FunctionDescriptor
import org.jetbrains.kotlin.resolve.BindingContext
import org.jetbrains.kotlin.resolve.descriptorUtil.module
import org.jetbrains.kotlin.resolve.descriptorUtil.secondaryConstructors
import org.jetbrains.kotlinx.serialization.compiler.diagnostic.VersionReader
import org.jetbrains.kotlinx.serialization.compiler.resolve.*
abstract class SerializableCodegen(
@@ -28,6 +31,10 @@ abstract class SerializableCodegen(
bindingContext: BindingContext
) : AbstractSerialGenerator(bindingContext, serializableDescriptor) {
protected val properties = bindingContext.serializablePropertiesFor(serializableDescriptor)
protected val staticDescriptor = serializableDescriptor.declaredTypeParameters.isEmpty()
private val fieldMissingOptimizationVersion = ApiVersion.parse("1.1")!!
protected val useFieldMissingOptimization = canUseFieldMissingOptimization()
fun generate() {
generateSyntheticInternalConstructor()
@@ -50,6 +57,40 @@ abstract class SerializableCodegen(
}
}
protected fun getGoldenMask(): Int {
var goldenMask = 0
var requiredBit = 1
for (property in properties.serializableProperties) {
if (!property.optional) {
goldenMask = goldenMask or requiredBit
}
requiredBit = requiredBit shl 1
}
return goldenMask
}
protected fun getGoldenMaskList(): List<Int> {
val maskSlotCount = properties.serializableProperties.bitMaskSlotCount()
val goldenMaskList = MutableList(maskSlotCount) { 0 }
for (i in properties.serializableProperties.indices) {
if (!properties.serializableProperties[i].optional) {
val slotNumber = i / 32
val bitInSlot = i % 32
goldenMaskList[slotNumber] = goldenMaskList[slotNumber] or (1 shl bitInSlot)
}
}
return goldenMaskList
}
private fun canUseFieldMissingOptimization(): Boolean {
val implementationVersion = VersionReader.getVersionsForCurrentModuleFromContext(
serializableDescriptor.module,
bindingContext
)?.implementationVersion
return if (implementationVersion != null) implementationVersion >= fieldMissingOptimizationVersion else false
}
protected abstract fun generateInternalConstructor(constructorDescriptor: ClassConstructorDescriptor)
protected open fun generateWriteSelfMethod(methodDescriptor: FunctionDescriptor) {
@@ -11,6 +11,7 @@ import org.jetbrains.kotlin.builtins.KotlinBuiltIns
import org.jetbrains.kotlin.descriptors.*
import org.jetbrains.kotlin.descriptors.annotations.Annotations
import org.jetbrains.kotlin.ir.builders.*
import org.jetbrains.kotlin.ir.builders.declarations.buildFun
import org.jetbrains.kotlin.ir.declarations.*
import org.jetbrains.kotlin.ir.expressions.*
import org.jetbrains.kotlin.ir.expressions.impl.*
@@ -67,6 +68,30 @@ interface IrBuilderExtension {
) { bodyGen(c) }
}
// function will not be created in the real class
fun IrClass.createInlinedFunction(
name: Name,
visibility: DescriptorVisibility,
origin: IrDeclarationOrigin,
returnType: IrType,
bodyGen: IrBlockBodyBuilder.(IrFunction) -> Unit
): IrSimpleFunction {
val function = factory.buildFun {
this.name = name
this.visibility = visibility
this.origin = origin
this.isInline = true
this.returnType = returnType
}
val functionSymbol = function.symbol
function.parent = this
function.body = DeclarationIrBuilder(compilerContext, functionSymbol, startOffset, endOffset).irBlockBody(
startOffset,
endOffset
) { bodyGen(function) }
return function
}
fun IrBuilderWithScope.irInvoke(
dispatchReceiver: IrExpression? = null,
callee: IrFunctionSymbol,
@@ -109,6 +134,19 @@ interface IrBuilderExtension {
}
}
fun IrBuilderWithScope.createPrimitiveArrayOfExpression(
elementPrimitiveType: IrType,
arrayElements: List<IrExpression>
): IrExpression {
val arrayType = compilerContext.irBuiltIns.primitiveArrayForType.getValue(elementPrimitiveType).defaultType
val arg0 = IrVarargImpl(startOffset, endOffset, arrayType, elementPrimitiveType, arrayElements)
val typeArguments = listOf(elementPrimitiveType)
return irCall(compilerContext.symbols.arrayOf, arrayType, typeArguments = typeArguments).apply {
putValueArgument(0, arg0)
}
}
fun IrBuilderWithScope.irBinOp(name: Name, lhs: IrExpression, rhs: IrExpression): IrExpression {
val classFqName = (lhs.type as IrSimpleType).classOrNull!!.owner.fqNameWhenAvailable!!
val symbol = compilerContext.referenceFunctions(classFqName.child(name)).single()
@@ -766,4 +804,4 @@ interface IrBuilderExtension {
return superClasses.singleOrNull { it.kind == ClassKind.CLASS }
}
}
}
@@ -68,7 +68,7 @@ class SerialInfoImplJvmIrGenerator(
generateSimplePropertyWithBackingField(property.descriptor, irClass, Name.identifier("_" + property.name.asString()))
val getter = property.getter!!
getter.origin = SERIALIZABLE_SYNTHETIC_ORIGIN
getter.origin = SERIALIZABLE_PLUGIN_ORIGIN
// Add JvmName annotation to property getters to force the resulting JVM method name for 'x' be 'x', instead of 'getX',
// and to avoid having useless bridges for it generated in BridgeLowering.
// Unfortunately, this results in an extra `@JvmName` annotation in the bytecode, but it shouldn't matter very much.
@@ -76,7 +76,7 @@ class SerialInfoImplJvmIrGenerator(
val field = property.backingField!!
field.visibility = DescriptorVisibilities.PRIVATE
field.origin = SERIALIZABLE_SYNTHETIC_ORIGIN
field.origin = SERIALIZABLE_PLUGIN_ORIGIN
val parameter = ctor.addValueParameter(property.name.asString(), getter.returnType)
ctorBody.statements += IrSetFieldImpl(
@@ -8,24 +8,31 @@ package org.jetbrains.kotlinx.serialization.compiler.backend.ir
import org.jetbrains.kotlin.backend.common.deepCopyWithVariables
import org.jetbrains.kotlin.backend.common.lower.irThrow
import org.jetbrains.kotlin.codegen.CompilationException
import org.jetbrains.kotlin.descriptors.ClassConstructorDescriptor
import org.jetbrains.kotlin.descriptors.ClassKind
import org.jetbrains.kotlin.descriptors.FunctionDescriptor
import org.jetbrains.kotlin.descriptors.*
import org.jetbrains.kotlin.ir.builders.*
import org.jetbrains.kotlin.ir.builders.declarations.addField
import org.jetbrains.kotlin.ir.declarations.*
import org.jetbrains.kotlin.ir.expressions.IrExpression
import org.jetbrains.kotlin.ir.expressions.impl.IrDelegatingConstructorCallImpl
import org.jetbrains.kotlin.ir.symbols.IrFunctionSymbol
import org.jetbrains.kotlin.ir.types.*
import org.jetbrains.kotlin.ir.util.getAnnotation
import org.jetbrains.kotlin.ir.util.patchDeclarationParents
import org.jetbrains.kotlin.ir.util.*
import org.jetbrains.kotlin.js.resolve.diagnostics.findPsi
import org.jetbrains.kotlin.name.Name
import org.jetbrains.kotlin.resolve.BindingContext
import org.jetbrains.kotlin.resolve.descriptorUtil.fqNameSafe
import org.jetbrains.kotlin.resolve.descriptorUtil.module
import org.jetbrains.kotlin.util.OperatorNameConventions
import org.jetbrains.kotlinx.serialization.compiler.backend.common.SerializableCodegen
import org.jetbrains.kotlinx.serialization.compiler.backend.common.serialName
import org.jetbrains.kotlinx.serialization.compiler.diagnostic.serializableAnnotationIsUseless
import org.jetbrains.kotlinx.serialization.compiler.extensions.SerializationPluginContext
import org.jetbrains.kotlinx.serialization.compiler.resolve.*
import org.jetbrains.kotlinx.serialization.compiler.resolve.SerialEntityNames.ARRAY_MASK_FIELD_MISSING_FUNC_FQ
import org.jetbrains.kotlinx.serialization.compiler.resolve.SerialEntityNames.MISSING_FIELD_EXC
import org.jetbrains.kotlinx.serialization.compiler.resolve.SerialEntityNames.SERIAL_DESC_FIELD
import org.jetbrains.kotlinx.serialization.compiler.resolve.SerialEntityNames.SINGLE_MASK_FIELD_MISSING_FUNC_FQ
import org.jetbrains.kotlinx.serialization.compiler.resolve.SerialEntityNames.initializedDescriptorFieldName
class SerializableIrGenerator(
val irClass: IrClass,
@@ -33,7 +40,20 @@ class SerializableIrGenerator(
bindingContext: BindingContext
) : SerializableCodegen(irClass.descriptor, bindingContext), IrBuilderExtension {
private val descriptorGenerationFunctionName = "createInitializedDescriptor"
private val serialDescClass: ClassDescriptor = serializableDescriptor.module
.getClassFromSerializationDescriptorsPackage(SerialEntityNames.SERIAL_DESCRIPTOR_CLASS)
private val serialDescImplClass: ClassDescriptor = serializableDescriptor
.getClassFromInternalSerializationPackage(SerialEntityNames.SERIAL_DESCRIPTOR_CLASS_IMPL)
private val addElementFun = serialDescImplClass.findFunctionSymbol(CallingConventions.addElement)
val throwMissedFieldExceptionFunc =
if (useFieldMissingOptimization) compilerContext.referenceFunctions(SINGLE_MASK_FIELD_MISSING_FUNC_FQ).single() else null
val throwMissedFieldExceptionArrayFunc =
if (useFieldMissingOptimization) compilerContext.referenceFunctions(ARRAY_MASK_FIELD_MISSING_FUNC_FQ).single() else null
private fun IrClass.hasSerializableAnnotationWithoutArgs(): Boolean {
val annot = getAnnotation(SerializationAnnotations.serializableAnnotationFqName) ?: return false
@@ -64,6 +84,10 @@ class SerializableIrGenerator(
val thiz = irClass.thisReceiver!!
val superClass = irClass.getSuperClassOrAny()
var startPropOffset: Int = 0
if (useFieldMissingOptimization) {
generateOptimizedGoldenMaskCheck(seenVars)
}
when {
superClass.symbol == compilerContext.irBuiltIns.anyClass -> generateAnySuperConstructorCall(toBuilder = this@contributeConstructor)
superClass.isInternalSerializable -> {
@@ -83,7 +107,14 @@ class SerializableIrGenerator(
requireNotNull(transformFieldInitializer(prop.irField)) { "Optional value without an initializer" } // todo: filter abstract here
setProperty(irGet(thiz), prop.irProp, initializerBody)
} else {
irThrow(irInvoke(null, exceptionCtorRef, irString(prop.name), typeHint = exceptionType))
// property required
if (useFieldMissingOptimization) {
// field definitely not empty as it's checked before - no need another IF, only assign property from param
+assignParamExpr
continue
} else {
irThrow(irInvoke(null, exceptionCtorRef, irString(prop.name), typeHint = exceptionType))
}
}
val propNotSeenTest =
@@ -116,6 +147,153 @@ class SerializableIrGenerator(
}
}
private fun IrBlockBodyBuilder.generateOptimizedGoldenMaskCheck(seenVars: List<IrValueParameter>) {
if (serializableDescriptor.isAbstractSerializableClass() || serializableDescriptor.isSealedSerializableClass()) {
// for abstract classes fields MUST BE checked in child classes
return
}
val fieldsMissedTest: IrExpression
val throwErrorExpr: IrExpression
val maskSlotCount = seenVars.size
if (maskSlotCount == 1) {
val goldenMask = getGoldenMask()
throwErrorExpr = irInvoke(
null,
throwMissedFieldExceptionFunc!!,
irGet(seenVars[0]),
irInt(goldenMask),
getSerialDescriptorExpr(),
typeHint = compilerContext.irBuiltIns.unitType
)
fieldsMissedTest = irNotEquals(
irInt(goldenMask),
irBinOp(
OperatorNameConventions.AND,
irInt(goldenMask),
irGet(seenVars[0])
)
)
} else {
val goldenMaskList = getGoldenMaskList()
var compositeExpression: IrExpression? = null
for (i in goldenMaskList.indices) {
val singleCheckExpr = irNotEquals(
irInt(goldenMaskList[i]),
irBinOp(
OperatorNameConventions.AND,
irInt(goldenMaskList[i]),
irGet(seenVars[i])
)
)
compositeExpression = if (compositeExpression == null) {
singleCheckExpr
} else {
irBinOp(
OperatorNameConventions.OR,
compositeExpression,
singleCheckExpr
)
}
}
fieldsMissedTest = compositeExpression!!
throwErrorExpr = irBlock {
+irInvoke(
null,
throwMissedFieldExceptionArrayFunc!!,
createPrimitiveArrayOfExpression(compilerContext.irBuiltIns.intType, goldenMaskList.indices.map { irGet(seenVars[it]) }),
createPrimitiveArrayOfExpression(compilerContext.irBuiltIns.intType, goldenMaskList.map { irInt(it) }),
getSerialDescriptorExpr(),
typeHint = compilerContext.irBuiltIns.unitType
)
}
}
+irIfThen(compilerContext.irBuiltIns.unitType, fieldsMissedTest, throwErrorExpr)
}
private fun IrBlockBodyBuilder.getSerialDescriptorExpr(): IrExpression {
return if (serializableDescriptor.shouldHaveGeneratedSerializer && staticDescriptor) {
val serializer = serializableDescriptor.classSerializer!!
val serialDescriptorGetter = compilerContext.referenceClass(serializer.fqNameSafe)!!.getPropertyGetter(SERIAL_DESC_FIELD)!!
irGet(
serialDescriptorGetter.owner.returnType,
irGetObject(serializer),
serialDescriptorGetter.owner.symbol
)
} else {
irGetField(null, generateStaticDescriptorField())
}
}
private fun IrBlockBodyBuilder.generateStaticDescriptorField(): IrField {
val serialDescItType = serialDescClass.defaultType.toIrType()
val function = irClass.createInlinedFunction(
Name.identifier(descriptorGenerationFunctionName),
DescriptorVisibilities.PRIVATE,
SERIALIZABLE_PLUGIN_ORIGIN,
serialDescItType
) {
val serialDescVar = irTemporary(
getInstantiateDescriptorExpr(),
nameHint = "serialDesc"
)
for (property in properties.serializableProperties) {
+getAddElementToDescriptorExpr(property, serialDescVar)
}
+irReturn(irGet(serialDescVar))
}
return irClass.addField {
name = Name.identifier(initializedDescriptorFieldName)
visibility = DescriptorVisibilities.PRIVATE
origin = SERIALIZABLE_PLUGIN_ORIGIN
isFinal = true
isStatic = true
type = serialDescItType
}.apply { initializer = irClass.factory.createExpressionBody(irCall(function)) }
}
private fun IrBlockBodyBuilder.getInstantiateDescriptorExpr(): IrExpression {
val classConstructors = compilerContext.referenceConstructors(serialDescImplClass.fqNameSafe)
val serialClassDescImplCtor = classConstructors.single { it.owner.isPrimary }
return irInvoke(
null, serialClassDescImplCtor,
irString(serializableDescriptor.serialName()), irNull(), irInt(properties.serializableProperties.size)
)
}
private fun IrBlockBodyBuilder.getAddElementToDescriptorExpr(
property: SerializableProperty,
serialDescVar: IrVariable
): IrExpression {
return irInvoke(
irGet(serialDescVar),
addElementFun,
irString(property.name),
irBoolean(property.optional),
typeHint = compilerContext.irBuiltIns.unitType
)
}
private inline fun ClassDescriptor.findFunctionSymbol(
functionName: String,
predicate: (IrSimpleFunction) -> Boolean = { true }
): IrFunctionSymbol {
val irClass = compilerContext.referenceClass(fqNameSafe)?.owner ?: error("Couldn't load class $this")
val simpleFunctions = irClass.declarations.filterIsInstance<IrSimpleFunction>()
return simpleFunctions.filter { it.name.asString() == functionName }.single { predicate(it) }.symbol
}
private fun IrBlockBodyBuilder.generateSuperNonSerializableCall(superClass: IrClass) {
val ctorRef = superClass.declarations.filterIsInstance<IrConstructor>().singleOrNull { it.valueParameters.isEmpty() }
?: error("Non-serializable parent of serializable $serializableDescriptor must have no arg constructor")
@@ -189,4 +367,4 @@ class SerializableIrGenerator(
}
}
}
}
}
@@ -37,10 +37,7 @@ import org.jetbrains.kotlinx.serialization.compiler.resolve.SerialEntityNames.ST
import org.jetbrains.kotlinx.serialization.compiler.resolve.SerialEntityNames.STRUCTURE_ENCODER_CLASS
import org.jetbrains.kotlinx.serialization.compiler.resolve.SerialEntityNames.UNKNOWN_FIELD_EXC
object SERIALIZABLE_PLUGIN_ORIGIN : IrDeclarationOriginImpl("SERIALIZER")
// TODO: use in places where elements need to have ACC_SYNTHETIC on JVM
object SERIALIZABLE_SYNTHETIC_ORIGIN : IrDeclarationOriginImpl("SERIALIZER")
object SERIALIZABLE_PLUGIN_ORIGIN : IrDeclarationOriginImpl("SERIALIZER", true)
open class SerializerIrGenerator(
val irClass: IrClass,
@@ -9,6 +9,7 @@ import org.jetbrains.kotlin.builtins.KotlinBuiltIns
import org.jetbrains.kotlin.codegen.*
import org.jetbrains.kotlin.descriptors.*
import org.jetbrains.kotlin.js.resolve.diagnostics.findPsi
import org.jetbrains.kotlin.load.java.JvmAbi
import org.jetbrains.kotlin.load.kotlin.TypeMappingMode
import org.jetbrains.kotlin.name.ClassId
import org.jetbrains.kotlin.name.Name
@@ -25,10 +26,12 @@ import org.jetbrains.kotlinx.serialization.compiler.resolve.SerialEntityNames.DE
import org.jetbrains.kotlinx.serialization.compiler.resolve.SerialEntityNames.ENCODER_CLASS
import org.jetbrains.kotlinx.serialization.compiler.resolve.SerialEntityNames.KSERIALIZER_CLASS
import org.jetbrains.kotlinx.serialization.compiler.resolve.SerialEntityNames.MISSING_FIELD_EXC
import org.jetbrains.kotlinx.serialization.compiler.resolve.SerialEntityNames.PLUGIN_EXCEPTIONS_FILE
import org.jetbrains.kotlinx.serialization.compiler.resolve.SerialEntityNames.SERIAL_CTOR_MARKER_NAME
import org.jetbrains.kotlinx.serialization.compiler.resolve.SerialEntityNames.SERIAL_DESCRIPTOR_CLASS
import org.jetbrains.kotlinx.serialization.compiler.resolve.SerialEntityNames.SERIAL_DESCRIPTOR_CLASS_IMPL
import org.jetbrains.kotlinx.serialization.compiler.resolve.SerialEntityNames.SERIAL_DESCRIPTOR_FOR_ENUM
import org.jetbrains.kotlinx.serialization.compiler.resolve.SerialEntityNames.SERIAL_DESC_FIELD
import org.jetbrains.kotlinx.serialization.compiler.resolve.SerialEntityNames.SERIAL_EXC
import org.jetbrains.kotlinx.serialization.compiler.resolve.SerialEntityNames.SERIAL_LOADER_CLASS
import org.jetbrains.kotlinx.serialization.compiler.resolve.SerialEntityNames.SERIAL_SAVER_CLASS
@@ -50,7 +53,7 @@ internal val kOutputType = Type.getObjectType("kotlinx/serialization/encoding/$S
internal val encoderType = Type.getObjectType("kotlinx/serialization/encoding/$ENCODER_CLASS")
internal val decoderType = Type.getObjectType("kotlinx/serialization/encoding/$DECODER_CLASS")
internal val kInputType = Type.getObjectType("kotlinx/serialization/encoding/$STRUCTURE_DECODER_CLASS")
internal val pluginUtilsType = Type.getObjectType("kotlinx/serialization/internal/${PLUGIN_EXCEPTIONS_FILE}Kt")
internal val kSerialSaverType = Type.getObjectType("kotlinx/serialization/$SERIAL_SAVER_CLASS")
internal val kSerialLoaderType = Type.getObjectType("kotlinx/serialization/$SERIAL_LOADER_CLASS")
@@ -61,6 +64,9 @@ internal val serializationExceptionName = "kotlinx/serialization/$SERIAL_EXC"
internal val serializationExceptionMissingFieldName = "kotlinx/serialization/$MISSING_FIELD_EXC"
internal val serializationExceptionUnknownIndexName = "kotlinx/serialization/$UNKNOWN_FIELD_EXC"
internal val descriptorGetterName = JvmAbi.getterName(SERIAL_DESC_FIELD)
val OPT_MASK_TYPE: Type = Type.INT_TYPE
val OPT_MASK_BITS = 32
@@ -28,10 +28,15 @@ import org.jetbrains.kotlin.resolve.BindingContext
import org.jetbrains.kotlin.resolve.DescriptorUtils
import org.jetbrains.kotlin.resolve.descriptorUtil.getSuperClassOrAny
import org.jetbrains.kotlin.resolve.descriptorUtil.module
import org.jetbrains.kotlin.resolve.jvm.diagnostics.OtherOrigin
import org.jetbrains.kotlinx.serialization.compiler.backend.common.*
import org.jetbrains.kotlinx.serialization.compiler.diagnostic.serializableAnnotationIsUseless
import org.jetbrains.kotlinx.serialization.compiler.resolve.*
import org.jetbrains.kotlinx.serialization.compiler.resolve.SerialEntityNames.ARRAY_MASK_FIELD_MISSING_FUNC_NAME
import org.jetbrains.kotlinx.serialization.compiler.resolve.SerialEntityNames.SINGLE_MASK_FIELD_MISSING_FUNC_NAME
import org.jetbrains.kotlinx.serialization.compiler.resolve.SerialEntityNames.initializedDescriptorFieldName
import org.jetbrains.org.objectweb.asm.Label
import org.jetbrains.org.objectweb.asm.Opcodes
import org.jetbrains.org.objectweb.asm.Type
import org.jetbrains.org.objectweb.asm.commons.InstructionAdapter
@@ -179,20 +184,28 @@ class SerializableCodegenImpl(
}
private fun InstructionAdapter.doGenerateConstructorImpl(exprCodegen: ExpressionCodegen) {
val seenMask = 1
val bitMaskOff = fun(it: Int): Int { return seenMask + bitMaskSlotAt(it) }
val bitMaskEnd = seenMask + properties.serializableProperties.bitMaskSlotCount()
var (propIndex, propOffset) = generateSuperSerializableCall(bitMaskEnd)
val seenMaskVar = 1
val bitMaskOff = fun(it: Int): Int { return seenMaskVar + bitMaskSlotAt(it) }
val bitMaskEnd = seenMaskVar + properties.serializableProperties.bitMaskSlotCount()
if (useFieldMissingOptimization) {
generateOptimizedGoldenMaskCheck(seenMaskVar)
}
var (propIndex, propOffset) = generateSuperSerializableCall(seenMaskVar, bitMaskEnd)
for (i in propIndex until properties.serializableProperties.size) {
val prop = properties[i]
val propType = prop.asmType
if (!prop.optional) {
// primary were validated before constructor call
genValidateProperty(i, bitMaskOff(i))
val nonThrowLabel = Label()
ificmpne(nonThrowLabel)
genMissingFieldExceptionThrow(prop.name)
visitLabel(nonThrowLabel)
if (!useFieldMissingOptimization) {
// primary were validated before constructor call
genValidateProperty(i, bitMaskOff(i))
val nonThrowLabel = Label()
ificmpne(nonThrowLabel)
genMissingFieldExceptionThrow(prop.name)
visitLabel(nonThrowLabel)
}
// setting field
load(0, thisAsmType)
load(propOffset, propType)
@@ -237,7 +250,7 @@ class SerializableCodegenImpl(
areturn(Type.VOID_TYPE)
}
private fun InstructionAdapter.generateSuperSerializableCall(propStartVar: Int): Pair<Int, Int> {
private fun InstructionAdapter.generateSuperSerializableCall(maskVar: Int, propStartVar: Int): Pair<Int, Int> {
val superClass = serializableDescriptor.getSuperClassOrAny()
val superType = classCodegen.typeMapper.mapType(superClass).internalName
@@ -261,12 +274,112 @@ class SerializableCodegenImpl(
return 0 to propStartVar
} else {
val superProps = bindingContext.serializablePropertiesFor(superClass).serializableProperties
val creator = buildInternalConstructorDesc(propStartVar, 1, classCodegen, superProps)
val creator = buildInternalConstructorDesc(propStartVar, maskVar, classCodegen, superProps)
invokespecial(superType, "<init>", creator, false)
return superProps.size to propStartVar + superProps.sumBy { it.asmType.size }
}
}
private fun InstructionAdapter.generateOptimizedGoldenMaskCheck(maskVar: Int) {
if (serializableDescriptor.isAbstractSerializableClass() || serializableDescriptor.isSealedSerializableClass()) {
// for abstract classes fields MUST BE checked in child classes
return
}
val allPresentsLabel = Label()
val maskSlotCount = properties.serializableProperties.bitMaskSlotCount()
if (maskSlotCount == 1) {
val goldenMask = getGoldenMask()
iconst(goldenMask)
dup()
load(maskVar, OPT_MASK_TYPE)
and(OPT_MASK_TYPE)
ificmpeq(allPresentsLabel)
load(maskVar, OPT_MASK_TYPE)
iconst(goldenMask)
stackSerialDescriptor()
invokestatic(
pluginUtilsType.internalName,
SINGLE_MASK_FIELD_MISSING_FUNC_NAME.asString(),
"(II${descType.descriptor})V",
false
)
} else {
val fieldsMissingLabel = Label()
val goldenMaskList = getGoldenMaskList()
goldenMaskList.forEachIndexed { i, goldenMask ->
val maskIndex = maskVar + i
// if( (goldenMask & seen) != goldenMask )
iconst(goldenMask)
dup()
load(maskIndex, OPT_MASK_TYPE)
and(OPT_MASK_TYPE)
ificmpne(fieldsMissingLabel)
}
goTo(allPresentsLabel)
visitLabel(fieldsMissingLabel)
// prepare seen array
fillArray(OPT_MASK_TYPE, goldenMaskList) { i, _ ->
load(maskVar + i, OPT_MASK_TYPE)
}
// prepare golden mask array
fillArray(OPT_MASK_TYPE, goldenMaskList) { _, goldenMask ->
iconst(goldenMask)
}
stackSerialDescriptor()
invokestatic(
pluginUtilsType.internalName,
ARRAY_MASK_FIELD_MISSING_FUNC_NAME.asString(),
"([I[I${descType.descriptor})V",
false
)
}
visitLabel(allPresentsLabel)
}
private fun InstructionAdapter.stackSerialDescriptor() {
if (serializableDescriptor.shouldHaveGeneratedSerializer && staticDescriptor) {
val serializer = serializableDescriptor.classSerializer!!
StackValue.singleton(serializer, classCodegen.typeMapper).put(kSerializerType, this)
invokeinterface(kSerializerType.internalName, descriptorGetterName, "()${descType.descriptor}")
} else {
generateStaticDescriptorField()
getstatic(thisAsmType.internalName, initializedDescriptorFieldName, descType.descriptor)
}
}
private fun generateStaticDescriptorField() {
val flags = Opcodes.ACC_PRIVATE or Opcodes.ACC_FINAL or Opcodes.ACC_SYNTHETIC or Opcodes.ACC_STATIC
classCodegen.v.newField(
OtherOrigin(classCodegen.myClass.psiOrParent), flags,
initializedDescriptorFieldName, descType.descriptor, null, null
)
val clInit = classCodegen.createOrGetClInitCodegen()
with(clInit.v) {
anew(descImplType)
dup()
aconst(serializableDescriptor.serialName())
aconst(null)
aconst(properties.serializableProperties.size)
invokespecial(descImplType.internalName, "<init>", "(Ljava/lang/String;${generatedSerializerType.descriptor}I)V", false)
for (property in properties.serializableProperties) {
dup()
aconst(property.name)
iconst(if (property.optional) 1 else 0)
invokevirtual(descImplType.internalName, CallingConventions.addElement, "(Ljava/lang/String;Z)V", false)
}
putstatic(thisAsmType.internalName, initializedDescriptorFieldName, descType.descriptor)
}
}
private fun ExpressionCodegen.genInitProperty(prop: SerializableProperty) = getProp(prop)?.let {
classCodegen.initializeProperty(this, it)
}
@@ -10,6 +10,7 @@ import org.jetbrains.kotlin.config.ApiVersion
import org.jetbrains.kotlin.config.KotlinCompilerVersion
import org.jetbrains.kotlin.descriptors.ModuleDescriptor
import org.jetbrains.kotlin.load.kotlin.KotlinJvmBinarySourceElement
import org.jetbrains.kotlin.resolve.BindingContext
import org.jetbrains.kotlin.resolve.BindingTrace
import org.jetbrains.kotlin.util.slicedMap.Slices
import org.jetbrains.kotlin.util.slicedMap.WritableSlice
@@ -50,6 +51,11 @@ object VersionReader {
return versions
}
fun getVersionsForCurrentModuleFromContext(module: ModuleDescriptor, context: BindingContext): RuntimeVersions? {
context.get(VERSIONS_SLICE, module)?.let { return it }
return getVersionsForCurrentModule(module)
}
fun getVersionsForCurrentModule(module: ModuleDescriptor): RuntimeVersions? {
val markerClass = module.getClassFromSerializationPackage(SerialEntityNames.KSERIALIZER_CLASS)
val location = (markerClass.source as? KotlinJvmBinarySourceElement)?.binaryClass?.location ?: return null
@@ -59,4 +65,4 @@ object VersionReader {
if (!file.exists()) return null
return getVersionsFromManifest(file)
}
}
}
@@ -44,6 +44,8 @@ object SerialEntityNames {
const val LOAD = "deserialize"
const val SERIALIZER_CLASS = "\$serializer"
const val initializedDescriptorFieldName = "\$initializedDescriptor"
// classes
val KSERIALIZER_NAME = Name.identifier(KSERIALIZER_CLASS)
val SERIAL_CTOR_MARKER_NAME = Name.identifier("SerializationConstructorMarker")
@@ -68,6 +70,8 @@ object SerialEntityNames {
const val SERIAL_DESCRIPTOR_CLASS_IMPL = "PluginGeneratedSerialDescriptor"
const val SERIAL_DESCRIPTOR_FOR_ENUM = "EnumDescriptor"
const val PLUGIN_EXCEPTIONS_FILE = "PluginExceptions"
//exceptions
const val SERIAL_EXC = "SerializationException"
const val MISSING_FIELD_EXC = "MissingFieldException"
@@ -81,6 +85,10 @@ object SerialEntityNames {
val TYPE_PARAMS_SERIALIZERS_GETTER = Name.identifier("typeParametersSerializers")
val WRITE_SELF_NAME = Name.identifier("write\$Self")
val SERIALIZER_PROVIDER_NAME = Name.identifier("serializer")
val SINGLE_MASK_FIELD_MISSING_FUNC_NAME = Name.identifier("throwMissingFieldException")
val ARRAY_MASK_FIELD_MISSING_FUNC_NAME = Name.identifier("throwArrayMissingFieldException")
val SINGLE_MASK_FIELD_MISSING_FUNC_FQ = SerializationPackages.internalPackageFqName.child(SINGLE_MASK_FIELD_MISSING_FUNC_NAME)
val ARRAY_MASK_FIELD_MISSING_FUNC_FQ = SerializationPackages.internalPackageFqName.child(ARRAY_MASK_FIELD_MISSING_FUNC_NAME)
// parameters
val dummyParamName = Name.identifier("serializationConstructorMarker")
@@ -82,6 +82,16 @@ internal fun ModuleDescriptor.getClassFromInternalSerializationPackage(classSimp
)
) { "Can't locate class $classSimpleName from package ${SerializationPackages.internalPackageFqName}" }
internal fun ModuleDescriptor.getClassFromSerializationDescriptorsPackage(classSimpleName: String) =
requireNotNull(
findClassAcrossModuleDependencies(
ClassId(
SerializationPackages.descriptorsPackageFqName,
Name.identifier(classSimpleName)
)
)
) { "Can't locate class $classSimpleName from package ${SerializationPackages.descriptorsPackageFqName}" }
internal fun getSerializationPackageFqn(classSimpleName: String): FqName =
SerializationPackages.packageFqName.child(Name.identifier(classSimpleName))