Skip to content

Commit ad954b0

Browse files
committed
Fix thread context elements not being cleaned up after collect
Wait for the upstream coroutine to finish before restoring the thread context element state to the old one. Fixes #4403
1 parent e5bb191 commit ad954b0

File tree

2 files changed

+49
-14
lines changed

2 files changed

+49
-14
lines changed

kotlinx-coroutines-core/common/src/flow/internal/ChannelFlow.kt

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -217,12 +217,11 @@ internal suspend fun <T, V> withContextUndispatched(
217217
value: V,
218218
countOrElement: Any = threadContextElements(newContext), // can be precomputed for speed
219219
block: suspend (V) -> T
220-
): T =
220+
): T = withCoroutineContext(newContext, countOrElement) {
221221
suspendCoroutineUninterceptedOrReturn { uCont ->
222-
withCoroutineContext(newContext, countOrElement) {
223-
block.startCoroutineUninterceptedOrReturn(value, StackFrameContinuation(uCont, newContext))
224-
}
222+
block.startCoroutineUninterceptedOrReturn(value, StackFrameContinuation(uCont, newContext))
225223
}
224+
}
226225

227226
// Continuation that links the caller with uCont with walkable CoroutineStackFrame
228227
private class StackFrameContinuation<T>(

kotlinx-coroutines-core/jvm/test/ThreadContextElementTest.kt

Lines changed: 46 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -237,17 +237,53 @@ class ThreadContextElementTest : TestBase() {
237237

238238
@Test
239239
fun testThreadLocalFlowOn() = runTest {
240-
val myData = MyData()
241-
myThreadLocal.set(myData)
242-
expect(1)
243-
flow {
244-
assertEquals(myData, myThreadLocal.get())
245-
emit(1)
240+
val parameters: List<Triple<CoroutineContext, Boolean, Boolean>> =
241+
listOf(EmptyCoroutineContext, Dispatchers.Default, Dispatchers.Unconfined).flatMap { dispatcher ->
242+
listOf(true, false).flatMap { doYield ->
243+
listOf(true, false).map { useThreadLocalInOuterContext ->
244+
Triple(dispatcher, doYield, useThreadLocalInOuterContext)
245+
}
246+
}
247+
}
248+
for ((dispatcher, doYield, useThreadLocalInOuterContext) in parameters) {
249+
try {
250+
testThreadLocalFlowOn(dispatcher, doYield, useThreadLocalInOuterContext)
251+
} catch (e: Throwable) {
252+
throw AssertionError("Failed with parameters: dispatcher=$dispatcher, " +
253+
"doYield=$doYield, " +
254+
"useThreadLocalInOuterContext=$useThreadLocalInOuterContext", e)
255+
}
256+
}
257+
}
258+
259+
private fun testThreadLocalFlowOn(
260+
extraFlowOnContext: CoroutineContext, doYield: Boolean, useThreadLocalInOuterContext: Boolean
261+
) = runTest {
262+
try {
263+
val myData1 = MyData()
264+
val myData2 = MyData()
265+
myThreadLocal.set(myData1)
266+
withContext(if (useThreadLocalInOuterContext) myThreadLocal.asContextElement() else EmptyCoroutineContext) {
267+
assertEquals(myData1, myThreadLocal.get())
268+
flow {
269+
repeat(5) {
270+
assertEquals(myData2, myThreadLocal.get())
271+
emit(1)
272+
if (doYield) yield()
273+
}
274+
}
275+
.flowOn(myThreadLocal.asContextElement(myData2) + extraFlowOnContext)
276+
.collect {
277+
if (useThreadLocalInOuterContext) {
278+
assertEquals(myData1, myThreadLocal.get())
279+
}
280+
}
281+
assertEquals(myData1, myThreadLocal.get())
282+
}
283+
assertEquals(myData1, myThreadLocal.get())
284+
} finally {
285+
myThreadLocal.set(null)
246286
}
247-
.flowOn(myThreadLocal.asContextElement() + Dispatchers.Default)
248-
.single()
249-
myThreadLocal.set(null)
250-
finish(2)
251287
}
252288
}
253289

0 commit comments

Comments
 (0)