Skip to content

Commit b825f5a

Browse files
authored
Improve Contains handling (#21361)
Make use of enclosing Contains assumptions to improve the subsumes logic.
2 parents d40da0b + b2292a8 commit b825f5a

File tree

8 files changed

+126
-23
lines changed

8 files changed

+126
-23
lines changed

Diff for: compiler/src/dotty/tools/dotc/cc/CaptureOps.scala

+18
Original file line numberDiff line numberDiff line change
@@ -713,3 +713,21 @@ extension (self: Type)
713713
case _ =>
714714
self
715715

716+
/** An extractor for a contains argument */
717+
object ContainsImpl:
718+
def unapply(tree: TypeApply)(using Context): Option[(Tree, Tree)] =
719+
tree.fun.tpe.widen match
720+
case fntpe: PolyType if tree.fun.symbol == defn.Caps_containsImpl =>
721+
tree.args match
722+
case csArg :: refArg :: Nil => Some((csArg, refArg))
723+
case _ => None
724+
case _ => None
725+
726+
/** An extractor for a contains parameter */
727+
object ContainsParam:
728+
def unapply(sym: Symbol)(using Context): Option[(TypeRef, CaptureRef)] =
729+
sym.info.dealias match
730+
case AppliedType(tycon, (cs: TypeRef) :: (ref: CaptureRef) :: Nil)
731+
if tycon.typeSymbol == defn.Caps_ContainsTrait
732+
&& cs.typeSymbol.isAbstractOrParamType => Some((cs, ref))
733+
case _ => None

Diff for: compiler/src/dotty/tools/dotc/cc/CaptureRef.scala

+4
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,12 @@ trait CaptureRef extends TypeProxy, ValueType:
116116
case x1: SingletonCaptureRef => x1.subsumes(y)
117117
case _ => false
118118
case x: TermParamRef => subsumesExistentially(x, y)
119+
case x: TypeRef => assumedContainsOf(x).contains(y)
119120
case _ => false
120121

122+
def assumedContainsOf(x: TypeRef)(using Context): SimpleIdentitySet[CaptureRef] =
123+
CaptureSet.assumedContains.getOrElse(x, SimpleIdentitySet.empty)
124+
121125
end CaptureRef
122126

123127
trait SingletonCaptureRef extends SingletonType, CaptureRef

Diff for: compiler/src/dotty/tools/dotc/cc/CaptureSet.scala

+7-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ import util.{SimpleIdentitySet, Property}
1616
import typer.ErrorReporting.Addenda
1717
import TypeComparer.subsumesExistentially
1818
import util.common.alwaysTrue
19-
import scala.collection.mutable
19+
import scala.collection.{mutable, immutable}
2020
import CCState.*
2121

2222
/** A class for capture sets. Capture sets can be constants or variables.
@@ -1125,6 +1125,12 @@ object CaptureSet:
11251125
foldOver(cs, t)
11261126
collect(CaptureSet.empty, tp)
11271127

1128+
type AssumedContains = immutable.Map[TypeRef, SimpleIdentitySet[CaptureRef]]
1129+
val AssumedContains: Property.Key[AssumedContains] = Property.Key()
1130+
1131+
def assumedContains(using Context): AssumedContains =
1132+
ctx.property(AssumedContains).getOrElse(immutable.Map.empty)
1133+
11281134
private val ShownVars: Property.Key[mutable.Set[Var]] = Property.Key()
11291135

11301136
/** Perform `op`. Under -Ycc-debug, collect and print info about all variables reachable

Diff for: compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala

+26-20
Original file line numberDiff line numberDiff line change
@@ -676,29 +676,24 @@ class CheckCaptures extends Recheck, SymTransformer:
676676
i"Sealed type variable $pname", "be instantiated to",
677677
i"This is often caused by a local capability$where\nleaking as part of its result.",
678678
tree.srcPos)
679-
val res = handleCall(meth, tree, () => Existential.toCap(super.recheckTypeApply(tree, pt)))
680-
if meth == defn.Caps_containsImpl then checkContains(tree)
681-
res
679+
try handleCall(meth, tree, () => Existential.toCap(super.recheckTypeApply(tree, pt)))
680+
finally checkContains(tree)
682681
end recheckTypeApply
683682

684683
/** Faced with a tree of form `caps.contansImpl[CS, r.type]`, check that `R` is a tracked
685684
* capability and assert that `{r} <:CS`.
686685
*/
687-
def checkContains(tree: TypeApply)(using Context): Unit =
688-
tree.fun.knownType.widen match
689-
case fntpe: PolyType =>
690-
tree.args match
691-
case csArg :: refArg :: Nil =>
692-
val cs = csArg.knownType.captureSet
693-
val ref = refArg.knownType
694-
capt.println(i"check contains $cs , $ref")
695-
ref match
696-
case ref: CaptureRef if ref.isTracked =>
697-
checkElem(ref, cs, tree.srcPos)
698-
case _ =>
699-
report.error(em"$refArg is not a tracked capability", refArg.srcPos)
700-
case _ =>
701-
case _ =>
686+
def checkContains(tree: TypeApply)(using Context): Unit = tree match
687+
case ContainsImpl(csArg, refArg) =>
688+
val cs = csArg.knownType.captureSet
689+
val ref = refArg.knownType
690+
capt.println(i"check contains $cs , $ref")
691+
ref match
692+
case ref: CaptureRef if ref.isTracked =>
693+
checkElem(ref, cs, tree.srcPos)
694+
case _ =>
695+
report.error(em"$refArg is not a tracked capability", refArg.srcPos)
696+
case _ =>
702697

703698
override def recheckBlock(tree: Block, pt: Type)(using Context): Type =
704699
inNestedLevel(super.recheckBlock(tree, pt))
@@ -814,15 +809,26 @@ class CheckCaptures extends Recheck, SymTransformer:
814809
val localSet = capturedVars(sym)
815810
if !localSet.isAlwaysEmpty then
816811
curEnv = Env(sym, EnvKind.Regular, localSet, curEnv)
812+
813+
// ctx with AssumedContains entries for each Contains parameter
814+
val bodyCtx =
815+
var ac = CaptureSet.assumedContains
816+
for paramSyms <- sym.paramSymss do
817+
for case ContainsParam(cs, ref) <- paramSyms do
818+
ac = ac.updated(cs, ac.getOrElse(cs, SimpleIdentitySet.empty) + ref)
819+
if ac.isEmpty then ctx
820+
else ctx.withProperty(CaptureSet.AssumedContains, Some(ac))
821+
817822
inNestedLevel: // TODO: needed here?
818-
try checkInferredResult(super.recheckDefDef(tree, sym), tree)
823+
try checkInferredResult(super.recheckDefDef(tree, sym)(using bodyCtx), tree)
819824
finally
820825
if !sym.isAnonymousFunction then
821826
// Anonymous functions propagate their type to the enclosing environment
822827
// so it is not in general sound to interpolate their types.
823828
interpolateVarsIn(tree.tpt)
824829
curEnv = saved
825-
830+
end recheckDefDef
831+
826832
/** If val or def definition with inferred (result) type is visible
827833
* in other compilation units, check that the actual inferred type
828834
* conforms to the expected type where all inferred capture sets are dropped.

Diff for: compiler/src/dotty/tools/dotc/core/Definitions.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -1002,7 +1002,7 @@ class Definitions {
10021002
@tu lazy val Caps_unsafeBox: Symbol = CapsUnsafeModule.requiredMethod("unsafeBox")
10031003
@tu lazy val Caps_unsafeUnbox: Symbol = CapsUnsafeModule.requiredMethod("unsafeUnbox")
10041004
@tu lazy val Caps_unsafeBoxFunArg: Symbol = CapsUnsafeModule.requiredMethod("unsafeBoxFunArg")
1005-
@tu lazy val Caps_ContainsTrait: TypeSymbol = CapsModule.requiredType("Capability")
1005+
@tu lazy val Caps_ContainsTrait: TypeSymbol = CapsModule.requiredType("Contains")
10061006
@tu lazy val Caps_containsImpl: TermSymbol = CapsModule.requiredMethod("containsImpl")
10071007

10081008
@tu lazy val PureClass: Symbol = requiredClass("scala.Pure")

Diff for: tests/pos-custom-args/captures/i21313.scala

+10-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,16 @@
11
import caps.CapSet
22

33
trait Async:
4-
def await[T, Cap^](using caps.Contains[Cap, this.type])(src: Source[T, Cap]^): T
4+
def await[T, Cap^](using caps.Contains[Cap, this.type])(src: Source[T, Cap]^): T =
5+
val x: Async^{this} = ???
6+
val y: Async^{Cap^} = x
7+
val ac: Async^ = ???
8+
def f(using caps.Contains[Cap, ac.type]) =
9+
val x2: Async^{this} = ???
10+
val y2: Async^{Cap^} = x2
11+
val x3: Async^{ac} = ???
12+
val y3: Async^{Cap^} = x3
13+
???
514

615
trait Source[+T, Cap^]:
716
final def await(using ac: Async^{Cap^}) = ac.await[T, Cap](this) // Contains[Cap, ac] is assured because {ac} <: Cap.

Diff for: tests/run/Providers.check

+8
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,11 @@ Executing query: insert into subscribers(name, email) values Daniel daniel@Rockt
1818
You've just been subscribed to RockTheJVM. Welcome, Martin
1919
Acquired connection
2020
Executing query: insert into subscribers(name, email) values Martin [email protected]
21+
22+
Injected2
23+
You've just been subscribed to RockTheJVM. Welcome, Daniel
24+
Acquired connection
25+
Executing query: insert into subscribers(name, email) values Daniel [email protected]
26+
You've just been subscribed to RockTheJVM. Welcome, Martin
27+
Acquired connection
28+
Executing query: insert into subscribers(name, email) values Martin [email protected]

Diff for: tests/run/Providers.scala

+52
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ end Providers
6565
Explicit().test()
6666
println(s"\nInjected")
6767
Injected().test()
68+
println(s"\nInjected2")
69+
Injected2().test()
6870

6971
/** Demonstrator for explicit dependency construction */
7072
class Explicit:
@@ -173,5 +175,55 @@ class Injected:
173175
end explicit
174176
end Injected
175177

178+
/** Injected with builders in companion objects */
179+
class Injected2:
180+
import Providers.*
181+
182+
case class User(name: String, email: String)
183+
184+
class UserSubscription(emailService: EmailService, db: UserDatabase):
185+
def subscribe(user: User) =
186+
emailService.email(user)
187+
db.insert(user)
188+
object UserSubscription:
189+
def apply()(using Provider[(EmailService, UserDatabase)]): UserSubscription =
190+
new UserSubscription(provided[EmailService], provided[UserDatabase])
191+
192+
class EmailService:
193+
def email(user: User) =
194+
println(s"You've just been subscribed to RockTheJVM. Welcome, ${user.name}")
195+
196+
class UserDatabase(pool: ConnectionPool):
197+
def insert(user: User) =
198+
pool.get().runQuery(s"insert into subscribers(name, email) values ${user.name} ${user.email}")
199+
object UserDatabase:
200+
def apply()(using Provider[(ConnectionPool)]): UserDatabase =
201+
new UserDatabase(provided[ConnectionPool])
202+
203+
class ConnectionPool(n: Int):
204+
def get(): Connection =
205+
println(s"Acquired connection")
206+
Connection()
207+
208+
class Connection():
209+
def runQuery(query: String): Unit =
210+
println(s"Executing query: $query")
211+
212+
def test() =
213+
given Provider[EmailService] = provide(EmailService())
214+
given Provider[ConnectionPool] = provide(ConnectionPool(10))
215+
given Provider[UserDatabase] = provide(UserDatabase())
216+
given Provider[UserSubscription] = provide(UserSubscription())
217+
218+
def subscribe(user: User)(using Provider[UserSubscription]) =
219+
val sub = UserSubscription()
220+
sub.subscribe(user)
221+
222+
subscribe(User("Daniel", "[email protected]"))
223+
subscribe(User("Martin", "[email protected]"))
224+
end test
225+
end Injected2
226+
227+
176228

177229

0 commit comments

Comments
 (0)