KT-45777: Parse classes only once with ASM

to get both constants and inline functions.

Also add BasicClassInfo to ClassFileWithContents to simplify the code.
This commit is contained in:
Hung Nguyen
2021-12-20 09:37:23 +00:00
committed by nataliya.valtman
parent 9a995af0df
commit 9b96dbe2d2
5 changed files with 88 additions and 95 deletions
@@ -37,6 +37,7 @@ import org.jetbrains.kotlin.name.Name
import org.jetbrains.kotlin.resolve.jvm.AsmTypes
import org.jetbrains.kotlin.resolve.jvm.JvmClassName
import org.jetbrains.org.objectweb.asm.*
import org.jetbrains.org.objectweb.asm.ClassReader.*
import java.io.File
import java.security.MessageDigest
@@ -646,83 +647,86 @@ class KotlinClassInfo constructor(
}
fun createFrom(classId: ClassId, classHeader: KotlinClassHeader, classContents: ByteArray): KotlinClassInfo {
val constantsAndInlineFunctions = getConstantsAndInlineFunctions(classHeader, classContents)
return KotlinClassInfo(
classId,
classHeader.kind,
classHeader.data ?: emptyArray(),
classHeader.strings ?: emptyArray(),
classHeader.multifileClassName,
getConstantsMap(classContents),
getInlineFunctionsMap(classHeader, classContents)
constantsMap = constantsAndInlineFunctions.first,
inlineFunctionsMap = constantsAndInlineFunctions.second
)
}
}
}
private fun getConstantsMap(bytes: ByteArray): LinkedHashMap<String, Any> {
val result = LinkedHashMap<String, Any>()
/** Parses the class file only once to get both constants and inline functions. */
private fun getConstantsAndInlineFunctions(
classHeader: KotlinClassHeader,
classContents: ByteArray
): Pair<LinkedHashMap<String, Any>, LinkedHashMap<String, Long>> {
val constantsClassVisitor = ConstantsClassVisitor()
val inlineFunctionNames = inlineFunctionsJvmNames(classHeader)
ClassReader(bytes).accept(object : ClassVisitor(Opcodes.API_VERSION) {
override fun visitField(access: Int, name: String, desc: String, signature: String?, value: Any?): FieldVisitor? {
if (access and Opcodes.ACC_PRIVATE == Opcodes.ACC_PRIVATE) return null
val staticFinal = Opcodes.ACC_STATIC or Opcodes.ACC_FINAL
if (value != null && access and staticFinal == staticFinal) {
result[name] = value
}
return null
}
}, ClassReader.SKIP_CODE or ClassReader.SKIP_DEBUG or ClassReader.SKIP_FRAMES)
return result
return if (inlineFunctionNames.isEmpty()) {
ClassReader(classContents).accept(constantsClassVisitor, SKIP_CODE or SKIP_DEBUG or SKIP_FRAMES)
Pair(constantsClassVisitor.getResult(), LinkedHashMap())
} else {
val inlineFunctionsClassVisitor = InlineFunctionsClassVisitor(inlineFunctionNames, constantsClassVisitor)
ClassReader(classContents).accept(inlineFunctionsClassVisitor, 0)
Pair(constantsClassVisitor.getResult(), inlineFunctionsClassVisitor.getResult())
}
}
private fun getInlineFunctionsMap(header: KotlinClassHeader, bytes: ByteArray): LinkedHashMap<String, Long> {
val inlineFunctions = inlineFunctionsJvmNames(header)
if (inlineFunctions.isEmpty()) return LinkedHashMap()
private class ConstantsClassVisitor : ClassVisitor(Opcodes.API_VERSION) {
private val result = LinkedHashMap<String, Any>()
val result = LinkedHashMap<String, Long>()
var dummyVersion: Int = -1
ClassReader(bytes).accept(object : ClassVisitor(Opcodes.API_VERSION) {
override fun visitField(access: Int, name: String, desc: String, signature: String?, value: Any?): FieldVisitor? {
if (access and Opcodes.ACC_PRIVATE == Opcodes.ACC_PRIVATE) return null
override fun visit(
version: Int,
access: Int,
name: String?,
signature: String?,
superName: String?,
interfaces: Array<out String>?
) {
super.visit(version, access, name, signature, superName, interfaces)
dummyVersion = version
val staticFinal = Opcodes.ACC_STATIC or Opcodes.ACC_FINAL
if (value != null && access and staticFinal == staticFinal) {
result[name] = value
}
return null
}
override fun visitMethod(
access: Int,
name: String,
desc: String,
signature: String?,
exceptions: Array<out String>?
): MethodVisitor? {
if (access and Opcodes.ACC_PRIVATE == Opcodes.ACC_PRIVATE) return null
fun getResult() = result
}
val dummyClassWriter = ClassWriter(0)
dummyClassWriter.visit(dummyVersion, 0, "dummy", null, AsmTypes.OBJECT_TYPE.internalName, null)
private class InlineFunctionsClassVisitor(
private val inlineFunctionNames: Set<String>,
cv: ClassVisitor // Note: cv must not override the visitMethod (it will not be called with the current implementation below)
) : ClassVisitor(Opcodes.API_VERSION, cv) {
return object : MethodVisitor(Opcodes.API_VERSION, dummyClassWriter.visitMethod(0, name, desc, null, exceptions)) {
override fun visitEnd() {
val jvmName = name + desc
if (jvmName !in inlineFunctions) return
private val result = LinkedHashMap<String, Long>()
private var dummyVersion: Int = -1
val dummyBytes = dummyClassWriter.toByteArray()!!
override fun visit(version: Int, access: Int, name: String?, signature: String?, superName: String?, interfaces: Array<out String>?) {
super.visit(version, access, name, signature, superName, interfaces)
dummyVersion = version
}
val hash = dummyBytes.md5()
result[jvmName] = hash
}
override fun visitMethod(access: Int, name: String, desc: String, signature: String?, exceptions: Array<out String>?): MethodVisitor? {
if (access and Opcodes.ACC_PRIVATE == Opcodes.ACC_PRIVATE) return null
val dummyClassWriter = ClassWriter(0)
dummyClassWriter.visit(dummyVersion, 0, "dummy", null, AsmTypes.OBJECT_TYPE.internalName, null)
return object : MethodVisitor(Opcodes.API_VERSION, dummyClassWriter.visitMethod(0, name, desc, null, exceptions)) {
override fun visitEnd() {
val jvmName = name + desc
if (jvmName !in inlineFunctionNames) return
val dummyBytes = dummyClassWriter.toByteArray()!!
val hash = dummyBytes.md5()
result[jvmName] = hash
}
}
}
}, 0)
return result
fun getResult() = result
}
@@ -12,6 +12,7 @@ import org.jetbrains.kotlin.name.ClassId
import org.jetbrains.kotlin.resolve.jvm.JvmClassName
import org.jetbrains.org.objectweb.asm.AnnotationVisitor
import org.jetbrains.org.objectweb.asm.ClassReader
import org.jetbrains.org.objectweb.asm.ClassReader.*
import org.jetbrains.org.objectweb.asm.ClassVisitor
import org.jetbrains.org.objectweb.asm.Opcodes
@@ -46,10 +47,7 @@ class BasicClassInfo(
val innerClassesClassVisitor = InnerClassesClassVisitor(kotlinClassHeaderClassVisitor)
val basicClassInfoVisitor = BasicClassInfoClassVisitor(innerClassesClassVisitor)
ClassReader(classContents).accept(
basicClassInfoVisitor,
ClassReader.SKIP_CODE or ClassReader.SKIP_DEBUG or ClassReader.SKIP_FRAMES
)
ClassReader(classContents).accept(basicClassInfoVisitor, SKIP_CODE or SKIP_DEBUG or SKIP_FRAMES)
val className = basicClassInfoVisitor.getClassName()
val innerClassesInfo = innerClassesClassVisitor.getInnerClassesInfo()
@@ -29,8 +29,12 @@ class ClassFile(
}
}
/** Information to locate a .class file, plus their contents. */
/** Contains the contents of a [ClassFile] and information extracted from the contents. */
class ClassFileWithContents(
val classFile: ClassFile,
val contents: ByteArray
)
) {
val classInfo: BasicClassInfo by lazy {
BasicClassInfo.compute(contents)
}
}
@@ -76,18 +76,13 @@ object ClassSnapshotter {
protoBased: Boolean? = null,
includeDebugInfoInSnapshot: Boolean? = null
): List<ClassSnapshot> {
val classesInfo: List<BasicClassInfo> = classes.map { BasicClassInfo.compute(it.contents) }
// Find inaccessible classes first, their snapshots will be `InaccessibleClassSnapshot`s.
val inaccessibleClasses: Set<BasicClassInfo> = getInaccessibleClasses(classesInfo).toSet()
val classesInfo: List<BasicClassInfo> = classes.map { it.classInfo }
val inaccessibleClassesInfo: Set<BasicClassInfo> = getInaccessibleClasses(classesInfo).toSet()
// Snapshot the remaining accessible classes
val accessibleClasses: List<ClassFileWithContents> = classes.mapIndexedNotNull { index, clazz ->
if (classesInfo[index] in inaccessibleClasses) null else clazz
}
val accessibleClassesInfo: List<BasicClassInfo> = classesInfo.filterNot { it in inaccessibleClasses }
val accessibleSnapshots: List<ClassSnapshot> =
doSnapshot(accessibleClasses, accessibleClassesInfo, protoBased, includeDebugInfoInSnapshot)
val accessibleClasses: List<ClassFileWithContents> = classes.filter { it.classInfo !in inaccessibleClassesInfo }
val accessibleSnapshots: List<ClassSnapshot> = doSnapshot(accessibleClasses, protoBased, includeDebugInfoInSnapshot)
val accessibleClassSnapshots: Map<ClassFileWithContents, ClassSnapshot> = accessibleClasses.zipToMap(accessibleSnapshots)
return classes.map { accessibleClassSnapshots[it] ?: InaccessibleClassSnapshot }
@@ -95,58 +90,47 @@ object ClassSnapshotter {
private fun doSnapshot(
classes: List<ClassFileWithContents>,
classesInfo: List<BasicClassInfo>,
protoBased: Boolean? = null,
includeDebugInfoInSnapshot: Boolean? = null
): List<ClassSnapshot> {
// Snapshot Kotlin classes first
val kotlinSnapshots: List<KotlinClassSnapshot?> = classes.mapIndexed { index, clazz ->
trySnapshotKotlinClass(clazz, classesInfo[index])
val kotlinSnapshots: List<KotlinClassSnapshot?> = classes.map { clazz ->
trySnapshotKotlinClass(clazz)
}
val kotlinClassSnapshots: Map<ClassFileWithContents, KotlinClassSnapshot?> = classes.zipToMap(kotlinSnapshots)
// Snapshot the remaining Java classes
val javaClasses: List<ClassFileWithContents> = classes.filter { kotlinClassSnapshots[it] == null }
val javaClassesInfo: List<BasicClassInfo> = classesInfo.mapIndexedNotNull { index, classInfo ->
val javaClass = classes[index]
if (kotlinClassSnapshots[javaClass] == null) classInfo else null
}
val javaSnapshots: List<JavaClassSnapshot> =
snapshotJavaClasses(javaClasses, javaClassesInfo, protoBased, includeDebugInfoInSnapshot)
val javaSnapshots: List<JavaClassSnapshot> = snapshotJavaClasses(javaClasses, protoBased, includeDebugInfoInSnapshot)
val javaClassSnapshots: Map<ClassFileWithContents, JavaClassSnapshot> = javaClasses.zipToMap(javaSnapshots)
return classes.map { kotlinClassSnapshots[it] ?: javaClassSnapshots[it]!! }
}
/** Creates [KotlinClassSnapshot] of the given class, or returns `null` if the class is not a Kotlin class. */
private fun trySnapshotKotlinClass(classFile: ClassFileWithContents, classInfo: BasicClassInfo): KotlinClassSnapshot? {
return if (classInfo.isKotlinClass) {
val kotlinClassInfo = KotlinClassInfo.createFrom(classInfo.classId, classInfo.kotlinClassHeader!!, classFile.contents)
KotlinClassSnapshot(kotlinClassInfo, classInfo.supertypes)
private fun trySnapshotKotlinClass(classFile: ClassFileWithContents): KotlinClassSnapshot? {
return if (classFile.classInfo.isKotlinClass) {
val kotlinClassInfo =
KotlinClassInfo.createFrom(classFile.classInfo.classId, classFile.classInfo.kotlinClassHeader!!, classFile.contents)
KotlinClassSnapshot(kotlinClassInfo, classFile.classInfo.supertypes)
} else null
}
/** Creates [JavaClassSnapshot]s of the given Java classes. */
private fun snapshotJavaClasses(
classes: List<ClassFileWithContents>,
classesInfo: List<BasicClassInfo>,
protoBased: Boolean? = null,
includeDebugInfoInSnapshot: Boolean? = null
): List<JavaClassSnapshot> {
return if (protoBased ?: protoBasedDefaultValue) {
snapshotJavaClassesProtoBased(classes, classesInfo)
snapshotJavaClassesProtoBased(classes)
} else {
classes.mapIndexed { index, clazz ->
JavaClassSnapshotter.snapshot(clazz.contents, classesInfo[index], includeDebugInfoInSnapshot)
}
classes.map { JavaClassSnapshotter.snapshot(it, includeDebugInfoInSnapshot) }
}
}
private fun snapshotJavaClassesProtoBased(
classFilesWithContents: List<ClassFileWithContents>,
classesInfo: List<BasicClassInfo>
): List<JavaClassSnapshot> {
val classIds = classesInfo.map { it.classId }
private fun snapshotJavaClassesProtoBased(classFilesWithContents: List<ClassFileWithContents>): List<JavaClassSnapshot> {
val classIds = classFilesWithContents.map { it.classInfo.classId }
val classesContents = classFilesWithContents.map { it.contents }
val descriptors: List<JavaClassDescriptor?> = JavaClassDescriptorCreator.create(classIds, classesContents)
val snapshots: List<JavaClassSnapshot> = descriptors.mapIndexed { index, descriptor ->
@@ -14,7 +14,7 @@ import org.jetbrains.org.objectweb.asm.tree.ClassNode
/** Computes a [JavaClassSnapshot] of a Java class. */
object JavaClassSnapshotter {
fun snapshot(classContents: ByteArray, classInfo: BasicClassInfo, includeDebugInfoInSnapshot: Boolean? = null): JavaClassSnapshot {
fun snapshot(classFile: ClassFileWithContents, includeDebugInfoInSnapshot: Boolean? = null): JavaClassSnapshot {
// We will extract ABI information from the given class and store it into the `abiClass` variable.
// It is acceptable to collect more info than required, but it is incorrect to collect less info than required.
// There are 2 approaches:
@@ -29,7 +29,7 @@ object JavaClassSnapshotter {
// - SKIP_CODE and SKIP_FRAMES are set as method bodies will not be part of the ABI of the class.
// - SKIP_DEBUG is not set as it would skip method parameters, which may be used by annotation processors like Room.
// - EXPAND_FRAMES is not needed (and not relevant when SKIP_CODE is set).
val classReader = ClassReader(classContents)
val classReader = ClassReader(classFile.contents)
classReader.accept(abiClass, ClassReader.SKIP_CODE or ClassReader.SKIP_FRAMES)
// Then, remove non-ABI info, which includes:
@@ -52,7 +52,10 @@ object JavaClassSnapshotter {
abiClass.methods.clear()
val classAbiExcludingMembers = abiClass.let { snapshotJavaElement(it, it.name, includeDebugInfoInSnapshot) }
return RegularJavaClassSnapshot(classInfo.classId, classInfo.supertypes, classAbiExcludingMembers, fieldsAbi, methodsAbi)
return RegularJavaClassSnapshot(
classFile.classInfo.classId, classFile.classInfo.supertypes,
classAbiExcludingMembers, fieldsAbi, methodsAbi
)
}
private val gson by lazy {