FIR: encapsulate annotation loading in KotlinDeserializedJvmSymbolsProvider.knownClassNamesInPackage into class

This commit is contained in:
Ilya Kirillov
2021-01-17 00:32:22 +01:00
parent 169134655a
commit 3cee5e848a
3 changed files with 172 additions and 161 deletions
@@ -0,0 +1,165 @@
/*
* Copyright 2010-2021 JetBrains s.r.o. and Kotlin Programming Language contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package org.jetbrains.kotlin.fir.java.deserialization
import org.jetbrains.kotlin.SpecialJvmAnnotations
import org.jetbrains.kotlin.fir.FirSession
import org.jetbrains.kotlin.fir.diagnostics.ConeSimpleDiagnostic
import org.jetbrains.kotlin.fir.diagnostics.DiagnosticKind
import org.jetbrains.kotlin.fir.expressions.FirAnnotationCall
import org.jetbrains.kotlin.fir.expressions.FirClassReferenceExpression
import org.jetbrains.kotlin.fir.expressions.FirExpression
import org.jetbrains.kotlin.fir.expressions.buildUnaryArgumentList
import org.jetbrains.kotlin.fir.expressions.builder.*
import org.jetbrains.kotlin.fir.java.createConstantOrError
import org.jetbrains.kotlin.fir.references.builder.buildErrorNamedReference
import org.jetbrains.kotlin.fir.references.builder.buildResolvedNamedReference
import org.jetbrains.kotlin.fir.references.impl.FirReferencePlaceholderForResolvedAnnotations
import org.jetbrains.kotlin.fir.resolve.firSymbolProvider
import org.jetbrains.kotlin.fir.resolve.providers.getClassDeclaredPropertySymbols
import org.jetbrains.kotlin.fir.symbols.ConeClassLikeLookupTag
import org.jetbrains.kotlin.fir.symbols.impl.ConeClassLikeLookupTagImpl
import org.jetbrains.kotlin.fir.types.FirResolvedTypeRef
import org.jetbrains.kotlin.fir.types.builder.buildResolvedTypeRef
import org.jetbrains.kotlin.fir.types.constructClassType
import org.jetbrains.kotlin.load.kotlin.KotlinJvmBinaryClass
import org.jetbrains.kotlin.name.ClassId
import org.jetbrains.kotlin.name.Name
import org.jetbrains.kotlin.resolve.constants.ClassLiteralValue
internal class AnnotationsLoader(private val session: FirSession) {
private fun loadAnnotation(
annotationClassId: ClassId, result: MutableList<FirAnnotationCall>,
): KotlinJvmBinaryClass.AnnotationArgumentVisitor {
val lookupTag = ConeClassLikeLookupTagImpl(annotationClassId)
return object : KotlinJvmBinaryClass.AnnotationArgumentVisitor {
private val argumentMap = mutableMapOf<Name, FirExpression>()
override fun visit(name: Name?, value: Any?) {
if (name != null) {
argumentMap[name] = createConstant(value)
}
}
private fun ClassLiteralValue.toFirClassReferenceExpression(): FirClassReferenceExpression {
val literalLookupTag = ConeClassLikeLookupTagImpl(classId)
return buildClassReferenceExpression {
classTypeRef = literalLookupTag.toDefaultResolvedTypeRef()
}
}
private fun ClassId.toEnumEntryReferenceExpression(name: Name): FirExpression {
return buildFunctionCall {
val entryPropertySymbol =
session.firSymbolProvider.getClassDeclaredPropertySymbols(
this@toEnumEntryReferenceExpression, name,
).firstOrNull()
calleeReference = when {
entryPropertySymbol != null -> {
buildResolvedNamedReference {
this.name = name
resolvedSymbol = entryPropertySymbol
}
}
else -> {
buildErrorNamedReference {
diagnostic = ConeSimpleDiagnostic(
"Strange deserialized enum value: ${this@toEnumEntryReferenceExpression}.$name",
DiagnosticKind.Java,
)
}
}
}
}
}
override fun visitClassLiteral(name: Name, value: ClassLiteralValue) {
argumentMap[name] = buildGetClassCall {
argumentList = buildUnaryArgumentList(value.toFirClassReferenceExpression())
}
}
override fun visitEnum(name: Name, enumClassId: ClassId, enumEntryName: Name) {
argumentMap[name] = enumClassId.toEnumEntryReferenceExpression(enumEntryName)
}
override fun visitArray(name: Name): KotlinJvmBinaryClass.AnnotationArrayArgumentVisitor {
return object : KotlinJvmBinaryClass.AnnotationArrayArgumentVisitor {
private val elements = mutableListOf<FirExpression>()
override fun visit(value: Any?) {
elements.add(createConstant(value))
}
override fun visitEnum(enumClassId: ClassId, enumEntryName: Name) {
elements.add(enumClassId.toEnumEntryReferenceExpression(enumEntryName))
}
override fun visitClassLiteral(value: ClassLiteralValue) {
elements.add(
buildGetClassCall {
argumentList = buildUnaryArgumentList(value.toFirClassReferenceExpression())
}
)
}
override fun visitEnd() {
argumentMap[name] = buildArrayOfCall {
argumentList = buildArgumentList {
arguments += elements
}
}
}
}
}
override fun visitAnnotation(name: Name, classId: ClassId): KotlinJvmBinaryClass.AnnotationArgumentVisitor {
val list = mutableListOf<FirAnnotationCall>()
val visitor = loadAnnotation(classId, list)
return object : KotlinJvmBinaryClass.AnnotationArgumentVisitor by visitor {
override fun visitEnd() {
visitor.visitEnd()
argumentMap[name] = list.single()
}
}
}
override fun visitEnd() {
result += buildAnnotationCall {
annotationTypeRef = lookupTag.toDefaultResolvedTypeRef()
argumentList = buildArgumentList {
for ((name, expression) in argumentMap) {
arguments += buildNamedArgumentExpression {
this.expression = expression
this.name = name
isSpread = false
}
}
}
calleeReference = FirReferencePlaceholderForResolvedAnnotations
}
}
private fun createConstant(value: Any?): FirExpression {
return value.createConstantOrError(session)
}
}
}
internal fun loadAnnotationIfNotSpecial(
annotationClassId: ClassId, result: MutableList<FirAnnotationCall>,
): KotlinJvmBinaryClass.AnnotationArgumentVisitor? {
if (annotationClassId in SpecialJvmAnnotations.SPECIAL_ANNOTATIONS) return null
return loadAnnotation(annotationClassId, result)
}
private fun ConeClassLikeLookupTag.toDefaultResolvedTypeRef(): FirResolvedTypeRef =
buildResolvedTypeRef {
type = constructClassType(emptyArray(), isNullable = false)
}
}
@@ -268,6 +268,7 @@ private data class MemberAnnotations(val memberAnnotations: MutableMap<MemberSig
// TODO: better to be in KotlinDeserializedJvmSymbolsProvider?
private fun FirSession.loadMemberAnnotations(kotlinBinaryClass: KotlinJvmBinaryClass, byteContent: ByteArray?): MemberAnnotations {
val memberAnnotations = hashMapOf<MemberSignature, MutableList<FirAnnotationCall>>()
val annotationsLoader = AnnotationsLoader(this)
kotlinBinaryClass.visitMembers(object : KotlinJvmBinaryClass.MemberVisitor {
override fun visitMethod(name: Name, desc: String): KotlinJvmBinaryClass.MethodAnnotationVisitor? {
@@ -296,7 +297,7 @@ private fun FirSession.loadMemberAnnotations(kotlinBinaryClass: KotlinJvmBinaryC
result = arrayListOf()
memberAnnotations[paramSignature] = result
}
return loadAnnotationIfNotSpecial(classId, result)
return annotationsLoader.loadAnnotationIfNotSpecial(classId, result)
}
}
@@ -304,7 +305,7 @@ private fun FirSession.loadMemberAnnotations(kotlinBinaryClass: KotlinJvmBinaryC
private val result = arrayListOf<FirAnnotationCall>()
override fun visitAnnotation(classId: ClassId, source: SourceElement): KotlinJvmBinaryClass.AnnotationArgumentVisitor? {
return loadAnnotationIfNotSpecial(classId, result)
return annotationsLoader.loadAnnotationIfNotSpecial(classId, result)
}
override fun visitEnd() {
@@ -316,15 +317,4 @@ private fun FirSession.loadMemberAnnotations(kotlinBinaryClass: KotlinJvmBinaryC
}, byteContent)
return MemberAnnotations(memberAnnotations)
}
// TODO: Or, better to migrate annotation deserialization in KotlinDeserializedJvmSymbolsProvider to here?
private fun FirSession.loadAnnotationIfNotSpecial(
annotationClassId: ClassId,
result: MutableList<FirAnnotationCall>
): KotlinJvmBinaryClass.AnnotationArgumentVisitor? =
(firSymbolProvider as? FirCompositeSymbolProvider)
?.providers
?.filterIsInstance<KotlinDeserializedJvmSymbolsProvider>()
?.singleOrNull()
?.loadAnnotationIfNotSpecial(annotationClassId, result)
}
@@ -7,7 +7,6 @@ package org.jetbrains.kotlin.fir.java.deserialization
import com.intellij.openapi.progress.ProcessCanceledException
import com.intellij.openapi.project.Project
import org.jetbrains.kotlin.SpecialJvmAnnotations
import org.jetbrains.kotlin.descriptors.SourceElement
import org.jetbrains.kotlin.fir.FirSession
import org.jetbrains.kotlin.fir.ThreadSafeMutableState
@@ -16,27 +15,15 @@ import org.jetbrains.kotlin.fir.declarations.*
import org.jetbrains.kotlin.fir.deserialization.FirConstDeserializer
import org.jetbrains.kotlin.fir.deserialization.FirDeserializationContext
import org.jetbrains.kotlin.fir.deserialization.deserializeClassToSymbol
import org.jetbrains.kotlin.fir.diagnostics.ConeSimpleDiagnostic
import org.jetbrains.kotlin.fir.diagnostics.DiagnosticKind
import org.jetbrains.kotlin.fir.expressions.*
import org.jetbrains.kotlin.fir.expressions.builder.*
import org.jetbrains.kotlin.fir.java.JavaSymbolProvider
import org.jetbrains.kotlin.fir.java.createConstantOrError
import org.jetbrains.kotlin.fir.java.topLevelName
import org.jetbrains.kotlin.fir.references.builder.buildErrorNamedReference
import org.jetbrains.kotlin.fir.references.builder.buildResolvedNamedReference
import org.jetbrains.kotlin.fir.references.impl.FirReferencePlaceholderForResolvedAnnotations
import org.jetbrains.kotlin.fir.resolve.firSymbolProvider
import org.jetbrains.kotlin.fir.resolve.providers.*
import org.jetbrains.kotlin.fir.scopes.KotlinScopeProvider
import org.jetbrains.kotlin.fir.symbols.ConeClassLikeLookupTag
import org.jetbrains.kotlin.fir.symbols.impl.*
import org.jetbrains.kotlin.fir.types.FirResolvedTypeRef
import org.jetbrains.kotlin.fir.types.builder.buildResolvedTypeRef
import org.jetbrains.kotlin.fir.types.constructClassType
import org.jetbrains.kotlin.load.java.JavaClassFinder
import org.jetbrains.kotlin.load.kotlin.*
import org.jetbrains.kotlin.load.kotlin.KotlinJvmBinaryClass.AnnotationArrayArgumentVisitor
import org.jetbrains.kotlin.load.kotlin.header.KotlinClassHeader
import org.jetbrains.kotlin.metadata.ProtoBuf
import org.jetbrains.kotlin.metadata.deserialization.Flags
@@ -47,7 +34,6 @@ import org.jetbrains.kotlin.name.ClassId
import org.jetbrains.kotlin.name.FqName
import org.jetbrains.kotlin.name.Name
import org.jetbrains.kotlin.name.isOneSegmentFQN
import org.jetbrains.kotlin.resolve.constants.ClassLiteralValue
import org.jetbrains.kotlin.resolve.jvm.JvmClassName
import org.jetbrains.kotlin.serialization.deserialization.IncompatibleVersionErrorData
import org.jetbrains.kotlin.serialization.deserialization.getName
@@ -63,6 +49,7 @@ class KotlinDeserializedJvmSymbolsProvider(
private val javaClassFinder: JavaClassFinder,
private val kotlinScopeProvider: KotlinScopeProvider,
) : FirSymbolProvider(session) {
private val annotationsLoader = AnnotationsLoader(session)
private val classCache = SymbolProviderCache<ClassId, FirRegularClassSymbol>()
private val typeAliasCache = SymbolProviderCache<ClassId, FirTypeAliasSymbol>()
private val packagePartsCache = SymbolProviderCache<FqName, Collection<PackagePartsCacheData>>()
@@ -159,137 +146,6 @@ class KotlinDeserializedJvmSymbolsProvider(
return JvmProtoBufUtil.readClassDataFrom(data, strings)
}
private fun ConeClassLikeLookupTag.toDefaultResolvedTypeRef(): FirResolvedTypeRef =
buildResolvedTypeRef {
type = constructClassType(emptyArray(), isNullable = false)
}
private fun loadAnnotation(
annotationClassId: ClassId, result: MutableList<FirAnnotationCall>,
): KotlinJvmBinaryClass.AnnotationArgumentVisitor {
val lookupTag = ConeClassLikeLookupTagImpl(annotationClassId)
return object : KotlinJvmBinaryClass.AnnotationArgumentVisitor {
private val argumentMap = mutableMapOf<Name, FirExpression>()
override fun visit(name: Name?, value: Any?) {
if (name != null) {
argumentMap[name] = createConstant(value)
}
}
private fun ClassLiteralValue.toFirClassReferenceExpression(): FirClassReferenceExpression {
val literalLookupTag = ConeClassLikeLookupTagImpl(classId)
return buildClassReferenceExpression {
classTypeRef = literalLookupTag.toDefaultResolvedTypeRef()
}
}
private fun ClassId.toEnumEntryReferenceExpression(name: Name): FirExpression {
return buildFunctionCall {
val entryPropertySymbol =
this@KotlinDeserializedJvmSymbolsProvider.session.firSymbolProvider.getClassDeclaredPropertySymbols(
this@toEnumEntryReferenceExpression, name,
).firstOrNull()
calleeReference = when {
entryPropertySymbol != null -> {
buildResolvedNamedReference {
this.name = name
resolvedSymbol = entryPropertySymbol
}
}
else -> {
buildErrorNamedReference {
diagnostic = ConeSimpleDiagnostic(
"Strange deserialized enum value: ${this@toEnumEntryReferenceExpression}.$name",
DiagnosticKind.Java,
)
}
}
}
}
}
override fun visitClassLiteral(name: Name, value: ClassLiteralValue) {
argumentMap[name] = buildGetClassCall {
argumentList = buildUnaryArgumentList(value.toFirClassReferenceExpression())
}
}
override fun visitEnum(name: Name, enumClassId: ClassId, enumEntryName: Name) {
argumentMap[name] = enumClassId.toEnumEntryReferenceExpression(enumEntryName)
}
override fun visitArray(name: Name): AnnotationArrayArgumentVisitor {
return object : AnnotationArrayArgumentVisitor {
private val elements = mutableListOf<FirExpression>()
override fun visit(value: Any?) {
elements.add(createConstant(value))
}
override fun visitEnum(enumClassId: ClassId, enumEntryName: Name) {
elements.add(enumClassId.toEnumEntryReferenceExpression(enumEntryName))
}
override fun visitClassLiteral(value: ClassLiteralValue) {
elements.add(
buildGetClassCall {
argumentList = buildUnaryArgumentList(value.toFirClassReferenceExpression())
}
)
}
override fun visitEnd() {
argumentMap[name] = buildArrayOfCall {
argumentList = buildArgumentList {
arguments += elements
}
}
}
}
}
override fun visitAnnotation(name: Name, classId: ClassId): KotlinJvmBinaryClass.AnnotationArgumentVisitor {
val list = mutableListOf<FirAnnotationCall>()
val visitor = loadAnnotation(classId, list)
return object : KotlinJvmBinaryClass.AnnotationArgumentVisitor by visitor {
override fun visitEnd() {
visitor.visitEnd()
argumentMap[name] = list.single()
}
}
}
override fun visitEnd() {
result += buildAnnotationCall {
annotationTypeRef = lookupTag.toDefaultResolvedTypeRef()
argumentList = buildArgumentList {
for ((name, expression) in argumentMap) {
arguments += buildNamedArgumentExpression {
this.expression = expression
this.name = name
isSpread = false
}
}
}
calleeReference = FirReferencePlaceholderForResolvedAnnotations
}
}
private fun createConstant(value: Any?): FirExpression {
return value.createConstantOrError(session)
}
}
}
internal fun loadAnnotationIfNotSpecial(
annotationClassId: ClassId, result: MutableList<FirAnnotationCall>,
): KotlinJvmBinaryClass.AnnotationArgumentVisitor? {
if (annotationClassId in SpecialJvmAnnotations.SPECIAL_ANNOTATIONS) return null
return loadAnnotation(annotationClassId, result)
}
private fun findAndDeserializeClassViaParent(classId: ClassId): FirRegularClassSymbol? {
val outerClassId = classId.outerClassId ?: return null
@@ -344,7 +200,7 @@ class KotlinDeserializedJvmSymbolsProvider(
kotlinJvmBinaryClass.loadClassAnnotations(
object : KotlinJvmBinaryClass.AnnotationVisitor {
override fun visitAnnotation(classId: ClassId, source: SourceElement): KotlinJvmBinaryClass.AnnotationArgumentVisitor? {
return loadAnnotationIfNotSpecial(classId, annotations)
return annotationsLoader.loadAnnotationIfNotSpecial(classId, annotations)
}
override fun visitEnd() {
@@ -415,4 +271,4 @@ private class KnownNameInPackageCache(session: FirSession, private val javaClass
val knownNames = knownClassNamesInPackage.getValue(classId.packageFqName) ?: return false
return classId.relativeClassName.topLevelName() !in knownNames
}
}
}