Skip to content

Commit c21133f

Browse files
committed
Fix signature computation involving nested type variables
During typechecking, we sometimes need to compute the signature of a type involving uninstantiated type variables or wildcards (in particular, because signature matching is doen as part of subtype checking for methods), this is handled with `tpnme.Uninstantiated` which is handled specially in `Signature`. Before this commit, `sigName` only checked for wildcards and type variables at the top-level of the type, even though nested types can still have an impact on type erasure, in particular when they appear as part of: - an intersection - a union - an underlying type of a derived value class - the element type of an array type - an element type of a tuple (... *: X *: ...) We keep track of all these situations by returning `null` in `TypeErasure#apply` when encountering a wildcard or uninstantiated type variable, which we then propage upwards to the `sigName` call. This propagation only happens for `sigName` calls (where `inSigName` is true), otherwise we throw an assertion since erasure shouldn't normally be computed for underdefined types.
1 parent 84bc1bd commit c21133f

File tree

3 files changed

+129
-38
lines changed

3 files changed

+129
-38
lines changed

Diff for: compiler/src/dotty/tools/dotc/core/TypeErasure.scala

+79-37
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,11 @@ object TypeErasure {
8383
* `EmptyTuple.type` because of a missing dealias, but this is now
8484
* impossible to fix.
8585
*
86-
* @return The arity if it can be determined or -1 otherwise.
86+
* @return The arity if it can be determined, or:
87+
* -1 if this type does not have a fixed arity
88+
* -2 if the arity depends on an uninstantiated type variable or WildcardType.
8789
*/
88-
def tupleArity(tp: Type)(using Context): Int = tp/*.dealias*/ match {
90+
def tupleArity(tp: Type)(using Context): Int = tp/*.dealias*/ match
8991
case AppliedType(tycon, _ :: tl :: Nil) if tycon.isRef(defn.PairClass) =>
9092
val arity = tupleArity(tl)
9193
if (arity < 0) arity else arity + 1
@@ -94,11 +96,14 @@ object TypeErasure {
9496
case tp: AndOrType =>
9597
val arity1 = tupleArity(tp.tp1)
9698
val arity2 = tupleArity(tp.tp2)
97-
if arity1 == arity2 then arity1 else -1
99+
if arity1 == arity2 then arity1 else math.min(-1, math.min(arity1, arity2))
100+
case tp: WildcardType => -2
101+
case tp: TypeVar if !tp.isInstantiated => -2
98102
case _ =>
99103
if defn.isTupleNType(tp) then tp.dealias.argInfos.length
100-
else -1
101-
}
104+
else tp.dealias match
105+
case tp: TypeVar if !tp.isInstantiated => -2
106+
case _ => -1
102107

103108
def normalizeClass(cls: ClassSymbol)(using Context): ClassSymbol = {
104109
if (cls.owner == defn.ScalaPackageClass) {
@@ -204,19 +209,19 @@ object TypeErasure {
204209
* @param tp The type to erase.
205210
*/
206211
def erasure(tp: Type)(using Context): Type =
207-
erasureFn(sourceLanguage = SourceLanguage.Scala3, semiEraseVCs = false, isConstructor = false, isSymbol = false, inSigName = false)(tp)(using preErasureCtx)
212+
erasureFn(sourceLanguage = SourceLanguage.Scala3, semiEraseVCs = false, isConstructor = false, isSymbol = false, inSigName = false)(tp)(using preErasureCtx).nn
208213

209214
/** The value class erasure of a Scala type, where value classes are semi-erased to
210215
* ErasedValueType (they will be fully erased in [[ElimErasedValueType]]).
211216
*
212217
* @param tp The type to erase.
213218
*/
214219
def valueErasure(tp: Type)(using Context): Type =
215-
erasureFn(sourceLanguage = SourceLanguage.Scala3, semiEraseVCs = true, isConstructor = false, isSymbol = false, inSigName = false)(tp)(using preErasureCtx)
220+
erasureFn(sourceLanguage = SourceLanguage.Scala3, semiEraseVCs = true, isConstructor = false, isSymbol = false, inSigName = false)(tp)(using preErasureCtx).nn
216221

217222
/** The erasure that Scala 2 would use for this type. */
218223
def scala2Erasure(tp: Type)(using Context): Type =
219-
erasureFn(sourceLanguage = SourceLanguage.Scala2, semiEraseVCs = true, isConstructor = false, isSymbol = false, inSigName = false)(tp)(using preErasureCtx)
224+
erasureFn(sourceLanguage = SourceLanguage.Scala2, semiEraseVCs = true, isConstructor = false, isSymbol = false, inSigName = false)(tp)(using preErasureCtx).nn
220225

221226
/** Like value class erasure, but value classes erase to their underlying type erasure */
222227
def fullErasure(tp: Type)(using Context): Type =
@@ -265,8 +270,8 @@ object TypeErasure {
265270
if (defn.isPolymorphicAfterErasure(sym)) eraseParamBounds(sym.info.asInstanceOf[PolyType])
266271
else if (sym.isAbstractType) TypeAlias(WildcardType)
267272
else if sym.is(ConstructorProxy) then NoType
268-
else if (sym.isConstructor) outer.addParam(sym.owner.asClass, erase(tp)(using preErasureCtx))
269-
else if (sym.is(Label)) erase.eraseResult(sym.info)(using preErasureCtx)
273+
else if (sym.isConstructor) outer.addParam(sym.owner.asClass, erase(tp)(using preErasureCtx).nn)
274+
else if (sym.is(Label)) erase.eraseResult(sym.info)(using preErasureCtx).nn
270275
else erase.eraseInfo(tp, sym)(using preErasureCtx) match {
271276
case einfo: MethodType =>
272277
if (sym.isGetter && einfo.resultType.isRef(defn.UnitClass))
@@ -587,8 +592,14 @@ import TypeErasure._
587592
*/
588593
class TypeErasure(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConstructor: Boolean, isSymbol: Boolean, inSigName: Boolean) {
589594

590-
/** The erasure |T| of a type T. This is:
595+
/** The erasure |T| of a type T.
596+
*
597+
* If computing the erasure of T requires erasing a WildcardType or an
598+
* uninstantiated type variable, then an exception signaling an internal
599+
* error will be thrown, unless `inSigName` is set in which case `null`
600+
* will be returned.
591601
*
602+
* In all other situations, |T| will be non-null and computed as follow:
592603
* - For a refined type scala.Array+[T]:
593604
* - if T is Nothing or Null, []Object
594605
* - otherwise, if T <: Object, []|T|
@@ -620,7 +631,7 @@ class TypeErasure(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConst
620631
* - For NoType or NoPrefix, the type itself.
621632
* - For any other type, exception.
622633
*/
623-
private def apply(tp: Type)(using Context): Type = tp match {
634+
private def apply(tp: Type)(using Context): Type | Null = (tp match
624635
case _: ErasedValueType =>
625636
tp
626637
case tp: TypeRef =>
@@ -641,13 +652,19 @@ class TypeErasure(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConst
641652
case _: ThisType =>
642653
this(tp.widen)
643654
case SuperType(thistpe, supertpe) =>
644-
SuperType(this(thistpe), this(supertpe))
655+
val eThis = this(thistpe)
656+
val eSuper = this(supertpe)
657+
if eThis == null || eSuper == null then null
658+
else SuperType(eThis, eSuper)
645659
case ExprType(rt) =>
646660
defn.FunctionType(0)
647661
case RefinedType(parent, nme.apply, refinedInfo) if parent.typeSymbol eq defn.PolyFunctionClass =>
648662
erasePolyFunctionApply(refinedInfo)
649663
case RefinedType(parent, nme.apply, refinedInfo: MethodType) if defn.isErasedFunctionType(parent) =>
650664
eraseErasedFunctionApply(refinedInfo)
665+
case tp: TypeVar if !tp.isInstantiated =>
666+
assert(inSigName, i"Cannot erase uninstantiated type variable $tp")
667+
null
651668
case tp: TypeProxy =>
652669
this(tp.underlying)
653670
case tp @ AndType(tp1, tp2) =>
@@ -656,7 +673,10 @@ class TypeErasure(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConst
656673
else if sourceLanguage.isScala2 then
657674
this(Scala2Erasure.intersectionDominator(Scala2Erasure.flattenedParents(tp)))
658675
else
659-
erasedGlb(this(tp1), this(tp2))
676+
val e1 = this(tp1)
677+
val e2 = this(tp2)
678+
if e1 == null || e2 == null then null
679+
else erasedGlb(e1, e2)
660680
case OrType(tp1, tp2) =>
661681
if isSymbol && sourceLanguage.isScala2 && ctx.settings.scalajs.value then
662682
// In Scala2Unpickler we unpickle Scala.js pseudo-unions as if they were
@@ -670,10 +690,13 @@ class TypeErasure(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConst
670690
// alone (and this doesn't impact the SJSIR we generate).
671691
JSDefinitions.jsdefn.PseudoUnionType
672692
else
673-
TypeComparer.orType(this(tp1), this(tp2), isErased = true)
693+
val e1 = this(tp1)
694+
val e2 = this(tp2)
695+
if e1 == null || e2 == null then null
696+
else TypeComparer.orType(e1, e2, isErased = true)
674697
case tp: MethodType =>
675698
def paramErasure(tpToErase: Type) =
676-
erasureFn(sourceLanguage, semiEraseVCs, isConstructor, isSymbol, inSigName)(tpToErase)
699+
erasureFn(sourceLanguage, semiEraseVCs, isConstructor, isSymbol, inSigName = false)(tpToErase).nn
677700
val (names, formals0) = if tp.hasErasedParams then
678701
tp.paramNames
679702
.zip(tp.paramInfos)
@@ -700,7 +723,7 @@ class TypeErasure(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConst
700723
else {
701724
def eraseParent(tp: Type) = tp.dealias match { // note: can't be opaque, since it's a class parent
702725
case tp: AppliedType if tp.tycon.isRef(defn.PairClass) => defn.ObjectType
703-
case _ => apply(tp)
726+
case _ => apply(tp).nn
704727
}
705728
val erasedParents: List[Type] =
706729
if ((cls eq defn.ObjectClass) || cls.isPrimitiveValueClass) Nil
@@ -725,11 +748,12 @@ class TypeErasure(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConst
725748
}
726749
case _: ErrorType | JavaArrayType(_) =>
727750
tp
728-
case tp: WildcardType if inSigName =>
729-
tp
751+
case tp: WildcardType =>
752+
assert(inSigName, i"Cannot erase wildcard type $tp")
753+
null
730754
case tp if (tp `eq` NoType) || (tp `eq` NoPrefix) =>
731755
tp
732-
}
756+
).ensuring(etp => etp != null || inSigName)
733757

734758
/** Like translucentSuperType, but issue a fatal error if it does not exist. */
735759
private def checkedSuperType(tp: TypeProxy)(using Context): Type =
@@ -760,15 +784,19 @@ class TypeErasure(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConst
760784
val defn.ArrayOf(elemtp) = tp: @unchecked
761785
if (isGenericArrayElement(elemtp, isScala2 = sourceLanguage.isScala2)) defn.ObjectType
762786
else
763-
try JavaArrayType(erasureFn(sourceLanguage, semiEraseVCs = false, isConstructor, isSymbol, inSigName)(elemtp))
787+
try
788+
val eElem = erasureFn(sourceLanguage, semiEraseVCs = false, isConstructor, isSymbol, inSigName)(elemtp)
789+
if eElem == null then null
790+
else JavaArrayType(eElem)
764791
catch case ex: Throwable =>
765792
handleRecursive("erase array type", tp.show, ex)
766793
}
767794

768-
private def erasePair(tp: Type)(using Context): Type = {
795+
private def erasePair(tp: Type)(using Context): Type | Null = {
769796
val arity = tupleArity(tp)
770-
if (arity < 0) defn.ProductClass.typeRef
771-
else if (arity <= Definitions.MaxTupleArity) defn.TupleType(arity).nn
797+
if arity == -2 then null // erasure depends on an uninstantiated type variable or WildcardType
798+
else if arity == -1 then defn.ProductClass.typeRef
799+
else if arity <= Definitions.MaxTupleArity then defn.TupleType(arity).nn
772800
else defn.TupleXXLClass.typeRef
773801
}
774802

@@ -777,12 +805,13 @@ class TypeErasure(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConst
777805
* to the underlying type.
778806
*/
779807
def eraseInfo(tp: Type, sym: Symbol)(using Context): Type =
808+
assert(!inSigName) // therefore apply(...).nn won't fail
780809
val tp1 = tp match
781810
case tp: MethodicType => integrateContextResults(tp, contextResultCount(sym))
782811
case _ => tp
783812
tp1 match
784813
case ExprType(rt) =>
785-
if sym.is(Param) then apply(tp1)
814+
if sym.is(Param) then apply(tp1).nn
786815
// Note that params with ExprTypes are eliminated by ElimByName,
787816
// but potentially re-introduced by ResolveSuper, when we add
788817
// forwarders to mixin methods.
@@ -794,9 +823,9 @@ class TypeErasure(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConst
794823
eraseResult(tp1.resultType) match
795824
case rt: MethodType => rt
796825
case rt => MethodType(Nil, Nil, rt)
797-
case tp1 => this(tp1)
826+
case tp1 => this(tp1).nn
798827

799-
private def eraseDerivedValueClass(tp: Type)(using Context): Type = {
828+
private def eraseDerivedValueClass(tp: Type)(using Context): Type | Null = {
800829
val cls = tp.classSymbol.asClass
801830
val unbox = valueClassUnbox(cls)
802831
if unbox.exists then
@@ -806,6 +835,7 @@ class TypeErasure(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConst
806835
// The underlying part of an ErasedValueType cannot be an ErasedValueType itself
807836
val erase = erasureFn(sourceLanguage, semiEraseVCs = false, isConstructor, isSymbol, inSigName)
808837
val erasedUnderlying = erase(underlying)
838+
if erasedUnderlying == null then return null
809839

810840
// Ideally, we would just use `erasedUnderlying` as the erasure of `tp`, but to
811841
// be binary-compatible with Scala 2 we need two special cases for polymorphic
@@ -839,6 +869,7 @@ class TypeErasure(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConst
839869

840870
/** The erasure of a function result type. */
841871
def eraseResult(tp: Type)(using Context): Type =
872+
assert(!inSigName) // therefore apply(...).nn won't fail
842873
// For a value class V, "new V(x)" should have type V for type adaptation to work
843874
// correctly (see SIP-15 and [[Erasure.Boxing.adaptToType]]), so the result type of a
844875
// constructor method should not be semi-erased.
@@ -848,18 +879,25 @@ class TypeErasure(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConst
848879
case tp: TypeRef =>
849880
val sym = tp.symbol
850881
if (sym eq defn.UnitClass) sym.typeRef
851-
else this(tp)
882+
else apply(tp).nn
852883
case tp: AppliedType =>
853884
val sym = tp.tycon.typeSymbol
854885
if (sym.isClass && !erasureDependsOnArgs(sym)) eraseResult(tp.tycon)
855-
else this(tp)
886+
else apply(tp).nn
856887
case _ =>
857-
this(tp)
888+
apply(tp).nn
858889

859890
/** The name of the type as it is used in `Signature`s.
860-
* Need to ensure correspondence with erasure!
891+
*
892+
* If `tp` is null, or if computing its erasure requires erasing a
893+
* WildcardType or an uninstantiated type variable, then the special name
894+
* `tpnme.Uninstantiated` which is used to signal an underdefined signature
895+
* is used.
896+
*
897+
* Note: Need to ensure correspondence with erasure!
861898
*/
862-
private def sigName(tp: Type)(using Context): TypeName = try
899+
private def sigName(tp: Type | Null)(using Context): TypeName = try
900+
if tp == null then return tpnme.Uninstantiated
863901
tp match {
864902
case tp: TypeRef =>
865903
if (!tp.denot.exists)
@@ -873,6 +911,7 @@ class TypeErasure(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConst
873911
}
874912
if (semiEraseVCs && isDerivedValueClass(sym)) {
875913
val erasedVCRef = eraseDerivedValueClass(tp)
914+
if erasedVCRef == null then return tpnme.Uninstantiated
876915
if (erasedVCRef.exists) return sigName(erasedVCRef)
877916
}
878917
if (defn.isSyntheticFunctionClass(sym))
@@ -897,14 +936,15 @@ class TypeErasure(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConst
897936
case ErasedValueType(_, underlying) =>
898937
sigName(underlying)
899938
case JavaArrayType(elem) =>
900-
sigName(elem) ++ "[]"
939+
val elemName = sigName(elem)
940+
if elemName eq tpnme.Uninstantiated then elemName
941+
else elemName ++ "[]"
901942
case tp: TermRef =>
902943
sigName(underlyingOfTermRef(tp))
903944
case ExprType(rt) =>
904945
sigName(defn.FunctionOf(Nil, rt))
905-
case tp: TypeVar =>
906-
val inst = tp.instanceOpt
907-
if (inst.exists) sigName(inst) else tpnme.Uninstantiated
946+
case tp: TypeVar if !tp.isInstantiated =>
947+
tpnme.Uninstantiated
908948
case tp @ RefinedType(parent, nme.apply, _) if parent.typeSymbol eq defn.PolyFunctionClass =>
909949
// we need this case rather than falling through to the default
910950
// because RefinedTypes <: TypeProxy and it would be caught by
@@ -916,7 +956,9 @@ class TypeErasure(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConst
916956
sigName(tp.underlying)
917957
case tp: WildcardType =>
918958
tpnme.Uninstantiated
919-
case _: ErrorType | NoType =>
959+
case tp: ErrorType =>
960+
tpnme.ERROR
961+
case _ if tp eq NoType => // Can't write `case NoType` because of #18083.
920962
tpnme.ERROR
921963
case _ =>
922964
val erasedTp = this(tp)

Diff for: compiler/test/dotty/tools/SignatureTest.scala

+39-1
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,19 @@ import vulpix.TestConfiguration
44

55
import org.junit.Test
66

7-
import dotc.ast.Trees._
7+
import dotc.ast.untpd
88
import dotc.core.Decorators._
99
import dotc.core.Contexts._
10+
import dotc.core.Flags._
1011
import dotc.core.Phases._
1112
import dotc.core.Types._
1213
import dotc.core.Symbols._
14+
import dotc.core.StdNames._
15+
import dotc.core.Signature
16+
import dotc.typer.ProtoTypes.constrained
17+
import dotc.typer.Inferencing.isFullyDefined
18+
import dotc.typer.ForceDegree
19+
import dotc.util.NoSourcePosition
1320

1421
import java.io.File
1522
import java.nio.file._
@@ -38,3 +45,34 @@ class SignatureTest:
3845
|${ref.denot.signature}""".stripMargin)
3946
}
4047
}
48+
49+
/** Ensure that signature computation returns an underdefined signature when
50+
* the signature depends on uninstantiated type variables.
51+
*/
52+
@Test def underdefined: Unit =
53+
inCompilerContext(TestConfiguration.basicClasspath, separateRun = false,
54+
"""trait Foo
55+
|trait Bar
56+
|class A[T <: Tuple]:
57+
| def and(x: T & Foo): Unit = {}
58+
| def andor(x: (T | Bar) & Foo): Unit = {}
59+
| def array(x: Array[(T | Bar) & Foo]): Unit = {}
60+
| def tuple(x: Foo *: T): Unit = {}
61+
| def tuple2(x: Foo *: (T | Tuple) & Foo): Unit = {}
62+
|""".stripMargin):
63+
val cls = requiredClass("A")
64+
val tvar = constrained(cls.requiredMethod(nme.CONSTRUCTOR).info.asInstanceOf[TypeLambda], untpd.EmptyTree, alwaysAddTypeVars = true)._2.head.tpe
65+
tvar <:< defn.TupleTypeRef
66+
val prefix = cls.typeRef.appliedTo(tvar)
67+
68+
def checkSignatures(expectedIsUnderDefined: Boolean)(using Context): Unit =
69+
for decl <- cls.info.decls.toList if decl.is(Method) && !decl.isConstructor do
70+
val meth = decl.asSeenFrom(prefix)
71+
val sig = meth.info.signature
72+
val what = if expectedIsUnderDefined then "underdefined" else "fully-defined"
73+
assert(sig.isUnderDefined == expectedIsUnderDefined, i"Signature of `$meth` with prefix `$prefix` and type `${meth.info}` should be $what but is `$sig`")
74+
75+
checkSignatures(expectedIsUnderDefined = true)
76+
assert(isFullyDefined(tvar, force = ForceDegree.all), s"Could not instantiate $tvar")
77+
checkSignatures(expectedIsUnderDefined = false)
78+

Diff for: tests/pos/scala3mock.scala

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
class MockFunction1[T1]:
2+
def expects(v1: T1 | Foo): Any = ???
3+
def expects(matcher: String): Any = ???
4+
5+
def when[T1](f: T1 => Any): MockFunction1[T1] = ???
6+
7+
class Foo
8+
9+
def main =
10+
val f: Foo = new Foo
11+
when((x: Foo) => "").expects(f)

0 commit comments

Comments
 (0)