Skip to content

Commit 58810fd

Browse files
authored
Add support for xsbti.compile.CompileProgress (#18739)
Fixes #13082
2 parents 7f803ec + b510772 commit 58810fd

File tree

23 files changed

+928
-109
lines changed

23 files changed

+928
-109
lines changed

Diff for: compiler/src/dotty/tools/dotc/Run.scala

+211-13
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@ import typer.Typer
1212
import typer.ImportInfo.withRootImports
1313
import Decorators._
1414
import io.AbstractFile
15-
import Phases.unfusedPhases
15+
import Phases.{unfusedPhases, Phase}
16+
17+
import sbt.interfaces.ProgressCallback
1618

1719
import util._
1820
import reporting.{Suppression, Action, Profile, ActiveProfile, NoProfile}
@@ -32,6 +34,10 @@ import scala.collection.mutable
3234
import scala.util.control.NonFatal
3335
import scala.io.Codec
3436

37+
import Run.Progress
38+
import scala.compiletime.uninitialized
39+
import dotty.tools.dotc.transform.MegaPhase
40+
3541
/** A compiler run. Exports various methods to compile source files */
3642
class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with ConstraintRunInfo {
3743

@@ -155,14 +161,75 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint
155161
}
156162

157163
/** The source files of all late entered symbols, as a set */
158-
private var lateFiles = mutable.Set[AbstractFile]()
164+
private val lateFiles = mutable.Set[AbstractFile]()
159165

160166
/** A cache for static references to packages and classes */
161167
val staticRefs = util.EqHashMap[Name, Denotation](initialCapacity = 1024)
162168

163169
/** Actions that need to be performed at the end of the current compilation run */
164170
private var finalizeActions = mutable.ListBuffer[() => Unit]()
165171

172+
private var _progress: Progress | Null = null // Set if progress reporting is enabled
173+
174+
private inline def trackProgress(using Context)(inline op: Context ?=> Progress => Unit): Unit =
175+
foldProgress(())(op)
176+
177+
private inline def foldProgress[T](using Context)(inline default: T)(inline op: Context ?=> Progress => T): T =
178+
val local = _progress
179+
if local != null then
180+
op(using ctx)(local)
181+
else
182+
default
183+
184+
def didEnterUnit(unit: CompilationUnit)(using Context): Boolean =
185+
foldProgress(true /* should progress by default */)(_.tryEnterUnit(unit))
186+
187+
def canProgress()(using Context): Boolean =
188+
foldProgress(true /* not cancelled by default */)(p => !p.checkCancellation())
189+
190+
def doAdvanceUnit()(using Context): Unit =
191+
trackProgress: progress =>
192+
progress.currentUnitCount += 1 // trace that we completed a unit in the current (sub)phase
193+
progress.refreshProgress()
194+
195+
def doAdvanceLate()(using Context): Unit =
196+
trackProgress: progress =>
197+
progress.currentLateUnitCount += 1 // trace that we completed a late compilation
198+
progress.refreshProgress()
199+
200+
private def doEnterPhase(currentPhase: Phase)(using Context): Unit =
201+
trackProgress: progress =>
202+
progress.enterPhase(currentPhase)
203+
204+
/** interrupt the thread and set cancellation state */
205+
private def cancelInterrupted(): Unit =
206+
try
207+
trackProgress(_.cancel())
208+
finally
209+
Thread.currentThread().nn.interrupt()
210+
211+
private def doAdvancePhase(currentPhase: Phase, wasRan: Boolean)(using Context): Unit =
212+
trackProgress: progress =>
213+
progress.currentUnitCount = 0 // reset unit count in current (sub)phase
214+
progress.currentCompletedSubtraversalCount = 0 // reset subphase index to initial
215+
progress.seenPhaseCount += 1 // trace that we've seen a (sub)phase
216+
if wasRan then
217+
// add an extra traversal now that we completed a (sub)phase
218+
progress.completedTraversalCount += 1
219+
else
220+
// no subphases were ran, remove traversals from expected total
221+
progress.totalTraversals -= currentPhase.traversals
222+
223+
private def tryAdvanceSubPhase()(using Context): Unit =
224+
trackProgress: progress =>
225+
if progress.canAdvanceSubPhase then
226+
progress.currentUnitCount = 0 // reset unit count in current (sub)phase
227+
progress.seenPhaseCount += 1 // trace that we've seen a (sub)phase
228+
progress.completedTraversalCount += 1 // add an extra traversal now that we completed a (sub)phase
229+
progress.currentCompletedSubtraversalCount += 1 // record that we've seen a subphase
230+
if !progress.isCancelled() then
231+
progress.tickSubphase()
232+
166233
/** Will be set to true if any of the compiled compilation units contains
167234
* a pureFunctions language import.
168235
*/
@@ -233,17 +300,20 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint
233300
if ctx.settings.YnoDoubleBindings.value then
234301
ctx.base.checkNoDoubleBindings = true
235302

236-
def runPhases(using Context) = {
303+
def runPhases(allPhases: Array[Phase])(using Context) = {
237304
var lastPrintedTree: PrintedTree = NoPrintedTree
238305
val profiler = ctx.profiler
239306
var phasesWereAdjusted = false
240307

241-
for (phase <- ctx.base.allPhases)
242-
if (phase.isRunnable)
308+
for phase <- allPhases do
309+
doEnterPhase(phase)
310+
val phaseWillRun = phase.isRunnable
311+
if phaseWillRun then
243312
Stats.trackTime(s"phase time ms/$phase") {
244313
val start = System.currentTimeMillis
245314
val profileBefore = profiler.beforePhase(phase)
246-
units = phase.runOn(units)
315+
try units = phase.runOn(units)
316+
catch case _: InterruptedException => cancelInterrupted()
247317
profiler.afterPhase(phase, profileBefore)
248318
if (ctx.settings.Xprint.value.containsPhase(phase))
249319
for (unit <- units)
@@ -261,18 +331,25 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint
261331
if !Feature.ccEnabledSomewhere then
262332
ctx.base.unlinkPhaseAsDenotTransformer(Phases.checkCapturesPhase.prev)
263333
ctx.base.unlinkPhaseAsDenotTransformer(Phases.checkCapturesPhase)
264-
334+
end if
335+
end if
336+
end if
337+
doAdvancePhase(phase, wasRan = phaseWillRun)
338+
end for
265339
profiler.finished()
266340
}
267341

268342
val runCtx = ctx.fresh
269343
runCtx.setProfiler(Profiler())
270344
unfusedPhases.foreach(_.initContext(runCtx))
271-
runPhases(using runCtx)
345+
val fusedPhases = runCtx.base.allPhases
346+
runCtx.withProgressCallback: cb =>
347+
_progress = Progress(cb, this, fusedPhases.map(_.traversals).sum)
348+
runPhases(allPhases = fusedPhases)(using runCtx)
272349
if (!ctx.reporter.hasErrors)
273350
Rewrites.writeBack()
274351
suppressions.runFinished(hasErrors = ctx.reporter.hasErrors)
275-
while (finalizeActions.nonEmpty) {
352+
while (finalizeActions.nonEmpty && canProgress()) {
276353
val action = finalizeActions.remove(0)
277354
action()
278355
}
@@ -294,10 +371,9 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint
294371
.withRootImports
295372

296373
def process()(using Context) =
297-
ctx.typer.lateEnterUnit(doTypeCheck =>
298-
if typeCheck then
299-
if compiling then finalizeActions += doTypeCheck
300-
else doTypeCheck()
374+
ctx.typer.lateEnterUnit(typeCheck)(doTypeCheck =>
375+
if compiling then finalizeActions += doTypeCheck
376+
else doTypeCheck()
301377
)
302378

303379
process()(using unitCtx)
@@ -400,7 +476,129 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint
400476
}
401477

402478
object Run {
479+
480+
case class SubPhase(val name: String):
481+
override def toString: String = name
482+
483+
class SubPhases(val phase: Phase):
484+
require(phase.exists)
485+
486+
private def baseName: String = phase match
487+
case phase: MegaPhase => phase.shortPhaseName
488+
case phase => phase.phaseName
489+
490+
val all = IArray.from(phase.subPhases.map(sub => s"$baseName[$sub]"))
491+
492+
def next(using Context): Option[SubPhases] =
493+
val next0 = phase.megaPhase.next.megaPhase
494+
if next0.exists then Some(SubPhases(next0))
495+
else None
496+
497+
def size: Int = all.size
498+
499+
def subPhase(index: Int) =
500+
if index < all.size then all(index)
501+
else baseName
502+
503+
504+
private class Progress(cb: ProgressCallback, private val run: Run, val initialTraversals: Int):
505+
export cb.{cancel, isCancelled}
506+
507+
var totalTraversals: Int = initialTraversals // track how many phases we expect to run
508+
var currentUnitCount: Int = 0 // current unit count in the current (sub)phase
509+
var currentLateUnitCount: Int = 0 // current late unit count
510+
var completedTraversalCount: Int = 0 // completed traversals over all files
511+
var currentCompletedSubtraversalCount: Int = 0 // completed subphases in the current phase
512+
var seenPhaseCount: Int = 0 // how many phases we've seen so far
513+
514+
private var currPhase: Phase = uninitialized // initialized by enterPhase
515+
private var subPhases: SubPhases = uninitialized // initialized by enterPhase
516+
private var currPhaseName: String = uninitialized // initialized by enterPhase
517+
private var nextPhaseName: String = uninitialized // initialized by enterPhase
518+
519+
/** Enter into a new real phase, setting the current and next (sub)phases */
520+
def enterPhase(newPhase: Phase)(using Context): Unit =
521+
if newPhase ne currPhase then
522+
currPhase = newPhase
523+
subPhases = SubPhases(newPhase)
524+
tickSubphase()
525+
526+
def canAdvanceSubPhase: Boolean =
527+
currentCompletedSubtraversalCount + 1 < subPhases.size
528+
529+
/** Compute the current (sub)phase name and next (sub)phase name */
530+
def tickSubphase()(using Context): Unit =
531+
val index = currentCompletedSubtraversalCount
532+
val s = subPhases
533+
currPhaseName = s.subPhase(index)
534+
nextPhaseName =
535+
if index + 1 < s.all.size then s.subPhase(index + 1)
536+
else s.next match
537+
case None => "<end>"
538+
case Some(next0) => next0.subPhase(0)
539+
if seenPhaseCount > 0 then
540+
refreshProgress()
541+
542+
543+
/** Counts the number of completed full traversals over files, plus the number of units in the current phase */
544+
private def currentProgress(): Int =
545+
completedTraversalCount * work() + currentUnitCount + currentLateUnitCount
546+
547+
/**Total progress is computed as the sum of
548+
* - the number of traversals we expect to make over all files
549+
* - the number of late compilations
550+
*/
551+
private def totalProgress(): Int =
552+
totalTraversals * work() + run.lateFiles.size
553+
554+
private def work(): Int = run.files.size
555+
556+
private def requireInitialized(): Unit =
557+
require((currPhase: Phase | Null) != null, "enterPhase was not called")
558+
559+
def checkCancellation(): Boolean =
560+
if Thread.interrupted() then cancel()
561+
isCancelled()
562+
563+
/** trace that we are beginning a unit in the current (sub)phase, unless cancelled */
564+
def tryEnterUnit(unit: CompilationUnit): Boolean =
565+
if checkCancellation() then false
566+
else
567+
requireInitialized()
568+
cb.informUnitStarting(currPhaseName, unit)
569+
true
570+
571+
/** trace the current progress out of the total, in the current (sub)phase, reporting the next (sub)phase */
572+
def refreshProgress()(using Context): Unit =
573+
requireInitialized()
574+
val total = totalProgress()
575+
if total > 0 && !cb.progress(currentProgress(), total, currPhaseName, nextPhaseName) then
576+
cancel()
577+
403578
extension (run: Run | Null)
579+
580+
/** record that the current phase has begun for the compilation unit of the current Context */
581+
def enterUnit(unit: CompilationUnit)(using Context): Boolean =
582+
if run != null then run.didEnterUnit(unit)
583+
else true // don't check cancellation if we're not tracking progress
584+
585+
/** check progress cancellation, true if not cancelled */
586+
def enterRegion()(using Context): Boolean =
587+
if run != null then run.canProgress()
588+
else true // don't check cancellation if we're not tracking progress
589+
590+
/** advance the unit count and record progress in the current phase */
591+
def advanceUnit()(using Context): Unit =
592+
if run != null then run.doAdvanceUnit()
593+
594+
/** if there exists another subphase, switch to it and record progress */
595+
def enterNextSubphase()(using Context): Unit =
596+
if run != null then run.tryAdvanceSubPhase()
597+
598+
/** advance the late count and record progress in the current phase */
599+
def advanceLate()(using Context): Unit =
600+
if run != null then run.doAdvanceLate()
601+
404602
def enrichedErrorMessage: Boolean = if run == null then false else run.myEnrichedErrorMessage
405603
def enrichErrorMessage(errorMessage: String)(using Context): String =
406604
if run == null then

Diff for: compiler/src/dotty/tools/dotc/core/Contexts.scala

+12-2
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ import scala.annotation.internal.sharable
3434

3535
import DenotTransformers.DenotTransformer
3636
import dotty.tools.dotc.profile.Profiler
37-
import dotty.tools.dotc.sbt.interfaces.IncrementalCallback
37+
import dotty.tools.dotc.sbt.interfaces.{IncrementalCallback, ProgressCallback}
3838
import util.Property.Key
3939
import util.Store
4040
import plugins._
@@ -53,8 +53,9 @@ object Contexts {
5353
private val (notNullInfosLoc, store8) = store7.newLocation[List[NotNullInfo]]()
5454
private val (importInfoLoc, store9) = store8.newLocation[ImportInfo | Null]()
5555
private val (typeAssignerLoc, store10) = store9.newLocation[TypeAssigner](TypeAssigner)
56+
private val (progressCallbackLoc, store11) = store10.newLocation[ProgressCallback | Null]()
5657

57-
private val initialStore = store10
58+
private val initialStore = store11
5859

5960
/** The current context */
6061
inline def ctx(using ctx: Context): Context = ctx
@@ -177,6 +178,14 @@ object Contexts {
177178
val local = incCallback
178179
local != null && local.enabled || forceRun
179180

181+
/** The Zinc compile progress callback implementation if we are run from Zinc, null otherwise */
182+
def progressCallback: ProgressCallback | Null = store(progressCallbackLoc)
183+
184+
/** Run `op` if there exists a Zinc progress callback */
185+
inline def withProgressCallback(inline op: ProgressCallback => Unit): Unit =
186+
val local = progressCallback
187+
if local != null then op(local)
188+
180189
/** The current plain printer */
181190
def printerFn: Context => Printer = store(printerFnLoc)
182191

@@ -675,6 +684,7 @@ object Contexts {
675684

676685
def setCompilerCallback(callback: CompilerCallback): this.type = updateStore(compilerCallbackLoc, callback)
677686
def setIncCallback(callback: IncrementalCallback): this.type = updateStore(incCallbackLoc, callback)
687+
def setProgressCallback(callback: ProgressCallback): this.type = updateStore(progressCallbackLoc, callback)
678688
def setPrinterFn(printer: Context => Printer): this.type = updateStore(printerFnLoc, printer)
679689
def setSettings(settingsState: SettingsState): this.type = updateStore(settingsStateLoc, settingsState)
680690
def setRun(run: Run | Null): this.type = updateStore(runLoc, run)

0 commit comments

Comments
 (0)