Skip to content

Performance improvement: use provisional state for better cache reuse; refactor TermRef widening logic #21278

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 129 additions & 36 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -118,39 +118,101 @@ object Types extends TypeUtils {
* a call to `resetInst`. This means all caches that rely on `isProvisional`
* can likely end up returning stale results.
*/
def isProvisional(using Context): Boolean = mightBeProvisional && testProvisional

private def testProvisional(using Context): Boolean =
def isProvisional(using Context): Boolean = mightBeProvisional && currentProvisionalState != null

// The provisonal state of a type stores the parts which might be changed and their
// info at a given point.
// For example, a `TypeVar` is provisional until it is permently instantiated,
// and its info is the current instantiation.
type ProvisionalState = util.HashMap[Type, Type] | Null

def currentProvisionalState(using Context): ProvisionalState =
Copy link
Contributor

@EugeneFlesselle EugeneFlesselle Sep 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As I understand it, it is important in the current implementation that we recompute the state from scratch each time, and not try to reuse the result of the previous currentProvisionalState computation itself. It might be worth documenting this too

var state: ProvisionalState = null
inline def record(tp: Type, info: Type): Unit =
if state == null then state = util.HashMap()
state.uncheckedNN(tp) = info
// Compared to previous `testProvisional`, we don't use short-circuiting or (`||`),
// because we want to collect all provisional types.
class ProAcc extends TypeAccumulator[Boolean]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason to keep the Boolean accumulator in the new scheme with a state? Why not use a TypeTraverser?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We use the recursive result to set mightBeProvisional field

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, but then maybe TypeAccumulator[ProvisionalState] and mightBeProvisional = x.nonEmpty?

override def apply(x: Boolean, t: Type) = x || test(t, this)
override def apply(x: Boolean, t: Type) = x | test(t, this)
def test(t: Type, theAcc: TypeAccumulator[Boolean] | Null): Boolean =
if t.mightBeProvisional then
t.mightBeProvisional = t match
case t: TypeRef =>
t.currentSymbol.isProvisional || !t.currentSymbol.isStatic && {
if t.currentSymbol.isProvisional then
// When t is a TypeRef and its symbol is provisional,
// t will be considered provisional and its state is always updating.
// We store itself as info.
record(t, t)
true
else if !t.currentSymbol.isStatic then
(t: Type).mightBeProvisional = false // break cycles
test(t.prefix, theAcc)
|| t.denot.infoOrCompleter.match
case info: LazyType => true
case info: AliasingBounds => test(info.alias, theAcc)
case TypeBounds(lo, hi) => test(lo, theAcc) || test(hi, theAcc)
case _ => false
}
if test(t.prefix, theAcc) then
// If the prefix is provisional, some provisional type from it
// must have been added to state, so we don't need to add t.
true
else t.denot.infoOrCompleter.match
case info: LazyType =>
record(t, info)
true
case info: AliasingBounds =>
test(info.alias, theAcc)
case TypeBounds(lo, hi) =>
test(lo, theAcc) | test(hi, theAcc)
case _ =>
// If a TypeRef has been fully completed, it is no longer provisional,
// so we don't need to traverse its info.
false
else false
case t: TermRef =>
!t.currentSymbol.isStatic && test(t.prefix, theAcc)
case t: AppliedType =>
t.fold(false, (x, tp) => x || test(tp, theAcc))
t.fold(false, (x, tp) => x | test(tp, theAcc))
case t: TypeVar =>
!t.isPermanentlyInstantiated || test(t.permanentInst, theAcc)
if t.isPermanentlyInstantiated then
test(t.permanentInst, theAcc)
else
val inst = t.instanceOpt
if inst.exists then
// We want to store the temporary instance to the state
// in order to reuse the cache when possible.
record(t, inst)
test(inst, theAcc)
else
// When t is a TypeVar and does not have an instantiation,
// we store itself as info.
record(t, t)
true
case t: LazyRef =>
!t.completed || test(t.ref, theAcc)
if !t.completed then
// When t is a LazyRef and is not completed,
// we store itself as info.
record(t, t)
true
else
test(t.ref, theAcc)
case _ =>
(if theAcc != null then theAcc else ProAcc()).foldOver(false, t)
end if
t.mightBeProvisional
end test
test(this, null)
end testProvisional
state
end currentProvisionalState

def isStateUpToDate(
currentState: ProvisionalState,
lastState: ProvisionalState)
(using Context): Boolean =
(currentState eq lastState)
|| currentState != null && lastState != null
&& currentState.size == lastState.size
&& currentState.iterator.forall: (tp, info) =>
lastState.contains(tp) && {
tp match
case tp: TypeRef => (info ne tp) && (info eq lastState(tp))
case _ => info eq lastState(tp)
}

/** Is this type different from NoType? */
final def exists: Boolean = this.ne(NoType)
Expand Down Expand Up @@ -1311,7 +1373,8 @@ object Types extends TypeUtils {
final def widen(using Context): Type = this match
case _: TypeRef | _: MethodOrPoly => this // fast path for most frequent cases
case tp: TermRef => // fast path for next most frequent case
if tp.isOverloaded then tp else tp.underlying.widen
val denot = tp.denot
if denot.isOverloaded then tp else denot.info.widen
case tp: SingletonType => tp.underlying.widen
case tp: ExprType => tp.resultType.widen
case tp =>
Expand All @@ -1324,10 +1387,12 @@ object Types extends TypeUtils {
/** Widen from singleton type to its underlying non-singleton
* base type by applying one or more `underlying` dereferences.
*/
final def widenSingleton(using Context): Type = stripped match {
case tp: SingletonType if !tp.isOverloaded => tp.underlying.widenSingleton
final def widenSingleton(using Context): Type = stripped match
case tp: TermRef =>
val denot = tp.denot
if denot.isOverloaded then this else denot.info.widenSingleton
case tp: SingletonType => tp.underlying.widenSingleton
case _ => this
}

/** Widen from TermRef to its underlying non-termref
* base type, while also skipping Expr types.
Expand Down Expand Up @@ -2305,10 +2370,12 @@ object Types extends TypeUtils {

private var myName: Name | Null = null
private var lastDenotation: Denotation | Null = null
private var lastDenotationProvState: ProvisionalState = null
private var lastSymbol: Symbol | Null = null
private var checkedPeriod: Period = Nowhere
private var myStableHash: Byte = 0
private var mySignature: Signature = uninitialized
private var mySignatureProvState: ProvisionalState = null
private var mySignatureRunId: Int = NoRunId

// Invariants:
Expand Down Expand Up @@ -2343,9 +2410,12 @@ object Types extends TypeUtils {
else if ctx.erasedTypes then atPhase(erasurePhase)(computeSignature)
else symbol.asSeenFrom(prefix).signature

if ctx.runId != mySignatureRunId then
val currentState = currentProvisionalState
if ctx.runId != mySignatureRunId
|| !isStateUpToDate(currentState, mySignatureProvState) then
mySignature = computeSignature
if !mySignature.isUnderDefined && !isProvisional then mySignatureRunId = ctx.runId
mySignatureProvState = currentState
if !mySignature.isUnderDefined then mySignatureRunId = ctx.runId
mySignature
end signature

Expand All @@ -2356,7 +2426,9 @@ object Types extends TypeUtils {
* some symbols change their signature at erasure.
*/
private def currentSignature(using Context): Signature =
if ctx.runId == mySignatureRunId then mySignature
if ctx.runId == mySignatureRunId
&& isStateUpToDate(currentProvisionalState, mySignatureProvState)
then mySignature
else
val lastd = lastDenotation
if lastd != null then sigFromDenot(lastd)
Expand All @@ -2376,7 +2448,9 @@ object Types extends TypeUtils {
final def symbol(using Context): Symbol =
// We can rely on checkedPeriod (unlike in the definition of `denot` below)
// because SymDenotation#installAfter never changes the symbol
if (checkedPeriod.code == ctx.period.code) lastSymbol.asInstanceOf[Symbol]
if checkedPeriod.code == ctx.period.code
&& isStateUpToDate(prefix.currentProvisionalState, lastDenotationProvState) then
lastSymbol.asInstanceOf[Symbol]
else computeSymbol

private def computeSymbol(using Context): Symbol =
Expand Down Expand Up @@ -2434,7 +2508,10 @@ object Types extends TypeUtils {
val lastd = lastDenotation.asInstanceOf[Denotation]
// Even if checkedPeriod == now we still need to recheck lastDenotation.validFor
// as it may have been mutated by SymDenotation#installAfter
if checkedPeriod.code != NowhereCode && lastd.validFor.contains(ctx.period) then lastd
if checkedPeriod.code != NowhereCode
&& lastd.validFor.contains(ctx.period)
&& isStateUpToDate(prefix.currentProvisionalState, lastDenotationProvState)
then lastd
else computeDenot

private def computeDenot(using Context): Denotation = {
Expand Down Expand Up @@ -2468,14 +2545,18 @@ object Types extends TypeUtils {
finish(symd.current)
}

def isLastDenotValid =
checkedPeriod.code != NowhereCode
&& isStateUpToDate(prefix.currentProvisionalState, lastDenotationProvState)

lastDenotation match {
case lastd0: SingleDenotation =>
val lastd = lastd0.skipRemoved
if lastd.validFor.runId == ctx.runId && checkedPeriod.code != NowhereCode then
if lastd.validFor.runId == ctx.runId && isLastDenotValid then
finish(lastd.current)
else lastd match {
case lastd: SymDenotation =>
if stillValid(lastd) && checkedPeriod.code != NowhereCode then finish(lastd.current)
if stillValid(lastd) && isLastDenotValid then finish(lastd.current)
else finish(memberDenot(lastd.initial.name, allowPrivate = false))
case _ =>
fromDesignator
Expand Down Expand Up @@ -2566,7 +2647,8 @@ object Types extends TypeUtils {

lastDenotation = denot
lastSymbol = denot.symbol
checkedPeriod = if (prefix.isProvisional) Nowhere else ctx.period
lastDenotationProvState = prefix.currentProvisionalState
checkedPeriod = ctx.period
designator match {
case sym: Symbol if designator ne lastSymbol.nn =>
designator = lastSymbol.asInstanceOf[Designator{ type ThisName = self.ThisName }]
Expand Down Expand Up @@ -3849,14 +3931,18 @@ object Types extends TypeUtils {
sealed abstract class MethodOrPoly extends UncachedGroundType with LambdaType with MethodicType {

// Invariants:
// (1) mySignatureRunId != NoRunId => mySignature != null
// (2) myJavaSignatureRunId != NoRunId => myJavaSignature != null
// (1) mySignatureRunId != NoRunId => mySignature != null
// (2) myJavaSignatureRunId != NoRunId => myJavaSignature != null
// (3) myScala2SignatureRunId != NoRunId => myScala2Signature != null

private var mySignature: Signature = uninitialized
private var mySignatureProvState: ProvisionalState = null
private var mySignatureRunId: Int = NoRunId
private var myJavaSignature: Signature = uninitialized
private var myJavaSignatureProvState: ProvisionalState = null
private var myJavaSignatureRunId: Int = NoRunId
private var myScala2Signature: Signature = uninitialized
private var myScala2SignatureProvState: ProvisionalState = null
private var myScala2SignatureRunId: Int = NoRunId

/** If `isJava` is false, the Scala signature of this method. Otherwise, its Java signature.
Expand Down Expand Up @@ -3892,21 +3978,28 @@ object Types extends TypeUtils {
case tp: PolyType =>
resultSignature.prependTypeParams(tp.paramNames.length)

val currentState = currentProvisionalState
sourceLanguage match
case SourceLanguage.Java =>
if ctx.runId != myJavaSignatureRunId then
if ctx.runId != myJavaSignatureRunId
|| !isStateUpToDate(currentState, myJavaSignatureProvState) then
myJavaSignature = computeSignature
if !myJavaSignature.isUnderDefined && !isProvisional then myJavaSignatureRunId = ctx.runId
myJavaSignatureProvState = currentState
if !myJavaSignature.isUnderDefined then myJavaSignatureRunId = ctx.runId
myJavaSignature
case SourceLanguage.Scala2 =>
if ctx.runId != myScala2SignatureRunId then
if ctx.runId != myScala2SignatureRunId
|| !isStateUpToDate(currentState, myScala2SignatureProvState) then
myScala2Signature = computeSignature
if !myScala2Signature.isUnderDefined && !isProvisional then myScala2SignatureRunId = ctx.runId
myScala2SignatureProvState = currentState
if !myScala2Signature.isUnderDefined then myScala2SignatureRunId = ctx.runId
myScala2Signature
case SourceLanguage.Scala3 =>
if ctx.runId != mySignatureRunId then
if ctx.runId != mySignatureRunId
|| !isStateUpToDate(currentState, mySignatureProvState) then
mySignature = computeSignature
if !mySignature.isUnderDefined && !isProvisional then mySignatureRunId = ctx.runId
mySignatureProvState = currentState
if !mySignature.isUnderDefined then mySignatureRunId = ctx.runId
mySignature
end signature

Expand Down
Loading
Loading