Skip to content

Commit 517c3fe

Browse files
committed
add in cooperative cancellation, test that it works
1 parent 4ccba38 commit 517c3fe

File tree

9 files changed

+247
-98
lines changed

9 files changed

+247
-98
lines changed

compiler/src/dotty/tools/dotc/Run.scala

+49-17
Original file line numberDiff line numberDiff line change
@@ -177,9 +177,18 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint
177177
if local != null then
178178
op(using ctx)(local)
179179

180-
def doBeginUnit()(using Context): Unit =
181-
trackProgress: progress =>
182-
progress.informUnitStarting(ctx.compilationUnit)
180+
private inline def foldProgress[T](using Context)(inline default: T)(inline op: Context ?=> Progress => T): T =
181+
val local = _progress
182+
if local != null then
183+
op(using ctx)(local)
184+
else
185+
default
186+
187+
def didEnterUnit()(using Context): Boolean =
188+
foldProgress(true /* should progress by default */)(_.tryEnterUnit(ctx.compilationUnit))
189+
190+
def didEnterFinal()(using Context): Boolean =
191+
foldProgress(true /* should progress by default */)(p => !p.checkCancellation())
183192

184193
def doAdvanceUnit()(using Context): Unit =
185194
trackProgress: progress =>
@@ -195,6 +204,13 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint
195204
trackProgress: progress =>
196205
progress.enterPhase(currentPhase)
197206

207+
/** interrupt the thread and set cancellation state */
208+
private def cancelInterrupted(): Unit =
209+
try
210+
trackProgress(_.cancel())
211+
finally
212+
Thread.currentThread().nn.interrupt()
213+
198214
private def doAdvancePhase(currentPhase: Phase, wasRan: Boolean)(using Context): Unit =
199215
trackProgress: progress =>
200216
progress.unitc = 0 // reset unit count in current (sub)phase
@@ -213,7 +229,8 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint
213229
progress.seen += 1 // trace that we've seen a (sub)phase
214230
progress.traversalc += 1 // add an extra traversal now that we completed a (sub)phase
215231
progress.subtraversalc += 1 // record that we've seen a subphase
216-
progress.tickSubphase()
232+
if !progress.isCancelled() then
233+
progress.tickSubphase()
217234

218235
/** Will be set to true if any of the compiled compilation units contains
219236
* a pureFunctions language import.
@@ -297,7 +314,8 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint
297314
Stats.trackTime(s"phase time ms/$phase") {
298315
val start = System.currentTimeMillis
299316
val profileBefore = profiler.beforePhase(phase)
300-
units = phase.runOn(units)
317+
try units = phase.runOn(units)
318+
catch case _: InterruptedException => cancelInterrupted()
301319
profiler.afterPhase(phase, profileBefore)
302320
if (ctx.settings.Xprint.value.containsPhase(phase))
303321
for (unit <- units)
@@ -333,7 +351,7 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint
333351
if (!ctx.reporter.hasErrors)
334352
Rewrites.writeBack()
335353
suppressions.runFinished(hasErrors = ctx.reporter.hasErrors)
336-
while (finalizeActions.nonEmpty) {
354+
while (finalizeActions.nonEmpty && didEnterFinal()) {
337355
val action = finalizeActions.remove(0)
338356
action()
339357
}
@@ -481,6 +499,8 @@ object Run {
481499

482500

483501
private class Progress(cb: ProgressCallback, private val run: Run, val initialTraversals: Int):
502+
export cb.{cancel, isCancelled}
503+
484504
private[Run] var totalTraversals: Int = initialTraversals // track how many phases we expect to run
485505
private[Run] var unitc: Int = 0 // current unit count in the current (sub)phase
486506
private[Run] var latec: Int = 0 // current late unit count
@@ -515,34 +535,46 @@ object Run {
515535

516536

517537
/** Counts the number of completed full traversals over files, plus the number of units in the current phase */
518-
private def currentProgress()(using Context): Int =
519-
traversalc * run.files.size + unitc + latec
538+
private def currentProgress(): Int =
539+
traversalc * work() + unitc + latec
520540

521541
/**Total progress is computed as the sum of
522542
* - the number of traversals we expect to make over all files
523543
* - the number of late compilations
524544
*/
525-
private def totalProgress()(using Context): Int =
526-
totalTraversals * run.files.size + run.lateFiles.size
545+
private def totalProgress(): Int =
546+
totalTraversals * work() + run.lateFiles.size
547+
548+
private def work(): Int = run.files.size
527549

528550
private def requireInitialized(): Unit =
529551
require((currPhase: Phase | Null) != null, "enterPhase was not called")
530552

531-
/** trace that we are beginning a unit in the current (sub)phase */
532-
private[Run] def informUnitStarting(unit: CompilationUnit)(using Context): Unit =
533-
requireInitialized()
534-
cb.informUnitStarting(currPhaseName, unit)
553+
private[Run] def checkCancellation(): Boolean =
554+
if Thread.interrupted() then cancel()
555+
isCancelled()
556+
557+
/** trace that we are beginning a unit in the current (sub)phase, unless cancelled */
558+
private[Run] def tryEnterUnit(unit: CompilationUnit): Boolean =
559+
if checkCancellation() then false
560+
else
561+
requireInitialized()
562+
cb.informUnitStarting(currPhaseName, unit)
563+
true
535564

536565
/** trace the current progress out of the total, in the current (sub)phase, reporting the next (sub)phase */
537566
private[Run] def refreshProgress()(using Context): Unit =
538567
requireInitialized()
539-
cb.progress(currentProgress(), totalProgress(), currPhaseName, nextPhaseName)
568+
val total = totalProgress()
569+
if total > 0 && !cb.progress(currentProgress(), total, currPhaseName, nextPhaseName) then
570+
cancel()
540571

541572
extension (run: Run | Null)
542573

543574
/** record that the current phase has begun for the compilation unit of the current Context */
544-
def beginUnit()(using Context): Unit =
545-
if run != null then run.doBeginUnit()
575+
def enterUnit()(using Context): Boolean =
576+
if run != null then run.didEnterUnit()
577+
else true // don't check cancellation if we're not tracking progress
546578

547579
/** advance the unit count and record progress in the current phase */
548580
def advanceUnit()(using Context): Unit =

compiler/src/dotty/tools/dotc/core/Phases.scala

+35-17
Original file line numberDiff line numberDiff line change
@@ -326,16 +326,20 @@ object Phases {
326326

327327
/** @pre `isRunnable` returns true */
328328
def runOn(units: List[CompilationUnit])(using runCtx: Context): List[CompilationUnit] =
329-
units.map { unit =>
329+
val buf = List.newBuilder[CompilationUnit]
330+
for unit <- units do
330331
given unitCtx: Context = runCtx.fresh.setPhase(this.start).setCompilationUnit(unit).withRootImports
331-
ctx.run.beginUnit()
332-
try run
333-
catch case ex: Throwable if !ctx.run.enrichedErrorMessage =>
334-
println(ctx.run.enrichErrorMessage(s"unhandled exception while running $phaseName on $unit"))
335-
throw ex
336-
finally ctx.run.advanceUnit()
337-
unitCtx.compilationUnit
338-
}
332+
if ctx.run.enterUnit() then
333+
try run
334+
catch case ex: Throwable if !ctx.run.enrichedErrorMessage =>
335+
println(ctx.run.enrichErrorMessage(s"unhandled exception while running $phaseName on $unit"))
336+
throw ex
337+
finally ctx.run.advanceUnit()
338+
buf += unitCtx.compilationUnit
339+
end if
340+
end for
341+
buf.result()
342+
end runOn
339343

340344
/** Convert a compilation unit's tree to a string; can be overridden */
341345
def show(tree: untpd.Tree)(using Context): String =
@@ -448,14 +452,28 @@ object Phases {
448452
Iterator.iterate(this)(_.next) takeWhile (_.hasNext)
449453

450454
/** run the body as one iteration of a (sub)phase (see Run.Progress), Enrich crash messages */
451-
final def monitor(doing: String)(body: Context ?=> Unit)(using Context): Unit =
452-
ctx.run.beginUnit()
453-
try body
454-
catch
455-
case NonFatal(ex) if !ctx.run.enrichedErrorMessage =>
456-
report.echo(ctx.run.enrichErrorMessage(s"exception occurred while $doing ${ctx.compilationUnit}"))
457-
throw ex
458-
finally ctx.run.advanceUnit()
455+
final def monitor(doing: String)(body: Context ?=> Unit)(using Context): Boolean =
456+
if ctx.run.enterUnit() then
457+
try {body; true}
458+
catch
459+
case NonFatal(ex) if !ctx.run.enrichedErrorMessage =>
460+
report.echo(ctx.run.enrichErrorMessage(s"exception occurred while $doing ${ctx.compilationUnit}"))
461+
throw ex
462+
finally ctx.run.advanceUnit()
463+
else
464+
false
465+
466+
/** run the body as one iteration of a (sub)phase (see Run.Progress), Enrich crash messages */
467+
final def monitorOpt[T](doing: String)(body: Context ?=> Option[T])(using Context): Option[T] =
468+
if ctx.run.enterUnit() then
469+
try body
470+
catch
471+
case NonFatal(ex) if !ctx.run.enrichedErrorMessage =>
472+
report.echo(ctx.run.enrichErrorMessage(s"exception occurred while $doing ${ctx.compilationUnit}"))
473+
throw ex
474+
finally ctx.run.advanceUnit()
475+
else
476+
None
459477

460478
override def toString: String = phaseName
461479
}

compiler/src/dotty/tools/dotc/fromtasty/ReadTasty.scala

+8-6
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,14 @@ class ReadTasty extends Phase {
2222
ctx.settings.fromTasty.value
2323

2424
override def runOn(units: List[CompilationUnit])(using Context): List[CompilationUnit] =
25-
withMode(Mode.ReadPositions)(units.flatMap(applyPhase(_)))
25+
withMode(Mode.ReadPositions) {
26+
val unitContexts = units.map(unit => ctx.fresh.setCompilationUnit(unit))
27+
unitContexts.flatMap(applyPhase()(using _))
28+
}
2629

27-
private def applyPhase(unit: CompilationUnit)(using Context): Option[CompilationUnit] =
28-
ctx.run.beginUnit()
29-
try readTASTY(unit)
30-
finally ctx.run.advanceUnit()
30+
private def applyPhase()(using Context): Option[CompilationUnit] = monitorOpt(phaseName):
31+
val unit = ctx.compilationUnit
32+
readTASTY(unit)
3133

3234
def readTASTY(unit: CompilationUnit)(using Context): Option[CompilationUnit] = unit match {
3335
case unit: TASTYCompilationUnit =>
@@ -82,7 +84,7 @@ class ReadTasty extends Phase {
8284
}
8385
}
8486
case unit =>
85-
Some(unit)
87+
Some(unit)
8688
}
8789

8890
def run(using Context): Unit = unsupported("run")

compiler/src/dotty/tools/dotc/parsing/ParserPhase.scala

+7-4
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class Parser extends Phase {
2222
*/
2323
private[dotc] var firstXmlPos: SourcePosition = NoSourcePosition
2424

25-
def parse(using Context) = monitor("parser") {
25+
def parse(using Context): Boolean = monitor("parser") {
2626
val unit = ctx.compilationUnit
2727
unit.untpdTree =
2828
if (unit.isJava) new JavaParsers.JavaParser(unit.source).parse()
@@ -46,12 +46,15 @@ class Parser extends Phase {
4646
report.inform(s"parsing ${unit.source}")
4747
ctx.fresh.setCompilationUnit(unit).withRootImports
4848

49-
for given Context <- unitContexts do
50-
parse
49+
val unitContexts0 =
50+
for
51+
given Context <- unitContexts
52+
if parse
53+
yield ctx
5154

5255
record("parsedTrees", ast.Trees.ntrees)
5356

54-
unitContexts.map(_.compilationUnit)
57+
unitContexts0.map(_.compilationUnit)
5558
}
5659

5760
def run(using Context): Unit = unsupported("run")

compiler/src/dotty/tools/dotc/sbt/interfaces/ProgressCallback.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ default void informUnitStarting(String phase, CompilationUnit unit) {}
1515
/** Record the current compilation progress.
1616
* @param current `completedPhaseCount * totalUnits + completedUnitsInCurrPhase + completedLate`
1717
* @param total `totalPhases * totalUnits + totalLate`
18-
* @return true if the compilation should continue (if false, then subsequent calls to `isCancelled()` will return true)
18+
* @return true if the compilation should continue (callers are expected to cancel if this returns false)
1919
*/
2020
default boolean progress(int current, int total, String currPhase, String nextPhase) { return true; }
2121
}

compiler/src/dotty/tools/dotc/transform/init/Checker.scala

+13-5
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,21 @@ class Checker extends Phase:
3131
override def isEnabled(using Context): Boolean =
3232
super.isEnabled && (ctx.settings.YcheckInit.value || ctx.settings.YcheckInitGlobal.value)
3333

34+
def traverse(traverser: InitTreeTraverser)(using Context): Boolean = monitor(phaseName):
35+
val unit = ctx.compilationUnit
36+
traverser.traverse(unit.tpdTree)
37+
3438
override def runOn(units: List[CompilationUnit])(using Context): List[CompilationUnit] =
3539
val checkCtx = ctx.fresh.setPhase(this.start)
3640
val traverser = new InitTreeTraverser()
37-
for unit <- units do
38-
checkCtx.run.beginUnit()
39-
try traverser.traverse(unit.tpdTree)
40-
finally ctx.run.advanceUnit()
41+
val unitContexts = units.map(unit => checkCtx.fresh.setCompilationUnit(unit))
42+
43+
val unitContexts0 =
44+
for
45+
given Context <- unitContexts
46+
if traverse(traverser)
47+
yield ctx
48+
4149
val classes = traverser.getClasses()
4250

4351
if ctx.settings.YcheckInit.value then
@@ -46,7 +54,7 @@ class Checker extends Phase:
4654
if ctx.settings.YcheckInitGlobal.value then
4755
Objects.checkClasses(classes)(using checkCtx)
4856

49-
units
57+
unitContexts0.map(_.compilationUnit)
5058

5159
def run(using Context): Unit = unsupported("run")
5260

compiler/src/dotty/tools/dotc/typer/TyperPhase.scala

+25-16
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,13 @@ class TyperPhase(addRootImports: Boolean = true) extends Phase {
3131
// Run regardless of parsing errors
3232
override def isRunnable(implicit ctx: Context): Boolean = true
3333

34-
def enterSyms(using Context): Unit = monitor("indexing") {
34+
def enterSyms(using Context): Boolean = monitor("indexing") {
3535
val unit = ctx.compilationUnit
3636
ctx.typer.index(unit.untpdTree)
3737
typr.println("entered: " + unit.source)
3838
}
3939

40-
def typeCheck(using Context): Unit = monitor("typechecking") {
40+
def typeCheck(using Context): Boolean = monitor("typechecking") {
4141
val unit = ctx.compilationUnit
4242
try
4343
if !unit.suspended then
@@ -49,7 +49,7 @@ class TyperPhase(addRootImports: Boolean = true) extends Phase {
4949
catch case _: CompilationUnit.SuspendException => ()
5050
}
5151

52-
def javaCheck(using Context): Unit = monitor("checking java") {
52+
def javaCheck(using Context): Boolean = monitor("checking java") {
5353
val unit = ctx.compilationUnit
5454
if unit.isJava then
5555
JavaChecks.check(unit.tpdTree)
@@ -72,11 +72,14 @@ class TyperPhase(addRootImports: Boolean = true) extends Phase {
7272
else
7373
newCtx
7474

75-
try
76-
for given Context <- unitContexts do
77-
enterSyms
78-
finally
79-
ctx.run.advanceSubPhase() // tick from "typer (indexing)" to "typer (typechecking)"
75+
val unitContexts0 =
76+
try
77+
for
78+
given Context <- unitContexts
79+
if enterSyms
80+
yield ctx
81+
finally
82+
ctx.run.advanceSubPhase() // tick from "typer (indexing)" to "typer (typechecking)"
8083

8184
ctx.base.parserPhase match {
8285
case p: ParserPhase =>
@@ -88,18 +91,24 @@ class TyperPhase(addRootImports: Boolean = true) extends Phase {
8891
case _ =>
8992
}
9093

91-
try
92-
for given Context <- unitContexts do
93-
typeCheck
94-
finally
95-
ctx.run.advanceSubPhase() // tick from "typer (typechecking)" to "typer (java checking)"
94+
val unitContexts1 =
95+
try
96+
for
97+
given Context <- unitContexts0
98+
if typeCheck
99+
yield ctx
100+
finally
101+
ctx.run.advanceSubPhase() // tick from "typer (typechecking)" to "typer (java checking)"
96102

97103
record("total trees after typer", ast.Trees.ntrees)
98104

99-
for given Context <- unitContexts do
100-
javaCheck // after typechecking to avoid cycles
105+
val unitContexts2 =
106+
for
107+
given Context <- unitContexts1
108+
if javaCheck // after typechecking to avoid cycles
109+
yield ctx
101110

102-
val newUnits = unitContexts.map(_.compilationUnit).filterNot(discardAfterTyper)
111+
val newUnits = unitContexts2.map(_.compilationUnit).filterNot(discardAfterTyper)
103112
ctx.run.nn.checkSuspendedUnits(newUnits)
104113
newUnits
105114

0 commit comments

Comments
 (0)