Skip to content

Commit a9e4c5e

Browse files
authored
Fix argument parsing of flags in the presence of allowPositional=true (#66)
Should fix #58 and #60 Previously, we allowed any arg to take positional arguments if `allowPositional = true` (which is the case for Ammonite and Mill user-defined entrypoints.), even `mainargs.Flag`s. for which being positional doesn't make sense. ```scala val positionalArgSigs = argSigs .filter { case x: ArgSig.Simple[_, _] if x.reader.noTokens => false case x: ArgSig.Simple[_, _] if x.positional => true case x => allowPositional } ``` The relevant code path was rewritten in #62, but the buggy behavior was preserved before and after that change. This wasn't caught in other uses of `mainargs.Flag`, e.g. for Ammonite/Mill's own flags, because those did not set `allowPositional=true` This PR tweaks `TokenGrouping.groupArgs` to be more discerning about how it selects positional, keyword, and missing arguments: 1. Now, only `TokenReader.Simple[_]`s with `.positional` or `allowPositional` can be positional; `Flag`s, `Leftover`, etc. cannot 2. Keyword arguments are limited only to `Flag`s and `Simple` with `!a.positional` Added `mainargs.IssueTests.issue60` as a regression test, that fails on main and passes on this PR. Existing tests all pass
1 parent 3298647 commit a9e4c5e

File tree

6 files changed

+57
-31
lines changed

6 files changed

+57
-31
lines changed

Diff for: build.sc

+1-1
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ trait MainArgsPublishModule extends PublishModule with CrossScalaModule with Mim
6262

6363
def ivyDeps = Agg(
6464
ivy"org.scala-lang.modules::scala-collection-compat::2.8.1"
65-
) ++ Agg(ivy"com.lihaoyi::pprint:0.8.1")
65+
)
6666
}
6767

6868
def scalaMajor(scalaVersion: String) = if (isScala3(scalaVersion)) "3" else "2"

Diff for: mainargs/src/Renderer.scala

+2
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ object Renderer {
7272
val flattenedAll: Seq[ArgSig] =
7373
mainMethods.map(_.flattenedArgSigs)
7474
.flatten
75+
.map(_._1)
76+
7577
val leftColWidth = getLeftColWidth(flattenedAll)
7678
mainMethods match {
7779
case Seq() => ""

Diff for: mainargs/src/TokenGrouping.scala

+16-23
Original file line numberDiff line numberDiff line change
@@ -7,25 +7,22 @@ case class TokenGrouping[B](remaining: List[String], grouped: Map[ArgSig, Seq[St
77
object TokenGrouping {
88
def groupArgs[B](
99
flatArgs0: Seq[String],
10-
argSigs0: Seq[ArgSig],
10+
argSigs: Seq[(ArgSig, TokensReader.Terminal[_])],
1111
allowPositional: Boolean,
1212
allowRepeats: Boolean,
1313
allowLeftover: Boolean
1414
): Result[TokenGrouping[B]] = {
15-
val argSigs: Seq[ArgSig] = argSigs0
16-
.map(ArgSig.flatten(_).collect { case x: ArgSig => x })
17-
.flatten
18-
19-
val positionalArgSigs = argSigs
20-
.filter {
21-
case x: ArgSig if x.reader.isLeftover || x.reader.isConstant => false
22-
case x: ArgSig if x.positional => true
23-
case x => allowPositional
24-
}
15+
val positionalArgSigs = argSigs.collect {
16+
case (a, r: TokensReader.Simple[_]) if allowPositional | a.positional =>
17+
a
18+
}
2519

2620
val flatArgs = flatArgs0.toList
2721
val keywordArgMap = argSigs
28-
.filter { case x: ArgSig if x.positional => false; case _ => true }
22+
.collect {
23+
case (a, r: TokensReader.Simple[_]) if !a.positional => a
24+
case (a, r: TokensReader.Flag) => a
25+
}
2926
.flatMap { x => (x.name.map("--" + _) ++ x.shortName.map("-" + _)).map(_ -> x) }
3027
.toMap[String, ArgSig]
3128

@@ -77,17 +74,13 @@ object TokenGrouping {
7774
}
7875
.toSeq
7976

80-
val missing = argSigs
81-
.collect { case x: ArgSig => x }
82-
.filter { x =>
83-
x.reader match {
84-
case r: TokensReader.Simple[_] =>
85-
!r.allowEmpty &&
86-
x.default.isEmpty &&
87-
!current.contains(x)
88-
case _ => false
89-
}
90-
}
77+
val missing = argSigs.collect {
78+
case (a, r: TokensReader.Simple[_])
79+
if !r.allowEmpty
80+
&& a.default.isEmpty
81+
&& !current.contains(a) =>
82+
a
83+
}
9184

9285
val unknown = if (allowLeftover) Nil else remaining
9386
if (missing.nonEmpty || duplicates.nonEmpty || unknown.nonEmpty) {

Diff for: mainargs/src/TokensReader.scala

+5-5
Original file line numberDiff line numberDiff line change
@@ -241,8 +241,8 @@ object ArgSig {
241241
)
242242
}
243243

244-
def flatten[T](x: ArgSig): Seq[ArgSig] = x.reader match {
245-
case _: TokensReader.Terminal[T] => Seq(x)
244+
def flatten[T](x: ArgSig): Seq[(ArgSig, TokensReader.Terminal[_])] = x.reader match {
245+
case r: TokensReader.Terminal[T] => Seq((x, r))
246246
case cls: TokensReader.Class[_] => cls.main.argSigs0.flatMap(flatten(_))
247247
}
248248
}
@@ -281,11 +281,11 @@ case class MainData[T, B](
281281
invokeRaw: (B, Seq[Any]) => T
282282
) {
283283

284-
val flattenedArgSigs: Seq[ArgSig] =
285-
argSigs0.iterator.flatMap[ArgSig](ArgSig.flatten(_)).toVector
284+
val flattenedArgSigs: Seq[(ArgSig, TokensReader.Terminal[_])] =
285+
argSigs0.iterator.flatMap[(ArgSig, TokensReader.Terminal[_])](ArgSig.flatten(_)).toVector
286286

287287
val renderedArgSigs: Seq[ArgSig] =
288-
flattenedArgSigs.filter(a => !a.hidden && !a.reader.isConstant)
288+
flattenedArgSigs.collect{case (a, r) if !a.hidden && !r.isConstant => a}
289289
}
290290

291291
object MainData {

Diff for: mainargs/test/src/CoreTests.scala

+2-2
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,8 @@ class CoreTests(allowPositional: Boolean) extends TestSuite {
5959
List("foo", "bar", "qux", "ex")
6060
)
6161
val evaledArgs = check.mains.value.map(_.flattenedArgSigs.map {
62-
case ArgSig(name, s, docs, None, parser, _, _) => (s, docs, None, parser)
63-
case ArgSig(name, s, docs, Some(default), parser, _, _) =>
62+
case (ArgSig(name, s, docs, None, parser, _, _), _) => (s, docs, None, parser)
63+
case (ArgSig(name, s, docs, Some(default), parser, _, _), _) =>
6464
(s, docs, Some(default(CoreBase)), parser)
6565
})
6666

Diff for: mainargs/test/src/IssueTests.scala

+31
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package mainargs
2+
import utest._
3+
4+
object IssueTests extends TestSuite {
5+
6+
object Main {
7+
@main
8+
def mycmd(@arg(name = "the-flag") f: mainargs.Flag = mainargs.Flag(false),
9+
@arg str: String = "s",
10+
args: Leftover[String]) = {
11+
(f.value, str, args.value)
12+
}
13+
}
14+
15+
val tests = Tests {
16+
test("issue60") {
17+
test {
18+
val parsed = ParserForMethods(Main)
19+
.runEither(Seq("--str", "str", "a", "b", "c", "d"), allowPositional = true)
20+
21+
assert(parsed == Right((false, "str", List("a", "b", "c", "d"))))
22+
}
23+
test {
24+
val parsed = ParserForMethods(Main)
25+
.runEither(Seq("a", "b", "c", "d"), allowPositional = true)
26+
27+
assert(parsed == Right((false, "a", List("b", "c", "d"))))
28+
}
29+
}
30+
}
31+
}

0 commit comments

Comments
 (0)