|
| 1 | +//> using options -experimental -Yno-experimental |
| 2 | + |
| 3 | +package example |
| 4 | + |
| 5 | +import scala.annotation.{experimental, MacroAnnotation, StaticAnnotation} |
| 6 | +import scala.quoted._ |
| 7 | +import scala.collection.mutable.Map |
| 8 | +import scala.compiletime.ops.double |
| 9 | + |
| 10 | +// TODO make unrollLast the macro annotation and remove unrollHelper |
| 11 | +class unrollLast extends StaticAnnotation |
| 12 | + |
| 13 | +@experimental |
| 14 | +class unrollHelper extends MacroAnnotation { |
| 15 | + def transform(using Quotes)(tree: quotes.reflect.Definition): List[quotes.reflect.Definition] = |
| 16 | + import quotes.reflect._ |
| 17 | + tree match |
| 18 | + case tree: DefDef => transformDefDef(tree) |
| 19 | + case _ => report.throwError("unrollHelper can only be applied to a method definition", tree.pos) |
| 20 | + |
| 21 | + private def transformDefDef(using Quotes)(ddef: quotes.reflect.DefDef): List[quotes.reflect.Definition] = |
| 22 | + import quotes.reflect._ |
| 23 | + val unrollLastSym = Symbol.requiredClass("example.unrollLast") |
| 24 | + ddef.paramss match |
| 25 | + case Nil => |
| 26 | + report.throwError("unrollHelper must have an @unrollLast parameter", ddef.pos) |
| 27 | + case _ :: _ :: _ => |
| 28 | + report.throwError("unrollHelper does not yet support multiple parameter lists", ddef.pos) |
| 29 | + case TermParamClause(params) :: Nil => |
| 30 | + if params.isEmpty then report.throwError("unrollHelper must have an @unrollLast parameter", ddef.pos) |
| 31 | + else if params.init.exists(_.symbol.hasAnnotation(unrollLastSym)) || !params.last.symbol.hasAnnotation(unrollLastSym) then |
| 32 | + report.throwError("@unrollLast must be on the last parameter", ddef.pos) |
| 33 | + List(ddef, makeTelescopedDefDefWithoutLastArgument(ddef.symbol)) |
| 34 | + case _ => |
| 35 | + report.throwError("unrollHelper does not yet support type parameters", ddef.pos) |
| 36 | + |
| 37 | + private def makeTelescopedDefDefWithoutLastArgument(using Quotes)(defSym: quotes.reflect.Symbol): quotes.reflect.DefDef = |
| 38 | + import quotes.reflect._ |
| 39 | + def ddef1Rhs(argss: List[List[Tree]]): Some[Term] = |
| 40 | + val defaultArg = defaultGetter(defSym, argss.size + 2) // +1 for 1-based and +1 for the argument that was dropped |
| 41 | + val args1 = argss.head.asInstanceOf[List[Term]] :+ defaultArg |
| 42 | + Some(Ref(defSym).appliedToArgs(args1)) |
| 43 | + val sym1 = makeTelescopedSymbolWithoutLastArgument(defSym) |
| 44 | + DefDef(sym1, ddef1Rhs) |
| 45 | + |
| 46 | + private def makeTelescopedSymbolWithoutLastArgument(using Quotes)(defSym: quotes.reflect.Symbol): quotes.reflect.Symbol = |
| 47 | + import quotes.reflect._ |
| 48 | + val info1 = defSym.info match |
| 49 | + case info: MethodType => MethodType(info.paramNames.init)(_ => info.paramTypes.init, _ => info.resType) |
| 50 | + Symbol.newMethod(defSym.owner, defSym.name, info1, Flags.EmptyFlags, Symbol.noSymbol) |
| 51 | + |
| 52 | + private def defaultGetter(using Quotes)(sym: quotes.reflect.Symbol, idx: Int): quotes.reflect.Term = |
| 53 | + import quotes.reflect._ |
| 54 | + val getterSym = sym.owner.methodMember(sym.name + "$default$" + idx).head |
| 55 | + Ref(getterSym) |
| 56 | +} |
0 commit comments