10
10
* additional information regarding copyright ownership.
11
11
*/
12
12
13
- package scala .tools .testkit
13
+ package scala .tools .testkit
14
14
15
15
import org .junit .Assert , Assert ._
16
16
import scala .reflect .ClassTag
17
17
import scala .runtime .ScalaRunTime .stringOf
18
- import scala .collection .GenIterable
19
18
import scala .collection .JavaConverters ._
20
19
import scala .collection .mutable
21
- import scala .concurrent .{Await , Awaitable , SyncVar , TimeoutException }
22
- import scala .util .Try
20
+ import scala .concurrent .{Await , Awaitable }
21
+ import scala .util .{ Failure , Success , Try }
23
22
import scala .util .Properties .isJavaAtLeast
24
23
import scala .util .control .NonFatal
24
+ import java .util .concurrent .{CountDownLatch , TimeUnit }
25
+ import java .util .concurrent .atomic .AtomicReference
25
26
import java .lang .ref ._
26
27
import java .lang .reflect .{Array => _ , _ }
27
28
import java .util .IdentityHashMap
@@ -129,14 +130,18 @@ object AssertUtil {
129
130
130
131
/** Assert no new threads, with some margin for arbitrary threads to exit. */
131
132
def assertZeroNetThreads (body : => Unit ): Unit = {
132
- val result = new SyncVar [Option [Throwable ]]
133
133
val group = new ThreadGroup (" junit" )
134
- def check () = {
134
+ try assertZeroNetThreads(group)(body)
135
+ finally group.destroy()
136
+ }
137
+ def assertZeroNetThreads [A ](group : ThreadGroup )(body : => A ): Try [A ] = {
138
+ val testDone = new CountDownLatch (1 )
139
+ def check (): Try [A ] = {
135
140
val beforeCount = group.activeCount
136
141
val beforeThreads = new Array [Thread ](beforeCount)
137
142
assertEquals(" Spurious early thread creation." , beforeCount, group.enumerate(beforeThreads))
138
143
139
- body
144
+ val outcome = Try ( body)
140
145
141
146
val afterCount = {
142
147
waitForIt(group.activeCount <= beforeCount, label = " after count" )
@@ -146,32 +151,47 @@ object AssertUtil {
146
151
assertEquals(" Spurious late thread creation." , afterCount, group.enumerate(afterThreads))
147
152
val staleThreads = afterThreads.toList.diff(beforeThreads)
148
153
// staleThreads.headOption.foreach(_.getStackTrace.foreach(println))
149
- assertEquals(staleThreads.mkString(" There are stale threads: " ," ," ," " ), beforeCount, afterCount)
150
- assertTrue(staleThreads.mkString(" There are stale threads: " ," ," ," " ), staleThreads.isEmpty)
154
+ val staleMessage = staleThreads.mkString(" There are stale threads: " ," ," ," " )
155
+ assertEquals(staleMessage, beforeCount, afterCount)
156
+ assertTrue(staleMessage, staleThreads.isEmpty)
157
+
158
+ outcome
151
159
}
152
- def test () = {
160
+ val result = new AtomicReference [Try [A ]]()
161
+ def test (): Try [A ] =
153
162
try {
154
- check()
155
- result.put(None )
156
- } catch {
157
- case t : Throwable => result.put(Some (t))
163
+ val checked = check()
164
+ result.set(checked)
165
+ checked
166
+ } finally {
167
+ testDone.countDown()
158
168
}
159
- }
160
- val timeout = 10 * 1000L // last chance timeout
169
+
170
+ val timeout = 10 * 1000L
161
171
val thread = new Thread (group, () => test())
162
- def resulted : Boolean = result.get(timeout).isDefined
172
+ def abort (): Try [A ] = {
173
+ group.interrupt()
174
+ new Failure (new AssertionError (" Test did not complete" ))
175
+ }
163
176
try {
164
177
thread.start()
165
- waitForIt(resulted, Slow , label = " test result" )
166
- val err = result.take(timeout)
167
- err.foreach(e => throw e)
178
+ waitForIt(testDone.getCount == 0 , Fast , label = " test result" )
179
+ if (testDone.await(timeout, TimeUnit .MILLISECONDS ))
180
+ result.get
181
+ else
182
+ abort()
168
183
} finally {
169
184
thread.join(timeout)
170
- group.destroy()
171
185
}
172
186
}
173
187
174
188
/** Wait for a condition, with a simple back-off strategy.
189
+ *
190
+ * This makes it easier to see hanging threads in development
191
+ * without tweaking a timeout parameter. Conversely, when a thread
192
+ * fails to make progress in a test environment, we allow the wait
193
+ * period to grow larger than usual, since a long wait for failure
194
+ * is acceptable.
175
195
*
176
196
* It would be nicer if what we're waiting for gave us
177
197
* a progress indicator: we don't care if something
@@ -213,9 +233,51 @@ object AssertUtil {
213
233
214
234
/** Like Await.ready but return false on timeout, true on completion, throw InterruptedException. */
215
235
def readyOrNot (awaitable : Awaitable [_]): Boolean = Try (Await .ready(awaitable, TestDuration .Standard )).isSuccess
236
+
237
+ def withoutATrace [A ](body : => A ) = NoTrace (body)
216
238
}
217
239
218
240
object TestDuration {
219
241
import scala .concurrent .duration .{Duration , SECONDS }
220
242
val Standard = Duration (4 , SECONDS )
221
243
}
244
+
245
+ /** Run a thunk, collecting uncaught exceptions from any spawned threads. */
246
+ class NoTrace [A ](body : => A ) extends Runnable {
247
+
248
+ private val uncaught = new mutable.ListBuffer [(Thread , Throwable )]()
249
+
250
+ @ volatile private [testkit] var result : Option [A ] = None
251
+
252
+ def run (): Unit = {
253
+ import AssertUtil .assertZeroNetThreads
254
+ val group = new ThreadGroup (" notrace" ) {
255
+ override def uncaughtException (t : Thread , e : Throwable ): Unit = synchronized {
256
+ uncaught += ((t, e))
257
+ }
258
+ }
259
+ try assertZeroNetThreads(group)(body) match {
260
+ case Success (a) => result = Some (a)
261
+ case Failure (e) => synchronized { uncaught += ((Thread .currentThread, e)) }
262
+ }
263
+ finally group.destroy()
264
+ }
265
+
266
+ private [testkit] lazy val errors : List [(Thread , Throwable )] = synchronized (uncaught.toList)
267
+
268
+ private def suppress (t : Throwable , other : Throwable ): t.type = { t.addSuppressed(other) ; t }
269
+
270
+ private final val noError = None : Option [Throwable ]
271
+
272
+ def asserted : Option [Throwable ] =
273
+ errors.collect { case (_, e : AssertionError ) => e }
274
+ .foldLeft(noError)((res, e) => res.map(suppress(_, e)).orElse(Some (e)))
275
+
276
+ def apply (test : (Option [A ], List [(Thread , Throwable )]) => Option [Throwable ]) = {
277
+ run()
278
+ test(result, errors).orElse(asserted).foreach(e => throw e)
279
+ }
280
+ }
281
+ object NoTrace {
282
+ def apply [A ](body : => A ): NoTrace [A ] = new NoTrace (body)
283
+ }
0 commit comments