Skip to content

Commit 13174d8

Browse files
author
Shane Myrick
committed
Refactor subscription caching logic
Move the saving of subscriptions to a separate class so we can vefify the logic with unit tests and simplify the ApolloSubscriptionProtocolHandler. This also exposed a bug that we were not saving the operation subscriptions to be stopped properly. This is now covered by the unit tests
1 parent 9425a5f commit 13174d8

File tree

3 files changed

+333
-55
lines changed

3 files changed

+333
-55
lines changed

graphql-kotlin-spring-server/src/main/kotlin/com/expediagroup/graphql/spring/execution/ApolloSubscriptionProtocolHandler.kt

Lines changed: 31 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ import com.expediagroup.graphql.spring.model.SubscriptionOperationMessage.Client
2323
import com.expediagroup.graphql.spring.model.SubscriptionOperationMessage.ClientMessages.GQL_CONNECTION_TERMINATE
2424
import com.expediagroup.graphql.spring.model.SubscriptionOperationMessage.ClientMessages.GQL_START
2525
import com.expediagroup.graphql.spring.model.SubscriptionOperationMessage.ClientMessages.GQL_STOP
26-
import com.expediagroup.graphql.spring.model.SubscriptionOperationMessage.ServerMessages.GQL_COMPLETE
2726
import com.expediagroup.graphql.spring.model.SubscriptionOperationMessage.ServerMessages.GQL_CONNECTION_ACK
2827
import com.expediagroup.graphql.spring.model.SubscriptionOperationMessage.ServerMessages.GQL_CONNECTION_ERROR
2928
import com.expediagroup.graphql.spring.model.SubscriptionOperationMessage.ServerMessages.GQL_CONNECTION_KEEP_ALIVE
@@ -32,12 +31,10 @@ import com.expediagroup.graphql.spring.model.SubscriptionOperationMessage.Server
3231
import com.fasterxml.jackson.databind.ObjectMapper
3332
import com.fasterxml.jackson.module.kotlin.convertValue
3433
import com.fasterxml.jackson.module.kotlin.readValue
35-
import org.reactivestreams.Subscription
3634
import org.slf4j.LoggerFactory
3735
import org.springframework.web.reactive.socket.WebSocketSession
3836
import reactor.core.publisher.Flux
3937
import java.time.Duration
40-
import java.util.concurrent.ConcurrentHashMap
4138

4239
/**
4340
* Implementation of the `graphql-ws` protocol defined by Apollo
@@ -48,11 +45,7 @@ class ApolloSubscriptionProtocolHandler(
4845
private val subscriptionHandler: SubscriptionHandler,
4946
private val objectMapper: ObjectMapper
5047
) {
51-
// Sessions are saved by web socket session id
52-
private val activeKeepAliveSessions = ConcurrentHashMap<String, Subscription>()
53-
// Operations are saved by web socket session id, then operation id
54-
private val activeOperations = ConcurrentHashMap<String, ConcurrentHashMap<String, Subscription>>()
55-
48+
private val sessionState = ApolloSubscriptionSessionState()
5649
private val logger = LoggerFactory.getLogger(ApolloSubscriptionProtocolHandler::class.java)
5750
private val keepAliveMessage = SubscriptionOperationMessage(type = GQL_CONNECTION_KEEP_ALIVE.type)
5851
private val basicConnectionErrorMessage = SubscriptionOperationMessage(type = GQL_CONNECTION_ERROR.type)
@@ -63,36 +56,26 @@ class ApolloSubscriptionProtocolHandler(
6356
try {
6457
val operationMessage: SubscriptionOperationMessage = objectMapper.readValue(payload)
6558

66-
return when (operationMessage.type) {
67-
GQL_CONNECTION_INIT.type -> {
68-
val flux = Flux.just(acknowledgeMessage)
69-
val keepAliveInterval = config.subscriptions.keepAliveInterval
70-
if (keepAliveInterval != null) {
71-
// Send the GQL_CONNECTION_KEEP_ALIVE message every interval until the connection is closed or terminated
72-
val keepAliveFlux = Flux.interval(Duration.ofMillis(keepAliveInterval))
73-
.map { keepAliveMessage }
74-
.doOnSubscribe {
75-
logger.debug("GraphQL subscription INIT, sessionId=${session.id} activeSessions=${activeKeepAliveSessions.count()}")
76-
activeKeepAliveSessions[session.id] = it
77-
}
78-
79-
return flux.concatWith(keepAliveFlux)
80-
}
59+
logger.debug("GraphQL subscription client message, sessionId=${session.id} operationMessage=$operationMessage")
8160

82-
return flux
61+
when (operationMessage.type) {
62+
GQL_CONNECTION_INIT.type -> {
63+
val ackowledgeMessageFlux = Flux.just(acknowledgeMessage)
64+
val keepAliveFlux = getKeepAliveFlux(session)
65+
return ackowledgeMessageFlux.concatWith(keepAliveFlux)
8366
}
84-
GQL_START.type -> startSubscription(operationMessage, session)
67+
GQL_START.type -> return startSubscription(operationMessage, session)
8568
GQL_STOP.type -> {
86-
stopSubscription(operationMessage, session)
69+
sessionState.stopOperation(session, operationMessage)
8770
return Flux.empty()
8871
}
8972
GQL_CONNECTION_TERMINATE.type -> {
90-
terminateSession(session)
73+
sessionState.terminateSession(session)
9174
return Flux.empty()
9275
}
9376
else -> {
9477
logger.error("Unknown subscription operation $operationMessage")
95-
stopSubscription(operationMessage, session)
78+
sessionState.stopOperation(session, operationMessage)
9679
return Flux.just(SubscriptionOperationMessage(type = GQL_CONNECTION_ERROR.type, id = operationMessage.id))
9780
}
9881
}
@@ -102,6 +85,21 @@ class ApolloSubscriptionProtocolHandler(
10285
}
10386
}
10487

88+
/**
89+
* If the keep alive configuraation is set, send a message back to client at every interval until the session is terminated.
90+
* Otherwise just return empty flux to append to the acknowledge message.
91+
*/
92+
private fun getKeepAliveFlux(session: WebSocketSession): Flux<SubscriptionOperationMessage> {
93+
val keepAliveInterval: Long? = config.subscriptions.keepAliveInterval
94+
if (keepAliveInterval != null) {
95+
return Flux.interval(Duration.ofMillis(keepAliveInterval))
96+
.map { keepAliveMessage }
97+
.doOnSubscribe { sessionState.saveKeepAliveSubscription(session, it) }
98+
}
99+
100+
return Flux.empty()
101+
}
102+
105103
@Suppress("Detekt.TooGenericExceptionCaught")
106104
private fun startSubscription(operationMessage: SubscriptionOperationMessage, session: WebSocketSession): Flux<SubscriptionOperationMessage> {
107105
if (operationMessage.id == null) {
@@ -113,7 +111,7 @@ class ApolloSubscriptionProtocolHandler(
113111

114112
if (payload == null) {
115113
logger.error("GraphQL subscription payload was null instead of a GraphQLRequest object")
116-
stopSubscription(operationMessage, session)
114+
sessionState.stopOperation(session, operationMessage)
117115
return Flux.just(SubscriptionOperationMessage(type = GQL_CONNECTION_ERROR.type, id = operationMessage.id))
118116
}
119117

@@ -127,35 +125,13 @@ class ApolloSubscriptionProtocolHandler(
127125
SubscriptionOperationMessage(type = GQL_DATA.type, id = operationMessage.id, payload = it)
128126
}
129127
}
130-
.concatWith(Flux.just(SubscriptionOperationMessage(type = GQL_COMPLETE.type, id = operationMessage.id)))
131-
.doOnSubscribe {
132-
logger.debug("GraphQL subscription START, sessionId=${session.id} operationId=${operationMessage.id}")
133-
activeOperations[session.id]?.put(operationMessage.id, it)
134-
}
135-
.doOnCancel { logger.debug("GraphQL subscription CANCEL, sessionId=${session.id} operationId=${operationMessage.id}") }
136-
.doOnComplete { logger.debug("GraphQL subscription COMPELTE, sessionId=${session.id} operationId=${operationMessage.id}") }
128+
.concatWith(Flux.just(SubscriptionOperationMessage(type = SubscriptionOperationMessage.ServerMessages.GQL_COMPLETE.type, id = operationMessage.id)))
129+
.doOnSubscribe { sessionState.saveOperation(session, operationMessage, it) }
137130
} catch (exception: Exception) {
138131
logger.error("Error running graphql subscription", exception)
139-
stopSubscription(operationMessage, session)
132+
// Do not terminate the session, just stop the operation messages
133+
sessionState.stopOperation(session, operationMessage)
140134
return Flux.just(SubscriptionOperationMessage(type = GQL_CONNECTION_ERROR.type, id = operationMessage.id))
141135
}
142136
}
143-
144-
private fun stopSubscription(operationMessage: SubscriptionOperationMessage, session: WebSocketSession) {
145-
logger.debug("GraphQL subscription STOP, sessionId=${session.id} operationId=${operationMessage.id}")
146-
if (operationMessage.id != null) {
147-
val operationsForSession = activeOperations[session.id]
148-
operationsForSession?.get(operationMessage.id)?.cancel()
149-
operationsForSession?.remove(operationMessage.id)
150-
}
151-
}
152-
153-
private fun terminateSession(session: WebSocketSession) {
154-
logger.debug("GraphQL subscription TERMINATE, sessionId=${session.id}")
155-
activeOperations[session.id]?.forEach { _, subscription -> subscription.cancel() }
156-
activeOperations.remove(session.id)
157-
activeKeepAliveSessions[session.id]?.cancel()
158-
activeKeepAliveSessions.remove(session.id)
159-
session.close()
160-
}
161137
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
/*
2+
* Copyright 2019 Expedia, Inc
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.expediagroup.graphql.spring.execution
18+
19+
import com.expediagroup.graphql.spring.model.SubscriptionOperationMessage
20+
import org.reactivestreams.Subscription
21+
import org.springframework.web.reactive.socket.WebSocketSession
22+
import java.util.concurrent.ConcurrentHashMap
23+
24+
internal class ApolloSubscriptionSessionState {
25+
26+
// Sessions are saved by web socket session id
27+
internal val activeKeepAliveSessions = ConcurrentHashMap<String, Subscription>()
28+
29+
// Operations are saved by web socket session id, then operation id
30+
internal val activeOperations = ConcurrentHashMap<String, ConcurrentHashMap<String, Subscription>>()
31+
32+
/**
33+
* Save the session that is sending keep alive messages.
34+
* This will override values without cancelling the subscription so it is the responsbility of the consumer to cancel.
35+
* These messages will be stopped on [terminateSession].
36+
*/
37+
fun saveKeepAliveSubscription(session: WebSocketSession, subscription: Subscription) {
38+
activeKeepAliveSessions[session.id] = subscription
39+
}
40+
41+
/**
42+
* Save the operation that is sending data to the client.
43+
* This will override values without cancelling the subscription so it is the responsbility of the consumer to cancel.
44+
* These messages will be stopped on [stopOperation].
45+
*/
46+
fun saveOperation(session: WebSocketSession, operationMessage: SubscriptionOperationMessage, subscription: Subscription) {
47+
if (operationMessage.id != null) {
48+
val operationsForSession: ConcurrentHashMap<String, Subscription> = activeOperations.getOrPut(session.id) { ConcurrentHashMap() }
49+
operationsForSession[operationMessage.id] = subscription
50+
}
51+
}
52+
53+
/**
54+
* Stop the subscription sending data. Does NOT terminate the session.
55+
*/
56+
fun stopOperation(session: WebSocketSession, operationMessage: SubscriptionOperationMessage) {
57+
if (operationMessage.id != null) {
58+
val operationsForSession = activeOperations[session.id]
59+
operationsForSession?.get(operationMessage.id)?.cancel()
60+
operationsForSession?.remove(operationMessage.id)
61+
}
62+
}
63+
64+
/**
65+
* Terminate the session, cancelling the keep alive messages and all operations active for this session.
66+
*/
67+
fun terminateSession(session: WebSocketSession) {
68+
activeOperations[session.id]?.forEach { _, subscription -> subscription.cancel() }
69+
activeOperations.remove(session.id)
70+
activeKeepAliveSessions[session.id]?.cancel()
71+
activeKeepAliveSessions.remove(session.id)
72+
session.close()
73+
}
74+
}

0 commit comments

Comments
 (0)