Skip to content

Commit 3cf3dd3

Browse files
authored
Merge pull request #6551 from dotty-staging/trust-case-class-unapply
Fix (part of) #6323: trust case class unapplies to produce checkable type tests
2 parents 49dd34d + 94cff4a commit 3cf3dd3

File tree

3 files changed

+47
-10
lines changed

3 files changed

+47
-10
lines changed

Diff for: compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala

+18-9
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import NameKinds.{PatMatStdBinderName, PatMatAltsName, PatMatResultName}
1616
import config.Printers.patmatch
1717
import reporting.diagnostic.messages._
1818
import dotty.tools.dotc.ast._
19+
import util.Property._
1920

2021
/** The pattern matching transform.
2122
* After this phase, the only Match nodes remaining in the code are simple switches
@@ -52,6 +53,8 @@ object PatternMatcher {
5253
/** Minimal number of cases to emit a switch */
5354
final val MinSwitchCases = 4
5455

56+
val TrustedTypeTestKey: Key[Unit] = new StickyKey[Unit]
57+
5558
/** Was symbol generated by pattern matcher? */
5659
def isPatmatGenerated(sym: Symbol)(implicit ctx: Context): Boolean =
5760
sym.is(Synthetic) && sym.name.is(PatMatStdBinderName)
@@ -153,24 +156,24 @@ object PatternMatcher {
153156

154157
/** The different kinds of tests */
155158
sealed abstract class Test
156-
case class TypeTest(tpt: Tree) extends Test { // scrutinee.isInstanceOf[tpt]
159+
case class TypeTest(tpt: Tree, trusted: Boolean) extends Test { // scrutinee.isInstanceOf[tpt]
157160
override def equals(that: Any): Boolean = that match {
158161
case that: TypeTest => this.tpt.tpe =:= that.tpt.tpe
159162
case _ => false
160163
}
161164
override def hashCode: Int = tpt.tpe.hash
162165
}
163-
case class EqualTest(tree: Tree) extends Test { // scrutinee == tree
166+
case class EqualTest(tree: Tree) extends Test { // scrutinee == tree
164167
override def equals(that: Any): Boolean = that match {
165168
case that: EqualTest => this.tree === that.tree
166169
case _ => false
167170
}
168171
override def hashCode: Int = tree.hash
169172
}
170-
case class LengthTest(len: Int, exact: Boolean) extends Test // scrutinee (== | >=) len
171-
case object NonEmptyTest extends Test // !scrutinee.isEmpty
172-
case object NonNullTest extends Test // scrutinee ne null
173-
case object GuardTest extends Test // scrutinee
173+
case class LengthTest(len: Int, exact: Boolean) extends Test // scrutinee (== | >=) len
174+
case object NonEmptyTest extends Test // !scrutinee.isEmpty
175+
case object NonNullTest extends Test // scrutinee ne null
176+
case object GuardTest extends Test // scrutinee
174177

175178
// ------- Generating plans from trees ------------------------
176179

@@ -352,7 +355,12 @@ object PatternMatcher {
352355
// begin patternPlan
353356
swapBind(tree) match {
354357
case Typed(pat, tpt) =>
355-
TestPlan(TypeTest(tpt), scrutinee, tree.span,
358+
val isTrusted = pat match {
359+
case UnApply(extractor, _, _) =>
360+
extractor.symbol.is(Synthetic) && extractor.symbol.owner.linkedClass.is(Case)
361+
case _ => false
362+
}
363+
TestPlan(TypeTest(tpt, isTrusted), scrutinee, tree.span,
356364
letAbstract(ref(scrutinee).cast(tpt.tpe)) { casted =>
357365
nonNull += casted
358366
patternPlan(casted, pat, onSuccess)
@@ -685,7 +693,7 @@ object PatternMatcher {
685693
.select(defn.Seq_length.matchingMember(scrutinee.tpe))
686694
.select(if (exact) defn.Int_== else defn.Int_>=)
687695
.appliedTo(Literal(Constant(len)))
688-
case TypeTest(tpt) =>
696+
case TypeTest(tpt, trusted) =>
689697
val expectedTp = tpt.tpe
690698

691699
// An outer test is needed in a situation like `case x: y.Inner => ...`
@@ -716,6 +724,7 @@ object PatternMatcher {
716724
scrutinee.isInstance(expectedTp) // will be translated to an equality test
717725
case _ =>
718726
val typeTest = scrutinee.select(defn.Any_typeTest).appliedToType(expectedTp)
727+
if (trusted) typeTest.pushAttachment(TrustedTypeTestKey, ())
719728
if (outerTestNeeded) typeTest.and(outerTest) else typeTest
720729
}
721730
}
@@ -899,7 +908,7 @@ object PatternMatcher {
899908
val seen = mutable.Set[Int]()
900909
def showTest(test: Test) = test match {
901910
case EqualTest(tree) => i"EqualTest($tree)"
902-
case TypeTest(tpt) => i"TypeTest($tpt)"
911+
case TypeTest(tpt, trusted) => i"TypeTest($tpt, trusted=$trusted)"
903912
case _ => test.toString
904913
}
905914
def showPlan(plan: Plan): Unit =

Diff for: compiler/src/dotty/tools/dotc/transform/TypeTestsCasts.scala

+2-1
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,8 @@ object TypeTestsCasts {
298298

299299
if (sym.isTypeTest) {
300300
val argType = tree.args.head.tpe
301-
if (!checkable(expr.tpe, argType, tree.span))
301+
val isTrusted = tree.getAttachment(PatternMatcher.TrustedTypeTestKey).nonEmpty
302+
if (!isTrusted && !checkable(expr.tpe, argType, tree.span))
302303
ctx.warning(i"the type test for $argType cannot be checked at runtime", tree.sourcePos)
303304
transformTypeTest(expr, tree.args.head.tpe, flagUnrelated = true)
304305
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
object Test {
2+
sealed trait Foo[A, B]
3+
final case class Bar[X](x: X) extends Foo[X, X]
4+
5+
def foo[A, B](value: Foo[A, B], a: A => Int): B = value match {
6+
case Bar(x) => a(x); x
7+
}
8+
9+
def bar[A, B](value: Foo[A, B], a: A => Int): B = value match {
10+
case b: Bar[a] => b.x
11+
}
12+
13+
def err1[A, B](value: Foo[A, B], a: A => Int): B = value match {
14+
case b: Bar[A] => // spurious // error
15+
b.x
16+
}
17+
18+
def err2[A, B](value: Foo[A, B], a: A => Int): B = value match {
19+
case b: Bar[B] => // spurious // error
20+
b.x
21+
}
22+
23+
def fail[A, B](value: Foo[A, B], a: A => Int): B = value match {
24+
case b: Bar[Int] => // error
25+
b.x
26+
}
27+
}

0 commit comments

Comments
 (0)