Skip to content

Fix context support in Publisher.asFlow.flowOn #1774

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 27 additions & 28 deletions reactive/kotlinx-coroutines-reactive/src/ReactiveFlow.kt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2016-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
* Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package kotlinx.coroutines.reactive
Expand Down Expand Up @@ -27,7 +27,7 @@ import kotlin.coroutines.*
* see its documentation for additional details.
*/
public fun <T : Any> Publisher<T>.asFlow(): Flow<T> =
PublisherAsFlow(this, 1)
PublisherAsFlow(this)

/**
* Transforms the given flow to a reactive specification compliant [Publisher].
Expand All @@ -39,30 +39,11 @@ public fun <T : Any> Flow<T>.asPublisher(): Publisher<T> = FlowAsPublisher(this)

private class PublisherAsFlow<T : Any>(
private val publisher: Publisher<T>,
capacity: Int
) : ChannelFlow<T>(EmptyCoroutineContext, capacity) {
context: CoroutineContext = EmptyCoroutineContext,
capacity: Int = 1
) : ChannelFlow<T>(context, capacity) {
override fun create(context: CoroutineContext, capacity: Int): ChannelFlow<T> =
PublisherAsFlow(publisher, capacity)

override fun produceImpl(scope: CoroutineScope): ReceiveChannel<T> {
// use another channel for conflation (cannot do openSubscription)
if (capacity < 0) return super.produceImpl(scope)
// Open subscription channel directly
val channel = publisher
.injectCoroutineContext(scope.coroutineContext)
.openSubscription(capacity)
val handle = scope.coroutineContext[Job]?.invokeOnCompletion(onCancelling = true) { cause ->
channel.cancel(cause?.let {
it as? CancellationException ?: CancellationException("Job was cancelled", it)
})
}
if (handle != null && handle !== NonDisposableHandle) {
(channel as SendChannel<*>).invokeOnClose {
handle.dispose()
}
}
return channel
}
PublisherAsFlow(publisher, context, capacity)

private val requestSize: Long
get() = when (capacity) {
Expand All @@ -73,8 +54,26 @@ private class PublisherAsFlow<T : Any>(
}

override suspend fun collect(collector: FlowCollector<T>) {
val collectContext = coroutineContext
val newDispatcher = context[ContinuationInterceptor]
if (newDispatcher == null || newDispatcher == collectContext[ContinuationInterceptor]) {
// fast path -- subscribe directly in this dispatcher
return collectImpl(collectContext + context, collector)
}
// slow path -- produce in a separate dispatcher
collectSlowPath(collector)
}

private suspend fun collectSlowPath(collector: FlowCollector<T>) {
coroutineScope {
collector.emitAll(produceImpl(this + context))
}
}

private suspend fun collectImpl(injectContext: CoroutineContext, collector: FlowCollector<T>) {
val subscriber = ReactiveSubscriber<T>(capacity, requestSize)
publisher.injectCoroutineContext(coroutineContext).subscribe(subscriber)
// inject subscribe context into publisher
publisher.injectCoroutineContext(injectContext).subscribe(subscriber)
try {
var consumed = 0L
while (true) {
Expand All @@ -90,9 +89,9 @@ private class PublisherAsFlow<T : Any>(
}
}

// The second channel here is used only for broadcast
// The second channel here is used for produceIn/broadcastIn and slow-path (dispatcher change)
override suspend fun collectTo(scope: ProducerScope<T>) =
collect(SendingCollector(scope.channel))
collectImpl(scope.coroutineContext, SendingCollector(scope.channel))
}

@Suppress("SubscriberImplementation")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ class PublisherAsFlowTest : TestBase() {
7 -> try {
send(value)
} catch (e: CancellationException) {
finish(6)
expect(5)
throw e
}
else -> expectUnreached()
Expand All @@ -143,6 +143,6 @@ class PublisherAsFlowTest : TestBase() {
}
}
}
expect(5)
finish(6)
}
}
52 changes: 48 additions & 4 deletions reactive/kotlinx-coroutines-reactor/test/FlowAsFluxTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,17 @@ import kotlinx.coroutines.*
import kotlinx.coroutines.flow.*
import kotlinx.coroutines.reactive.*
import org.junit.Test
import reactor.core.publisher.Mono
import reactor.core.publisher.*
import reactor.util.context.Context
import kotlin.test.assertEquals
import kotlin.test.*

class FlowAsFluxTest : TestBase() {
@Test
fun testFlowToFluxContextPropagation() {
fun testFlowAsFluxContextPropagation() {
val flux = flow<String> {
(1..4).forEach { i -> emit(createMono(i).awaitFirst()) }
} .asFlux()
}
.asFlux()
.subscriberContext(Context.of(1, "1"))
.subscriberContext(Context.of(2, "2", 3, "3", 4, "4"))
val list = flux.collectList().block()!!
Expand All @@ -24,4 +25,47 @@ class FlowAsFluxTest : TestBase() {
val ctx = coroutineContext[ReactorContext]!!.context
ctx.getOrDefault(i, "noValue")
}

@Test
fun testFluxAsFlowContextPropagationWithFlowOn() = runTest {
expect(1)
Flux.create<String> {
it.next("OK")
it.complete()
}
.subscriberContext { ctx ->
expect(2)
assertEquals("CTX", ctx.get(1))
ctx
}
.asFlow()
.flowOn(ReactorContext(Context.of(1, "CTX")))
.collect {
expect(3)
assertEquals("OK", it)
}
finish(4)
}

@Test
fun testFluxAsFlowContextPropagationFromScope() = runTest {
expect(1)
withContext(ReactorContext(Context.of(1, "CTX"))) {
Flux.create<String> {
it.next("OK")
it.complete()
}
.subscriberContext { ctx ->
expect(2)
assertEquals("CTX", ctx.get(1))
ctx
}
.asFlow()
.collect {
expect(3)
assertEquals("OK", it)
}
}
finish(4)
}
}
43 changes: 43 additions & 0 deletions reactive/kotlinx-coroutines-reactor/test/FluxContextTest.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Copyright 2016-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package kotlinx.coroutines.reactor

import kotlinx.coroutines.*
import kotlinx.coroutines.flow.*
import kotlinx.coroutines.reactive.*
import org.junit.*
import org.junit.Test
import reactor.core.publisher.*
import kotlin.test.*

class FluxContextTest : TestBase() {
private val dispatcher = newSingleThreadContext("FluxContextTest")

@After
fun tearDown() {
dispatcher.close()
}

@Test
fun testFluxCreateAsFlowThread() = runTest {
expect(1)
val mainThread = Thread.currentThread()
val dispatcherThread = withContext(dispatcher) { Thread.currentThread() }
assertTrue(dispatcherThread != mainThread)
Flux.create<String> {
assertEquals(dispatcherThread, Thread.currentThread())
it.next("OK")
it.complete()
}
.asFlow()
.flowOn(dispatcher)
.collect {
expect(2)
assertEquals("OK", it)
assertEquals(mainThread, Thread.currentThread())
}
finish(3)
}
}
43 changes: 43 additions & 0 deletions reactive/kotlinx-coroutines-rx2/test/FlowableContextTest.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Copyright 2016-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package kotlinx.coroutines.rx2

import io.reactivex.*
import kotlinx.coroutines.*
import kotlinx.coroutines.flow.*
import kotlinx.coroutines.reactive.*
import org.junit.*
import org.junit.Test
import kotlin.test.*

class FlowableContextTest : TestBase() {
private val dispatcher = newSingleThreadContext("FlowableContextTest")

@After
fun tearDown() {
dispatcher.close()
}

@Test
fun testFlowableCreateAsFlowThread() = runTest {
expect(1)
val mainThread = Thread.currentThread()
val dispatcherThread = withContext(dispatcher) { Thread.currentThread() }
assertTrue(dispatcherThread != mainThread)
Flowable.create<String>({
assertEquals(dispatcherThread, Thread.currentThread())
it.onNext("OK")
it.onComplete()
}, BackpressureStrategy.BUFFER)
.asFlow()
.flowOn(dispatcher)
.collect {
expect(2)
assertEquals("OK", it)
assertEquals(mainThread, Thread.currentThread())
}
finish(3)
}
}