|
| 1 | +package dotty.tools.pc |
| 2 | + |
| 3 | +import dotty.tools.dotc.ast.tpd |
| 4 | +import dotty.tools.dotc.ast.tpd.* |
| 5 | +import dotty.tools.dotc.core.Constants.Constant |
| 6 | +import dotty.tools.dotc.core.Contexts.Context |
| 7 | +import dotty.tools.dotc.core.Flags |
| 8 | +import dotty.tools.dotc.core.StdNames |
| 9 | +import dotty.tools.dotc.core.Symbols |
| 10 | +import dotty.tools.dotc.core.Types.* |
| 11 | +import dotty.tools.dotc.core.Types.Type |
| 12 | +import dotty.tools.dotc.interactive.Interactive |
| 13 | +import dotty.tools.dotc.interactive.InteractiveDriver |
| 14 | +import dotty.tools.dotc.util.SourceFile |
| 15 | +import dotty.tools.dotc.util.Spans.Span |
| 16 | +import dotty.tools.pc.IndexedContext |
| 17 | +import dotty.tools.pc.printer.ShortenedTypePrinter |
| 18 | +import dotty.tools.pc.printer.ShortenedTypePrinter.IncludeDefaultParam |
| 19 | +import dotty.tools.pc.utils.InteractiveEnrichments.* |
| 20 | + |
| 21 | +import scala.meta.internal.metals.ReportContext |
| 22 | +import scala.meta.pc.OffsetParams |
| 23 | +import scala.meta.pc.SymbolSearch |
| 24 | + |
| 25 | +class InferExpectedType( |
| 26 | + search: SymbolSearch, |
| 27 | + driver: InteractiveDriver, |
| 28 | + params: OffsetParams |
| 29 | +)(implicit rc: ReportContext): |
| 30 | + val uri = params.uri().nn |
| 31 | + val code = params.text().nn |
| 32 | + |
| 33 | + val sourceFile = SourceFile.virtual(uri, code) |
| 34 | + driver.run(uri, sourceFile) |
| 35 | + |
| 36 | + val ctx = driver.currentCtx |
| 37 | + val pos = driver.sourcePosition(params) |
| 38 | + |
| 39 | + def infer() = |
| 40 | + driver.compilationUnits.get(uri) match |
| 41 | + case Some(unit) => |
| 42 | + val path = |
| 43 | + Interactive.pathTo(driver.openedTrees(uri), pos)(using ctx) |
| 44 | + val newctx = ctx.fresh.setCompilationUnit(unit) |
| 45 | + val tpdPath = |
| 46 | + Interactive.pathTo(newctx.compilationUnit.tpdTree, pos.span)(using |
| 47 | + newctx |
| 48 | + ) |
| 49 | + val locatedCtx = |
| 50 | + Interactive.contextOfPath(tpdPath)(using newctx) |
| 51 | + val indexedCtx = IndexedContext(locatedCtx) |
| 52 | + val printer = |
| 53 | + ShortenedTypePrinter(search, IncludeDefaultParam.ResolveLater)(using indexedCtx) |
| 54 | + InterCompletionType.inferType(path)(using newctx).map{ |
| 55 | + tpe => printer.tpe(tpe) |
| 56 | + } |
| 57 | + case None => None |
| 58 | + |
| 59 | +object InterCompletionType: |
| 60 | + def inferType(path: List[Tree])(using Context): Option[Type] = |
| 61 | + path match |
| 62 | + case (lit: Literal) :: Select(Literal(_), _) :: Apply(Select(Literal(_), _), List(Literal(Constant(null)))) :: rest => inferType(rest, lit.span) |
| 63 | + case ident :: rest => inferType(rest, ident.span) |
| 64 | + case _ => None |
| 65 | + |
| 66 | + def inferType(path: List[Tree], span: Span)(using Context): Option[Type] = |
| 67 | + path match |
| 68 | + case Typed(expr, tpt) :: _ if expr.span.contains(span) && !tpt.tpe.isErroneous => Some(tpt.tpe) |
| 69 | + case Block(_, expr) :: rest if expr.span.contains(span) => |
| 70 | + inferType(rest, span) |
| 71 | + case Bind(_, body) :: rest if body.span.contains(span) => inferType(rest, span) |
| 72 | + case Alternative(_) :: rest => inferType(rest, span) |
| 73 | + case Try(block, _, _) :: rest if block.span.contains(span) => inferType(rest, span) |
| 74 | + case CaseDef(_, _, body) :: Try(_, cases, _) :: rest if body.span.contains(span) && cases.exists(_.span.contains(span)) => inferType(rest, span) |
| 75 | + case If(cond, _, _) :: rest if !cond.span.contains(span) => inferType(rest, span) |
| 76 | + case If(cond, _, _) :: rest if cond.span.contains(span) => Some(Symbols.defn.BooleanType) |
| 77 | + case CaseDef(_, _, body) :: Match(_, cases) :: rest if body.span.contains(span) && cases.exists(_.span.contains(span)) => |
| 78 | + inferType(rest, span) |
| 79 | + case NamedArg(_, arg) :: rest if arg.span.contains(span) => inferType(rest, span) |
| 80 | + // x match |
| 81 | + // case @@ |
| 82 | + case CaseDef(pat, _, _) :: Match(sel, cases) :: rest if pat.span.contains(span) && cases.exists(_.span.contains(span)) && !sel.tpe.isErroneous => |
| 83 | + sel.tpe match |
| 84 | + case tpe: TermRef => Some(tpe.symbol.info).filterNot(_.isErroneous) |
| 85 | + case tpe => Some(tpe) |
| 86 | + // List(@@) |
| 87 | + case SeqLiteral(_, tpe) :: _ if !tpe.tpe.isErroneous => |
| 88 | + Some(tpe.tpe) |
| 89 | + // val _: T = @@ |
| 90 | + // def _: T = @@ |
| 91 | + case (defn: ValOrDefDef) :: rest if !defn.tpt.tpe.isErroneous => Some(defn.tpt.tpe) |
| 92 | + // f(@@) |
| 93 | + case (app: Apply) :: rest => |
| 94 | + val param = |
| 95 | + for { |
| 96 | + ind <- app.args.zipWithIndex.collectFirst { |
| 97 | + case (arg, id) if arg.span.contains(span) => id |
| 98 | + } |
| 99 | + params <- app.symbol.paramSymss.find(!_.exists(_.isTypeParam)) |
| 100 | + param <- params.get(ind) |
| 101 | + } yield param.info |
| 102 | + param match |
| 103 | + // def f[T](a: T): T = ??? |
| 104 | + // f[Int](@@) |
| 105 | + // val _: Int = f(@@) |
| 106 | + case Some(t : TypeRef) if t.symbol.is(Flags.TypeParam) => |
| 107 | + for { |
| 108 | + (typeParams, args) <- |
| 109 | + app match |
| 110 | + case Apply(TypeApply(fun, args), _) => |
| 111 | + val typeParams = fun.symbol.paramSymss.headOption.filter(_.forall(_.isTypeParam)) |
| 112 | + typeParams.map((_, args.map(_.tpe))) |
| 113 | + // val f: (j: "a") => Int |
| 114 | + // f(@@) |
| 115 | + case Apply(Select(v, StdNames.nme.apply), _) => |
| 116 | + v.symbol.info match |
| 117 | + case AppliedType(des, args) => |
| 118 | + Some((des.typeSymbol.typeParams, args)) |
| 119 | + case _ => None |
| 120 | + case _ => None |
| 121 | + ind = typeParams.indexOf(t.symbol) |
| 122 | + tpe <- args.get(ind) |
| 123 | + if !tpe.isErroneous |
| 124 | + } yield tpe |
| 125 | + case Some(tpe) => Some(tpe) |
| 126 | + case _ => None |
| 127 | + case _ => None |
| 128 | + |
0 commit comments