@@ -7,13 +7,13 @@ package io.ktor.client.engine.okhttp
7
7
import io.ktor.client.plugins.sse.*
8
8
import io.ktor.http.*
9
9
import io.ktor.sse.*
10
- import kotlinx.coroutines.CancellationException
11
- import kotlinx.coroutines.CompletableDeferred
10
+ import kotlinx.coroutines.*
12
11
import kotlinx.coroutines.channels.Channel
13
12
import kotlinx.coroutines.channels.onFailure
14
13
import kotlinx.coroutines.channels.trySendBlocking
15
14
import kotlinx.coroutines.flow.Flow
16
15
import kotlinx.coroutines.flow.consumeAsFlow
16
+ import kotlinx.coroutines.flow.onCompletion
17
17
import okhttp3.OkHttpClient
18
18
import okhttp3.Request
19
19
import okhttp3.Response
@@ -22,18 +22,39 @@ import okhttp3.sse.EventSourceListener
22
22
import okhttp3.sse.EventSources
23
23
import kotlin.coroutines.CoroutineContext
24
24
25
- internal class OkHttpSSESession (
26
- engine : OkHttpClient ,
25
+ internal class OkHttpSSESession private constructor (
26
+ factory : EventSource . Factory ,
27
27
engineRequest : Request ,
28
28
override val coroutineContext : CoroutineContext ,
29
29
) : SSESession, EventSourceListener() {
30
- private val serverSentEventsSource = EventSources .createFactory(engine).newEventSource(engineRequest, this )
30
+
31
+ constructor (
32
+ engine: OkHttpClient ,
33
+ engineRequest: Request ,
34
+ callContext: CoroutineContext ,
35
+ ) : this (
36
+ factory = EventSources .createFactory(engine),
37
+ engineRequest = engineRequest,
38
+ coroutineContext = callContext + Job () + CoroutineName (" OkHttpSSESession" ),
39
+ )
40
+
41
+ private val serverSentEventsSource = factory.newEventSource(engineRequest, this )
31
42
32
43
internal val originResponse: CompletableDeferred <Response > = CompletableDeferred ()
33
44
34
45
private val _incoming = Channel <ServerSentEvent >(8 )
35
46
36
47
override val incoming: Flow <ServerSentEvent > = _incoming .consumeAsFlow()
48
+ .onCompletion { cause ->
49
+ // Use onCompletion operator to handle CancellationExceptions which occur in downstream flow.
50
+ if (cause is CancellationException ) close(cause = null )
51
+ }
52
+
53
+ init {
54
+ coroutineContext.job.invokeOnCompletion {
55
+ close(cause = null )
56
+ }
57
+ }
37
58
38
59
override fun onOpen (eventSource : EventSource , response : Response ) {
39
60
originResponse.complete(response)
@@ -52,6 +73,7 @@ internal class OkHttpSSESession(
52
73
(statusCode != HttpStatusCode .OK .value || contentType != ContentType .Text .EventStream .toString())
53
74
) {
54
75
originResponse.complete(response)
76
+ close(cause = null )
55
77
} else {
56
78
val error = t?.let {
57
79
SSEClientException (
@@ -60,15 +82,19 @@ internal class OkHttpSSESession(
60
82
)
61
83
} ? : mapException(response)
62
84
originResponse.completeExceptionally(error)
85
+ close(cause = error)
63
86
}
64
-
65
- _incoming .close()
66
- serverSentEventsSource.cancel()
67
87
}
68
88
69
89
override fun onClosed (eventSource : EventSource ) {
70
- _incoming .close()
90
+ close(cause = null )
91
+ }
92
+
93
+ private fun close (cause : Throwable ? ) {
94
+ _incoming .close(cause)
71
95
serverSentEventsSource.cancel()
96
+ // Cancel context last so 'invokeOnCompletion' doesn't override 'cause' for closing '_incoming'.
97
+ coroutineContext.cancel()
72
98
}
73
99
74
100
private fun mapException (response : Response ? ): SSEClientException {
0 commit comments