Skip to content

Commit ed319e8

Browse files
authored
Fix signature computation and caching involving type variables (#18092)
This fixes the logic taking the signature of a type uninstantiated type variables (using tpnme.Uninstantiated) and non-permanently instantiated type variables (which have a legitimate signature in the currenty TyperState but cannot be cached). The test case was minimized from fmonniot/scala3mock#2.
2 parents 66b6dff + 4499bc0 commit ed319e8

File tree

6 files changed

+304
-165
lines changed

6 files changed

+304
-165
lines changed

Diff for: compiler/src/dotty/tools/dotc/core/TypeErasure.scala

+202-140
Large diffs are not rendered by default.

Diff for: compiler/src/dotty/tools/dotc/core/Types.scala

+13-4
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,11 @@ object Types {
109109
/** Is this type still provisional? This is the case if the type contains, or depends on,
110110
* uninstantiated type variables or type symbols that have the Provisional flag set.
111111
* This is an antimonotonic property - once a type is not provisional, it stays so forever.
112+
*
113+
* FIXME: The semantics of this flag are broken by the existence of `TypeVar#resetInst`,
114+
* a non-provisional type could go back to being provisional after
115+
* a call to `resetInst`. This means all caches that rely on `isProvisional`
116+
* can likely end up returning stale results.
112117
*/
113118
def isProvisional(using Context): Boolean = mightBeProvisional && testProvisional
114119

@@ -2272,7 +2277,7 @@ object Types {
22722277

22732278
if ctx.runId != mySignatureRunId then
22742279
mySignature = computeSignature
2275-
if !mySignature.isUnderDefined then mySignatureRunId = ctx.runId
2280+
if !mySignature.isUnderDefined && !isProvisional then mySignatureRunId = ctx.runId
22762281
mySignature
22772282
end signature
22782283

@@ -3784,17 +3789,17 @@ object Types {
37843789
case SourceLanguage.Java =>
37853790
if ctx.runId != myJavaSignatureRunId then
37863791
myJavaSignature = computeSignature
3787-
if !myJavaSignature.isUnderDefined then myJavaSignatureRunId = ctx.runId
3792+
if !myJavaSignature.isUnderDefined && !isProvisional then myJavaSignatureRunId = ctx.runId
37883793
myJavaSignature
37893794
case SourceLanguage.Scala2 =>
37903795
if ctx.runId != myScala2SignatureRunId then
37913796
myScala2Signature = computeSignature
3792-
if !myScala2Signature.isUnderDefined then myScala2SignatureRunId = ctx.runId
3797+
if !myScala2Signature.isUnderDefined && !isProvisional then myScala2SignatureRunId = ctx.runId
37933798
myScala2Signature
37943799
case SourceLanguage.Scala3 =>
37953800
if ctx.runId != mySignatureRunId then
37963801
mySignature = computeSignature
3797-
if !mySignature.isUnderDefined then mySignatureRunId = ctx.runId
3802+
if !mySignature.isUnderDefined && !isProvisional then mySignatureRunId = ctx.runId
37983803
mySignature
37993804
end signature
38003805

@@ -4760,6 +4765,10 @@ object Types {
47604765
* is different from the variable's creation state (meaning unrolls are possible)
47614766
* in the current typer state.
47624767
*
4768+
* FIXME: the "once" in the statement above is not true anymore now that `resetInst`
4769+
* exists, this is problematic for caching (see `Type#isProvisional`),
4770+
* we should try getting rid of this method.
4771+
*
47634772
* @param origin the parameter that's tracked by the type variable.
47644773
* @param creatorState the typer state in which the variable was created.
47654774
* @param initNestingLevel the initial nesting level of the type variable. (c.f. nestingLevel)

Diff for: compiler/src/dotty/tools/dotc/transform/GenericSignatures.scala

+2-2
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ import core.Flags._
1111
import core.Names.Name
1212
import core.Symbols._
1313
import core.TypeApplications.{EtaExpansion, TypeParamInfo}
14-
import core.TypeErasure.{erasedGlb, erasure, fullErasure, isGenericArrayElement}
14+
import core.TypeErasure.{erasedGlb, erasure, fullErasure, isGenericArrayElement, tupleArity}
1515
import core.Types._
1616
import core.classfile.ClassfileConstants
1717
import SymUtils._
@@ -255,7 +255,7 @@ object GenericSignatures {
255255
case _ => jsig(elemtp)
256256

257257
case RefOrAppliedType(sym, pre, args) =>
258-
if (sym == defn.PairClass && tp.tupleArity > Definitions.MaxTupleArity)
258+
if (sym == defn.PairClass && tupleArity(tp) > Definitions.MaxTupleArity)
259259
jsig(defn.TupleXXLClass.typeRef)
260260
else if (isTypeParameterInSig(sym, sym0)) {
261261
assert(!sym.isAliasType, "Unexpected alias type: " + sym)

Diff for: compiler/src/dotty/tools/dotc/transform/TypeUtils.scala

-18
Original file line numberDiff line numberDiff line change
@@ -49,24 +49,6 @@ object TypeUtils {
4949
case ps => ps.reduceLeft(AndType(_, _))
5050
}
5151

52-
/** The arity of this tuple type, which can be made up of EmptyTuple, TupleX and `*:` pairs,
53-
* or -1 if this is not a tuple type.
54-
*/
55-
def tupleArity(using Context): Int = self/*.dealias*/ match { // TODO: why does dealias cause a failure in tests/run-deep-subtype/Tuple-toArray.scala
56-
case AppliedType(tycon, _ :: tl :: Nil) if tycon.isRef(defn.PairClass) =>
57-
val arity = tl.tupleArity
58-
if (arity < 0) arity else arity + 1
59-
case self: SingletonType =>
60-
if self.termSymbol == defn.EmptyTupleModule then 0 else -1
61-
case self: AndOrType =>
62-
val arity1 = self.tp1.tupleArity
63-
val arity2 = self.tp2.tupleArity
64-
if arity1 == arity2 then arity1 else -1
65-
case _ =>
66-
if defn.isTupleNType(self) then self.dealias.argInfos.length
67-
else -1
68-
}
69-
7052
/** The element types of this tuple type, which can be made up of EmptyTuple, TupleX and `*:` pairs */
7153
def tupleElementTypes(using Context): Option[List[Type]] = self.dealias match {
7254
case AppliedType(tycon, hd :: tl :: Nil) if tycon.isRef(defn.PairClass) =>

Diff for: compiler/test/dotty/tools/SignatureTest.scala

+76-1
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,23 @@ package dotty.tools
22

33
import vulpix.TestConfiguration
44

5+
import org.junit.Assert._
56
import org.junit.Test
67

7-
import dotc.ast.Trees._
8+
import dotc.ast.untpd
89
import dotc.core.Decorators._
910
import dotc.core.Contexts._
11+
import dotc.core.Flags._
1012
import dotc.core.Phases._
13+
import dotc.core.Names._
1114
import dotc.core.Types._
1215
import dotc.core.Symbols._
16+
import dotc.core.StdNames._
17+
import dotc.core.Signature
18+
import dotc.typer.ProtoTypes.constrained
19+
import dotc.typer.Inferencing.isFullyDefined
20+
import dotc.typer.ForceDegree
21+
import dotc.util.NoSourcePosition
1322

1423
import java.io.File
1524
import java.nio.file._
@@ -38,3 +47,69 @@ class SignatureTest:
3847
|${ref.denot.signature}""".stripMargin)
3948
}
4049
}
50+
51+
/** Ensure that signature computation returns an underdefined signature when
52+
* the signature depends on uninstantiated type variables.
53+
*/
54+
@Test def underdefined: Unit =
55+
inCompilerContext(TestConfiguration.basicClasspath, separateRun = false,
56+
"""trait Foo
57+
|trait Bar
58+
|class A[T <: Tuple]:
59+
| def and(x: T & Foo): Unit = {}
60+
| def andor(x: (T | Bar) & Foo): Unit = {}
61+
| def array(x: Array[(T | Bar) & Foo]): Unit = {}
62+
| def tuple(x: Foo *: T): Unit = {}
63+
| def tuple2(x: Foo *: (T | Tuple) & Foo): Unit = {}
64+
|""".stripMargin):
65+
val cls = requiredClass("A")
66+
val tvar = constrained(cls.requiredMethod(nme.CONSTRUCTOR).info.asInstanceOf[TypeLambda], untpd.EmptyTree, alwaysAddTypeVars = true)._2.head.tpe
67+
tvar <:< defn.TupleTypeRef
68+
val prefix = cls.typeRef.appliedTo(tvar)
69+
70+
def checkSignatures(expectedIsUnderDefined: Boolean)(using Context): Unit =
71+
for decl <- cls.info.decls.toList if decl.is(Method) && !decl.isConstructor do
72+
val meth = decl.asSeenFrom(prefix)
73+
val sig = meth.info.signature
74+
val what = if expectedIsUnderDefined then "underdefined" else "fully-defined"
75+
assert(sig.isUnderDefined == expectedIsUnderDefined, i"Signature of `$meth` with prefix `$prefix` and type `${meth.info}` should be $what but is `$sig`")
76+
77+
checkSignatures(expectedIsUnderDefined = true)
78+
assert(isFullyDefined(tvar, force = ForceDegree.all), s"Could not instantiate $tvar")
79+
checkSignatures(expectedIsUnderDefined = false)
80+
81+
/** Check that signature caching behaves correctly with respect to retracted
82+
* instantiations of type variables.
83+
*/
84+
@Test def cachingWithRetraction: Unit =
85+
inCompilerContext(TestConfiguration.basicClasspath, separateRun = false,
86+
"""trait Foo
87+
|trait Bar
88+
|class A[T]:
89+
| def and(x: T & Foo): Unit = {}
90+
|""".stripMargin):
91+
val cls = requiredClass("A")
92+
val tvar = constrained(cls.requiredMethod(nme.CONSTRUCTOR).info.asInstanceOf[TypeLambda], untpd.EmptyTree, alwaysAddTypeVars = true)._2.head.tpe
93+
val prefix = cls.typeRef.appliedTo(tvar)
94+
val ref = prefix.select(cls.requiredMethod("and")).asInstanceOf[TermRef]
95+
96+
/** Check that the signature of the first parameter of `ref` is equal to `expectedParamSig`. */
97+
def checkParamSig(ref: TermRef, expectedParamSig: TypeName)(using Context): Unit =
98+
assertEquals(i"Check failed for param signature of $ref",
99+
expectedParamSig, ref.signature.paramsSig.head)
100+
// Both NamedType and MethodOrPoly cache signatures, so check both caches.
101+
assertEquals(i"Check failed for param signature of ${ref.info} (but not for $ref itself)",
102+
expectedParamSig, ref.info.signature.paramsSig.head)
103+
104+
105+
// Initially, the param signature is Uninstantiated since it depends on an uninstantiated type variable
106+
checkParamSig(ref, tpnme.Uninstantiated)
107+
108+
// In this context, the signature is the erasure of `Bar & Foo`.
109+
inContext(ctx.fresh.setNewTyperState()):
110+
tvar =:= requiredClass("Bar").typeRef
111+
assert(isFullyDefined(tvar, force = ForceDegree.all), s"Could not instantiate $tvar")
112+
checkParamSig(ref, "Bar".toTypeName)
113+
114+
// If our caching logic is working correctly, we should get the original signature here.
115+
checkParamSig(ref, tpnme.Uninstantiated)

Diff for: tests/pos/scala3mock.scala

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
class MockFunction1[T1]:
2+
def expects(v1: T1 | Foo): Any = ???
3+
def expects(matcher: String): Any = ???
4+
5+
def when[T1](f: T1 => Any): MockFunction1[T1] = ???
6+
7+
class Foo
8+
9+
def main =
10+
val f: Foo = new Foo
11+
when((x: Foo) => "").expects(f)

0 commit comments

Comments
 (0)