Skip to content

Commit 17d9e85

Browse files
author
EnzeXing
committed
Adding semantics for calling on SafeValue and evaluating SeqLiteral
1 parent 0ff1370 commit 17d9e85

File tree

2 files changed

+82
-30
lines changed

2 files changed

+82
-30
lines changed

Diff for: compiler/src/dotty/tools/dotc/transform/init/Objects.scala

+66-30
Original file line numberDiff line numberDiff line change
@@ -227,11 +227,19 @@ class Objects(using Context @constructorOnly):
227227
case class Fun(code: Tree, thisV: ThisValue, klass: ClassSymbol, env: Env.Data) extends ValueElement:
228228
def show(using Context) = "Fun(" + code.show + ", " + thisV.show + ", " + klass.show + ")"
229229

230-
/** Represents common base values like Int, String, etc.
230+
/**
231+
* Represents common base values like Int, String, etc.
232+
* Assumption: all methods calls on such values should be pure (no side effects)
231233
*/
232-
case object SafeValue extends ValueElement:
233-
val safeTypes = defn.ScalaNumericValueTypeList ++ List(defn.UnitType, defn.BooleanType, defn.StringType)
234-
def show(using Context): String = "SafeValue"
234+
case class SafeValue(tpe: Type) extends ValueElement:
235+
// tpe could be a AppliedType(java.lang.Class, T)
236+
val baseType = if tpe.isInstanceOf[AppliedType] then tpe.asInstanceOf[AppliedType].underlying else tpe
237+
assert(baseType.isInstanceOf[TypeRef] && SafeValue.safeTypes.contains(baseType), "Invalid creation of SafeValue! Type = " + tpe)
238+
val typeref = baseType.asInstanceOf[TypeRef]
239+
def show(using Context): String = "SafeValue of type " + tpe
240+
241+
object SafeValue:
242+
val safeTypes = defn.ScalaNumericValueTypeList ++ List(defn.UnitType, defn.BooleanType, defn.StringType, defn.NullType, defn.ClassClass.typeRef)
235243

236244
/**
237245
* Represents a set of values
@@ -669,7 +677,7 @@ class Objects(using Context @constructorOnly):
669677
a match
670678
case UnknownValue => UnknownValue
671679
case Package(_) => a
672-
case SafeValue => SafeValue
680+
case SafeValue(_) => a
673681
case ref: Ref => if ref.klass.isSubClass(klass) then ref else Bottom
674682
case ValueSet(values) => values.map(v => v.filterClass(klass)).join
675683
case arr: OfArray => if defn.ArrayClass.isSubClass(klass) then arr else Bottom
@@ -698,7 +706,7 @@ class Objects(using Context @constructorOnly):
698706
* @param superType The type of the super in a super call. NoType for non-super calls.
699707
* @param needResolve Whether the target of the call needs resolution?
700708
*/
701-
def call(value: Value, meth: Symbol, args: List[ArgInfo], receiver: Type, superType: Type, needResolve: Boolean = true): Contextual[Value] = log("call " + meth.show + ", this = " + value.show + ", args = " + args.map(_.tree.show), printer, (_: Value).show) {
709+
def call(value: Value, meth: Symbol, args: List[ArgInfo], receiver: Type, superType: Type, needResolve: Boolean = true): Contextual[Value] = log("call " + meth.show + ", this = " + value.show + ", args = " + args.map(_.value.show), printer, (_: Value).show) {
702710
value.filterClass(meth.owner) match
703711
case UnknownValue =>
704712
if reportUnknown then
@@ -708,11 +716,33 @@ class Objects(using Context @constructorOnly):
708716
UnknownValue
709717

710718
case Package(packageSym) =>
711-
report.warning("[Internal error] Unexpected call on package = " + value.show + ", meth = " + meth.show + Trace.show, Trace.position)
712-
Bottom
713-
714-
case SafeValue =>
715-
SafeValue // Check return type, if not safe, try to analyze body, 1.until(2).map(i => UninitializedObject)
719+
// calls on packages are unexpected. However the typer might mistakenly
720+
// set the receiver to be a package instead of package object.
721+
// See packageObjectStringInterpolator.scala
722+
if !meth.owner.denot.isPackageObject then
723+
report.warning("[Internal error] Unexpected call on package = " + value.show + ", meth = " + meth.show + Trace.show, Trace.position)
724+
Bottom
725+
else
726+
// Method call on package object instead
727+
val packageObj = accessObject(meth.owner.moduleClass.asClass)
728+
call(packageObj, meth, args, receiver, superType, needResolve)
729+
730+
case v @ SafeValue(tpe) =>
731+
// Assume such method is pure. Check return type, only try to analyze body if return type is not safe
732+
val target = resolve(v.typeref.symbol.asClass, meth)
733+
if !target.hasSource then
734+
UnknownValue
735+
else
736+
val ddef = target.defTree.asInstanceOf[DefDef]
737+
val returnType = ddef.tpt.tpe
738+
if SafeValue.safeTypes.contains(returnType) then
739+
// since method is pure and return type is safe, no need to analyze method body
740+
SafeValue(returnType)
741+
else
742+
val cls = target.owner.enclosingClass.asClass
743+
// convert SafeType to an OfClass before analyzing method body
744+
val ref = OfClass(cls, Bottom, NoSymbol, Nil, Env.NoEnv)
745+
call(ref, meth, args, receiver, superType, needResolve)
716746

717747
case Bottom =>
718748
Bottom
@@ -739,7 +769,7 @@ class Objects(using Context @constructorOnly):
739769
Bottom
740770
else
741771
// Array.length is OK
742-
SafeValue
772+
SafeValue(defn.IntType)
743773

744774
case ref: Ref =>
745775
val isLocal = !meth.owner.isClass
@@ -760,10 +790,10 @@ class Objects(using Context @constructorOnly):
760790
arr
761791
else if target.equals(defn.Predef_classOf) then
762792
// Predef.classOf is a stub method in tasty and is replaced in backend
763-
SafeValue
793+
UnknownValue
764794
else if target.equals(defn.ClassTagModule_apply) then
765-
// ClassTag and other reflection related values are considered safe
766-
SafeValue
795+
// ClassTag and other reflection related values are not analyzed
796+
UnknownValue
767797
else if target.hasSource then
768798
val cls = target.owner.enclosingClass.asClass
769799
val ddef = target.defTree.asInstanceOf[DefDef]
@@ -851,6 +881,7 @@ class Objects(using Context @constructorOnly):
851881
Returns.installHandler(ctor)
852882
eval(ddef.rhs, ref, cls, cacheResult = true)
853883
Returns.popHandler(ctor)
884+
value
854885
}
855886
else
856887
// no source code available
@@ -877,8 +908,9 @@ class Objects(using Context @constructorOnly):
877908
else
878909
UnknownValue
879910

880-
case SafeValue =>
881-
SafeValue
911+
case v @ SafeValue(_) =>
912+
report.warning("[Internal error] Unexpected selection on safe value " + v.show + ", field = " + field.show + Trace.show, Trace.position)
913+
Bottom
882914

883915
case Package(packageSym) =>
884916
if field.isStaticObject then
@@ -962,7 +994,7 @@ class Objects(using Context @constructorOnly):
962994
case arr: OfArray =>
963995
report.warning("[Internal error] unexpected tree in assignment, array = " + arr.show + " field = " + field + Trace.show, Trace.position)
964996

965-
case SafeValue | UnknownValue =>
997+
case SafeValue(_) | UnknownValue =>
966998
report.warning("Assigning to base or unknown value is forbidden. " + Trace.show, Trace.position)
967999

9681000
case ValueSet(values) =>
@@ -994,7 +1026,7 @@ class Objects(using Context @constructorOnly):
9941026
*/
9951027
def instantiate(outer: Value, klass: ClassSymbol, ctor: Symbol, args: List[ArgInfo]): Contextual[Value] = log("instantiating " + klass.show + ", outer = " + outer + ", args = " + args.map(_.value.show), printer, (_: Value).show) {
9961028
outer.filterClass(klass.owner) match
997-
case _ : Fun | _: OfArray | SafeValue =>
1029+
case _ : Fun | _: OfArray | SafeValue(_) =>
9981030
report.warning("[Internal error] unexpected outer in instantiating a class, outer = " + outer.show + ", class = " + klass.show + ", " + Trace.show, Trace.position)
9991031
Bottom
10001032

@@ -1089,7 +1121,7 @@ class Objects(using Context @constructorOnly):
10891121
case UnknownValue =>
10901122
report.warning("Calling on unknown value. " + Trace.show, Trace.position)
10911123
Bottom
1092-
case _: ValueSet | _: Ref | _: OfArray | _: Package | SafeValue =>
1124+
case _: ValueSet | _: Ref | _: OfArray | _: Package | SafeValue(_) =>
10931125
report.warning("[Internal error] Unexpected by-name value " + value.show + ". " + Trace.show, Trace.position)
10941126
Bottom
10951127
else
@@ -1276,8 +1308,8 @@ class Objects(using Context @constructorOnly):
12761308
case _: This =>
12771309
evalType(expr.tpe, thisV, klass)
12781310

1279-
case Literal(_) =>
1280-
SafeValue
1311+
case Literal(const) =>
1312+
SafeValue(const.tpe)
12811313

12821314
case Typed(expr, tpt) =>
12831315
if tpt.tpe.hasAnnotation(defn.UncheckedAnnot) then
@@ -1352,7 +1384,12 @@ class Objects(using Context @constructorOnly):
13521384
res
13531385

13541386
case SeqLiteral(elems, elemtpt) =>
1355-
evalExprs(elems, thisV, klass).join
1387+
// Obtain the output Seq from SeqLiteral tree by calling respective wrapArrayMethod
1388+
val wrapArrayMethodName = ast.tpd.wrapArrayMethodName(elemtpt.tpe)
1389+
val meth = defn.getWrapVarargsArrayModule.requiredMethod(wrapArrayMethodName)
1390+
val module = defn.getWrapVarargsArrayModule.moduleClass.asClass
1391+
val args = evalArgs(elems.map(Arg.apply), thisV, klass)
1392+
call(ObjectRef(module), meth, args, module.typeRef, NoType)
13561393

13571394
case Inlined(call, bindings, expansion) =>
13581395
evalExprs(bindings, thisV, klass)
@@ -1563,7 +1600,7 @@ class Objects(using Context @constructorOnly):
15631600

15641601
// call .apply
15651602
val applyDenot = getMemberMethod(scrutineeType, nme.apply, applyType(elemType))
1566-
val applyRes = call(scrutinee, applyDenot.symbol, ArgInfo(SafeValue, summon[Trace], EmptyTree) :: Nil, scrutineeType, superType = NoType, needResolve = true)
1603+
val applyRes = call(scrutinee, applyDenot.symbol, ArgInfo(SafeValue(defn.IntType), summon[Trace], EmptyTree) :: Nil, scrutineeType, superType = NoType, needResolve = true)
15671604

15681605
if isWildcardStarArgList(pats) then
15691606
if pats.size == 1 then
@@ -1574,7 +1611,7 @@ class Objects(using Context @constructorOnly):
15741611
else
15751612
// call .drop
15761613
val dropDenot = getMemberMethod(scrutineeType, nme.drop, dropType(elemType))
1577-
val dropRes = call(scrutinee, dropDenot.symbol, ArgInfo(SafeValue, summon[Trace], EmptyTree) :: Nil, scrutineeType, superType = NoType, needResolve = true)
1614+
val dropRes = call(scrutinee, dropDenot.symbol, ArgInfo(SafeValue(defn.IntType), summon[Trace], EmptyTree) :: Nil, scrutineeType, superType = NoType, needResolve = true)
15781615
for pat <- pats.init do evalPattern(applyRes, pat)
15791616
evalPattern(dropRes, pats.last)
15801617
end if
@@ -1585,8 +1622,7 @@ class Objects(using Context @constructorOnly):
15851622
end evalSeqPatterns
15861623

15871624
def canSkipCase(remainingScrutinee: Value, catchValue: Value) =
1588-
(remainingScrutinee == Bottom && scrutinee != Bottom) ||
1589-
(catchValue == Bottom && remainingScrutinee != Bottom)
1625+
remainingScrutinee == Bottom || catchValue == Bottom
15901626

15911627
var remainingScrutinee = scrutinee
15921628
val caseResults: mutable.ArrayBuffer[Value] = mutable.ArrayBuffer()
@@ -1615,8 +1651,8 @@ class Objects(using Context @constructorOnly):
16151651
*/
16161652
def evalType(tp: Type, thisV: ThisValue, klass: ClassSymbol, elideObjectAccess: Boolean = false): Contextual[Value] = log("evaluating " + tp.show, printer, (_: Value).show) {
16171653
tp match
1618-
case _: ConstantType =>
1619-
SafeValue
1654+
case consttpe: ConstantType =>
1655+
SafeValue(consttpe.underlying)
16201656

16211657
case tmref: TermRef if tmref.prefix == NoPrefix =>
16221658
val sym = tmref.symbol
@@ -1866,7 +1902,7 @@ class Objects(using Context @constructorOnly):
18661902
resolveThis(target, ref.outerValue(klass), outerCls)
18671903
case ValueSet(values) =>
18681904
values.map(ref => resolveThis(target, ref, klass)).join
1869-
case _: Fun | _ : OfArray | _: Package | SafeValue =>
1905+
case _: Fun | _ : OfArray | _: Package | SafeValue(_) =>
18701906
report.warning("[Internal error] unexpected thisV = " + thisV + ", target = " + target.show + ", klass = " + klass.show + Trace.show, Trace.position)
18711907
Bottom
18721908
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package p
2+
package object a {
3+
val b = 10
4+
implicit class CI(s: StringContext) {
5+
def ci(args: Any*) = 10
6+
}
7+
}
8+
9+
import p.a._
10+
11+
object A:
12+
val f = b // p.a(ObjectRef(p.a)).b
13+
def foo(s: String): String = s
14+
val f1 = ci"a" // => p.a(Package(p).select(a)).CI(StringContext"a").ci()
15+
16+

0 commit comments

Comments
 (0)