Skip to content

Commit 1632cdf

Browse files
authored
KTOR-8409 & KTOR-7947 Fix closing of OkHttp SSE connection (#4811)
1 parent d491930 commit 1632cdf

File tree

2 files changed

+124
-10
lines changed

2 files changed

+124
-10
lines changed

ktor-client/ktor-client-okhttp/jvm/src/io/ktor/client/engine/okhttp/OkHttpSSESession.kt

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@ package io.ktor.client.engine.okhttp
77
import io.ktor.client.plugins.sse.*
88
import io.ktor.http.*
99
import io.ktor.sse.*
10-
import kotlinx.coroutines.CancellationException
11-
import kotlinx.coroutines.CompletableDeferred
10+
import kotlinx.coroutines.*
1211
import kotlinx.coroutines.channels.Channel
1312
import kotlinx.coroutines.channels.onFailure
1413
import kotlinx.coroutines.channels.trySendBlocking
1514
import kotlinx.coroutines.flow.Flow
1615
import kotlinx.coroutines.flow.consumeAsFlow
16+
import kotlinx.coroutines.flow.onCompletion
1717
import okhttp3.OkHttpClient
1818
import okhttp3.Request
1919
import okhttp3.Response
@@ -22,18 +22,39 @@ import okhttp3.sse.EventSourceListener
2222
import okhttp3.sse.EventSources
2323
import kotlin.coroutines.CoroutineContext
2424

25-
internal class OkHttpSSESession(
26-
engine: OkHttpClient,
25+
internal class OkHttpSSESession private constructor(
26+
factory: EventSource.Factory,
2727
engineRequest: Request,
2828
override val coroutineContext: CoroutineContext,
2929
) : SSESession, EventSourceListener() {
30-
private val serverSentEventsSource = EventSources.createFactory(engine).newEventSource(engineRequest, this)
30+
31+
constructor(
32+
engine: OkHttpClient,
33+
engineRequest: Request,
34+
callContext: CoroutineContext,
35+
) : this(
36+
factory = EventSources.createFactory(engine),
37+
engineRequest = engineRequest,
38+
coroutineContext = callContext + Job() + CoroutineName("OkHttpSSESession"),
39+
)
40+
41+
private val serverSentEventsSource = factory.newEventSource(engineRequest, this)
3142

3243
internal val originResponse: CompletableDeferred<Response> = CompletableDeferred()
3344

3445
private val _incoming = Channel<ServerSentEvent>(8)
3546

3647
override val incoming: Flow<ServerSentEvent> = _incoming.consumeAsFlow()
48+
.onCompletion { cause ->
49+
// Use onCompletion operator to handle CancellationExceptions which occur in downstream flow.
50+
if (cause is CancellationException) close(cause = null)
51+
}
52+
53+
init {
54+
coroutineContext.job.invokeOnCompletion {
55+
close(cause = null)
56+
}
57+
}
3758

3859
override fun onOpen(eventSource: EventSource, response: Response) {
3960
originResponse.complete(response)
@@ -52,6 +73,7 @@ internal class OkHttpSSESession(
5273
(statusCode != HttpStatusCode.OK.value || contentType != ContentType.Text.EventStream.toString())
5374
) {
5475
originResponse.complete(response)
76+
close(cause = null)
5577
} else {
5678
val error = t?.let {
5779
SSEClientException(
@@ -60,15 +82,19 @@ internal class OkHttpSSESession(
6082
)
6183
} ?: mapException(response)
6284
originResponse.completeExceptionally(error)
85+
close(cause = error)
6386
}
64-
65-
_incoming.close()
66-
serverSentEventsSource.cancel()
6787
}
6888

6989
override fun onClosed(eventSource: EventSource) {
70-
_incoming.close()
90+
close(cause = null)
91+
}
92+
93+
private fun close(cause: Throwable?) {
94+
_incoming.close(cause)
7195
serverSentEventsSource.cancel()
96+
// Cancel context last so 'invokeOnCompletion' doesn't override 'cause' for closing '_incoming'.
97+
coroutineContext.cancel()
7298
}
7399

74100
private fun mapException(response: Response?): SSEClientException {

ktor-client/ktor-client-okhttp/jvm/test/io/ktor/client/engine/okhttp/OkHttpHttpClientTest.kt

Lines changed: 89 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,94 @@
44

55
package io.ktor.client.engine.okhttp
66

7+
import io.ktor.client.*
8+
import io.ktor.client.plugins.sse.*
9+
import io.ktor.client.test.base.*
710
import io.ktor.client.tests.*
11+
import io.ktor.network.sockets.*
12+
import kotlinx.coroutines.Job
13+
import kotlinx.coroutines.delay
14+
import kotlinx.coroutines.flow.collect
15+
import kotlinx.coroutines.launch
16+
import kotlinx.coroutines.test.runTest
17+
import okhttp3.OkHttpClient
18+
import org.junit.jupiter.api.assertInstanceOf
19+
import java.util.concurrent.TimeUnit
20+
import kotlin.test.Test
21+
import kotlin.test.assertEquals
22+
import kotlin.test.fail
823

9-
class OkHttpHttpClientTest : HttpClientTest(OkHttp)
24+
class OkHttpHttpClientTest : HttpClientTest(OkHttp) {
25+
@Test
26+
fun testCancelSseRequestIncomingCollect() {
27+
val okHttpClient = OkHttpClient()
28+
29+
HttpClient(OkHttp) {
30+
engine { preconfigured = okHttpClient }
31+
install(SSE)
32+
}.use { client ->
33+
runTest {
34+
var request: Job? = null
35+
request = launch {
36+
client.sse("${TEST_SERVER}/sse/hello?times=20&interval=100") {
37+
request?.cancel() // Cancel the request once the connection is open.
38+
incoming.collect() // Collect all messages.
39+
}
40+
fail("Request should be cancelled.")
41+
}
42+
request.join()
43+
}
44+
}
45+
46+
okHttpClient.connectionPool.evictAll() // Make sure idle connections are removed.
47+
assertEquals(0, okHttpClient.connectionPool.connectionCount())
48+
}
49+
50+
@Test
51+
fun testCancelSseRequestWithDelay() {
52+
val okHttpClient = OkHttpClient()
53+
54+
HttpClient(OkHttp) {
55+
engine { preconfigured = okHttpClient }
56+
install(SSE)
57+
}.use { client ->
58+
runTest {
59+
var request: Job? = null
60+
request = launch {
61+
client.sse("${TEST_SERVER}/sse/hello?times=20&interval=100") {
62+
request?.cancel() // Cancel the request once the connection is open.
63+
delay(1) // Never read from incoming.
64+
}
65+
fail("Request should be cancelled.")
66+
}
67+
request.join()
68+
}
69+
}
70+
71+
okHttpClient.connectionPool.evictAll() // Make sure idle connections are removed.
72+
assertEquals(0, okHttpClient.connectionPool.connectionCount())
73+
}
74+
75+
@Test
76+
fun testSSESessionTimeout() {
77+
val okHttpClient = OkHttpClient.Builder().apply {
78+
readTimeout(1L, TimeUnit.SECONDS)
79+
}.build()
80+
81+
HttpClient(OkHttp) {
82+
engine { preconfigured = okHttpClient }
83+
install(SSE)
84+
}.use { client ->
85+
runTest {
86+
client.sse("$TEST_SERVER/sse/hello?delay=10000") {
87+
try {
88+
incoming.collect()
89+
fail("Request should error.")
90+
} catch (e: SSEClientException) {
91+
assertInstanceOf<SocketTimeoutException>(e.cause)
92+
}
93+
}
94+
}
95+
}
96+
}
97+
}

0 commit comments

Comments
 (0)