Skip to content

Commit 307d65b

Browse files
authored
Fix flowOn handling of thread context elements (#4431)
Fixes #4403 Fixes #4422 Fixes some other, similar bugs that weren't reported.
1 parent f0feb8e commit 307d65b

File tree

8 files changed

+339
-40
lines changed

8 files changed

+339
-40
lines changed

kotlinx-coroutines-core/common/src/channels/Produce.kt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,7 @@ public fun <E> CoroutineScope.produce(
266266
produce(context, capacity, BufferOverflow.SUSPEND, start, onCompletion, block)
267267

268268
// Internal version of produce that is maximally flexible, but is not exposed through public API (too many params)
269+
// (scope + context1).produce(context2) == scope.produce(context1 + context2)
269270
internal fun <E> CoroutineScope.produce(
270271
context: CoroutineContext = EmptyCoroutineContext,
271272
capacity: Int = 0,

kotlinx-coroutines-core/common/src/flow/internal/ChannelFlow.kt

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -217,12 +217,11 @@ internal suspend fun <T, V> withContextUndispatched(
217217
value: V,
218218
countOrElement: Any = threadContextElements(newContext), // can be precomputed for speed
219219
block: suspend (V) -> T
220-
): T =
220+
): T = withCoroutineContext(newContext, countOrElement) {
221221
suspendCoroutineUninterceptedOrReturn { uCont ->
222-
withCoroutineContext(newContext, countOrElement) {
223-
block.startCoroutineUninterceptedOrReturn(value, StackFrameContinuation(uCont, newContext))
224-
}
222+
block.startCoroutineUninterceptedOrReturn(value, StackFrameContinuation(uCont, newContext))
225223
}
224+
}
226225

227226
// Continuation that links the caller with uCont with walkable CoroutineStackFrame
228227
private class StackFrameContinuation<T>(

kotlinx-coroutines-core/jvm/src/CoroutineContext.kt

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,13 @@ import kotlin.coroutines.jvm.internal.CoroutineStackFrame
99
* [ContinuationInterceptor] is specified and adds optional support for debugging facilities (when turned on)
1010
* and copyable-thread-local facilities on JVM.
1111
* See [DEBUG_PROPERTY_NAME] for description of debugging facilities on JVM.
12+
*
13+
* When [CopyableThreadContextElement] values are used, the logic for processing them is as follows:
14+
* - If [this] or [context] has a copyable thread-local value whose key is absent in [context] or [this],
15+
* it is [copied][CopyableThreadContextElement.copyForChild] to the new context.
16+
* - If [this] has a copyable thread-local value whose key is present in [context],
17+
* it is [merged][CopyableThreadContextElement.mergeForChild] with the one from [context].
18+
* - The other values are added to the new context as is, with [context] values taking precedence.
1219
*/
1320
@ExperimentalCoroutinesApi
1421
public actual fun CoroutineScope.newCoroutineContext(context: CoroutineContext): CoroutineContext {
@@ -37,13 +44,14 @@ private fun CoroutineContext.hasCopyableElements(): Boolean =
3744

3845
/**
3946
* Folds two contexts properly applying [CopyableThreadContextElement] rules when necessary.
40-
* The rules are the following:
41-
* - If neither context has CTCE, the sum of two contexts is returned
42-
* - Every CTCE from the left-hand side context that does not have a matching (by key) element from right-hand side context
43-
* is [copied][CopyableThreadContextElement.copyForChild] if [isNewCoroutine] is `true`.
44-
* - Every CTCE from the left-hand side context that has a matching element in the right-hand side context is [merged][CopyableThreadContextElement.mergeForChild]
45-
* - Every CTCE from the right-hand side context that hasn't been merged is copied
46-
* - Everything else is added to the resulting context as is.
47+
48+
* The rules are as follows:
49+
* - If both contexts have the same (by key) CTCE, they are [merged][CopyableThreadContextElement.mergeForChild].
50+
* - If [isNewCoroutine] is `true`, the CTCEs that one context has and the other does not are
51+
* [copied][CopyableThreadContextElement.copyForChild].
52+
* - If [isNewCoroutine] is `false`, then the CTCEs that the right context has and the left does not are copied,
53+
* but those that only the left context has are not copied but added to the resulting context as is.
54+
* - Every non-CTCE is added to the resulting context as is.
4755
*/
4856
private fun foldCopies(originalContext: CoroutineContext, appendContext: CoroutineContext, isNewCoroutine: Boolean): CoroutineContext {
4957
// Do we have something to copy left-hand side?
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
package kotlinx.coroutines
2+
3+
import kotlinx.coroutines.testing.*
4+
import kotlin.coroutines.*
5+
import kotlin.test.*
6+
7+
class CoroutineScopeTestJvm: TestBase() {
8+
/**
9+
* Test the documented behavior of [CoroutineScope.newCoroutineContext] regarding the copyable context elements.
10+
*/
11+
@Test
12+
fun testNewCoroutineContextCopyableContextElements() {
13+
val ce1L = MyMutableContextElement("key1", "value1_l")
14+
val ce2L = MyMutableContextElement("key2", "value2_l")
15+
val ce2R = MyMutableContextElement("key2", "value2_r")
16+
val ce3R = MyMutableContextElement("key3", "value3_r")
17+
val nonce1L = CoroutineExceptionHandler { _, _ -> }
18+
val nonce2L = Dispatchers.Default
19+
val nonce2R = Dispatchers.IO
20+
val nonce3R = CoroutineName("name3_r")
21+
val leftContext = randomlyShuffledContext(ce1L, ce2L, nonce1L, nonce2L)
22+
val rightContext = randomlyShuffledContext(ce2R, ce3R, nonce2R, nonce3R)
23+
CoroutineScope(leftContext).newCoroutineContext(rightContext).let { ctx ->
24+
assertEquals("Copy of 'value1_l'", ctx[MyMutableContextElementKey("key1")]?.value)
25+
assertEquals("Merged 'value2_l' and 'value2_r'", ctx[MyMutableContextElementKey("key2")]?.value)
26+
assertEquals("Copy of 'value3_r'", ctx[MyMutableContextElementKey("key3")]?.value)
27+
assertSame(nonce1L, ctx[CoroutineExceptionHandler])
28+
assertSame(nonce2R, ctx[ContinuationInterceptor])
29+
assertSame(nonce3R, ctx[CoroutineName])
30+
}
31+
}
32+
33+
private fun randomlyShuffledContext(
34+
vararg elements: CoroutineContext.Element
35+
): CoroutineContext = elements.toList().shuffled().fold(EmptyCoroutineContext, CoroutineContext::plus)
36+
}
37+
38+
class MyMutableContextElementKey(val key: String): CoroutineContext.Key<MyMutableContextElement> {
39+
override fun equals(other: Any?): Boolean =
40+
this === other || other is MyMutableContextElementKey && key == other.key
41+
42+
override fun hashCode(): Int = key.hashCode()
43+
}
44+
45+
class MyMutableContextElement(
46+
val keyId: String,
47+
var value: String
48+
) : AbstractCoroutineContextElement(MyMutableContextElementKey(keyId)), CopyableThreadContextElement<String> {
49+
override fun updateThreadContext(context: CoroutineContext): String {
50+
return value
51+
}
52+
53+
override fun restoreThreadContext(context: CoroutineContext, oldState: String) {
54+
value = oldState
55+
}
56+
57+
override fun toString(): String {
58+
return "MyMutableContextElement(keyId='$keyId', value='$value')"
59+
}
60+
61+
override fun equals(other: Any?): Boolean =
62+
this === other || other is MyMutableContextElement && keyId == other.keyId && value == other.value
63+
64+
override fun hashCode(): Int = 31 * key.hashCode() + value.hashCode()
65+
66+
override fun copyForChild(): CopyableThreadContextElement<String> =
67+
MyMutableContextElement(keyId, "Copy of '$value'")
68+
69+
override fun mergeForChild(overwritingElement: CoroutineContext.Element): CoroutineContext =
70+
MyMutableContextElement(
71+
keyId,
72+
"Merged '$value' and '${(overwritingElement as MyMutableContextElement).value}'"
73+
)
74+
}

kotlinx-coroutines-core/jvm/test/ThreadContextElementTest.kt

Lines changed: 46 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -237,17 +237,53 @@ class ThreadContextElementTest : TestBase() {
237237

238238
@Test
239239
fun testThreadLocalFlowOn() = runTest {
240-
val myData = MyData()
241-
myThreadLocal.set(myData)
242-
expect(1)
243-
flow {
244-
assertEquals(myData, myThreadLocal.get())
245-
emit(1)
240+
val parameters: List<Triple<CoroutineContext, Boolean, Boolean>> =
241+
listOf(EmptyCoroutineContext, Dispatchers.Default, Dispatchers.Unconfined).flatMap { dispatcher ->
242+
listOf(true, false).flatMap { doYield ->
243+
listOf(true, false).map { useThreadLocalInOuterContext ->
244+
Triple(dispatcher, doYield, useThreadLocalInOuterContext)
245+
}
246+
}
247+
}
248+
for ((dispatcher, doYield, useThreadLocalInOuterContext) in parameters) {
249+
try {
250+
testThreadLocalFlowOn(dispatcher, doYield, useThreadLocalInOuterContext)
251+
} catch (e: Throwable) {
252+
throw AssertionError("Failed with parameters: dispatcher=$dispatcher, " +
253+
"doYield=$doYield, " +
254+
"useThreadLocalInOuterContext=$useThreadLocalInOuterContext", e)
255+
}
256+
}
257+
}
258+
259+
private fun testThreadLocalFlowOn(
260+
extraFlowOnContext: CoroutineContext, doYield: Boolean, useThreadLocalInOuterContext: Boolean
261+
) = runTest {
262+
try {
263+
val myData1 = MyData()
264+
val myData2 = MyData()
265+
myThreadLocal.set(myData1)
266+
withContext(if (useThreadLocalInOuterContext) myThreadLocal.asContextElement() else EmptyCoroutineContext) {
267+
assertEquals(myData1, myThreadLocal.get())
268+
flow {
269+
repeat(5) {
270+
assertEquals(myData2, myThreadLocal.get())
271+
emit(1)
272+
if (doYield) yield()
273+
}
274+
}
275+
.flowOn(myThreadLocal.asContextElement(myData2) + extraFlowOnContext)
276+
.collect {
277+
if (useThreadLocalInOuterContext) {
278+
assertEquals(myData1, myThreadLocal.get())
279+
}
280+
}
281+
assertEquals(myData1, myThreadLocal.get())
282+
}
283+
assertEquals(myData1, myThreadLocal.get())
284+
} finally {
285+
myThreadLocal.set(null)
246286
}
247-
.flowOn(myThreadLocal.asContextElement() + Dispatchers.Default)
248-
.single()
249-
myThreadLocal.set(null)
250-
finish(2)
251287
}
252288
}
253289

kotlinx-coroutines-core/jvm/test/ThreadContextMutableCopiesTest.kt

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -132,29 +132,27 @@ class ThreadContextMutableCopiesTest : TestBase() {
132132

133133
@Test
134134
fun testDataIsCopiedThroughFlowOnUndispatched() = runTest {
135-
expect(1)
136-
val root = MyMutableElement(ArrayList())
137-
val originalData = root.mutableData
135+
val originalData = mutableListOf("X")
136+
val root = MyMutableElement(originalData)
138137
flow {
139138
assertNotSame(originalData, threadLocalData.get())
139+
assertEquals(originalData, threadLocalData.get())
140140
emit(1)
141141
}
142142
.flowOn(root)
143143
.single()
144-
finish(2)
145144
}
146145

147146
@Test
148147
fun testDataIsCopiedThroughFlowOnDispatched() = runTest {
149-
expect(1)
150-
val root = MyMutableElement(ArrayList())
151-
val originalData = root.mutableData
148+
val originalData = mutableListOf("X")
149+
val root = MyMutableElement(originalData)
152150
flow {
153151
assertNotSame(originalData, threadLocalData.get())
152+
assertEquals(originalData, threadLocalData.get())
154153
emit(1)
155154
}
156155
.flowOn(root + Dispatchers.Default)
157156
.single()
158-
finish(2)
159157
}
160158
}

reactive/kotlinx-coroutines-reactive/src/ReactiveFlow.kt

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -69,19 +69,14 @@ private class PublisherAsFlow<T : Any>(
6969

7070
override suspend fun collect(collector: FlowCollector<T>) {
7171
val collectContext = coroutineContext
72-
val newDispatcher = context[ContinuationInterceptor]
73-
if (newDispatcher == null || newDispatcher == collectContext[ContinuationInterceptor]) {
74-
// fast path -- subscribe directly in this dispatcher
75-
return collectImpl(collectContext + context, collector)
72+
val newContext = collectContext.newCoroutineContext(context)
73+
// quickest path: if the context has not changed, just subscribe inline
74+
if (newContext == collectContext) {
75+
return collectImpl(collectContext, collector)
7676
}
77+
// TODO: copy-paste/share the ChannelFlowOperatorImpl implementation for the same-dispatcher quick path
7778
// slow path -- produce in a separate dispatcher
78-
collectSlowPath(collector)
79-
}
80-
81-
private suspend fun collectSlowPath(collector: FlowCollector<T>) {
82-
coroutineScope {
83-
collector.emitAll(produceImpl(this + context))
84-
}
79+
super.collect(collector)
8580
}
8681

8782
private suspend fun collectImpl(injectContext: CoroutineContext, collector: FlowCollector<T>) {

0 commit comments

Comments
 (0)