@@ -15,12 +15,15 @@ import Trees.*
15
15
import Types .*
16
16
import Symbols .*
17
17
import Names .*
18
+ import StdNames .str
18
19
import NameOps .*
19
20
import inlines .Inlines
20
21
import transform .ValueClasses
21
- import dotty .tools .io .{File , FileExtension }
22
+ import dotty .tools .io .{File , FileExtension , JarArchive }
23
+ import util .{Property , SourceFile }
22
24
import java .io .PrintWriter
23
25
26
+ import ExtractAPI .NonLocalClassSymbolsInCurrentUnits
24
27
25
28
import scala .collection .mutable
26
29
import scala .util .hashing .MurmurHash3
@@ -64,13 +67,59 @@ class ExtractAPI extends Phase {
64
67
// definitions, and `PostTyper` does not change definitions).
65
68
override def runsAfter : Set [String ] = Set (transform.PostTyper .name)
66
69
70
+ override def runOn (units : List [CompilationUnit ])(using Context ): List [CompilationUnit ] =
71
+ val nonLocalClassSymbols = new mutable.HashSet [Symbol ]
72
+ val ctx0 = ctx.withProperty(NonLocalClassSymbolsInCurrentUnits , Some (nonLocalClassSymbols))
73
+ val units0 = super .runOn(units)(using ctx0)
74
+ ctx.withIncCallback(recordNonLocalClasses(nonLocalClassSymbols, _))
75
+ units0
76
+ end runOn
77
+
78
+ private def recordNonLocalClasses (nonLocalClassSymbols : mutable.HashSet [Symbol ], cb : interfaces.IncrementalCallback )(using Context ): Unit =
79
+ for cls <- nonLocalClassSymbols do
80
+ val sourceFile = cls.source
81
+ if sourceFile.exists && cls.isDefinedInCurrentRun then
82
+ recordNonLocalClass(cls, sourceFile, cb)
83
+ cb.apiPhaseCompleted()
84
+ cb.dependencyPhaseCompleted()
85
+
86
+ private def recordNonLocalClass (cls : Symbol , sourceFile : SourceFile , cb : interfaces.IncrementalCallback )(using Context ): Unit =
87
+ def registerProductNames (fullClassName : String , binaryClassName : String ) =
88
+ val pathToClassFile = s " ${binaryClassName.replace('.' , java.io.File .separatorChar)}.class "
89
+
90
+ val classFile = {
91
+ ctx.settings.outputDir.value match {
92
+ case jar : JarArchive =>
93
+ new java.io.File (s " $jar! $pathToClassFile" )
94
+ case outputDir =>
95
+ new java.io.File (outputDir.file, pathToClassFile)
96
+ }
97
+ }
98
+
99
+ cb.generatedNonLocalClass(sourceFile, classFile.toPath(), binaryClassName, fullClassName)
100
+ end registerProductNames
101
+
102
+ val fullClassName = atPhase(sbtExtractDependenciesPhase) {
103
+ ExtractDependencies .classNameAsString(cls)
104
+ }
105
+ val binaryClassName = cls.binaryClassName
106
+ registerProductNames(fullClassName, binaryClassName)
107
+
108
+ // Register the names of top-level module symbols that emit two class files
109
+ val isTopLevelUniqueModule =
110
+ cls.owner.is(PackageClass ) && cls.is(ModuleClass ) && cls.companionClass == NoSymbol
111
+ if isTopLevelUniqueModule then
112
+ registerProductNames(fullClassName, binaryClassName.stripSuffix(str.MODULE_SUFFIX ))
113
+ end recordNonLocalClass
114
+
67
115
override def run (using Context ): Unit = {
68
116
val unit = ctx.compilationUnit
69
117
val sourceFile = unit.source
70
118
ctx.withIncCallback: cb =>
71
119
cb.startSource(sourceFile)
72
120
73
- val apiTraverser = new ExtractAPICollector
121
+ val nonLocalClassSymbols = ctx.property(NonLocalClassSymbolsInCurrentUnits ).get
122
+ val apiTraverser = ExtractAPICollector (nonLocalClassSymbols)
74
123
val classes = apiTraverser.apiSource(unit.tpdTree)
75
124
val mainClasses = apiTraverser.mainClasses
76
125
@@ -94,6 +143,8 @@ object ExtractAPI:
94
143
val name : String = " sbt-api"
95
144
val description : String = " sends a representation of the API of classes to sbt"
96
145
146
+ private val NonLocalClassSymbolsInCurrentUnits : Property .Key [mutable.HashSet [Symbol ]] = Property .Key ()
147
+
97
148
/** Extracts full (including private members) API representation out of Symbols and Types.
98
149
*
99
150
* The exact representation used for each type is not important: the only thing
@@ -136,7 +187,7 @@ object ExtractAPI:
136
187
* without going through an intermediate representation, see
137
188
* http://www.scala-sbt.org/0.13/docs/Understanding-Recompilation.html#Hashing+an+API+representation
138
189
*/
139
- private class ExtractAPICollector (using Context ) extends ThunkHolder {
190
+ private class ExtractAPICollector (nonLocalClassSymbols : mutable. HashSet [ Symbol ])( using Context ) extends ThunkHolder {
140
191
import tpd .*
141
192
import xsbti .api
142
193
@@ -254,6 +305,8 @@ private class ExtractAPICollector(using Context) extends ThunkHolder {
254
305
childrenOfSealedClass, topLevel, tparams)
255
306
256
307
allNonLocalClassesInSrc += cl
308
+ if ! sym.isLocal then
309
+ nonLocalClassSymbols += sym
257
310
258
311
if (sym.isStatic && ! sym.is(Trait ) && ctx.platform.hasMainMethod(sym)) {
259
312
// If sym is an object, all main methods count, otherwise only @static ones count.
0 commit comments