Refactor Main-Class computation in CLI for JVM with .jar outputs

Compute FQ name of the main class right after running the analysis and
before invoking codegen. This is needed because MainFunctionDetector
depends on BindingContext, and JVM IR needs to clear BindingContext as
soon as it's not necessary to reduce peak memory usage, thus breaking
any usages of data from it after the codegen.

Also refactor and use the extracted, but not properly reused previously,
copy of findMainClass in findMainClass.kt.

Note that this replaces NPE in KT-42868 with an UOE.
This commit is contained in:
Alexander Udalov
2020-10-21 22:52:37 +02:00
parent 0a18be62e5
commit dd813777b9
3 changed files with 27 additions and 32 deletions
@@ -53,7 +53,6 @@ import org.jetbrains.kotlin.config.*
import org.jetbrains.kotlin.container.get
import org.jetbrains.kotlin.descriptors.ModuleDescriptor
import org.jetbrains.kotlin.diagnostics.*
import org.jetbrains.kotlin.fileClasses.JvmFileClassUtil
import org.jetbrains.kotlin.fir.FirPsiSourceElement
import org.jetbrains.kotlin.fir.FirSession
import org.jetbrains.kotlin.fir.analysis.FirAnalyzerFacade
@@ -63,7 +62,6 @@ import org.jetbrains.kotlin.fir.backend.jvm.FirMetadataSerializer
import org.jetbrains.kotlin.fir.checkers.registerExtendedCommonCheckers
import org.jetbrains.kotlin.fir.java.FirProjectSessionProvider
import org.jetbrains.kotlin.fir.session.FirSessionFactory
import org.jetbrains.kotlin.idea.MainFunctionDetector
import org.jetbrains.kotlin.ir.backend.jvm.jvmResolveLibraries
import org.jetbrains.kotlin.javac.JavacWrapper
import org.jetbrains.kotlin.load.kotlin.ModuleVisibilityManager
@@ -88,14 +86,14 @@ object KotlinToJVMBytecodeCompiler {
private fun writeOutput(
configuration: CompilerConfiguration,
outputFiles: OutputFileCollection,
mainClassProvider: MainClassProvider?
mainClassFqName: FqName?
) {
val reportOutputFiles = configuration.getBoolean(CommonConfigurationKeys.REPORT_OUTPUT_FILES)
val jarPath = configuration.get(JVMConfigurationKeys.OUTPUT_JAR)
val messageCollector = configuration.get(CLIConfigurationKeys.MESSAGE_COLLECTOR_KEY, MessageCollector.NONE)
if (jarPath != null) {
val includeRuntime = configuration.get(JVMConfigurationKeys.INCLUDE_RUNTIME, false)
CompileEnvironmentUtil.writeToJar(jarPath, includeRuntime, mainClassProvider?.mainClassFqName, outputFiles)
CompileEnvironmentUtil.writeToJar(jarPath, includeRuntime, mainClassFqName, outputFiles)
if (reportOutputFiles) {
val message = OutputMessageUtil.formatOutputMessage(outputFiles.asList().flatMap { it.sourceFiles }.distinct(), jarPath)
messageCollector.report(OUTPUT, message)
@@ -200,6 +198,11 @@ object KotlinToJVMBytecodeCompiler {
result.throwIfError()
val mainClassFqName =
if (chunk.size == 1 && projectConfiguration.get(JVMConfigurationKeys.OUTPUT_JAR) != null)
findMainClass(result.bindingContext, projectConfiguration.languageVersionSettings, environment.getSourceFiles())
else null
val outputs = newLinkedHashMapWithExpectedSize<Module, GenerationState>(chunk.size)
val localFileSystem = VirtualFileManager.getInstance().getFileSystem(StandardFileSystems.FILE_PROTOCOL)
@@ -214,20 +217,20 @@ object KotlinToJVMBytecodeCompiler {
outputs[module] = generate(environment, moduleConfiguration, result, ktFiles, module)
}
return writeOutputs(environment, projectConfiguration, chunk, outputs)
return writeOutputs(environment, projectConfiguration, chunk, outputs, mainClassFqName)
}
private fun writeOutputs(
environment: KotlinCoreEnvironment,
projectConfiguration: CompilerConfiguration,
chunk: List<Module>,
outputs: Map<Module, GenerationState>
outputs: Map<Module, GenerationState>,
mainClassFqName: FqName?
): Boolean {
try {
for ((_, state) in outputs) {
ProgressIndicatorAndCompilationCanceledStatus.checkCanceled()
val mainClassProvider = if (outputs.size == 1) MainClassProvider(state, environment) else null
writeOutput(state.configuration, state.factory, mainClassProvider)
writeOutput(state.configuration, state.factory, mainClassFqName)
}
} finally {
outputs.values.forEach(GenerationState::destroy)
@@ -436,7 +439,13 @@ object KotlinToJVMBytecodeCompiler {
ProgressIndicatorAndCompilationCanceledStatus.checkCanceled()
outputs[module] = generationState
}
return writeOutputs(environment, projectConfiguration, chunk, outputs)
val mainClassFqName: FqName? =
if (chunk.size == 1 && projectConfiguration.get(JVMConfigurationKeys.OUTPUT_JAR) != null)
TODO(".jar output is not yet supported for -Xuse-fir: KT-42868")
else null
return writeOutputs(environment, projectConfiguration, chunk, outputs, mainClassFqName)
}
private fun FirDiagnostic<*>.toRegularDiagnostic(): Diagnostic {
@@ -469,22 +478,6 @@ object KotlinToJVMBytecodeCompiler {
(File(path).takeIf(File::isAbsolute) ?: buildFile.resolveSibling(path)).absolutePath
}
class MainClassProvider(generationState: GenerationState, environment: KotlinCoreEnvironment) {
val mainClassFqName: FqName? by lazy { findMainClass(generationState, environment.getSourceFiles()) }
private fun findMainClass(generationState: GenerationState, files: List<KtFile>): FqName? {
val mainFunctionDetector = MainFunctionDetector(generationState.bindingContext, generationState.languageVersionSettings)
return files.asSequence()
.map { file ->
if (mainFunctionDetector.hasMain(file.declarations))
JvmFileClassUtil.getFileClassInfoNoResolve(file).facadeClassFqName
else
null
}
.singleOrNull { it != null }
}
}
fun compileBunchOfSources(environment: KotlinCoreEnvironment): Boolean {
val moduleVisibilityManager = ModuleVisibilityManager.SERVICE.getInstance(environment.project)
@@ -498,7 +491,7 @@ object KotlinToJVMBytecodeCompiler {
val generationState = analyzeAndGenerate(environment) ?: return false
try {
writeOutput(environment.configuration, generationState.factory, MainClassProvider(generationState, environment))
writeOutput(environment.configuration, generationState.factory, null)
return true
} finally {
generationState.destroy()
@@ -5,14 +5,15 @@
package org.jetbrains.kotlin.cli.jvm.compiler
import org.jetbrains.kotlin.codegen.state.GenerationState
import org.jetbrains.kotlin.config.LanguageVersionSettings
import org.jetbrains.kotlin.fileClasses.JvmFileClassUtil
import org.jetbrains.kotlin.idea.MainFunctionDetector
import org.jetbrains.kotlin.name.FqName
import org.jetbrains.kotlin.psi.KtFile
import org.jetbrains.kotlin.resolve.BindingContext
fun findMainClass(generationState: GenerationState, files: List<KtFile>): FqName? {
val mainFunctionDetector = MainFunctionDetector(generationState.bindingContext, generationState.languageVersionSettings)
fun findMainClass(bindingContext: BindingContext, languageVersionSettings: LanguageVersionSettings, files: List<KtFile>): FqName? {
val mainFunctionDetector = MainFunctionDetector(bindingContext, languageVersionSettings)
return files.asSequence()
.map { file ->
if (mainFunctionDetector.hasMain(file.declarations))
@@ -21,4 +22,4 @@ fun findMainClass(generationState: GenerationState, files: List<KtFile>): FqName
null
}
.singleOrNull { it != null }
}
}
@@ -10,7 +10,6 @@ import com.intellij.openapi.module.Module
import com.intellij.openapi.roots.LibraryOrderEntry
import com.intellij.openapi.roots.ModuleRootManager
import com.intellij.openapi.roots.OrderRootType
import com.intellij.openapi.roots.libraries.ui.OrderRoot
import com.intellij.openapi.util.io.FileUtil
import com.intellij.openapi.vfs.LocalFileSystem
import com.intellij.openapi.vfs.VirtualFile
@@ -32,6 +31,7 @@ import org.jetbrains.kotlin.config.CompilerConfiguration
import org.jetbrains.kotlin.config.JVMConfigurationKeys
import org.jetbrains.kotlin.config.JvmTarget
import org.jetbrains.kotlin.diagnostics.rendering.DefaultErrorMessages
import org.jetbrains.kotlin.idea.resolve.getLanguageVersionSettings
import org.jetbrains.kotlin.psi.KtFile
import org.jetbrains.kotlin.test.KotlinBaseTest.TestFile
import org.jetbrains.kotlin.test.MockLibraryUtil
@@ -198,7 +198,8 @@ class DebuggerTestCompilerFacility(
state.factory.writeAllTo(classesDir)
return findMainClass(state, files)?.asString() ?: error("Cannot find main class name")
return findMainClass(bindingContext, resolutionFacade.getLanguageVersionSettings(), files)?.asString()
?: error("Cannot find main class name")
}
private fun getClasspath(module: Module): List<String> {