From 7c3bdde102d18de855ea48d66edd71aa19ee3d7a Mon Sep 17 00:00:00 2001 From: Ilya Chernikov Date: Fri, 15 Feb 2019 17:57:52 +0100 Subject: [PATCH] Fix sequence of the script definitions search - explicit ones should be tried first --- .../kotlin/scripts/ScriptProviderTest.kt | 12 ++++++++---- .../legacy/CliScriptDefinitionProvider.kt | 16 +++++++++++----- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/compiler/tests/org/jetbrains/kotlin/scripts/ScriptProviderTest.kt b/compiler/tests/org/jetbrains/kotlin/scripts/ScriptProviderTest.kt index efafcf026f2..16d54f2a338 100644 --- a/compiler/tests/org/jetbrains/kotlin/scripts/ScriptProviderTest.kt +++ b/compiler/tests/org/jetbrains/kotlin/scripts/ScriptProviderTest.kt @@ -22,7 +22,7 @@ class ScriptProviderTest : KtUsefulTestCase() { val genDefCounter = AtomicInteger() val standardDef = FakeScriptDefinition() val shadedDef = FakeScriptDefinition(".x.kts") - val provider = CliScriptDefinitionProvider().apply { + val provider = TestCliScriptDefinitionProvider(standardDef).apply { setScriptDefinitions(listOf(shadedDef, standardDef)) setScriptDefinitionsSources(listOf(TestScriptDefinitionSource(genDefCounter, ".y.kts", ".x.kts"))) } @@ -41,8 +41,8 @@ class ScriptProviderTest : KtUsefulTestCase() { provider.isScript("a.x.kts").let { Assert.assertTrue(it) - Assert.assertEquals(2, genDefCounter.get()) - Assert.assertEquals(0, shadedDef.matchCounter.get()) + Assert.assertEquals(1, genDefCounter.get()) + Assert.assertEquals(1, shadedDef.matchCounter.get()) } provider.isScript("a.z.kts").let { @@ -58,7 +58,7 @@ class ScriptProviderTest : KtUsefulTestCase() { } } -private class FakeScriptDefinition(val suffix: String = ".kts") : KotlinScriptDefinition(ScriptTemplateWithArgs::class) { +private open class FakeScriptDefinition(val suffix: String = ".kts") : KotlinScriptDefinition(ScriptTemplateWithArgs::class) { val matchCounter = AtomicInteger() override fun isScript(fileName: String): Boolean = fileName.endsWith(suffix).also { if (it) matchCounter.incrementAndGet() @@ -76,4 +76,8 @@ private class TestScriptDefinitionSource(val counter: AtomicInteger, val defGens yield(gen()) } } +} + +private class TestCliScriptDefinitionProvider(private val standardDef: KotlinScriptDefinition) : CliScriptDefinitionProvider() { + override fun getDefaultScriptDefinition(): KotlinScriptDefinition = standardDef } \ No newline at end of file diff --git a/plugins/scripting/scripting-compiler/src/org/jetbrains/kotlin/scripting/legacy/CliScriptDefinitionProvider.kt b/plugins/scripting/scripting-compiler/src/org/jetbrains/kotlin/scripting/legacy/CliScriptDefinitionProvider.kt index 78f5ed1a74f..96bf0c60103 100644 --- a/plugins/scripting/scripting-compiler/src/org/jetbrains/kotlin/scripting/legacy/CliScriptDefinitionProvider.kt +++ b/plugins/scripting/scripting-compiler/src/org/jetbrains/kotlin/scripting/legacy/CliScriptDefinitionProvider.kt @@ -10,12 +10,16 @@ import org.jetbrains.kotlin.script.ScriptDefinitionsSource import org.jetbrains.kotlin.script.StandardScriptDefinition import kotlin.concurrent.write -class CliScriptDefinitionProvider : LazyScriptDefinitionProvider() { +open class CliScriptDefinitionProvider : LazyScriptDefinitionProvider() { private val definitionsFromSources: MutableList> = arrayListOf() - private val definitions: MutableList = arrayListOf(StandardScriptDefinition) + private val definitions: MutableList = arrayListOf() + private var hasStandardDefinition = true - override val currentDefinitions: Sequence = - definitionsFromSources.asSequence().flatMap { it } + definitions.asSequence() + override val currentDefinitions: Sequence + get() { + val base = definitions.asSequence() + definitionsFromSources.asSequence().flatMap { it } + return if (hasStandardDefinition) base + getDefaultScriptDefinition() else base + } override fun getDefaultScriptDefinition(): KotlinScriptDefinition { return StandardScriptDefinition @@ -24,7 +28,9 @@ class CliScriptDefinitionProvider : LazyScriptDefinitionProvider() { fun setScriptDefinitions(newDefinitions: List) { lock.write { definitions.clear() - definitions.addAll(newDefinitions) + val (withoutStdDef, stdDef) = newDefinitions.partition { it != getDefaultScriptDefinition() } + definitions.addAll(withoutStdDef) + hasStandardDefinition = stdDef.isNotEmpty() } }