Skip to content

Commit 2f61ee2

Browse files
committed
Ensure to escape characters before constructing JSON profile trace
1 parent 91ef921 commit 2f61ee2

File tree

4 files changed

+191
-6
lines changed

4 files changed

+191
-6
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
package dotty.tools.dotc.profile
2+
3+
import scala.annotation.internal.sharable
4+
5+
// Based on NameTransformer but dedicated for JSON encoding rules
6+
object JsonNameTransformer {
7+
private val nops = 128
8+
private val ncodes = 26 * 26
9+
10+
@sharable private val op2code = new Array[String](nops)
11+
private def enterOp(op: Char, code: String) = op2code(op.toInt) = code
12+
13+
enterOp('\"', "\\\"")
14+
enterOp('\\', "\\\\")
15+
// enterOp('/', "\\/") // optional, no need for escaping outside of html context
16+
enterOp('\b', "\\b")
17+
enterOp('\f', "\\f")
18+
enterOp('\n', "\\n")
19+
enterOp('\r', "\\r")
20+
enterOp('\t', "\\t")
21+
22+
def encode(name: String): String = {
23+
var buf: StringBuilder = null.asInstanceOf
24+
val len = name.length
25+
var i = 0
26+
while (i < len) {
27+
val c = name(i)
28+
if (c < nops && (op2code(c.toInt) ne null)) {
29+
if (buf eq null) {
30+
buf = new StringBuilder()
31+
buf.append(name.subSequence(0, i))
32+
}
33+
buf.append(op2code(c.toInt))
34+
} else if (c <= 0x1F || c > 0x7F) {
35+
if (buf eq null) {
36+
buf = new StringBuilder()
37+
buf.append(name.subSequence(0, i))
38+
}
39+
buf.append("\\u%04X".format(c.toInt))
40+
} else if (buf ne null) {
41+
buf.append(c)
42+
}
43+
i += 1
44+
}
45+
if (buf eq null) name else buf.toString
46+
}
47+
}

Diff for: compiler/src/dotty/tools/dotc/profile/Profiler.scala

+10-5
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ private [profile] class RealProfiler(reporter : ProfileReporter)(using Context)
273273
override def beforePhase(phase: Phase): (TracedEventId, ProfileSnap) = {
274274
assert(mainThread eq Thread.currentThread())
275275
traceThreadSnapshotCounters()
276-
val eventId = traceDurationStart(Category.Phase, phase.phaseName)
276+
val eventId = traceDurationStart(Category.Phase, escapeSpecialChars(phase.phaseName))
277277
if (ctx.settings.YprofileRunGcBetweenPhases.value.contains(phase.toString))
278278
doGC()
279279
if (ctx.settings.YprofileExternalTool.value.contains(phase.toString)) {
@@ -287,7 +287,7 @@ private [profile] class RealProfiler(reporter : ProfileReporter)(using Context)
287287
assert(mainThread eq Thread.currentThread())
288288
if chromeTrace != null then
289289
traceThreadSnapshotCounters()
290-
traceDurationStart(Category.File, unit.source.name)
290+
traceDurationStart(Category.File, escapeSpecialChars(unit.source.name))
291291
else TracedEventId.Empty
292292
}
293293

@@ -325,7 +325,7 @@ private [profile] class RealProfiler(reporter : ProfileReporter)(using Context)
325325
then EmptyCompletionEvent
326326
else
327327
val completionName = this.completionName(root, associatedFile)
328-
val event = TracedEventId(associatedFile.name)
328+
val event = TracedEventId(escapeSpecialChars(associatedFile.name))
329329
chromeTrace.traceDurationEventStart(Category.Completion.name, "", colour = "thread_state_sleeping")
330330
chromeTrace.traceDurationEventStart(Category.File.name, event)
331331
chromeTrace.traceDurationEventStart(Category.Completion.name, completionName)
@@ -350,8 +350,13 @@ private [profile] class RealProfiler(reporter : ProfileReporter)(using Context)
350350
if chromeTrace != null then
351351
chromeTrace.traceDurationEventEnd(category.name, event, colour)
352352

353-
private def symbolName(sym: Symbol): String = s"${sym.showKind} ${sym.showName}"
354-
private def completionName(root: Symbol, associatedFile: AbstractFile): String =
353+
private inline def escapeSpecialChars(value: String): String =
354+
JsonNameTransformer.encode(value)
355+
356+
private def symbolName(sym: Symbol): String = escapeSpecialChars:
357+
s"${sym.showKind} ${sym.showName}"
358+
359+
private def completionName(root: Symbol, associatedFile: AbstractFile): String = escapeSpecialChars:
355360
def isTopLevel = root.owner != NoSymbol && root.owner.is(Flags.Package)
356361
if root.is(Flags.Package) || isTopLevel
357362
then root.javaBinaryName

Diff for: compiler/test/dotty/tools/DottyTest.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ trait DottyTest extends ContextEscapeDetection {
4646

4747
protected def defaultCompiler: Compiler = new Compiler()
4848

49-
private def compilerWithChecker(phase: String)(assertion: (tpd.Tree, Context) => Unit) = new Compiler {
49+
protected def compilerWithChecker(phase: String)(assertion: (tpd.Tree, Context) => Unit) = new Compiler {
5050

5151
private val baseCompiler = defaultCompiler
5252

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
package dotty.tools.dotc.profile
2+
3+
import org.junit.Assert.*
4+
import org.junit.*
5+
6+
import scala.annotation.tailrec
7+
import dotty.tools.DottyTest
8+
import dotty.tools.dotc.util.SourceFile
9+
import dotty.tools.dotc.core.Contexts.FreshContext
10+
import java.nio.file.Files
11+
import java.util.Locale
12+
13+
class TraceNameManglingTest extends DottyTest {
14+
15+
override protected def initializeCtx(fc: FreshContext): Unit = {
16+
super.initializeCtx(fc)
17+
val tmpDir = Files.createTempDirectory("trace_name_mangling_test").nn
18+
fc.setSetting(fc.settings.YprofileEnabled, true)
19+
fc.setSetting(
20+
fc.settings.YprofileTrace,
21+
tmpDir.resolve("trace.json").nn.toAbsolutePath().toString()
22+
)
23+
fc.setSetting(
24+
fc.settings.YprofileDestination,
25+
tmpDir.resolve("profiler.out").nn.toAbsolutePath().toString()
26+
)
27+
}
28+
29+
@Test def escapeBackslashes(): Unit = {
30+
val isWindows = sys.props("os.name").toLowerCase(Locale.ROOT) == "windows"
31+
val filename = if isWindows then "/.scala" else "\\.scala"
32+
checkTraceEvents(
33+
"""
34+
|class /\ :
35+
| var /\ = ???
36+
|object /\{
37+
| def /\ = ???
38+
|}""".stripMargin,
39+
filename = filename
40+
)(
41+
Set(
42+
raw"class /\\",
43+
raw"object /\\",
44+
raw"method /\\",
45+
raw"variable /\\",
46+
raw"setter /\\_="
47+
).map(TraceEvent("typecheck", _))
48+
++ Set(
49+
TraceEvent("file", if isWindows then "/.scala" else "\\\\.scala")
50+
)
51+
)
52+
}
53+
54+
@Test def escapeDoubleQuotes(): Unit = {
55+
val filename = "\"quoted\".scala"
56+
checkTraceEvents(
57+
"""
58+
|class `"QuotedClass"`:
59+
| var `"quotedVar"` = ???
60+
|object `"QuotedObject"` {
61+
| def `"quotedMethod"` = ???
62+
|}""".stripMargin,
63+
filename = filename
64+
):
65+
Set(
66+
raw"class \"QuotedClass\"",
67+
raw"object \"QuotedObject\"",
68+
raw"method \"quotedMethod\"",
69+
raw"variable \"quotedVar\""
70+
).map(TraceEvent("typecheck", _))
71+
++ Set(TraceEvent("file", "\\\"quoted\\\".scala"))
72+
}
73+
@Test def escapeNonAscii(): Unit = {
74+
val filename = "unic😀de.scala"
75+
checkTraceEvents(
76+
"""
77+
|class ΩUnicodeClass:
78+
| var `中文Var` = ???
79+
|object ΩUnicodeObject {
80+
| def 中文Method = ???
81+
|}""".stripMargin,
82+
filename = filename
83+
):
84+
Set(
85+
"class \\u03A9UnicodeClass",
86+
"object \\u03A9UnicodeObject",
87+
"method \\u4E2D\\u6587Method",
88+
"variable \\u4E2D\\u6587Var"
89+
).map(TraceEvent("typecheck", _))
90+
++ Set(TraceEvent("file", "unic\\uD83D\\uDE00de.scala"))
91+
}
92+
93+
case class TraceEvent(category: String, name: String)
94+
private def compileWithTracer(
95+
code: String,
96+
filename: String,
97+
afterPhase: String = "typer"
98+
)(checkEvents: Seq[TraceEvent] => Unit) = {
99+
val runCtx = locally:
100+
val source = SourceFile.virtual(filename, code)
101+
val c = compilerWithChecker(afterPhase) { (_, _) => () }
102+
val run = c.newRun
103+
run.compileSources(List(source))
104+
run.runContext
105+
assert(!runCtx.reporter.hasErrors, "compilation failed")
106+
val outfile = ctx.settings.YprofileTrace.value
107+
checkEvents:
108+
scala.io.Source
109+
.fromFile(outfile)
110+
.getLines()
111+
.collect:
112+
case s"""${_}"cat":"${category}","name":${name},"ph":${_}""" =>
113+
TraceEvent(category, name.stripPrefix("\"").stripSuffix("\""))
114+
.distinct.toSeq
115+
}
116+
117+
private def checkTraceEvents(code: String, filename: String = "test")(expected: Set[TraceEvent]): Unit = {
118+
compileWithTracer(code, filename = filename, afterPhase = "typer"){ events =>
119+
val missing = expected.diff(events.toSet)
120+
def showFound = events
121+
.groupBy(_.category)
122+
.collect:
123+
case (category, events)
124+
if expected.exists(_.category == category) =>
125+
s"- $category: [${events.map(_.name).mkString(", ")}]"
126+
.mkString("\n")
127+
assert(
128+
missing.isEmpty,
129+
s"""Missing ${missing.size} names [${missing.mkString(", ")}] in events, got:\n${showFound}"""
130+
)
131+
}
132+
}
133+
}

0 commit comments

Comments
 (0)