Skip to content

Fix #2578 Part 1: Tighten type checking of pattern bindings #6389

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
May 7, 2019
9 changes: 8 additions & 1 deletion compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ object desugar {
*/
val DerivingCompanion: Property.Key[SourcePosition] = new Property.Key

/** An attachment for match expressions generated from a PatDef */
val PatDefMatch: Property.Key[Unit] = new Property.Key

/** Info of a variable in a pattern: The named tree and its type */
private type VarInfo = (NameTree, Tree)

Expand Down Expand Up @@ -956,7 +959,11 @@ object desugar {
// - `pat` is a tuple of N variables or wildcard patterns like `(x1, x2, ..., xN)`
val tupleOptimizable = forallResults(rhs, isMatchingTuple)

def rhsUnchecked = makeAnnotated("scala.unchecked", rhs)
def rhsUnchecked = {
val rhs1 = makeAnnotated("scala.unchecked", rhs)
rhs1.pushAttachment(PatDefMatch, ())
rhs1
}
val vars =
if (tupleOptimizable) // include `_`
pat match {
Expand Down
14 changes: 9 additions & 5 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1528,13 +1528,17 @@ object Types {
*/
def signature(implicit ctx: Context): Signature = Signature.NotAMethod

def dropRepeatedAnnot(implicit ctx: Context): Type = this match {
case AnnotatedType(parent, annot) if annot.symbol eq defn.RepeatedAnnot => parent
case tp @ AnnotatedType(parent, annot) =>
tp.derivedAnnotatedType(parent.dropRepeatedAnnot, annot)
case tp => tp
/** Drop annotation of given `cls` from this type */
def dropAnnot(cls: Symbol)(implicit ctx: Context): Type = stripTypeVar match {
case self @ AnnotatedType(pre, annot) =>
if (annot.symbol eq cls) pre
else self.derivedAnnotatedType(pre.dropAnnot(cls), annot)
case _ =>
this
}

def dropRepeatedAnnot(implicit ctx: Context): Type = dropAnnot(defn.RepeatedAnnot)

def annotatedToRepeated(implicit ctx: Context): Type = this match {
case tp @ ExprType(tp1) => tp.derivedExprType(tp1.annotatedToRepeated)
case AnnotatedType(tp, annot) if annot matches defn.RepeatedAnnot =>
Expand Down
59 changes: 41 additions & 18 deletions compiler/src/dotty/tools/dotc/parsing/Parsers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1200,14 +1200,15 @@ object Parsers {
* | ForExpr
* | [SimpleExpr `.'] id `=' Expr
* | SimpleExpr1 ArgumentExprs `=' Expr
* | PostfixExpr [Ascription]
* | [‘inline’] PostfixExpr `match' `{' CaseClauses `}'
* | Expr2
* | [‘inline’] Expr2 `match' `{' CaseClauses `}'
* | `implicit' `match' `{' ImplicitCaseClauses `}'
* Bindings ::= `(' [Binding {`,' Binding}] `)'
* Binding ::= (id | `_') [`:' Type]
* Ascription ::= `:' CompoundType
* | `:' Annotation {Annotation}
* | `:' `_' `*'
* Bindings ::= `(' [Binding {`,' Binding}] `)'
* Binding ::= (id | `_') [`:' Type]
* Expr2 ::= PostfixExpr [Ascription]
* Ascription ::= `:' InfixType
* | `:' Annotation {Annotation}
* | `:' `_' `*'
*/
val exprInParens: () => Tree = () => expr(Location.InParens)

Expand Down Expand Up @@ -1324,15 +1325,16 @@ object Parsers {
t
}
case COLON =>
ascription(t, location)
in.nextToken()
val t1 = ascription(t, location)
if (in.token == MATCH) expr1Rest(t1, location) else t1
case MATCH =>
matchExpr(t, startOffset(t), Match)
case _ =>
t
}

def ascription(t: Tree, location: Location.Value): Tree = atSpan(startOffset(t)) {
in.skipToken()
in.token match {
case USCORE =>
val uscoreStart = in.skipToken()
Expand Down Expand Up @@ -1801,7 +1803,10 @@ object Parsers {
*/
def pattern1(): Tree = {
val p = pattern2()
if (isVarPattern(p) && in.token == COLON) ascription(p, Location.InPattern)
if (isVarPattern(p) && in.token == COLON) {
in.nextToken()
ascription(p, Location.InPattern)
}
else p
}

Expand Down Expand Up @@ -2353,14 +2358,32 @@ object Parsers {
tmplDef(start, mods)
}

/** PatDef ::= Pattern2 {`,' Pattern2} [`:' Type] `=' Expr
* VarDef ::= PatDef | id {`,' id} `:' Type `=' `_'
* ValDcl ::= id {`,' id} `:' Type
* VarDcl ::= id {`,' id} `:' Type
/** PatDef ::= ids [‘:’ Type] ‘=’ Expr
* | Pattern2 [‘:’ Type | Ascription] ‘=’ Expr
* VarDef ::= PatDef | id {`,' id} `:' Type `=' `_'
* ValDcl ::= id {`,' id} `:' Type
* VarDcl ::= id {`,' id} `:' Type
*/
def patDefOrDcl(start: Offset, mods: Modifiers): Tree = atSpan(start, nameStart) {
val lhs = commaSeparated(pattern2)
val tpt = typedOpt()
val first = pattern2()
var lhs = first match {
case id: Ident if in.token == COMMA =>
in.nextToken()
id :: commaSeparated(() => termIdent())
case _ =>
first :: Nil
}
def emptyType = TypeTree().withSpan(Span(in.lastOffset))
val tpt =
if (in.token == COLON) {
in.nextToken()
if (in.token == AT && lhs.tail.isEmpty) {
lhs = ascription(first, Location.ElseWhere) :: Nil
emptyType
}
else toplevelTyp()
}
else emptyType
val rhs =
if (tpt.isEmpty || in.token == EQUALS) {
accept(EQUALS)
Expand All @@ -2374,9 +2397,9 @@ object Parsers {
lhs match {
case (id: BackquotedIdent) :: Nil if id.name.isTermName =>
finalizeDef(BackquotedValDef(id.name.asTermName, tpt, rhs), mods, start)
case Ident(name: TermName) :: Nil => {
case Ident(name: TermName) :: Nil =>
finalizeDef(ValDef(name, tpt, rhs), mods, start)
} case _ =>
case _ =>
PatDef(mods, lhs, tpt, rhs)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ object TypeTestsCasts {

if (expr.tpe <:< testType)
if (expr.tpe.isNotNull) {
ctx.warning(TypeTestAlwaysSucceeds(foundCls, testCls), tree.sourcePos)
if (!inMatch) ctx.warning(TypeTestAlwaysSucceeds(foundCls, testCls), tree.sourcePos)
constant(expr, Literal(Constant(true)))
}
else expr.testNotNull
Expand Down
31 changes: 18 additions & 13 deletions compiler/src/dotty/tools/dotc/transform/patmat/Space.scala
Original file line number Diff line number Diff line change
Expand Up @@ -280,11 +280,25 @@ trait SpaceLogic {
}
}

object SpaceEngine {

/** Is the unapply irrefutable?
* @param unapp The unapply function reference
*/
def isIrrefutableUnapply(unapp: tpd.Tree)(implicit ctx: Context): Boolean = {
val unappResult = unapp.tpe.widen.finalResultType
unappResult.isRef(defn.SomeClass) ||
unappResult =:= ConstantType(Constant(true)) ||
(unapp.symbol.is(Synthetic) && unapp.symbol.owner.linkedClass.is(Case)) ||
productArity(unappResult) > 0
}
}

/** Scala implementation of space logic */
class SpaceEngine(implicit ctx: Context) extends SpaceLogic {
import tpd._
import SpaceEngine._

private val scalaSomeClass = ctx.requiredClass("scala.Some")
private val scalaSeqFactoryClass = ctx.requiredClass("scala.collection.generic.SeqFactory")
private val scalaListType = ctx.requiredClassRef("scala.collection.immutable.List")
private val scalaNilType = ctx.requiredModuleRef("scala.collection.immutable.Nil")
Expand All @@ -309,15 +323,6 @@ class SpaceEngine(implicit ctx: Context) extends SpaceLogic {
else Typ(AndType(tp1, tp2), true)
}

/** Whether the extractor is irrefutable */
def irrefutable(unapp: Tree): Boolean = {
// TODO: optionless patmat
unapp.tpe.widen.finalResultType.isRef(scalaSomeClass) ||
unapp.tpe.widen.finalResultType =:= ConstantType(Constant(true)) ||
(unapp.symbol.is(Synthetic) && unapp.symbol.owner.linkedClass.is(Case)) ||
productArity(unapp.tpe.widen.finalResultType) > 0
}

/** Return the space that represents the pattern `pat` */
def project(pat: Tree): Space = pat match {
case Literal(c) =>
Expand All @@ -340,12 +345,12 @@ class SpaceEngine(implicit ctx: Context) extends SpaceLogic {
else {
val (arity, elemTp, resultTp) = unapplySeqInfo(fun.tpe.widen.finalResultType, fun.sourcePos)
if (elemTp.exists)
Prod(erase(pat.tpe.stripAnnots), fun.tpe, fun.symbol, projectSeq(pats) :: Nil, irrefutable(fun))
Prod(erase(pat.tpe.stripAnnots), fun.tpe, fun.symbol, projectSeq(pats) :: Nil, isIrrefutableUnapply(fun))
else
Prod(erase(pat.tpe.stripAnnots), fun.tpe, fun.symbol, pats.take(arity - 1).map(project) :+ projectSeq(pats.drop(arity - 1)), irrefutable(fun))
Prod(erase(pat.tpe.stripAnnots), fun.tpe, fun.symbol, pats.take(arity - 1).map(project) :+ projectSeq(pats.drop(arity - 1)),isIrrefutableUnapply(fun))
}
else
Prod(erase(pat.tpe.stripAnnots), fun.tpe, fun.symbol, pats.map(project), irrefutable(fun))
Prod(erase(pat.tpe.stripAnnots), fun.tpe, fun.symbol, pats.map(project), isIrrefutableUnapply(fun))
case Typed(pat @ UnApply(_, _, _), _) => project(pat)
case Typed(expr, tpt) =>
Typ(erase(expr.tpe.stripAnnots), true)
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/Applications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1105,7 +1105,7 @@ trait Applications extends Compatibility { self: Typer with Dynamic =>
if (selType <:< unapplyArgType) {
unapp.println(i"case 1 $unapplyArgType ${ctx.typerState.constraint}")
fullyDefinedType(unapplyArgType, "pattern selector", tree.span)
selType
selType.dropAnnot(defn.UncheckedAnnot) // need to drop @unchecked. Just because the selector is @unchecked, the pattern isn't.
} else if (isSubTypeOfParent(unapplyArgType, selType)(ctx.addMode(Mode.GADTflexible))) {
val patternBound = maximizeType(unapplyArgType, tree.span, fromScala2x)
if (patternBound.nonEmpty) unapplyFn = addBinders(unapplyFn, patternBound)
Expand Down
47 changes: 46 additions & 1 deletion compiler/src/dotty/tools/dotc/typer/Checking.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,17 @@ import ProtoTypes._
import Scopes._
import CheckRealizable._
import ErrorReporting.errorTree
import rewrites.Rewrites.patch
import util.Spans.Span

import util.SourcePosition
import transform.SymUtils._
import Decorators._
import ErrorReporting.{err, errorType}
import config.Printers.typr
import config.Printers.{typr, patmatch}
import NameKinds.DefaultGetterName
import Applications.unapplyArgs
import transform.patmat.SpaceEngine.isIrrefutableUnapply

import collection.mutable
import SymDenotations.{NoCompleter, NoDenotation}
Expand Down Expand Up @@ -594,6 +598,47 @@ trait Checking {
ctx.error(ex"$cls cannot be instantiated since it${rstatus.msg}", pos)
}

/** Check that pattern `pat` is irrefutable for scrutinee tye `pt`.
* This means `pat` is either marked @unchecked or `pt` conforms to the
* pattern's type. If pattern is an UnApply, do the check recursively.
*/
def checkIrrefutable(pat: Tree, pt: Type)(implicit ctx: Context): Boolean = {
patmatch.println(i"check irrefutable $pat: ${pat.tpe} against $pt")

def fail(pat: Tree, pt: Type): Boolean = {
ctx.errorOrMigrationWarning(
ex"""pattern's type ${pat.tpe} is more specialized than the right hand side expression's type ${pt.dropAnnot(defn.UncheckedAnnot)}
|
|If the narrowing is intentional, this can be communicated by writing `: @unchecked` after the full pattern.${err.rewriteNotice}""",
pat.sourcePos)
false
}

def check(pat: Tree, pt: Type): Boolean = (pt <:< pat.tpe) || fail(pat, pt)

!ctx.settings.strict.value || // only in -strict mode for now since mitigations work only after this PR
pat.tpe.widen.hasAnnotation(defn.UncheckedAnnot) || {
pat match {
case Bind(_, pat1) =>
checkIrrefutable(pat1, pt)
case UnApply(fn, _, pats) =>
check(pat, pt) &&
(isIrrefutableUnapply(fn) || fail(pat, pt)) && {
val argPts = unapplyArgs(fn.tpe.widen.finalResultType, fn, pats, pat.sourcePos)
pats.corresponds(argPts)(checkIrrefutable)
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For UnApply, we should also check if the extractor itself is irrefutable, i.e. if it returns Some, or true, or a product. The logic to do that is in SpaceEngine#irrefutable.

Additionally, for some reason unapplies are simply rejected as val definitions:

scala> {
     |   object Positive { def unapply(i: Int): Option[Int] = Some(i).filter(_ > 0) }
     |   val Positive(p) = 5
     |   5 match { case Positive(p) => p }
     | }
3 |  val Positive(p) = 5
  |      ^^^^^^^^^^^
  | ((i: Int): Option[Int])(Positive.unapply) is not a valid result type of an unapply method of an extractor.

case Alternative(pats) =>
pats.forall(checkIrrefutable(_, pt))
case Typed(arg, tpt) =>
check(pat, pt) && checkIrrefutable(arg, pt)
case Ident(nme.WILDCARD) =>
true
case _ =>
check(pat, pt)
}
}
}

/** Check that `path` is a legal prefix for an import or export clause */
def checkLegalImportPath(path: Tree)(implicit ctx: Context): Unit = {
checkStable(path.tpe, path.sourcePos)
Expand Down
4 changes: 4 additions & 0 deletions compiler/src/dotty/tools/dotc/typer/ErrorReporting.scala
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,10 @@ object ErrorReporting {
}
"""\$\{\w*\}""".r.replaceSomeIn(raw, m => translate(m.matched.drop(2).init))
}

def rewriteNotice: String =
if (ctx.scala2Mode) "\nThis patch can be inserted automatically under -rewrite."
else ""
}

def err(implicit ctx: Context): Errors = new Errors
Expand Down
17 changes: 14 additions & 3 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1036,7 +1036,15 @@ class Typer extends Namer
if (tree.isInline) checkInInlineContext("inline match", tree.posd)
val sel1 = typedExpr(tree.selector)
val selType = fullyDefinedType(sel1.tpe, "pattern selector", tree.span).widen
typedMatchFinish(tree, sel1, selType, tree.cases, pt)
val result = typedMatchFinish(tree, sel1, selType, tree.cases, pt)
result match {
case Match(sel, CaseDef(pat, _, _) :: _)
if (tree.selector.removeAttachment(desugar.PatDefMatch).isDefined) =>
if (!checkIrrefutable(pat, sel.tpe) && ctx.scala2Mode)
patch(Span(pat.span.end), ": @unchecked")
case _ =>
}
result
}
}

Expand Down Expand Up @@ -1817,8 +1825,11 @@ class Typer extends Namer
}
case _ => arg1
}
val tpt = TypeTree(AnnotatedType(arg1.tpe.widenIfUnstable, Annotation(annot1)))
assignType(cpy.Typed(tree)(arg2, tpt), tpt)
val argType =
if (arg1.isInstanceOf[Bind]) arg1.tpe.widen // bound symbol is not accessible outside of Bind node
else arg1.tpe.widenIfUnstable
val annotatedTpt = TypeTree(AnnotatedType(argType, Annotation(annot1)))
assignType(cpy.Typed(tree)(arg2, annotatedTpt), annotatedTpt)
}
}

Expand Down
3 changes: 0 additions & 3 deletions compiler/test-resources/repl/patdef
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,3 @@ scala> val _ @ List(x) = List(1)
val x: Int = 1
scala> val List(_ @ List(x)) = List(List(2))
val x: Int = 2
scala> val B @ List(), C: List[Int] = List()
val B: List[Int] = List()
val C: List[Int] = List()
7 changes: 4 additions & 3 deletions compiler/test/dotty/tools/dotc/CompilationTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ class CompilationTests extends ParallelTesting {
aggregateTests(
compileFilesInDir("tests/neg", defaultOptions),
compileFilesInDir("tests/neg-tailcall", defaultOptions),
compileFilesInDir("tests/neg-strict", defaultOptions.and("-strict")),
compileFilesInDir("tests/neg-no-kind-polymorphism", defaultOptions and "-Yno-kind-polymorphism"),
compileFilesInDir("tests/neg-custom-args/deprecation", defaultOptions.and("-Xfatal-warnings", "-deprecation")),
compileFilesInDir("tests/neg-custom-args/fatal-warnings", defaultOptions.and("-Xfatal-warnings")),
Expand All @@ -160,8 +161,6 @@ class CompilationTests extends ParallelTesting {
compileFile("tests/neg-custom-args/i3246.scala", scala2Mode),
compileFile("tests/neg-custom-args/overrideClass.scala", scala2Mode),
compileFile("tests/neg-custom-args/autoTuplingTest.scala", defaultOptions.and("-language:noAutoTupling")),
compileFile("tests/neg-custom-args/i1050.scala", defaultOptions.and("-strict")),
compileFile("tests/neg-custom-args/nullless.scala", defaultOptions.and("-strict")),
compileFile("tests/neg-custom-args/nopredef.scala", defaultOptions.and("-Yno-predef")),
compileFile("tests/neg-custom-args/noimports.scala", defaultOptions.and("-Yno-imports")),
compileFile("tests/neg-custom-args/noimports2.scala", defaultOptions.and("-Yno-imports")),
Expand Down Expand Up @@ -249,7 +248,9 @@ class CompilationTests extends ParallelTesting {

val lib =
compileList("src", librarySources,
defaultOptions.and("-Ycheck-reentrant", "-strict", "-priorityclasspath", defaultOutputDir))(libGroup)
defaultOptions.and("-Ycheck-reentrant",
// "-strict", // TODO: re-enable once we allow : @unchecked in pattern definitions. Right now, lots of narrowing pattern definitions fail.
"-priorityclasspath", defaultOutputDir))(libGroup)

val compilerSources = sources(Paths.get("compiler/src"))
val compilerManagedSources = sources(Properties.dottyCompilerManagedSources)
Expand Down
Loading