Skip to content

Commit 4793fd8

Browse files
committed
improvement: Detect objects with main class in scripts
Prebiously, if user had a legacy script with main method then it would not be picked up at all. Now, when we detect the correct signature we try to run it. This will work in case of `def main..` and when object extends App The possibility of false positives is pretty low, since user would have to have their own App, String or Array types. We will also only use that object if there are no toplevel statements
1 parent a0edbcd commit 4793fd8

File tree

10 files changed

+259
-22
lines changed

10 files changed

+259
-22
lines changed

modules/build/src/main/java/scala/build/internal/JavaParserProxyMakerSubst.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,6 @@ public JavaParserProxy get(
2020
scala.build.Logger logger,
2121
Supplier<String> javaCommand
2222
) {
23-
return new JavaParserProxyBinary(archiveCache, logger, javaClassNameVersionOpt, javaCommand);
23+
return new JavaParserProxyBinary( archiveCache, logger, javaClassNameVersionOpt, javaCommand);
2424
}
2525
}

modules/build/src/main/scala/scala/build/ScopedSources.scala

+2-2
Original file line numberDiff line numberDiff line change
@@ -62,14 +62,14 @@ final case class ScopedSources(
6262
): Either[BuildException, Sources] = either {
6363
val combinedOptions = combinedBuildOptions(scope, baseOptions)
6464

65-
val codeWrapper = ScriptPreprocessor.getScriptWrapper(combinedOptions)
65+
val codeWrapper = ScriptPreprocessor.getScriptWrapper(combinedOptions, logger)
6666

6767
val wrappedScripts = unwrappedScripts
6868
.flatMap(_.valueFor(scope).toSeq)
6969
.map(_.wrap(codeWrapper))
7070

7171
codeWrapper match {
72-
case _: AppCodeWrapper.type if wrappedScripts.size > 1 =>
72+
case _: AppCodeWrapper if wrappedScripts.size > 1 =>
7373
wrappedScripts.find(_.originalPath.exists(_._1.toString == "main.sc"))
7474
.foreach(_ => logger.diagnostic(WarningMessages.mainScriptNameClashesWithAppWrapper))
7575
case _ => ()

modules/build/src/main/scala/scala/build/internal/AppCodeWrapper.scala

+8-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
package scala.build.internal
22

3-
case object AppCodeWrapper extends CodeWrapper {
3+
case class AppCodeWrapper(scalaVersion: String, log: String => Unit) extends CodeWrapper {
44
override def mainClassObject(className: Name) = className
55

66
def apply(
@@ -12,13 +12,19 @@ case object AppCodeWrapper extends CodeWrapper {
1212
) = {
1313
val wrapperObjectName = indexedWrapperName.backticked
1414

15+
val mainObject = WrapperUtils.mainObjectInScript(scalaVersion, code)
16+
val invokeMain = mainObject match
17+
case WrapperUtils.ScriptMainMethod.Exists(name) => s"\n$name.main(args)"
18+
case otherwise =>
19+
otherwise.warningMessage.foreach(log)
20+
""
1521
val packageDirective =
1622
if (pkgName.isEmpty) "" else s"package ${AmmUtil.encodeScalaSourcePath(pkgName)}" + "\n"
1723
val top = AmmUtil.normalizeNewlines(
1824
s"""$packageDirective
1925
|
2026
|object $wrapperObjectName extends App {
21-
|val scriptPath = \"\"\"$scriptPath\"\"\"
27+
|val scriptPath = \"\"\"$scriptPath\"\"\"$invokeMain
2228
|""".stripMargin
2329
)
2430
val bottom = AmmUtil.normalizeNewlines(

modules/build/src/main/scala/scala/build/internal/ClassCodeWrapper.scala

+11-3
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ package scala.build.internal
55
* running interconnected scripts using Scala CLI <br> <br> Incompatible with Scala 2 - it uses
66
* Scala 3 feature 'export'<br> Incompatible with native JS members - the wrapper is a class
77
*/
8-
case object ClassCodeWrapper extends CodeWrapper {
8+
case class ClassCodeWrapper(scalaVersion: String, log: String => Unit) extends CodeWrapper {
99

1010
override def mainClassObject(className: Name): Name =
1111
Name(className.raw ++ "_sc")
@@ -16,8 +16,16 @@ case object ClassCodeWrapper extends CodeWrapper {
1616
extraCode: String,
1717
scriptPath: String
1818
) = {
19+
20+
val mainObject = WrapperUtils.mainObjectInScript(scalaVersion, code)
21+
val mainInvocation = mainObject match
22+
case WrapperUtils.ScriptMainMethod.Exists(name) => s"script.$name.main(args)"
23+
case otherwise =>
24+
otherwise.warningMessage.foreach(log)
25+
s"val _ = script.hashCode()"
26+
1927
val name = mainClassObject(indexedWrapperName).backticked
20-
val wrapperClassName = Name(indexedWrapperName.raw ++ "$_").backticked
28+
val wrapperClassName = scala.build.internal.Name(indexedWrapperName.raw ++ "$_").backticked
2129
val mainObjectCode =
2230
AmmUtil.normalizeNewlines(s"""|object $name {
2331
| private var args$$opt0 = Option.empty[Array[String]]
@@ -33,7 +41,7 @@ case object ClassCodeWrapper extends CodeWrapper {
3341
|
3442
| def main(args: Array[String]): Unit = {
3543
| args$$set(args)
36-
| val _ = script.hashCode() // hashCode to clear scalac warning about pure expression in statement position
44+
| $mainInvocation // hashCode to clear scalac warning about pure expression in statement position
3745
| }
3846
|}
3947
|

modules/build/src/main/scala/scala/build/internal/ObjectCodeWrapper.scala

+12-5
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ package scala.build.internal
44
* or/and not using JS native prefer [[ClassCodeWrapper]], since it prevents deadlocks when running
55
* threads from script
66
*/
7-
case object ObjectCodeWrapper extends CodeWrapper {
7+
case class ObjectCodeWrapper(scalaVersion: String, log: String => Unit) extends CodeWrapper {
88

99
override def mainClassObject(className: Name): Name =
1010
Name(className.raw ++ "_sc")
@@ -15,12 +15,19 @@ case object ObjectCodeWrapper extends CodeWrapper {
1515
extraCode: String,
1616
scriptPath: String
1717
) = {
18+
val mainObject = WrapperUtils.mainObjectInScript(scalaVersion, code)
1819
val name = mainClassObject(indexedWrapperName).backticked
1920
val aliasedWrapperName = name + "$$alias"
20-
val funHashCodeMethod =
21+
val realScript =
2122
if (name == "main_sc")
22-
s"$aliasedWrapperName.alias.hashCode()" // https://github.com/VirtusLab/scala-cli/issues/314
23-
else s"${indexedWrapperName.backticked}.hashCode()"
23+
s"$aliasedWrapperName.alias" // https://github.com/VirtusLab/scala-cli/issues/314
24+
else s"${indexedWrapperName.backticked}"
25+
26+
val funHashCodeMethod = mainObject match
27+
case WrapperUtils.ScriptMainMethod.Exists(name) => s"$realScript.$name.main(args)"
28+
case otherwise =>
29+
otherwise.warningMessage.foreach(log)
30+
s"val _ = $realScript.hashCode()"
2431
// We need to call hashCode (or any other method so compiler does not report a warning)
2532
val mainObjectCode =
2633
AmmUtil.normalizeNewlines(s"""|object $name {
@@ -34,7 +41,7 @@ case object ObjectCodeWrapper extends CodeWrapper {
3441
| }
3542
| def main(args: Array[String]): Unit = {
3643
| args$$set(args)
37-
| val _ = $funHashCodeMethod // hashCode to clear scalac warning about pure expression in statement position
44+
| $funHashCodeMethod // hashCode to clear scalac warning about pure expression in statement position
3845
| }
3946
|}
4047
|""".stripMargin)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
package scala.build.internal
2+
3+
import scala.build.internal.util.WarningMessages
4+
5+
object WrapperUtils {
6+
7+
enum ScriptMainMethod:
8+
case Exists(name: String)
9+
case Multiple(names: Seq[String])
10+
case ToplevelStatsPresent
11+
case NoMain
12+
13+
def warningMessage: Option[String] =
14+
this match
15+
case ScriptMainMethod.Multiple(names) =>
16+
Some(WarningMessages.multipleMainObjectsInScript(names))
17+
case ScriptMainMethod.ToplevelStatsPresent => Some(
18+
WarningMessages.mixedToplvelAndObjectInScript
19+
)
20+
case _ => None
21+
22+
def mainObjectInScript(scalaVersion: String, code: String): ScriptMainMethod =
23+
import scala.meta.*
24+
25+
val scriptDialect =
26+
if scalaVersion.startsWith("3") then dialects.Scala3Future else dialects.Scala213Source3
27+
28+
given Dialect = scriptDialect.withAllowToplevelStatements(true).withAllowToplevelTerms(true)
29+
val parsedCode = code.parse[Source] match
30+
case Parsed.Success(Source(stats)) => stats
31+
case _ => Nil
32+
33+
// Check if there is a main function defined inside an object
34+
def checkSignature(defn: Defn.Def) =
35+
defn.paramClauseGroups match
36+
case List(Member.ParamClauseGroup(
37+
Type.ParamClause(Nil),
38+
List(Term.ParamClause(
39+
List(Term.Param(
40+
Nil,
41+
_: Term.Name,
42+
Some(Type.Apply.After_4_6_0(
43+
Type.Name("Array"),
44+
Type.ArgClause(List(Type.Name("String")))
45+
)),
46+
None
47+
)),
48+
None
49+
))
50+
)) => true
51+
case _ => false
52+
53+
def noToplevelStatements = parsedCode.forall {
54+
case _: Term => false
55+
case _ => true
56+
}
57+
58+
def hasMainSignature(templ: Template) = templ.body.stats.exists {
59+
case defn: Defn.Def =>
60+
defn.name.value == "main" && checkSignature(defn)
61+
case _ => false
62+
}
63+
def extendsApp(templ: Template) = templ.inits match
64+
case Init.After_4_6_0(Type.Name("App"), _, Nil) :: Nil => true
65+
case _ => false
66+
val potentialMains = parsedCode.collect {
67+
case Defn.Object(_, objName, templ) if extendsApp(templ) || hasMainSignature(templ) =>
68+
Seq(objName.value)
69+
}.flatten
70+
71+
potentialMains match
72+
case head :: Nil if noToplevelStatements =>
73+
ScriptMainMethod.Exists(head)
74+
case head :: Nil =>
75+
ScriptMainMethod.ToplevelStatsPresent
76+
case Nil => ScriptMainMethod.NoMain
77+
case seq =>
78+
ScriptMainMethod.Multiple(seq)
79+
80+
}

modules/build/src/main/scala/scala/build/internal/util/WarningMessages.scala

+6
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,12 @@ object WarningMessages {
105105
val offlineModeBloopJvmNotFound =
106106
"Offline mode is ON and a JVM for Bloop could not be fetched from the local cache, using scalac as fallback"
107107

108+
def multipleMainObjectsInScript(names: Seq[String]) =
109+
s"Only single main is allowed within scripts and multiple main classes were found in the script: ${names.mkString(", ")}"
110+
111+
def mixedToplvelAndObjectInScript =
112+
"Script contains objects with main methods and top-level statements, only the latter will be run."
113+
108114
def directivesInMultipleFilesWarning(
109115
projectFilePath: String,
110116
pathsToReport: Iterable[String] = Nil

modules/build/src/main/scala/scala/build/preprocessing/ScriptPreprocessor.scala

+10-7
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ case object ScriptPreprocessor extends Preprocessor {
134134
(codeWrapper: CodeWrapper) =>
135135
if (containsMainAnnot) logger.diagnostic(
136136
codeWrapper match {
137-
case _: AppCodeWrapper.type =>
137+
case _: AppCodeWrapper =>
138138
WarningMessages.mainAnnotationNotSupported( /* annotationIgnored */ true)
139139
case _ => WarningMessages.mainAnnotationNotSupported( /* annotationIgnored */ false)
140140
}
@@ -157,24 +157,27 @@ case object ScriptPreprocessor extends Preprocessor {
157157
* @return
158158
* code wrapper compatible with provided BuildOptions
159159
*/
160-
def getScriptWrapper(buildOptions: BuildOptions): CodeWrapper = {
160+
def getScriptWrapper(buildOptions: BuildOptions, logger: Logger): CodeWrapper = {
161161
val effectiveScalaVersion =
162162
buildOptions.scalaOptions.scalaVersion.flatMap(_.versionOpt)
163163
.orElse(buildOptions.scalaOptions.defaultScalaVersion)
164164
.getOrElse(Constants.defaultScalaVersion)
165+
def logWarning(msg: String) = logger.diagnostic(msg)
165166

166167
def objectCodeWrapperForScalaVersion =
167168
// AppObjectWrapper only introduces the 'main.sc' restriction when used in Scala 3, there's no gain in using it with Scala 3
168-
if effectiveScalaVersion.startsWith("2") then AppCodeWrapper
169-
else ObjectCodeWrapper
169+
if effectiveScalaVersion.startsWith("2") then
170+
AppCodeWrapper(effectiveScalaVersion, logWarning)
171+
else ObjectCodeWrapper(effectiveScalaVersion, logWarning)
170172

171173
buildOptions.scriptOptions.forceObjectWrapper match {
172174
case Some(true) => objectCodeWrapperForScalaVersion
173175
case _ =>
174176
buildOptions.scalaOptions.platform.map(_.value) match {
175-
case Some(_: Platform.JS.type) => objectCodeWrapperForScalaVersion
176-
case _ if effectiveScalaVersion.startsWith("2") => AppCodeWrapper
177-
case _ => ClassCodeWrapper
177+
case Some(_: Platform.JS.type) => objectCodeWrapperForScalaVersion
178+
case _ if effectiveScalaVersion.startsWith("2") =>
179+
AppCodeWrapper(effectiveScalaVersion, logWarning)
180+
case _ => ClassCodeWrapper(effectiveScalaVersion, logWarning)
178181
}
179182
}
180183
}

0 commit comments

Comments
 (0)