Skip to content

Commit e896db2

Browse files
kasiaMarekdwijnand
authored andcommitted
feat: infer expected type
1 parent 18af52a commit e896db2

File tree

4 files changed

+433
-0
lines changed

4 files changed

+433
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
package dotty.tools.pc
2+
3+
import dotty.tools.dotc.ast.tpd
4+
import dotty.tools.dotc.ast.tpd.*
5+
import dotty.tools.dotc.core.Constants.Constant
6+
import dotty.tools.dotc.core.Contexts.Context
7+
import dotty.tools.dotc.core.Flags
8+
import dotty.tools.dotc.core.StdNames
9+
import dotty.tools.dotc.core.Symbols
10+
import dotty.tools.dotc.core.Types.*
11+
import dotty.tools.dotc.core.Types.Type
12+
import dotty.tools.dotc.interactive.Interactive
13+
import dotty.tools.dotc.interactive.InteractiveDriver
14+
import dotty.tools.dotc.util.SourceFile
15+
import dotty.tools.dotc.util.Spans.Span
16+
import dotty.tools.pc.IndexedContext
17+
import dotty.tools.pc.printer.ShortenedTypePrinter
18+
import dotty.tools.pc.printer.ShortenedTypePrinter.IncludeDefaultParam
19+
import dotty.tools.pc.utils.InteractiveEnrichments.*
20+
21+
import scala.meta.internal.metals.ReportContext
22+
import scala.meta.pc.OffsetParams
23+
import scala.meta.pc.SymbolSearch
24+
25+
class InferExpectedType(
26+
search: SymbolSearch,
27+
driver: InteractiveDriver,
28+
params: OffsetParams
29+
)(implicit rc: ReportContext):
30+
val uri = params.uri().nn
31+
val code = params.text().nn
32+
33+
val sourceFile = SourceFile.virtual(uri, code)
34+
driver.run(uri, sourceFile)
35+
36+
val ctx = driver.currentCtx
37+
val pos = driver.sourcePosition(params)
38+
39+
def infer() =
40+
driver.compilationUnits.get(uri) match
41+
case Some(unit) =>
42+
val path =
43+
Interactive.pathTo(driver.openedTrees(uri), pos)(using ctx)
44+
val newctx = ctx.fresh.setCompilationUnit(unit)
45+
val tpdPath =
46+
Interactive.pathTo(newctx.compilationUnit.tpdTree, pos.span)(using
47+
newctx
48+
)
49+
val locatedCtx =
50+
Interactive.contextOfPath(tpdPath)(using newctx)
51+
val indexedCtx = IndexedContext(locatedCtx)
52+
val printer =
53+
ShortenedTypePrinter(search, IncludeDefaultParam.ResolveLater)(using indexedCtx)
54+
InterCompletionType.inferType(path)(using newctx).map{
55+
tpe => printer.tpe(tpe)
56+
}
57+
case None => None
58+
59+
object InterCompletionType:
60+
def inferType(path: List[Tree])(using Context): Option[Type] =
61+
path match
62+
case (lit: Literal) :: Select(Literal(_), _) :: Apply(Select(Literal(_), _), List(Literal(Constant(null)))) :: rest => inferType(rest, lit.span)
63+
case ident :: rest => inferType(rest, ident.span)
64+
case _ => None
65+
66+
def inferType(path: List[Tree], span: Span)(using Context): Option[Type] =
67+
path match
68+
case Typed(expr, tpt) :: _ if expr.span.contains(span) && !tpt.tpe.isErroneous => Some(tpt.tpe)
69+
case Block(_, expr) :: rest if expr.span.contains(span) =>
70+
inferType(rest, span)
71+
case Bind(_, body) :: rest if body.span.contains(span) => inferType(rest, span)
72+
case Alternative(_) :: rest => inferType(rest, span)
73+
case Try(block, _, _) :: rest if block.span.contains(span) => inferType(rest, span)
74+
case CaseDef(_, _, body) :: Try(_, cases, _) :: rest if body.span.contains(span) && cases.exists(_.span.contains(span)) => inferType(rest, span)
75+
case If(cond, _, _) :: rest if !cond.span.contains(span) => inferType(rest, span)
76+
case If(cond, _, _) :: rest if cond.span.contains(span) => Some(Symbols.defn.BooleanType)
77+
case CaseDef(_, _, body) :: Match(_, cases) :: rest if body.span.contains(span) && cases.exists(_.span.contains(span)) =>
78+
inferType(rest, span)
79+
case NamedArg(_, arg) :: rest if arg.span.contains(span) => inferType(rest, span)
80+
// x match
81+
// case @@
82+
case CaseDef(pat, _, _) :: Match(sel, cases) :: rest if pat.span.contains(span) && cases.exists(_.span.contains(span)) && !sel.tpe.isErroneous =>
83+
sel.tpe match
84+
case tpe: TermRef => Some(tpe.symbol.info).filterNot(_.isErroneous)
85+
case tpe => Some(tpe)
86+
// List(@@)
87+
case SeqLiteral(_, tpe) :: _ if !tpe.tpe.isErroneous =>
88+
Some(tpe.tpe)
89+
// val _: T = @@
90+
// def _: T = @@
91+
case (defn: ValOrDefDef) :: rest if !defn.tpt.tpe.isErroneous => Some(defn.tpt.tpe)
92+
// f(@@)
93+
case (app: Apply) :: rest =>
94+
val param =
95+
for {
96+
ind <- app.args.zipWithIndex.collectFirst {
97+
case (arg, id) if arg.span.contains(span) => id
98+
}
99+
params <- app.symbol.paramSymss.find(!_.exists(_.isTypeParam))
100+
param <- params.get(ind)
101+
} yield param.info
102+
param match
103+
// def f[T](a: T): T = ???
104+
// f[Int](@@)
105+
// val _: Int = f(@@)
106+
case Some(t : TypeRef) if t.symbol.is(Flags.TypeParam) =>
107+
for {
108+
(typeParams, args) <-
109+
app match
110+
case Apply(TypeApply(fun, args), _) =>
111+
val typeParams = fun.symbol.paramSymss.headOption.filter(_.forall(_.isTypeParam))
112+
typeParams.map((_, args.map(_.tpe)))
113+
// val f: (j: "a") => Int
114+
// f(@@)
115+
case Apply(Select(v, StdNames.nme.apply), _) =>
116+
v.symbol.info match
117+
case AppliedType(des, args) =>
118+
Some((des.typeSymbol.typeParams, args))
119+
case _ => None
120+
case _ => None
121+
ind = typeParams.indexOf(t.symbol)
122+
tpe <- args.get(ind)
123+
if !tpe.isErroneous
124+
} yield tpe
125+
case Some(tpe) => Some(tpe)
126+
case _ => None
127+
case _ => None
128+

Diff for: presentation-compiler/src/main/dotty/tools/pc/ScalaPresentationCompiler.scala

+10
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ import scala.meta.pc.{PcSymbolInformation as IPcSymbolInformation}
3030

3131
import dotty.tools.dotc.reporting.StoreReporter
3232
import dotty.tools.pc.completions.CompletionProvider
33+
import dotty.tools.pc.InferExpectedType
3334
import dotty.tools.pc.completions.OverrideCompletions
3435
import dotty.tools.pc.buildinfo.BuildInfo
3536

@@ -198,6 +199,15 @@ case class ScalaPresentationCompiler(
198199
.asJava
199200
}
200201

202+
def inferExpectedType(params: OffsetParams): CompletableFuture[ju.Optional[String]] =
203+
compilerAccess.withInterruptableCompiler(Some(params))(
204+
Optional.empty(),
205+
params.token,
206+
) { access =>
207+
val driver = access.compiler()
208+
new InferExpectedType(search, driver, params).infer().asJava
209+
}
210+
201211
def shutdown(): Unit =
202212
compilerAccess.shutdown()
203213

Diff for: presentation-compiler/src/main/dotty/tools/pc/utils/InteractiveEnrichments.scala

+3
Original file line numberDiff line numberDiff line change
@@ -412,4 +412,7 @@ object InteractiveEnrichments extends CommonMtagsEnrichments:
412412
RefinedType(parent.dealias, name, refinedInfo.deepDealias)
413413
case dealised => dealised
414414

415+
extension[T] (list: List[T])
416+
def get(n: Int): Option[T] = if 0 <= n && n < list.size then Some(list(n)) else None
417+
415418
end InteractiveEnrichments

0 commit comments

Comments
 (0)