Skip to content

Commit fcded74

Browse files
author
EnzeXing
committed
Refactor pattern matching, skipping cases when safe to do so
1 parent 7573951 commit fcded74

File tree

1 file changed

+38
-18
lines changed

1 file changed

+38
-18
lines changed

Diff for: compiler/src/dotty/tools/dotc/transform/init/Objects.scala

+38-18
Original file line numberDiff line numberDiff line change
@@ -602,6 +602,12 @@ class Objects(using Context @constructorOnly):
602602
case (ValueSet(values), b : ValueElement) => ValueSet(values + b)
603603
case (a : ValueElement, b : ValueElement) => ValueSet(ListSet(a, b))
604604

605+
def remove(b: Value): Value = (a, b) match
606+
case (ValueSet(values1), b: ValueElement) => ValueSet(values1 - b)
607+
case (ValueSet(values1), ValueSet(values2)) => ValueSet(values1.removedAll(values2))
608+
case (a: Ref, b: Ref) if a.equals(b) => Bottom
609+
case _ => a
610+
605611
def widen(height: Int)(using Context): Value =
606612
if height == 0 then Cold
607613
else
@@ -1341,29 +1347,25 @@ class Objects(using Context @constructorOnly):
13411347
def getMemberMethod(receiver: Type, name: TermName, tp: Type): Denotation =
13421348
receiver.member(name).suchThat(receiver.memberInfo(_) <:< tp)
13431349

1344-
def evalCase(caseDef: CaseDef): Value =
1345-
evalPattern(scrutinee, caseDef.pat)
1346-
eval(caseDef.guard, thisV, klass)
1347-
eval(caseDef.body, thisV, klass)
1348-
13491350
/** Abstract evaluation of patterns.
13501351
*
13511352
* It augments the local environment for bound pattern variables. As symbols are globally
13521353
* unique, we can put them in a single environment.
13531354
*
13541355
* Currently, we assume all cases are reachable, thus all patterns are assumed to match.
13551356
*/
1356-
def evalPattern(scrutinee: Value, pat: Tree): Value = log("match " + scrutinee.show + " against " + pat.show, printer, (_: Value).show):
1357+
def evalPattern(scrutinee: Value, pat: Tree): (Type, Value) = log("match " + scrutinee.show + " against " + pat.show, printer, (_: (Type, Value))._2.show):
13571358
val trace2 = Trace.trace.add(pat)
13581359
pat match
13591360
case Alternative(pats) =>
1360-
for pat <- pats do evalPattern(scrutinee, pat)
1361-
scrutinee
1361+
val (types, values) = pats.map(evalPattern(scrutinee, _)).unzip()
1362+
val orType = types.fold(defn.NothingType)(OrType(_, _, false))
1363+
(orType, values.join)
13621364

13631365
case bind @ Bind(_, pat) =>
1364-
val value = evalPattern(scrutinee, pat)
1366+
val (tpe, value) = evalPattern(scrutinee, pat)
13651367
initLocal(bind.symbol, value)
1366-
scrutinee
1368+
(tpe, value)
13671369

13681370
case UnApply(fun, implicits, pats) =>
13691371
given Trace = trace2
@@ -1372,6 +1374,10 @@ class Objects(using Context @constructorOnly):
13721374
val funRef = fun1.tpe.asInstanceOf[TermRef]
13731375
val unapplyResTp = funRef.widen.finalResultType
13741376

1377+
val receiverType = fun1 match
1378+
case ident: Ident => funRef.prefix
1379+
case select: Select => select.qualifier.tpe
1380+
13751381
val receiver = fun1 match
13761382
case ident: Ident =>
13771383
evalType(funRef.prefix, thisV, klass)
@@ -1460,17 +1466,18 @@ class Objects(using Context @constructorOnly):
14601466
end if
14611467
end if
14621468
end if
1463-
scrutinee
1469+
(receiverType, scrutinee.filterType(receiverType))
14641470

14651471
case Ident(nme.WILDCARD) | Ident(nme.WILDCARD_STAR) =>
1466-
scrutinee
1472+
(defn.ThrowableType, scrutinee)
14671473

1468-
case Typed(pat, _) =>
1469-
evalPattern(scrutinee, pat)
1474+
case Typed(pat, typeTree) =>
1475+
val (_, value) = evalPattern(scrutinee.filterType(typeTree.tpe), pat)
1476+
(typeTree.tpe, value)
14701477

14711478
case tree =>
14721479
// For all other trees, the semantics is normal.
1473-
eval(tree, thisV, klass)
1480+
(defn.ThrowableType, eval(tree, thisV, klass))
14741481

14751482
end evalPattern
14761483

@@ -1494,12 +1501,12 @@ class Objects(using Context @constructorOnly):
14941501
if isWildcardStarArgList(pats) then
14951502
if pats.size == 1 then
14961503
// call .toSeq
1497-
val toSeqDenot = getMemberMethod(scrutineeType, nme.toSeq, toSeqType(elemType))
1504+
val toSeqDenot = scrutineeType.member(nme.toSeq).suchThat(_.info.isParameterless)
14981505
val toSeqRes = call(scrutinee, toSeqDenot.symbol, Nil, scrutineeType, superType = NoType, needResolve = true)
14991506
evalPattern(toSeqRes, pats.head)
15001507
else
15011508
// call .drop
1502-
val dropDenot = getMemberMethod(scrutineeType, nme.drop, dropType(elemType))
1509+
val dropDenot = getMemberMethod(scrutineeType, nme.drop, applyType(elemType))
15031510
val dropRes = call(scrutinee, dropDenot.symbol, ArgInfo(Bottom, summon[Trace], EmptyTree) :: Nil, scrutineeType, superType = NoType, needResolve = true)
15041511
for pat <- pats.init do evalPattern(applyRes, pat)
15051512
evalPattern(dropRes, pats.last)
@@ -1510,8 +1517,21 @@ class Objects(using Context @constructorOnly):
15101517
end if
15111518
end evalSeqPatterns
15121519

1520+
def canSkipCase(remainingScrutinee: Value, catchValue: Value) =
1521+
(remainingScrutinee == Bottom && scrutinee != Bottom) ||
1522+
(catchValue == Bottom && remainingScrutinee != Bottom)
15131523

1514-
cases.map(evalCase).join
1524+
var remainingScrutinee = scrutinee
1525+
val caseResults: mutable.ArrayBuffer[Value] = mutable.ArrayBuffer()
1526+
for caseDef <- cases do
1527+
val (tpe, value) = evalPattern(remainingScrutinee, caseDef.pat)
1528+
eval(caseDef.guard, thisV, klass)
1529+
if !canSkipCase(remainingScrutinee, value) then
1530+
caseResults.addOne(eval(caseDef.body, thisV, klass))
1531+
if catchesAllOf(caseDef, tpe) then
1532+
remainingScrutinee = remainingScrutinee.remove(value)
1533+
1534+
caseResults.join
15151535
end patternMatch
15161536

15171537
/** Handle semantics of leaf nodes

0 commit comments

Comments
 (0)