Skip to content

Commit 26f34d0

Browse files
authored
Merge pull request scala#8087 from som-snytt/topic/test-cleanup
Suppress stack trace in test
2 parents b2a4255 + 966827b commit 26f34d0

File tree

3 files changed

+162
-30
lines changed

3 files changed

+162
-30
lines changed

src/testkit/scala/tools/testkit/AssertUtil.scala

Lines changed: 83 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -10,18 +10,19 @@
1010
* additional information regarding copyright ownership.
1111
*/
1212

13-
package scala.tools.testkit
13+
package scala.tools.testkit
1414

1515
import org.junit.Assert, Assert._
1616
import scala.reflect.ClassTag
1717
import scala.runtime.ScalaRunTime.stringOf
18-
import scala.collection.GenIterable
1918
import scala.collection.JavaConverters._
2019
import scala.collection.mutable
21-
import scala.concurrent.{Await, Awaitable, SyncVar, TimeoutException}
22-
import scala.util.Try
20+
import scala.concurrent.{Await, Awaitable}
21+
import scala.util.{Failure, Success, Try}
2322
import scala.util.Properties.isJavaAtLeast
2423
import scala.util.control.NonFatal
24+
import java.util.concurrent.{CountDownLatch, TimeUnit}
25+
import java.util.concurrent.atomic.AtomicReference
2526
import java.lang.ref._
2627
import java.lang.reflect.{Array => _, _}
2728
import java.util.IdentityHashMap
@@ -129,14 +130,18 @@ object AssertUtil {
129130

130131
/** Assert no new threads, with some margin for arbitrary threads to exit. */
131132
def assertZeroNetThreads(body: => Unit): Unit = {
132-
val result = new SyncVar[Option[Throwable]]
133133
val group = new ThreadGroup("junit")
134-
def check() = {
134+
try assertZeroNetThreads(group)(body)
135+
finally group.destroy()
136+
}
137+
def assertZeroNetThreads[A](group: ThreadGroup)(body: => A): Try[A] = {
138+
val testDone = new CountDownLatch(1)
139+
def check(): Try[A] = {
135140
val beforeCount = group.activeCount
136141
val beforeThreads = new Array[Thread](beforeCount)
137142
assertEquals("Spurious early thread creation.", beforeCount, group.enumerate(beforeThreads))
138143

139-
body
144+
val outcome = Try(body)
140145

141146
val afterCount = {
142147
waitForIt(group.activeCount <= beforeCount, label = "after count")
@@ -146,32 +151,47 @@ object AssertUtil {
146151
assertEquals("Spurious late thread creation.", afterCount, group.enumerate(afterThreads))
147152
val staleThreads = afterThreads.toList.diff(beforeThreads)
148153
//staleThreads.headOption.foreach(_.getStackTrace.foreach(println))
149-
assertEquals(staleThreads.mkString("There are stale threads: ",",",""), beforeCount, afterCount)
150-
assertTrue(staleThreads.mkString("There are stale threads: ",",",""), staleThreads.isEmpty)
154+
val staleMessage = staleThreads.mkString("There are stale threads: ",",","")
155+
assertEquals(staleMessage, beforeCount, afterCount)
156+
assertTrue(staleMessage, staleThreads.isEmpty)
157+
158+
outcome
151159
}
152-
def test() = {
160+
val result = new AtomicReference[Try[A]]()
161+
def test(): Try[A] =
153162
try {
154-
check()
155-
result.put(None)
156-
} catch {
157-
case t: Throwable => result.put(Some(t))
163+
val checked = check()
164+
result.set(checked)
165+
checked
166+
} finally {
167+
testDone.countDown()
158168
}
159-
}
160-
val timeout = 10 * 1000L // last chance timeout
169+
170+
val timeout = 10 * 1000L
161171
val thread = new Thread(group, () => test())
162-
def resulted: Boolean = result.get(timeout).isDefined
172+
def abort(): Try[A] = {
173+
group.interrupt()
174+
new Failure(new AssertionError("Test did not complete"))
175+
}
163176
try {
164177
thread.start()
165-
waitForIt(resulted, Slow, label = "test result")
166-
val err = result.take(timeout)
167-
err.foreach(e => throw e)
178+
waitForIt(testDone.getCount == 0, Fast, label = "test result")
179+
if (testDone.await(timeout, TimeUnit.MILLISECONDS))
180+
result.get
181+
else
182+
abort()
168183
} finally {
169184
thread.join(timeout)
170-
group.destroy()
171185
}
172186
}
173187

174188
/** Wait for a condition, with a simple back-off strategy.
189+
*
190+
* This makes it easier to see hanging threads in development
191+
* without tweaking a timeout parameter. Conversely, when a thread
192+
* fails to make progress in a test environment, we allow the wait
193+
* period to grow larger than usual, since a long wait for failure
194+
* is acceptable.
175195
*
176196
* It would be nicer if what we're waiting for gave us
177197
* a progress indicator: we don't care if something
@@ -213,9 +233,51 @@ object AssertUtil {
213233

214234
/** Like Await.ready but return false on timeout, true on completion, throw InterruptedException. */
215235
def readyOrNot(awaitable: Awaitable[_]): Boolean = Try(Await.ready(awaitable, TestDuration.Standard)).isSuccess
236+
237+
def withoutATrace[A](body: => A) = NoTrace(body)
216238
}
217239

218240
object TestDuration {
219241
import scala.concurrent.duration.{Duration, SECONDS}
220242
val Standard = Duration(4, SECONDS)
221243
}
244+
245+
/** Run a thunk, collecting uncaught exceptions from any spawned threads. */
246+
class NoTrace[A](body: => A) extends Runnable {
247+
248+
private val uncaught = new mutable.ListBuffer[(Thread, Throwable)]()
249+
250+
@volatile private[testkit] var result: Option[A] = None
251+
252+
def run(): Unit = {
253+
import AssertUtil.assertZeroNetThreads
254+
val group = new ThreadGroup("notrace") {
255+
override def uncaughtException(t: Thread, e: Throwable): Unit = synchronized {
256+
uncaught += ((t, e))
257+
}
258+
}
259+
try assertZeroNetThreads(group)(body) match {
260+
case Success(a) => result = Some(a)
261+
case Failure(e) => synchronized { uncaught += ((Thread.currentThread, e)) }
262+
}
263+
finally group.destroy()
264+
}
265+
266+
private[testkit] lazy val errors: List[(Thread, Throwable)] = synchronized(uncaught.toList)
267+
268+
private def suppress(t: Throwable, other: Throwable): t.type = { t.addSuppressed(other) ; t }
269+
270+
private final val noError = None: Option[Throwable]
271+
272+
def asserted: Option[Throwable] =
273+
errors.collect { case (_, e: AssertionError) => e }
274+
.foldLeft(noError)((res, e) => res.map(suppress(_, e)).orElse(Some(e)))
275+
276+
def apply(test: (Option[A], List[(Thread, Throwable)]) => Option[Throwable]) = {
277+
run()
278+
test(result, errors).orElse(asserted).foreach(e => throw e)
279+
}
280+
}
281+
object NoTrace {
282+
def apply[A](body: => A): NoTrace[A] = new NoTrace(body)
283+
}

test/junit/scala/sys/process/ProcessTest.scala

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
package scala.sys.process
22

3-
import java.io.{ByteArrayInputStream, File}
3+
import java.io.{ByteArrayInputStream, File, IOException}
44
import java.nio.file.{Files, Paths}, Files.createTempFile
55
import java.nio.charset.StandardCharsets.UTF_8
66

@@ -10,6 +10,7 @@ import scala.util.Try
1010
//import scala.sys.process._
1111
import scala.util.Properties._
1212
import scala.collection.JavaConverters._
13+
import scala.tools.testkit.AssertUtil._
1314

1415
import org.junit.runner.RunWith
1516
import org.junit.runners.JUnit4
@@ -89,7 +90,7 @@ class ProcessTest {
8990
Files.write(file, List(prefix).asJava, UTF_8)
9091
file
9192
}
92-
val file1 = Paths.get("total", "junk")
93+
val noFile = Paths.get("total", "junk")
9394
val p2 = new ProcessMock(false)
9495
val failed = new java.util.concurrent.atomic.AtomicBoolean
9596
val pb2 = new ProcessBuilderMock(p2, error = true) {
@@ -102,14 +103,23 @@ class ProcessTest {
102103
val out = createTempFile("out", "tmp")
103104
val outf = out.toFile
104105

105-
try {
106-
val p0 = (file1.toFile : ProcessBuilder.Source).cat #&& pb2
107-
val p = p0 #> outf
106+
def process =
107+
try {
108+
val p0 = (noFile.toFile : ProcessBuilder.Source).cat #&& pb2
109+
val p = p0 #> outf
108110

109-
assertEquals(1, p.!)
110-
assertFalse(failed.get)
111-
} finally {
112-
Files.delete(out)
111+
assertEquals(1, p.!)
112+
assertFalse(failed.get)
113+
} finally {
114+
Files.delete(out)
115+
}
116+
117+
def fail(why: String): Option[Throwable] = Some(new AssertionError(why))
118+
119+
withoutATrace(process) {
120+
case (None, _) => fail("No main result")
121+
case (_, (_, (_: IOException)) :: Nil) => None
122+
case (_, other) => fail(s"Expected one IOException, got $other")
113123
}
114124
}
115125
}

test/junit/scala/tools/testing/AssertUtilTest.scala

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,64 @@ class AssertUtilTest {
1717
val r = new SoftReference(o)
1818
assertNotReachable(o, new Holder(r)) { }
1919
}
20+
21+
@Test def `asserts on child threads are suppressed`(): Unit = {
22+
def kickoff(body: => Unit): Unit = {
23+
val t = new Thread(() => body)
24+
t.start()
25+
t.join()
26+
}
27+
val sut = withoutATrace {
28+
kickoff(assertEquals(42, 17))
29+
kickoff(???)
30+
kickoff(assertEquals("hi", "bi"))
31+
kickoff(assertEquals("low", "brow"))
32+
}
33+
sut.run()
34+
sut.asserted match {
35+
case None => fail("Expected assertion errors")
36+
case Some(e) => assertEquals(2, e.getSuppressed.length)
37+
}
38+
}
39+
40+
@Test def `waits for child threads to complete`(): Unit = {
41+
import java.util.concurrent.CountDownLatch
42+
val latch = new CountDownLatch(1)
43+
def kickoff(body: => Unit): Unit = {
44+
val t = new Thread(() => body)
45+
t.start()
46+
}
47+
val sut = withoutATrace {
48+
kickoff {
49+
latch.await()
50+
assertEquals(42, 17) // must wait to see this
51+
}
52+
kickoff {
53+
Thread.sleep(100L) // make them wait for it
54+
latch.countDown()
55+
}
56+
kickoff(assertEquals("hi", "bi")) // ordinary background thread assertion
57+
assertEquals("low", "brow") // "foreground" thread assertion must be handled
58+
}
59+
sut.run()
60+
sut.asserted match {
61+
case None => fail("Expected assertion errors")
62+
case Some(e) => assertEquals(2, e.getSuppressed.length)
63+
}
64+
}
65+
66+
@Test def `result is returned`(): Unit = {
67+
def kickoff(body: => Unit): Unit = new Thread(() => body).start()
68+
def f() = {
69+
kickoff {
70+
assertEquals(42, 17)
71+
}
72+
27
73+
}
74+
val sut = withoutATrace(f())
75+
sut.run()
76+
assertEquals(Some(27), sut.result)
77+
assertEquals(1, sut.errors.size)
78+
assertEquals(0, sut.errors.head._2.getSuppressed.length)
79+
}
2080
}

0 commit comments

Comments
 (0)