Skip to content

Commit e6fae5c

Browse files
committed
improvement: Detect objects with main class in scripts
Prebiously, if user had a legacy script with main method then it would not be picked up at all. Now, when we detect the correct signature we try to run it. The possibility of false positives is pretty low, since user would have to have their own String or Array types.
1 parent a0edbcd commit e6fae5c

File tree

8 files changed

+117
-18
lines changed

8 files changed

+117
-18
lines changed

modules/build/src/main/scala/scala/build/ScopedSources.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ final case class ScopedSources(
6969
.map(_.wrap(codeWrapper))
7070

7171
codeWrapper match {
72-
case _: AppCodeWrapper.type if wrappedScripts.size > 1 =>
72+
case _: AppCodeWrapper if wrappedScripts.size > 1 =>
7373
wrappedScripts.find(_.originalPath.exists(_._1.toString == "main.sc"))
7474
.foreach(_ => logger.diagnostic(WarningMessages.mainScriptNameClashesWithAppWrapper))
7575
case _ => ()

modules/build/src/main/scala/scala/build/internal/AppCodeWrapper.scala

+7-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
package scala.build.internal
22

3-
case object AppCodeWrapper extends CodeWrapper {
3+
case class AppCodeWrapper(scalaVersion: String) extends CodeWrapper {
44
override def mainClassObject(className: Name) = className
55

66
def apply(
@@ -12,13 +12,18 @@ case object AppCodeWrapper extends CodeWrapper {
1212
) = {
1313
val wrapperObjectName = indexedWrapperName.backticked
1414

15+
val mainObject = WrapperUtils.mainObjectInScript(scalaVersion, code)
16+
val invokeMain = mainObject match {
17+
case None => ""
18+
case Some(name) => s"\n$name.main(args)"
19+
}
1520
val packageDirective =
1621
if (pkgName.isEmpty) "" else s"package ${AmmUtil.encodeScalaSourcePath(pkgName)}" + "\n"
1722
val top = AmmUtil.normalizeNewlines(
1823
s"""$packageDirective
1924
|
2025
|object $wrapperObjectName extends App {
21-
|val scriptPath = \"\"\"$scriptPath\"\"\"
26+
|val scriptPath = \"\"\"$scriptPath\"\"\"$invokeMain
2227
|""".stripMargin
2328
)
2429
val bottom = AmmUtil.normalizeNewlines(

modules/build/src/main/scala/scala/build/internal/ClassCodeWrapper.scala

+12-3
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
package scala.build.internal
22

3+
// import scala.meta.parsers.Parsed
4+
35
/** Script code wrapper that solves problem of deadlocks when using threads. The code is placed in a
46
* class instance constructor, the created object is kept in 'mainObjectCode'.script to support
57
* running interconnected scripts using Scala CLI <br> <br> Incompatible with Scala 2 - it uses
68
* Scala 3 feature 'export'<br> Incompatible with native JS members - the wrapper is a class
79
*/
8-
case object ClassCodeWrapper extends CodeWrapper {
10+
case class ClassCodeWrapper(scalaVersion: String) extends CodeWrapper {
911

1012
override def mainClassObject(className: Name): Name =
1113
Name(className.raw ++ "_sc")
@@ -16,8 +18,15 @@ case object ClassCodeWrapper extends CodeWrapper {
1618
extraCode: String,
1719
scriptPath: String
1820
) = {
21+
22+
val mainObject = WrapperUtils.mainObjectInScript(scalaVersion, code)
23+
24+
val mainInvocation = mainObject match
25+
case None => s"val _ = script.hashCode()"
26+
case Some(name) => s"script.$name.main(args)"
27+
1928
val name = mainClassObject(indexedWrapperName).backticked
20-
val wrapperClassName = Name(indexedWrapperName.raw ++ "$_").backticked
29+
val wrapperClassName = scala.build.internal.Name(indexedWrapperName.raw ++ "$_").backticked
2130
val mainObjectCode =
2231
AmmUtil.normalizeNewlines(s"""|object $name {
2332
| private var args$$opt0 = Option.empty[Array[String]]
@@ -33,7 +42,7 @@ case object ClassCodeWrapper extends CodeWrapper {
3342
|
3443
| def main(args: Array[String]): Unit = {
3544
| args$$set(args)
36-
| val _ = script.hashCode() // hashCode to clear scalac warning about pure expression in statement position
45+
| $mainInvocation // hashCode to clear scalac warning about pure expression in statement position
3746
| }
3847
|}
3948
|

modules/build/src/main/scala/scala/build/internal/ObjectCodeWrapper.scala

+11-5
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ package scala.build.internal
44
* or/and not using JS native prefer [[ClassCodeWrapper]], since it prevents deadlocks when running
55
* threads from script
66
*/
7-
case object ObjectCodeWrapper extends CodeWrapper {
7+
case class ObjectCodeWrapper(scalaVersion: String) extends CodeWrapper {
88

99
override def mainClassObject(className: Name): Name =
1010
Name(className.raw ++ "_sc")
@@ -15,12 +15,18 @@ case object ObjectCodeWrapper extends CodeWrapper {
1515
extraCode: String,
1616
scriptPath: String
1717
) = {
18+
val mainObject = WrapperUtils.mainObjectInScript(scalaVersion, code)
1819
val name = mainClassObject(indexedWrapperName).backticked
1920
val aliasedWrapperName = name + "$$alias"
20-
val funHashCodeMethod =
21+
val realScript =
2122
if (name == "main_sc")
22-
s"$aliasedWrapperName.alias.hashCode()" // https://github.com/VirtusLab/scala-cli/issues/314
23-
else s"${indexedWrapperName.backticked}.hashCode()"
23+
s"$aliasedWrapperName.alias" // https://github.com/VirtusLab/scala-cli/issues/314
24+
else s"${indexedWrapperName.backticked}"
25+
26+
val funHashCodeMethod = mainObject match {
27+
case None => s"val _ = $realScript.hashCode()"
28+
case Some(name) => s"$realScript.$name.main(args)"
29+
}
2430
// We need to call hashCode (or any other method so compiler does not report a warning)
2531
val mainObjectCode =
2632
AmmUtil.normalizeNewlines(s"""|object $name {
@@ -34,7 +40,7 @@ case object ObjectCodeWrapper extends CodeWrapper {
3440
| }
3541
| def main(args: Array[String]): Unit = {
3642
| args$$set(args)
37-
| val _ = $funHashCodeMethod // hashCode to clear scalac warning about pure expression in statement position
43+
| $funHashCodeMethod // hashCode to clear scalac warning about pure expression in statement position
3844
| }
3945
|}
4046
|""".stripMargin)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
package scala.build.internal
2+
3+
object WrapperUtils {
4+
5+
def mainObjectInScript(scalaVersion: String, code: String): Option[String] =
6+
import scala.meta.*
7+
8+
val scriptDialect =
9+
if scalaVersion.startsWith("3") then dialects.Scala3Future else dialects.Scala213Source3
10+
11+
given Dialect = scriptDialect.withAllowToplevelStatements(true).withAllowToplevelTerms(true)
12+
val parsedCode = code.parse[Source] match
13+
case Parsed.Success(Source(stats)) => stats
14+
case _ => Nil
15+
16+
// Check if there is a main function defined inside an object
17+
def checkSignature(defn: Defn.Def) =
18+
defn.paramClauseGroups match
19+
case List(Member.ParamClauseGroup(
20+
Type.ParamClause(Nil),
21+
List(Term.ParamClause(
22+
List(Term.Param(
23+
Nil,
24+
_: Term.Name,
25+
Some(Type.Apply.After_4_6_0(
26+
Type.Name("Array"),
27+
Type.ArgClause(List(Type.Name("String")))
28+
)),
29+
None
30+
)),
31+
None
32+
))
33+
)) => true
34+
case _ => false
35+
parsedCode.collect {
36+
case Defn.Object(_, objName, templ) =>
37+
templ.body.stats.find {
38+
case defn: Defn.Def =>
39+
defn.name.value == "main" && checkSignature(defn)
40+
case _ => false
41+
}.map(_ => objName.value)
42+
}.flatten.headOption
43+
}

modules/build/src/main/scala/scala/build/preprocessing/ScriptPreprocessor.scala

+5-5
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ case object ScriptPreprocessor extends Preprocessor {
134134
(codeWrapper: CodeWrapper) =>
135135
if (containsMainAnnot) logger.diagnostic(
136136
codeWrapper match {
137-
case _: AppCodeWrapper.type =>
137+
case _: AppCodeWrapper =>
138138
WarningMessages.mainAnnotationNotSupported( /* annotationIgnored */ true)
139139
case _ => WarningMessages.mainAnnotationNotSupported( /* annotationIgnored */ false)
140140
}
@@ -165,16 +165,16 @@ case object ScriptPreprocessor extends Preprocessor {
165165

166166
def objectCodeWrapperForScalaVersion =
167167
// AppObjectWrapper only introduces the 'main.sc' restriction when used in Scala 3, there's no gain in using it with Scala 3
168-
if effectiveScalaVersion.startsWith("2") then AppCodeWrapper
169-
else ObjectCodeWrapper
168+
if effectiveScalaVersion.startsWith("2") then AppCodeWrapper(effectiveScalaVersion)
169+
else ObjectCodeWrapper(effectiveScalaVersion)
170170

171171
buildOptions.scriptOptions.forceObjectWrapper match {
172172
case Some(true) => objectCodeWrapperForScalaVersion
173173
case _ =>
174174
buildOptions.scalaOptions.platform.map(_.value) match {
175175
case Some(_: Platform.JS.type) => objectCodeWrapperForScalaVersion
176-
case _ if effectiveScalaVersion.startsWith("2") => AppCodeWrapper
177-
case _ => ClassCodeWrapper
176+
case _ if effectiveScalaVersion.startsWith("2") => AppCodeWrapper(effectiveScalaVersion)
177+
case _ => ClassCodeWrapper(effectiveScalaVersion)
178178
}
179179
}
180180
}

modules/integration/src/test/scala/scala/cli/integration/RunScriptTestDefinitions.scala

+36
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,42 @@ trait RunScriptTestDefinitions { _: RunTestDefinitions =>
6969
}
7070
}
7171

72+
test("main.sc has object") {
73+
val message = "Hello"
74+
val inputs = TestInputs(
75+
os.rel / "main.sc" ->
76+
s"""|
77+
|object Main {
78+
| def main(args: Array[String]): Unit = println("$message")
79+
|}
80+
|""".stripMargin
81+
)
82+
inputs.fromRoot { root =>
83+
val output = os.proc(TestUtil.cli, extraOptions, "main.sc").call(cwd =
84+
root
85+
).out.trim()
86+
expect(output == message)
87+
}
88+
}
89+
90+
test("main.sc has object with object wrapper") {
91+
val message = "Hello"
92+
val inputs = TestInputs(
93+
os.rel / "main.sc" ->
94+
s"""|//> using objectWrapper
95+
|object Main {
96+
| def main(args: Array[String]): Unit = println("$message")
97+
|}
98+
|""".stripMargin
99+
)
100+
inputs.fromRoot { root =>
101+
val output = os.proc(TestUtil.cli, extraOptions, "--power", "main.sc").call(cwd =
102+
root
103+
).out.trim()
104+
expect(output == message)
105+
}
106+
}
107+
72108
if (actualScalaVersion.startsWith("3"))
73109
test("use method from main.sc file") {
74110
val message = "Hello"

project/deps.sc

+2-2
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ object Deps {
117117
def jsoniterScala = "2.23.2"
118118
def jsoniterScalaJava8 = "2.13.5.2"
119119
def jsoup = "1.18.3"
120-
def scalaMeta = "4.9.9"
120+
def scalaMeta = "4.12.7"
121121
def scalaNative04 = "0.4.17"
122122
def scalaNative05 = "0.5.6"
123123
def scalaNative = scalaNative05
@@ -227,7 +227,7 @@ object Deps {
227227
def semanticDbJavac = ivy"com.sourcegraph:semanticdb-javac:${Versions.javaSemanticdb}"
228228
def semanticDbScalac = ivy"org.scalameta:::semanticdb-scalac:${Versions.scalaMeta}"
229229
def scalametaSemanticDbShared =
230-
ivy"org.scalameta:semanticdb-shared_${Scala.scala213}:${Versions.scalaMeta}"
230+
ivy"org.scalameta:semanticdb-shared_2.13:${Versions.scalaMeta}"
231231
.exclude("org.jline" -> "jline") // to prevent incompatibilities with GraalVM <23
232232
.exclude("com.lihaoyi" -> "sourcecode_2.13")
233233
.exclude("org.scala-lang.modules" -> "scala-collection-compat_2.13")

0 commit comments

Comments
 (0)