Skip to content

Commit e404de4

Browse files
committed
Add additional checks for type parameters and update tests
1 parent 2ca2fc3 commit e404de4

File tree

8 files changed

+97
-24
lines changed

8 files changed

+97
-24
lines changed

Diff for: compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala

+35-10
Original file line numberDiff line numberDiff line change
@@ -2727,7 +2727,7 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
27272727
assert(!clsPrivateWithin.exists || clsPrivateWithin.isType, "clsPrivateWithin must be a type symbol or `Symbol.noSymbol`")
27282728
assert(!conPrivateWithin.exists || conPrivateWithin.isType, "consPrivateWithin must be a type symbol or `Symbol.noSymbol`")
27292729
checkValidFlags(clsFlags.toTypeFlags, Flags.validClassFlags)
2730-
checkValidFlags(conFlags, Flags.validClassConstructorFlags)
2730+
checkValidFlags(conFlags.toTermFlags, Flags.validClassConstructorFlags)
27312731
val cls = dotc.core.Symbols.newNormalizedClassSymbolUsingClassSymbolinParents(
27322732
owner,
27332733
name.toTypeName,
@@ -2750,33 +2750,58 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
27502750
if (conParamFlags.length <= clauseIdx) throwShapeException()
27512751
if (conParamFlags(clauseIdx).length != params.length) throwShapeException()
27522752
checkMethodOrPolyShape(res, clauseIdx + 1)
2753-
case _ =>
2753+
case other =>
2754+
xCheckMacroAssert(
2755+
other.typeSymbol == cls,
2756+
"Incorrect type returned from the innermost PolyOrMethod."
2757+
)
2758+
(other, methodType) match
2759+
case (AppliedType(tycon, args), pt: PolyType) =>
2760+
xCheckMacroAssert(
2761+
args.length == pt.typeParams.length &&
2762+
args.zip(pt.typeParams).forall {
2763+
case (arg, param) => arg == param.paramRef
2764+
},
2765+
"Constructor result type does not correspond to the declared type parameters"
2766+
)
2767+
case _ =>
2768+
xCheckMacroAssert(
2769+
!(other.isInstanceOf[AppliedType] || methodType.isInstanceOf[PolyType]),
2770+
"AppliedType has to be the innermost resultTypeExp result if and only if conMethodType returns a PolyType"
2771+
)
27542772
checkMethodOrPolyShape(methodType, clauseIdx = 0)
2773+
27552774
cls.enter(dotc.core.Symbols.newSymbol(cls, nme.CONSTRUCTOR, Flags.Synthetic | Flags.Method | conFlags, methodType, conPrivateWithin, dotty.tools.dotc.util.Spans.NoCoord))
2756-
def getParamAccessors(methodType: TypeRepr, clauseIdx: Int): List[((String, TypeRepr, Boolean, Int), Int)] =
2775+
2776+
case class ParamSymbolData(name: String, tpe: TypeRepr, isTypeParam: Boolean, clauseIdx: Int, elementIdx: Int)
2777+
def getParamSymbolsData(methodType: TypeRepr, clauseIdx: Int): List[ParamSymbolData] =
27572778
methodType match
27582779
case MethodType(paramInfosExp, resultTypeExp, res) =>
2759-
paramInfosExp.zip(resultTypeExp).map(_ :* false :* clauseIdx).zipWithIndex ++ getParamAccessors(res, clauseIdx + 1)
2780+
paramInfosExp.zip(resultTypeExp).zipWithIndex.map { case ((name, tpe), elementIdx) =>
2781+
ParamSymbolData(name, tpe, isTypeParam = false, clauseIdx, elementIdx)
2782+
} ++ getParamSymbolsData(res, clauseIdx + 1)
27602783
case pt @ PolyType(paramNames, paramBounds, res) =>
2761-
paramNames.zip(paramBounds).map(_ :* true :* clauseIdx).zipWithIndex ++ getParamAccessors(res, clauseIdx + 1)
2784+
paramNames.zip(paramBounds).zipWithIndex.map {case ((name, tpe), elementIdx) =>
2785+
ParamSymbolData(name, tpe, isTypeParam = true, clauseIdx, elementIdx)
2786+
} ++ getParamSymbolsData(res, clauseIdx + 1)
27622787
case result =>
27632788
List()
2764-
// Maps PolyType indexes to type parameter symbols
2789+
// Maps PolyType indexes to type parameter symbol typerefs
27652790
val paramRefMap = collection.mutable.HashMap[Int, Symbol]()
27662791
val paramRefRemapper = new Types.TypeMap {
27672792
def apply(tp: Types.Type) = tp match {
27682793
case pRef: ParamRef if pRef.binder == methodType => paramRefMap(pRef.paramNum).typeRef
27692794
case _ => mapOver(tp)
27702795
}
27712796
}
2772-
for ((name, tpe, isType, clauseIdx), elementIdx) <- getParamAccessors(methodType, 0) do
2773-
if isType then
2774-
checkValidFlags(conParamFlags(clauseIdx)(elementIdx), Flags.validClassTypeParamFlags)
2797+
for case ParamSymbolData(name, tpe, isTypeParam, clauseIdx, elementIdx) <- getParamSymbolsData(methodType, 0) do
2798+
if isTypeParam then
2799+
checkValidFlags(conParamFlags(clauseIdx)(elementIdx).toTypeFlags, Flags.validClassTypeParamFlags)
27752800
val symbol = dotc.core.Symbols.newSymbol(cls, name.toTypeName, Flags.Param | Flags.Deferred | Flags.Private | Flags.PrivateLocal | Flags.Local | conParamFlags(clauseIdx)(elementIdx), tpe, conParamPrivateWithins(clauseIdx)(elementIdx))
27762801
paramRefMap.addOne(elementIdx, symbol)
27772802
cls.enter(symbol)
27782803
else
2779-
checkValidFlags(conParamFlags(clauseIdx)(elementIdx), Flags.validClassTermParamFlags)
2804+
checkValidFlags(conParamFlags(clauseIdx)(elementIdx).toTermFlags, Flags.validClassTermParamFlags)
27802805
val fixedType = paramRefRemapper(tpe)
27812806
cls.enter(dotc.core.Symbols.newSymbol(cls, name.toTermName, Flags.ParamAccessor | conParamFlags(clauseIdx)(elementIdx), fixedType, conParamPrivateWithins(clauseIdx)(elementIdx)))
27822807
for sym <- decls(cls) do cls.enter(sym)

Diff for: library/src/scala/quoted/Quotes.scala

+38-9
Original file line numberDiff line numberDiff line change
@@ -3812,6 +3812,7 @@ trait Quotes { self: runtime.QuoteUnpickler & runtime.QuoteMatching =>
38123812
def classSymbol(fullName: String): Symbol
38133813

38143814
/** Generates a new class symbol for a class with a public parameterless constructor.
3815+
* For more settings, look to the other newClass methods.
38153816
*
38163817
* Example usage:
38173818
* ```
@@ -3856,13 +3857,41 @@ trait Quotes { self: runtime.QuoteUnpickler & runtime.QuoteMatching =>
38563857

38573858
/** Generates a new class symbol for a class with a public single term clause constructor.
38583859
*
3859-
* @param owner The owner of the class
3860-
* @param name The name of the class
3861-
* @param parents Function returning the parent classes of the class. The first parent must not be a trait.
3862-
* Takes the constructed class symbol as an argument. Calling `cls.typeRef.asType` as part of this function will lead to cyclic reference errors.
3863-
* @param clsFlags extra flags with which the class symbol should be constructed.
3864-
* @param clsPrivateWithin the symbol within which this new class symbol should be private. May be noSymbol.
3865-
* @param conParams constructor parameter pairs of names and types.
3860+
* Example usage:
3861+
* ```
3862+
* val name = nameExpr.valueOrAbort
3863+
* def decls(cls: Symbol): List[Symbol] =
3864+
* List(Symbol.newMethod(cls, "foo", MethodType(Nil)(_ => Nil, _ => TypeRepr.of[Unit])))
3865+
* val parents = List(TypeTree.of[Object])
3866+
* val cls = Symbol.newClass(
3867+
* Symbol.spliceOwner,
3868+
* name,
3869+
* parents = _ => parents.map(_.tpe),
3870+
* decls,
3871+
* selfType = None,
3872+
* clsFlags = Flags.EmptyFlags,
3873+
* Symbol.noSymbol,
3874+
* List(("idx", TypeRepr.of[Int]), ("str", TypeRepr.of[String]))
3875+
* )
3876+
*
3877+
* val fooSym = cls.declaredMethod("foo").head
3878+
* val idxSym = cls.fieldMember("idx")
3879+
* val strSym = cls.fieldMember("str")
3880+
* val fooDef = DefDef(fooSym, argss =>
3881+
* Some('{println(s"Foo method call with (${${Ref(idxSym).asExpr}}, ${${Ref(strSym).asExpr}})")}.asTerm)
3882+
* )
3883+
* val clsDef = ClassDef(cls, parents, body = List(fooDef))
3884+
* val newCls = Apply(Select(New(TypeIdent(cls)), cls.primaryConstructor), List(idxExpr.asTerm, strExpr.asTerm))
3885+
*
3886+
* Block(List(clsDef), Apply(Select(newCls, cls.methodMember("foo")(0)), Nil)).asExprOf[Unit]
3887+
* ```
3888+
* @param owner The owner of the class
3889+
* @param name The name of the class
3890+
* @param parents Function returning the parent classes of the class. The first parent must not be a trait.
3891+
* Takes the constructed class symbol as an argument. Calling `cls.typeRef.asType` as part of this function will lead to cyclic reference errors.
3892+
* @param clsFlags extra flags with which the class symbol should be constructed.
3893+
* @param clsPrivateWithin the symbol within which this new class symbol should be private. May be noSymbol.
3894+
* @param conParams constructor parameter pairs of names and types.
38663895
*
38673896
* Parameters assigned by the constructor can be obtained via `classSymbol.memberField`.
38683897
* This symbol starts without an accompanying definition.
@@ -3893,7 +3922,7 @@ trait Quotes { self: runtime.QuoteUnpickler & runtime.QuoteMatching =>
38933922
* val conMethodType =
38943923
* (classType: TypeRepr) => PolyType(List("T"))(_ => List(TypeBounds.empty), polyType =>
38953924
* MethodType(List("param"))((_: MethodType) => List(polyType.param(0)), (_: MethodType) =>
3896-
* classType
3925+
* AppliedType(classType, List(polyType.param(0)))
38973926
* )
38983927
* )
38993928
* val cls = Symbol.newClass(
@@ -3955,7 +3984,7 @@ trait Quotes { self: runtime.QuoteUnpickler & runtime.QuoteMatching =>
39553984
* @param clsPrivateWithin the symbol within which this new class symbol should be private. May be noSymbol
39563985
* @param clsAnnotations annotations of the class
39573986
* @param conMethodType Function returning MethodOrPoly type representing the type of the constructor.
3958-
* Takes the result type as parameter which must be returned from the innermost MethodOrPoly.
3987+
* Takes the result type as parameter which must be returned from the innermost MethodOrPoly and have type parameters applied if those are used.
39593988
* PolyType may only represent the first clause of the constructor.
39603989
* @param conFlags extra flags with which the constructor symbol should be constructed. Can be `Synthetic` | `Method` | `Private` | `Protected` | `PrivateLocal` | `Local`
39613990
* @param conPrivateWithin the symbol within which the constructor for this new class symbol should be private. May be noSymbol.

Diff for: tests/run-macros/newClassParams/Macro_1.scala

+18-3
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,26 @@ private def makeClassAndCallExpr(nameExpr: Expr[String], idxExpr: Expr[Int], str
88

99
val name = nameExpr.valueOrAbort
1010

11-
def decls(cls: Symbol): List[Symbol] = List(Symbol.newMethod(cls, "foo", MethodType(Nil)(_ => Nil, _ => TypeRepr.of[Unit])))
11+
def decls(cls: Symbol): List[Symbol] =
12+
List(Symbol.newMethod(cls, "foo", MethodType(Nil)(_ => Nil, _ => TypeRepr.of[Unit])))
1213
val parents = List(TypeTree.of[Object])
13-
val cls = Symbol.newClass(Symbol.spliceOwner, name, parents = _ => parents.map(_.tpe), decls, selfType = None, Flags.EmptyFlags, Symbol.noSymbol, List(("idx", TypeRepr.of[Int]), ("str", TypeRepr.of[String])))
14+
val cls = Symbol.newClass(
15+
Symbol.spliceOwner,
16+
name,
17+
parents = _ => parents.map(_.tpe),
18+
decls,
19+
selfType = None,
20+
clsFlags = Flags.EmptyFlags,
21+
Symbol.noSymbol,
22+
List(("idx", TypeRepr.of[Int]), ("str", TypeRepr.of[String]))
23+
)
1424

15-
val fooDef = DefDef(cls.methodMember("foo")(0), argss => Some('{println(s"Foo method call with (${${Ref(cls.fieldMember("idx")).asExpr}}, ${${Ref(cls.fieldMember("str")).asExpr}})")}.asTerm))
25+
val fooSym = cls.declaredMethod("foo").head
26+
val idxSym = cls.fieldMember("idx")
27+
val strSym = cls.fieldMember("str")
28+
val fooDef = DefDef(fooSym, argss =>
29+
Some('{println(s"Foo method call with (${${Ref(idxSym).asExpr}}, ${${Ref(strSym).asExpr}})")}.asTerm)
30+
)
1631
val clsDef = ClassDef(cls, parents, body = List(fooDef))
1732
val newCls = Apply(Select(New(TypeIdent(cls)), cls.primaryConstructor), List(idxExpr.asTerm, strExpr.asTerm))
1833

Diff for: tests/run-macros/newClassTraitAndAbstract/Macro_1.scala

+3-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@ private def makeClassExpr(using Quotes)(
3333
val conMethodType =
3434
(classType: TypeRepr) => PolyType(List("A", "B"))(
3535
_ => List(TypeBounds.empty, TypeBounds.upper(TypeRepr.of[Int])),
36-
polyType => MethodType(List("param1", "param2"))((_: MethodType) => List(polyType.param(0), polyType.param(1)), (_: MethodType) => classType)
36+
polyType => MethodType(List("param1", "param2"))((_: MethodType) => List(polyType.param(0), polyType.param(1)), (_: MethodType) =>
37+
AppliedType(classType, List(polyType.param(0), polyType.param(1)))
38+
)
3739
)
3840

3941
val traitSymbol = Symbol.newClass(

Diff for: tests/run-macros/newClassTypeParams/Macro_1.scala

+3-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@ private def makeClassExpr(nameExpr: Expr[String])(using Quotes): Expr[Any] = {
1111
val conMethodType =
1212
(classType: TypeRepr) => PolyType(List("A", "B"))(
1313
_ => List(TypeBounds.empty, TypeBounds.upper(TypeRepr.of[Int])),
14-
polyType => MethodType(List("param1", "param2"))((_: MethodType) => List(polyType.param(0), polyType.param(1)), (_: MethodType) => classType)
14+
polyType => MethodType(List("param1", "param2"))((_: MethodType) => List(polyType.param(0), polyType.param(1)), (_: MethodType) =>
15+
AppliedType(classType, List(polyType.param(0), polyType.param(1)))
16+
)
1517
)
1618

1719
val cls = Symbol.newClass(

0 commit comments

Comments
 (0)