Skip to content

Clean up Middle IR and variable assignment/reassignment #293

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: hkmc2
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 53 additions & 24 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,24 @@ sealed abstract class Block extends Product with AutoLocated:

protected def children: Ls[Located] = ??? // Maybe extending AutoLocated is unnecessary

/** Returns all local variables defined in this block through an `Assign` instruction.
* Does not look inside nested definitions. */
lazy val definedVars: Set[Local] = this match
case _: Return | _: Throw => Set.empty
case Begin(sub, rst) => sub.definedVars ++ rst.definedVars

// * Currently, `TermSymbol` is used to represent private fields,
// * such as `let` bindings in object scopes and non-`val` class parameters.
// * Assignments to such fields are not defining a variable of the IR,
// * even though they're represented using the same Assign instruction.
// * TODO: It may be much cleaner to just have the elaborator use explicit selections and ReassignField instead.
// * TODO: Confusingly, we also have the case of `Define` being used to define a `val` without an owner (see below). We should simplify all this ad-hoc logic.
case Assign(l: TermSymbol, r, rst) => rst.definedVars
case Assign(l, r, rst) => rst.definedVars + l
case AssignField(l, n, r, rst) => rst.definedVars
case AssignDynField(l, n, ai, r, rst) => rst.definedVars

case Reassign(l, r, rst) => rst.definedVars
case ReassignField(l, n, r, rst) => rst.definedVars
case ReassignDynField(l, n, ai, r, rst) => rst.definedVars
case Match(scrut, arms, dflt, rst) =>
arms.flatMap(_._2.definedVars).toSet ++ dflt.toList.flatMap(_.definedVars) ++ rst.definedVars
case End(_) => Set.empty
Expand All @@ -49,8 +60,9 @@ sealed abstract class Block extends Product with AutoLocated:
case _: Return | _: Throw | _: End | _: Break | _: Continue => 1
case Begin(sub, rst) => sub.size + rst.size
case Assign(_, _, rst) => 1 + rst.size
case AssignField(_, _, _, rst) => 1 + rst.size
case AssignDynField(_, _, _, _, rst) => 1 + rst.size
case Reassign(_, _, rst) => 1 + rst.size
case ReassignField(_, _, _, rst) => 1 + rst.size
case ReassignDynField(_, _, _, _, rst) => 1 + rst.size
case Match(_, arms, dflt, rst) =>
1 + arms.map(_._2.size).sum + dflt.map(_.size).getOrElse(0) + rst.size
case Define(_, rst) => 1 + rst.size
Expand All @@ -63,6 +75,7 @@ sealed abstract class Block extends Product with AutoLocated:
case b: BlockTail => f(b)
case Begin(sub, rst) => Begin(sub, rst.mapTail(f))
case Assign(lhs, rhs, rst) => Assign(lhs, rhs, rst.mapTail(f))
case Reassign(lhs, rhs, rst) => Reassign(lhs, rhs, rst.mapTail(f))
case Define(defn, rst) => Define(defn, rst.mapTail(f))
case HandleBlock(lhs, res, par, args, cls, handlers, body, rest) =>
HandleBlock(lhs, res, par, args, cls, handlers.map(h => Handler(h.sym, h.resumeSym, h.params, h.body)), body, rest.mapTail(f))
Expand All @@ -71,10 +84,10 @@ sealed abstract class Block extends Product with AutoLocated:
case Match(scrut, arms, dflt, rst) =>
Match(scrut, arms, dflt, rst.mapTail(f))
case Label(label, body, rest) => Label(label, body, rest.mapTail(f))
case af @ AssignField(lhs, nme, rhs, rest) =>
AssignField(lhs, nme, rhs, rest.mapTail(f))(af.symbol)
case adf @ AssignDynField(lhs, fld, arrayIdx, rhs, rest) =>
AssignDynField(lhs, fld, arrayIdx, rhs, rest.mapTail(f))
case af @ ReassignField(lhs, nme, rhs, rest) =>
ReassignField(lhs, nme, rhs, rest.mapTail(f))(af.symbol)
case adf @ ReassignDynField(lhs, fld, arrayIdx, rhs, rest) =>
ReassignDynField(lhs, fld, arrayIdx, rhs, rest.mapTail(f))
case tb @ TryBlock(sub, fin, rest) =>
TryBlock(sub, fin, rest.mapTail(f))

Expand All @@ -90,9 +103,10 @@ sealed abstract class Block extends Product with AutoLocated:
case Continue(label) => Set(label)
case Begin(sub, rest) => sub.freeVars ++ rest.freeVars
case TryBlock(sub, finallyDo, rest) => sub.freeVars ++ finallyDo.freeVars ++ rest.freeVars
case Assign(lhs, rhs, rest) => Set(lhs) ++ rhs.freeVars ++ rest.freeVars
case AssignField(lhs, nme, rhs, rest) => lhs.freeVars ++ rhs.freeVars ++ rest.freeVars
case AssignDynField(lhs, fld, arrayIdx, rhs, rest) => lhs.freeVars ++ fld.freeVars ++ rhs.freeVars ++ rest.freeVars
case Assign(lhs, rhs, rest) => rhs.freeVars ++ rest.freeVars
case Reassign(lhs, rhs, rest) => rhs.freeVars ++ rest.freeVars + lhs
case ReassignField(lhs, nme, rhs, rest) => lhs.freeVars ++ rhs.freeVars ++ rest.freeVars
case ReassignDynField(lhs, fld, arrayIdx, rhs, rest) => lhs.freeVars ++ fld.freeVars ++ rhs.freeVars ++ rest.freeVars
case Define(defn, rest) => defn.freeVars ++ rest.freeVars
case HandleBlock(lhs, res, par, args, cls, hdr, bod, rst) =>
(bod.freeVars - lhs) ++ rst.freeVars ++ hdr.flatMap(_.freeVars)
Expand All @@ -113,9 +127,10 @@ sealed abstract class Block extends Product with AutoLocated:
case Continue(label) => Set(label)
case Begin(sub, rest) => sub.freeVarsLLIR ++ rest.freeVarsLLIR
case TryBlock(sub, finallyDo, rest) => sub.freeVarsLLIR ++ finallyDo.freeVarsLLIR ++ rest.freeVarsLLIR
case Assign(lhs, rhs, rest) => Set(lhs) ++ rhs.freeVarsLLIR ++ rest.freeVarsLLIR
case AssignField(lhs, nme, rhs, rest) => lhs.freeVarsLLIR ++ rhs.freeVarsLLIR ++ rest.freeVarsLLIR
case AssignDynField(lhs, fld, arrayIdx, rhs, rest) => lhs.freeVarsLLIR ++ fld.freeVarsLLIR ++ rhs.freeVarsLLIR ++ rest.freeVarsLLIR
case Assign(lhs, rhs, rest) => rhs.freeVarsLLIR ++ rest.freeVarsLLIR
case Reassign(lhs, rhs, rest) => Set(lhs) ++ rhs.freeVarsLLIR ++ rest.freeVarsLLIR
case ReassignField(lhs, nme, rhs, rest) => lhs.freeVarsLLIR ++ rhs.freeVarsLLIR ++ rest.freeVarsLLIR
case ReassignDynField(lhs, fld, arrayIdx, rhs, rest) => lhs.freeVarsLLIR ++ fld.freeVarsLLIR ++ rhs.freeVarsLLIR ++ rest.freeVarsLLIR
case Define(defn, rest) => defn.freeVarsLLIR ++ rest.freeVarsLLIR
case HandleBlock(lhs, res, par, args, cls, hdr, bod, rst) =>
(bod.freeVarsLLIR - lhs) ++ rst.freeVarsLLIR ++ hdr.flatMap(_.freeVars)
Expand All @@ -126,8 +141,9 @@ sealed abstract class Block extends Product with AutoLocated:
case Begin(sub, rest) => sub :: rest :: Nil
case TryBlock(sub, finallyDo, rest) => sub :: finallyDo :: rest :: Nil
case Assign(_, rhs, rest) => rhs.subBlocks ::: rest :: Nil
case AssignField(_, _, rhs, rest) => rhs.subBlocks ::: rest :: Nil
case AssignDynField(_, _, _, rhs, rest) => rhs.subBlocks ::: rest :: Nil
case Reassign(_, rhs, rest) => rhs.subBlocks ::: rest :: Nil
case ReassignField(_, _, rhs, rest) => rhs.subBlocks ::: rest :: Nil
case ReassignDynField(_, _, _, rhs, rest) => rhs.subBlocks ::: rest :: Nil
case Define(d, rest) => d.subBlocks ::: rest :: Nil
case HandleBlock(_, _, par, args, _, handlers, body, rest) => par.subBlocks ++ args.flatMap(_.subBlocks) ++ handlers.map(_.body) :+ body :+ rest
case Label(_, body, rest) => body :: rest :: Nil
Expand Down Expand Up @@ -198,18 +214,24 @@ sealed abstract class Block extends Product with AutoLocated:
if newRest is rest
then this
else Assign(lhs, rhs, newRest)

case Reassign(lhs, rhs, rest) =>
val newRest = rest.flatten(k)
if newRest is rest
then this
else Reassign(lhs, rhs, newRest)

case a@AssignField(lhs, nme, rhs, rest) =>
case a@ReassignField(lhs, nme, rhs, rest) =>
val newRest = rest.flatten(k)
if newRest is rest
then this
else AssignField(lhs, nme, rhs, newRest)(a.symbol)
else ReassignField(lhs, nme, rhs, newRest)(a.symbol)

case AssignDynField(lhs, fld, arrayIdx, rhs, rest) =>
case ReassignDynField(lhs, fld, arrayIdx, rhs, rest) =>
val newRest = rest.flatten(k)
if newRest is rest
then this
else AssignDynField(lhs, fld, arrayIdx, rhs, newRest)
else ReassignDynField(lhs, fld, arrayIdx, rhs, newRest)

case Define(defn, rest) =>
val newDefn = defn match
Expand Down Expand Up @@ -271,12 +293,18 @@ case class Begin(sub: Block, rest: Block) extends Block with ProductWithTail

case class TryBlock(sub: Block, finallyDo: Block, rest: Block) extends Block with ProductWithTail

// * Assigns the initial value of a variable or private field symbol.
// * Invariant: any given symbol may be assigned a value at most once, in the same block
// * (and specifically not in nested definitions),
// * and may only be read or reassigned after it has been assigned first.
case class Assign(lhs: Local, rhs: Result, rest: Block) extends Block with ProductWithTail
// case class Assign(lhs: Path, rhs: Result, rest: Block) extends Block with ProductWithTail

case class AssignField(lhs: Path, nme: Tree.Ident, rhs: Result, rest: Block)(val symbol: Opt[FieldSymbol]) extends Block with ProductWithTail
// * Reassigns the value of a variable or private field symbol.
case class Reassign(lhs: Local, rhs: Result, rest: Block) extends Block with ProductWithTail

case class ReassignField(lhs: Path, nme: Tree.Ident, rhs: Result, rest: Block)(val symbol: Opt[FieldSymbol]) extends Block with ProductWithTail

case class AssignDynField(lhs: Path, fld: Path, arrayIdx: Bool, rhs: Result, rest: Block) extends Block with ProductWithTail
case class ReassignDynField(lhs: Path, fld: Path, arrayIdx: Bool, rhs: Result, rest: Block) extends Block with ProductWithTail

case class Define(defn: Defn, rest: Block) extends Block with ProductWithTail

Expand Down Expand Up @@ -466,7 +494,8 @@ extension (k: Block => Block)
def transform(f: (Block => Block) => (Block => Block)) = f(k)

def assign(l: Local, r: Result) = k.chain(Assign(l, r, _))
def assignFieldN(lhs: Path, nme: Tree.Ident, rhs: Result) = k.chain(AssignField(lhs, nme, rhs, _)(N))
def reassign(l: Local, r: Result) = k.chain(Reassign(l, r, _))
def reassignFieldN(lhs: Path, nme: Tree.Ident, rhs: Result) = k.chain(ReassignField(lhs, nme, rhs, _)(N))
def break(l: Local): Block = k.rest(Break(l))
def continue(l: Local): Block = k.rest(Continue(l))
def define(defn: Defn) = k.chain(Define(defn, _))
Expand Down
13 changes: 9 additions & 4 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTransformer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,18 @@ class BlockTransformer(subst: SymbolSubst):
val l2 = applyLocal(l)
val rst2 = applySubBlock(rst)
if (l2 is l) && (r2 is r) && (rst2 is rst) then b else Assign(l2, r2, rst2)
case b @ AssignField(l, n, r, rst) =>
case Reassign(l, r, rst) =>
applyResult2(r): r2 =>
val l2 = applyLocal(l)
val rst2 = applySubBlock(rst)
if (l2 is l) && (r2 is r) && (rst2 is rst) then b else Reassign(l2, r2, rst2)
case b @ ReassignField(l, n, r, rst) =>
applyResult2(r): r2 =>
val l2 = applyPath(l)
val rst2 = applySubBlock(rst)
val sym = b.symbol.mapConserve(_.subst)
if (l2 is l) && (r2 is r) && (rst2 is rst) && (sym is b.symbol)
then b else AssignField(l2, n, r2, rst2)(sym)
then b else ReassignField(l2, n, r2, rst2)(sym)
case Define(defn, rst) =>
val defn2 = applyDefn(defn)
val rst2 = applySubBlock(rst)
Expand All @@ -83,14 +88,14 @@ class BlockTransformer(subst: SymbolSubst):
if (l2 is l) && (res2 is res) && (par2 is par) && (args2 is args) &&
(cls2 is cls) && (hdr2 is hdr) && (bod2 is bod) && (rst2 is rst)
then b else HandleBlock(l2, res2, par2, args2, cls2, hdr2, bod2, rst2)
case AssignDynField(lhs, fld, arrayIdx, rhs, rest) =>
case ReassignDynField(lhs, fld, arrayIdx, rhs, rest) =>
applyResult2(rhs): rhs2 =>
val lhs2 = applyPath(lhs)
val fld2 = applyPath(fld)
val rest2 = applySubBlock(rest)
if (lhs2 is lhs) && (fld2 is fld) && (rhs2 is rhs) && (rest2 is rest)
then b
else AssignDynField(lhs2, fld2, arrayIdx, rhs2, rest2)
else ReassignDynField(lhs2, fld2, arrayIdx, rhs2, rest2)


def applyResult2(r: Result)(k: Result => Block): Block = k(applyResult(r))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ class BlockTraverser:
case Begin(sub, rst) => applySubBlock(sub); applySubBlock(rst)
case TryBlock(sub, fin, rst) => applySubBlock(sub); applySubBlock(fin); applySubBlock(rst)
case Assign(l, r, rst) => applyLocal(l); applyResult(r); applySubBlock(rst)
case b @ AssignField(l, n, r, rst) =>
case Reassign(l, r, rst) => applyLocal(l); applyResult(r); applySubBlock(rst)
case b @ ReassignField(l, n, r, rst) =>
applyPath(l); applyResult(r); applySubBlock(rst); b.symbol.foreach(_.traverse)
case Define(defn, rst) => applyDefn(defn); applySubBlock(rst)
case HandleBlock(l, res, par, args, cls, hdr, bod, rst) =>
Expand All @@ -46,7 +47,7 @@ class BlockTraverser:
hdr.foreach(applyHandler)
applySubBlock(bod)
applySubBlock(rst)
case AssignDynField(lhs, fld, arrayIdx, rhs, rest) =>
case ReassignDynField(lhs, fld, arrayIdx, rhs, rest) =>
applyPath(lhs)
applyResult(rhs)
applyPath(fld)
Expand Down
29 changes: 16 additions & 13 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,10 @@ class HandlerLowering(paths: HandlerPaths)(using TL, Raise, Elaborator.State, El
private def funcLikeHandlerCtx(ctorThis: Option[Path], isHandlerMtd: Bool, nme: Str) =
HandlerCtx(false, false, nme, ctorThis, state =>
blockBuilder
.assignFieldN(state.res.contTrace.last, nextIdent, Instantiate(
.reassignFieldN(state.res.contTrace.last, nextIdent, Instantiate(
state.cls.selN(Tree.Ident("class")),
Value.Lit(Tree.IntLit(state.uid)) :: Nil))
.assignFieldN(state.res.contTrace, lastIdent, state.res.contTrace.last.next)
.reassignFieldN(state.res.contTrace, lastIdent, state.res.contTrace.last.next)
.ret(state.res))
private def functionHandlerCtx(nme: Str) = funcLikeHandlerCtx(N, false, nme)
private def topLevelCtx(nme: Str) = HandlerCtx(true, false, nme, N, _ => rtThrowMsg("Unhandled effects"))
Expand Down Expand Up @@ -270,12 +270,15 @@ class HandlerLowering(paths: HandlerPaths)(using TL, Raise, Elaborator.State, El
case Assign(lhs, rhs, rest) =>
val PartRet(head, parts) = go(rest)
PartRet(Assign(lhs, rhs, head), parts)
case blk @ AssignField(lhs, nme, rhs, rest) =>
case Reassign(lhs, rhs, rest) =>
val PartRet(head, parts) = go(rest)
PartRet(AssignField(lhs, nme, rhs, head)(blk.symbol), parts)
case AssignDynField(lhs, fld, arrayIdx, rhs, rest) =>
PartRet(Reassign(lhs, rhs, head), parts)
case blk @ ReassignField(lhs, nme, rhs, rest) =>
val PartRet(head, parts) = go(rest)
PartRet(AssignDynField(lhs, fld, arrayIdx, rhs, head), parts)
PartRet(ReassignField(lhs, nme, rhs, head)(blk.symbol), parts)
case ReassignDynField(lhs, fld, arrayIdx, rhs, rest) =>
val PartRet(head, parts) = go(rest)
PartRet(ReassignDynField(lhs, fld, arrayIdx, rhs, head), parts)
case Return(_, _) => PartRet(blk, Nil)
// ignored cases
case TryBlock(sub, finallyDo, rest) => ??? // ignore
Expand Down Expand Up @@ -385,7 +388,7 @@ class HandlerLowering(paths: HandlerPaths)(using TL, Raise, Elaborator.State, El

val handlerBody = translateBlock(h.body, HandlerCtx(false, true,
s"Cont$$handleBlock$$${h.lhs.nme}$$", N, state => blockBuilder
.assignFieldN(state.res.contTrace.last, nextIdent, PureCall(state.cls, Value.Lit(Tree.IntLit(state.uid)) :: Nil))
.reassignFieldN(state.res.contTrace.last, nextIdent, PureCall(state.cls, Value.Lit(Tree.IntLit(state.uid)) :: Nil))
.ret(PureCall(paths.handleBlockImplPath, state.res :: h.lhs.asPath :: Nil))))

val handlerMtds = h.handlers.map: handler =>
Expand Down Expand Up @@ -460,13 +463,13 @@ class HandlerLowering(paths: HandlerPaths)(using TL, Raise, Elaborator.State, El
override def applyBlock(b: Block): Block = b match
case ReturnCont(res, uid) =>
blockBuilder
.assign(pcSymbol, Value.Lit(Tree.IntLit(uid)))
.assignFieldN(res.asPath.contTrace.last, nextIdent, clsSym.asPath)
.assignFieldN(res.asPath.contTrace, lastIdent, clsSym.asPath)
.reassign(pcSymbol, Value.Lit(Tree.IntLit(uid)))
.reassignFieldN(res.asPath.contTrace.last, nextIdent, clsSym.asPath)
.reassignFieldN(res.asPath.contTrace, lastIdent, clsSym.asPath)
.ret(res.asPath)
case StateTransition(uid) =>
blockBuilder
.assign(pcSymbol, Value.Lit(Tree.IntLit(uid)))
.reassign(pcSymbol, Value.Lit(Tree.IntLit(uid)))
.continue(loopLbl)
case FnEnd() =>
blockBuilder.break(loopLbl)
Expand All @@ -486,7 +489,7 @@ class HandlerLowering(paths: HandlerPaths)(using TL, Raise, Elaborator.State, El

val resumedVal = VarSymbol(Tree.Ident("value$"))

def createAssignment(sym: Local) = Assign(sym, resumedVal.asPath, End())
def createAssignment(sym: Local) = Reassign(sym, resumedVal.asPath, End()) // TODO: or is this `Assign`?
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is always a reassignment, as the old symbol would hold the previous EffectSig


val assignedResumedCases = for
b <- parts
Expand Down Expand Up @@ -531,7 +534,7 @@ class HandlerLowering(paths: HandlerPaths)(using TL, Raise, Elaborator.State, El
Assign(freshTmp(), PureCall(
Value.Ref(State.builtinOpsMap("super")), // refers to runtime.FunctionContFrame which is pure
Value.Lit(Tree.UnitLit(true)) :: Nil), End()),
AssignField(
ReassignField(
clsSym.asPath,
pcVar.id,
Value.Ref(pcVar),
Expand Down
17 changes: 10 additions & 7 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -601,24 +601,27 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise):
case (blk, (bms, local)) =>
val initial = blk.assign(local, createCall(bms, ctx))
ctx.defns(bms) match
case c: ClsLikeDefn => initial.assignFieldN(local.asPath, Tree.Ident("class"), bms.asPath)
case c: ClsLikeDefn => initial.reassignFieldN(local.asPath, Tree.Ident("class"), bms.asPath)
case _ => initial

val remaining = rewritten match
case Assign(lhs: InnerSymbol, rhs, rest) => ctx.getIsymPath(lhs) match

// TODO (LP:) not sure about the semantics of these "Reassign" cases
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@CAG2Mark 👀


case Reassign(lhs: InnerSymbol, rhs, rest) => ctx.getIsymPath(lhs) match
case Some(value) if !belongsToCtor(lhs) =>
Assign(value, applyResult(rhs), applyBlock(rest))
Reassign(value, applyResult(rhs), applyBlock(rest))
case _ => super.applyBlock(rewritten)

case Assign(t: TermSymbol, rhs, rest) if t.owner.isDefined =>
case Reassign(t: TermSymbol, rhs, rest) if t.owner.isDefined =>
ctx.getIsymPath(t.owner.get) match
case Some(value) if !belongsToCtor(t.owner.get) =>
AssignField(value.asPath, t.id, applyResult(rhs), applyBlock(rest))(N)
ReassignField(value.asPath, t.id, applyResult(rhs), applyBlock(rest))(N)
case _ => super.applyBlock(rewritten)

case Assign(lhs, rhs, rest) => ctx.getLocalCaptureSym(lhs) match
case Reassign(lhs, rhs, rest) => ctx.getLocalCaptureSym(lhs) match
case Some(captureSym) =>
AssignField(ctx.getLocalClosPath(lhs).get, captureSym.id, applyResult(rhs), applyBlock(rest))(N)
ReassignField(ctx.getLocalClosPath(lhs).get, captureSym.id, applyResult(rhs), applyBlock(rest))(N)
case None => ctx.getLocalPath(lhs) match
case None => super.applyBlock(rewritten)
case Some(value) => Assign(value, applyResult(rhs), applyBlock(rest))
Expand Down
Loading
Loading