@@ -8,7 +8,11 @@ import kotlinx.coroutines.internal.*
8
8
import kotlinx.coroutines.selects.*
9
9
import kotlinx.coroutines.sync.*
10
10
import org.junit.*
11
+ import org.junit.Test
11
12
import java.util.concurrent.*
13
+ import java.util.concurrent.atomic.AtomicBoolean
14
+ import java.util.concurrent.atomic.AtomicInteger
15
+ import kotlin.test.*
12
16
13
17
class MutexCancellationStressTest : TestBase () {
14
18
@Test
@@ -18,13 +22,16 @@ class MutexCancellationStressTest : TestBase() {
18
22
val mutexOwners = Array (mutexJobNumber) { " $it " }
19
23
val dispatcher = Executors .newFixedThreadPool(mutexJobNumber + 2 ).asCoroutineDispatcher()
20
24
var counter = 0
21
- val counterLocal = Array (mutexJobNumber) { LocalAtomicInt (0 ) }
22
- val completed = LocalAtomicInt ( 0 )
25
+ val counterLocal = Array (mutexJobNumber) { AtomicInteger (0 ) }
26
+ val completed = AtomicBoolean ( false )
23
27
val mutexJobLauncher: (jobNumber: Int ) -> Job = { jobId ->
24
28
val coroutineName = " MutexJob-$jobId "
25
- launch(dispatcher + CoroutineName (coroutineName)) {
26
- while (completed.value == 0 ) {
29
+ // ATOMIC to always have a chance to proceed
30
+ launch(dispatcher + CoroutineName (coroutineName), CoroutineStart .ATOMIC ) {
31
+ while (! completed.get()) {
32
+ // Stress out holdsLock
27
33
mutex.holdsLock(mutexOwners[(jobId + 1 ) % mutexJobNumber])
34
+ // Stress out lock-like primitives
28
35
if (mutex.tryLock(mutexOwners[jobId])) {
29
36
counterLocal[jobId].incrementAndGet()
30
37
counter++
@@ -47,30 +54,32 @@ class MutexCancellationStressTest : TestBase() {
47
54
val mutexJobs = (0 until mutexJobNumber).map { mutexJobLauncher(it) }.toMutableList()
48
55
val checkProgressJob = launch(dispatcher + CoroutineName (" checkProgressJob" )) {
49
56
var lastCounterLocalSnapshot = (0 until mutexJobNumber).map { 0 }
50
- while (completed.value == 0 ) {
51
- delay(1000 )
57
+ while (! completed.get()) {
58
+ delay(500 )
59
+ // If we've caught the completion after delay, then there is a chance no progress were made whatsoever, bail out
60
+ if (completed.get()) return @launch
52
61
val c = counterLocal.map { it.value }
53
62
for (i in 0 until mutexJobNumber) {
54
- assert (c[i] > lastCounterLocalSnapshot[i]) { " No progress in MutexJob-$i " }
63
+ assert (c[i] > lastCounterLocalSnapshot[i]) { " No progress in MutexJob-$i , last observed state: ${c[i]} " }
55
64
}
56
65
lastCounterLocalSnapshot = c
57
66
}
58
67
}
59
68
val cancellationJob = launch(dispatcher + CoroutineName (" cancellationJob" )) {
60
69
var cancellingJobId = 0
61
- while (completed.value == 0 ) {
70
+ while (! completed.get() ) {
62
71
val jobToCancel = mutexJobs.removeFirst()
63
72
jobToCancel.cancelAndJoin()
64
73
mutexJobs + = mutexJobLauncher(cancellingJobId)
65
74
cancellingJobId = (cancellingJobId + 1 ) % mutexJobNumber
66
75
}
67
76
}
68
77
delay(2000L * stressTestMultiplier)
69
- completed.value = 1
78
+ completed.set( true )
70
79
cancellationJob.join()
71
80
mutexJobs.forEach { it.join() }
72
81
checkProgressJob.join()
73
- check (counter == counterLocal.sumOf { it.value })
82
+ assertEquals (counter, counterLocal.sumOf { it.value })
74
83
dispatcher.close()
75
84
}
76
85
}
0 commit comments