Skip to content

Commit fd72f1a

Browse files
authored
Add inlay hints for by-name parameters (#23283)
Add `=>` hints for function parameters that are passed by name. Porting scalameta/metals#7404
1 parent dec859f commit fd72f1a

File tree

2 files changed

+162
-5
lines changed

2 files changed

+162
-5
lines changed

presentation-compiler/src/main/dotty/tools/pc/PcInlayHintsProvider.scala

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,29 @@ class PcInlayHintsProvider(
116116
InlayHintKind.Type,
117117
)
118118
.addDefinition(adjustedPos.start)
119+
case ByNameParameters(byNameParams) =>
120+
def adjustByNameParameterPos(pos: SourcePosition): SourcePosition =
121+
val adjusted = adjustPos(pos)
122+
val start = text.indexWhere(!_.isWhitespace, adjusted.start)
123+
val end = text.lastIndexWhere(!_.isWhitespace, adjusted.end - 1)
124+
125+
val startsWithBrace = text.lift(start).contains('{')
126+
val endsWithBrace = text.lift(end).contains('}')
127+
128+
if startsWithBrace && endsWithBrace then
129+
adjusted.withStart(start + 1)
130+
else
131+
adjusted
132+
133+
byNameParams.foldLeft(inlayHints) {
134+
case (ih, pos) =>
135+
val adjusted = adjustByNameParameterPos(pos)
136+
ih.add(
137+
adjusted.startPos.toLsp,
138+
List(LabelPart("=> ")),
139+
InlayHintKind.Parameter
140+
)
141+
}
119142
case _ => inlayHints
120143

121144
private def toLabelParts(
@@ -388,3 +411,28 @@ object InferredType:
388411
index >= 0 && index < afterDef.size && afterDef(index) == '@'
389412

390413
end InferredType
414+
415+
object ByNameParameters:
416+
def unapply(tree: Tree)(using params: InlayHintsParams, ctx: Context): Option[List[SourcePosition]] =
417+
def shouldSkipSelect(sel: Select) =
418+
isForComprehensionMethod(sel) || sel.symbol.name == nme.unapply
419+
420+
if (params.byNameParameters()){
421+
tree match
422+
case Apply(TypeApply(sel: Select, _), _) if shouldSkipSelect(sel) =>
423+
None
424+
case Apply(sel: Select, _) if shouldSkipSelect(sel) =>
425+
None
426+
case Apply(fun, args) =>
427+
val funTp = fun.typeOpt.widenTermRefExpr
428+
val params = funTp.paramInfoss.flatten
429+
Some(
430+
args
431+
.zip(params)
432+
.collect {
433+
case (tree, param) if param.isByName => tree.sourcePos
434+
}
435+
)
436+
case _ => None
437+
} else None
438+
end ByNameParameters

presentation-compiler/test/dotty/tools/pc/tests/inlayHints/InlayHintsSuite.scala

Lines changed: 114 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -611,17 +611,17 @@ class InlayHintsSuite extends BaseInlayHintsSuite {
611611
|class DemoSpec {
612612
| import ScalatestMock._
613613
|
614-
| /*StringTestOps<<(6:17)>>(*/"foo"/*)*/ should {
615-
| /*StringTestOps<<(6:17)>>(*/"checkThing1"/*)*/ in {
614+
| /*StringTestOps<<(6:17)>>(*/"foo"/*)*/ should {/*=> */
615+
| /*StringTestOps<<(6:17)>>(*/"checkThing1"/*)*/ in {/*=> */
616616
| checkThing1[String]/*(using instancesString<<(10:15)>>)*/
617617
| }/*(using here<<(5:15)>>)*/
618-
| /*StringTestOps<<(6:17)>>(*/"checkThing2"/*)*/ in {
618+
| /*StringTestOps<<(6:17)>>(*/"checkThing2"/*)*/ in {/*=> */
619619
| checkThing2[String]/*(using instancesString<<(10:15)>>, instancesString<<(10:15)>>)*/
620620
| }/*(using here<<(5:15)>>)*/
621621
| }/*(using subjectRegistrationFunction<<(3:15)>>)*/
622622
|
623-
| /*StringTestOps<<(6:17)>>(*/"bar"/*)*/ should {
624-
| /*StringTestOps<<(6:17)>>(*/"checkThing1"/*)*/ in {
623+
| /*StringTestOps<<(6:17)>>(*/"bar"/*)*/ should {/*=> */
624+
| /*StringTestOps<<(6:17)>>(*/"checkThing1"/*)*/ in {/*=> */
625625
| checkThing1[String]/*(using instancesString<<(10:15)>>)*/
626626
| }/*(using here<<(5:15)>>)*/
627627
| }/*(using subjectRegistrationFunction<<(3:15)>>)*/
@@ -1075,4 +1075,113 @@ class InlayHintsSuite extends BaseInlayHintsSuite {
10751075
| val x: (path: String, num: Int, line: Int) = test
10761076
|""".stripMargin
10771077
)
1078+
1079+
@Test def `by-name-regular` =
1080+
check(
1081+
"""|object Main:
1082+
| def foo(x: => Int, y: Int, z: => Int)(w: Int, v: => Int): Unit = ()
1083+
| foo(1, 2, 3)(4, 5)
1084+
|""".stripMargin,
1085+
"""|object Main:
1086+
| def foo(x: => Int, y: Int, z: => Int)(w: Int, v: => Int): Unit = ()
1087+
| foo(/*=> */1, 2, /*=> */3)(4, /*=> */5)
1088+
|""".stripMargin
1089+
)
1090+
1091+
@Test def `by-name-block` =
1092+
check(
1093+
"""|object Main:
1094+
| def Future[A](arg: => A): A = arg
1095+
|
1096+
| Future(1 + 2)
1097+
| Future {
1098+
| 1 + 2
1099+
| }
1100+
| Future {
1101+
| val x = 1
1102+
| val y = 2
1103+
| x + y
1104+
| }
1105+
| Some(Option(2)
1106+
| .getOrElse {
1107+
| List(1,2)
1108+
| .headOption
1109+
| })
1110+
|""".stripMargin,
1111+
"""|object Main:
1112+
| def Future[A](arg: => A): A = arg
1113+
|
1114+
| Future/*[Int<<scala/Int#>>]*/(/*=> */1 + 2)
1115+
| Future/*[Int<<scala/Int#>>]*/ {/*=> */
1116+
| 1 + 2
1117+
| }
1118+
| Future/*[Int<<scala/Int#>>]*/ {/*=> */
1119+
| val x/*: Int<<scala/Int#>>*/ = 1
1120+
| val y/*: Int<<scala/Int#>>*/ = 2
1121+
| x + y
1122+
| }
1123+
| Some/*[Int<<scala/Int#>> | Option<<scala/Option#>>[Int<<scala/Int#>>]]*/(Option/*[Int<<scala/Int#>>]*/(2)
1124+
| .getOrElse/*[Int<<scala/Int#>> | Option<<scala/Option#>>[Int<<scala/Int#>>]]*/ {/*=> */
1125+
| List/*[Int<<scala/Int#>>]*/(1,2)
1126+
| .headOption
1127+
| })
1128+
|""".stripMargin
1129+
)
1130+
1131+
@Test def `by-name-for-comprehension` =
1132+
check(
1133+
"""|object Main:
1134+
| case class Test[A](v: A):
1135+
| def flatMap(f: => (A => Test[Int])): Test[Int] = f(v)
1136+
| def map(f: => (A => Int)): Test[Int] = Test(f(v))
1137+
|
1138+
| def main(args: Array[String]): Unit =
1139+
| val result: Test[Int] = for {
1140+
| a <- Test(10)
1141+
| b <- Test(20)
1142+
| } yield a + b
1143+
|
1144+
|""".stripMargin,
1145+
"""|object Main:
1146+
| case class Test[A](v: A):
1147+
| def flatMap(f: => (A => Test[Int])): Test[Int] = f(v)
1148+
| def map(f: => (A => Int)): Test[Int] = Test/*[Int<<scala/Int#>>]*/(f(v))
1149+
|
1150+
| def main(args: Array[String]): Unit =
1151+
| val result: Test[Int] = for {
1152+
| a <- Test/*[Int<<scala/Int#>>]*/(10)
1153+
| b <- Test/*[Int<<scala/Int#>>]*/(20)
1154+
| } yield a + b
1155+
|
1156+
|""".stripMargin,
1157+
)
1158+
1159+
@Test def `by-name-for-comprehension-generic` =
1160+
check(
1161+
"""|object Main:
1162+
| case class Test[A](v: A):
1163+
| def flatMap[B](f: => (A => Test[B])): Test[B] = f(v)
1164+
| def map[B](f: => (A => B)): Test[B] = Test(f(v))
1165+
|
1166+
| def main(args: Array[String]): Unit =
1167+
| val result: Test[Int] = for {
1168+
| a <- Test(10)
1169+
| b <- Test(20)
1170+
| } yield a + b
1171+
|
1172+
|""".stripMargin,
1173+
"""|object Main:
1174+
| case class Test[A](v: A):
1175+
| def flatMap[B](f: => (A => Test[B])): Test[B] = f(v)
1176+
| def map[B](f: => (A => B)): Test[B] = Test/*[B<<(4:13)>>]*/(f(v))
1177+
|
1178+
| def main(args: Array[String]): Unit =
1179+
| val result: Test[Int] = for {
1180+
| a <- Test/*[Int<<scala/Int#>>]*/(10)
1181+
| b <- Test/*[Int<<scala/Int#>>]*/(20)
1182+
| } yield a + b
1183+
|
1184+
|""".stripMargin,
1185+
)
1186+
10781187
}

0 commit comments

Comments
 (0)