Skip to content

Commit b310670

Browse files
gklijssmyrick
authored andcommitted
fix subscriptions by not removing the ak subscription when one of the… (ExpediaGroup#510)
* fix subscriptions by not removing the ak subscription when one of the subscriptions is stopped. * Prevent cancelling subscription for someone else of operation name is the same, by including operation name in the id. Keep hasmaps clean by removing canceled subscriptions
1 parent 22f0d38 commit b310670

File tree

2 files changed

+44
-25
lines changed

2 files changed

+44
-25
lines changed

graphql-kotlin-spring-server/src/main/kotlin/com/expediagroup/graphql/spring/execution/ApolloSubscriptionProtocolHandler.kt

Lines changed: 37 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,10 @@ class ApolloSubscriptionProtocolHandler(
4444
private val subscriptionHandler: SubscriptionHandler,
4545
private val objectMapper: ObjectMapper
4646
) {
47-
// Keep Alive subscriptions are saved by web socket session id since they are sent on connection init
48-
private val keepAliveSubscriptions = ConcurrentHashMap<String, Subscription>()
49-
// Data subscriptions are saved by SubscriptionOperationMessage.id
47+
// Data subscriptions are saved by web socket session id + SubscriptionOperationMessage.id
5048
private val subscriptions = ConcurrentHashMap<String, Subscription>()
49+
// Mapping from client id to active subscriptions
50+
private val subscriptionsForClient = ConcurrentHashMap<String, MutableList<String>>()
5151

5252
private val logger = LoggerFactory.getLogger(ApolloSubscriptionProtocolHandler::class.java)
5353
private val keepAliveMessage = SubscriptionOperationMessage(type = GQL_CONNECTION_KEEP_ALIVE.type)
@@ -57,35 +57,36 @@ class ApolloSubscriptionProtocolHandler(
5757
try {
5858
val operationMessage: SubscriptionOperationMessage = objectMapper.readValue(payload)
5959

60-
return when {
61-
operationMessage.type == GQL_CONNECTION_INIT.type -> {
60+
return when (operationMessage.type) {
61+
GQL_CONNECTION_INIT.type -> {
6262
val flux = Flux.just(SubscriptionOperationMessage(GQL_CONNECTION_ACK.type))
6363
val keepAliveInterval = config.subscriptions.keepAliveInterval
64+
subscriptionsForClient[session.id] = mutableListOf()
6465
if (keepAliveInterval != null) {
66+
subscriptionsForClient[session.id]?.add(session.id)
6567
// Send the GQL_CONNECTION_KEEP_ALIVE message every interval until the connection is closed or terminated
6668
val keepAliveFlux = Flux.interval(Duration.ofMillis(keepAliveInterval))
67-
.map { keepAliveMessage }
68-
.doOnSubscribe {
69-
keepAliveSubscriptions[session.id] = it
70-
}
69+
.map { keepAliveMessage }
70+
.doOnSubscribe {
71+
subscriptions[session.id] = it
72+
}
7173
return flux.concatWith(keepAliveFlux)
7274
}
73-
7475
return flux
7576
}
76-
operationMessage.type == GQL_START.type -> startSubscription(operationMessage, session)
77-
operationMessage.type == GQL_STOP.type -> {
77+
GQL_START.type -> startSubscription(operationMessage, session)
78+
GQL_STOP.type -> {
7879
stopSubscription(operationMessage, session)
79-
Flux.empty()
80+
return Flux.empty()
8081
}
81-
operationMessage.type == GQL_CONNECTION_TERMINATE.type -> {
82-
stopSubscription(operationMessage, session)
82+
GQL_CONNECTION_TERMINATE.type -> {
83+
terminateSubscription(session)
8384
session.close()
84-
Flux.empty()
85+
return Flux.empty()
8586
}
8687
else -> {
8788
logger.error("Unknown subscription operation $operationMessage")
88-
Flux.just(SubscriptionOperationMessage(type = GQL_CONNECTION_ERROR.type, id = operationMessage.id))
89+
return Flux.just(SubscriptionOperationMessage(type = GQL_CONNECTION_ERROR.type, id = operationMessage.id))
8990
}
9091
}
9192
} catch (exception: Exception) {
@@ -108,9 +109,9 @@ class ApolloSubscriptionProtocolHandler(
108109
return Flux.just(SubscriptionOperationMessage(type = GQL_CONNECTION_ERROR.type, id = operationMessage.id))
109110
}
110111

111-
return try {
112+
try {
112113
val request = objectMapper.convertValue<GraphQLRequest>(payload)
113-
subscriptionHandler.executeSubscription(request)
114+
return subscriptionHandler.executeSubscription(request)
114115
.map {
115116
if (it.errors?.isNotEmpty() == true) {
116117
SubscriptionOperationMessage(type = GQL_ERROR.type, id = operationMessage.id, payload = it)
@@ -121,20 +122,33 @@ class ApolloSubscriptionProtocolHandler(
121122
.concatWith(Flux.just(SubscriptionOperationMessage(type = GQL_COMPLETE.type, id = operationMessage.id)))
122123
.doOnSubscribe {
123124
logger.trace("WebSocket GraphQL subscription subscribe, WebSocketSessionID=${session.id} OperationMessageID=${operationMessage.id}")
124-
subscriptions[operationMessage.id] = it
125+
subscriptions[session.id + operationMessage.id] = it
126+
subscriptionsForClient[session.id]?.add(session.id + operationMessage.id)
125127
}
126128
.doOnCancel { logger.trace("WebSocket GraphQL subscription cancel, WebSocketSessionID=${session.id} OperationMessageID=${operationMessage.id}") }
127129
.doOnComplete { logger.trace("WebSocket GraphQL subscription complete, WebSocketSessionID=${session.id} OperationMessageID=${operationMessage.id}") }
128130
} catch (exception: Exception) {
129131
logger.error("Error running graphql subscription", exception)
130-
Flux.just(SubscriptionOperationMessage(type = GQL_CONNECTION_ERROR.type, id = operationMessage.id))
132+
return Flux.just(SubscriptionOperationMessage(type = GQL_CONNECTION_ERROR.type, id = operationMessage.id))
131133
}
132134
}
133135

134136
private fun stopSubscription(operationMessage: SubscriptionOperationMessage, session: WebSocketSession) {
135137
if (operationMessage.id != null) {
136-
keepAliveSubscriptions[session.id]?.cancel()
137-
subscriptions[operationMessage.id]?.cancel()
138+
val key = session.id + operationMessage.id
139+
subscriptions[key]?.let {
140+
it.cancel()
141+
subscriptions.remove(key)
142+
subscriptionsForClient[session.id]?.remove(key)
143+
}
144+
}
145+
}
146+
147+
private fun terminateSubscription(session: WebSocketSession) {
148+
subscriptionsForClient[session.id]?.let {
149+
it.forEach { subscriptions[it]?.cancel() }
150+
it.forEach { subscriptions.remove(it) }
138151
}
152+
subscriptionsForClient.remove(session.id)
139153
}
140154
}

graphql-kotlin-spring-server/src/test/kotlin/com/expediagroup/graphql/spring/execution/ApolloSubscriptionProtocolHandlerTest.kt

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,9 @@ class ApolloSubscriptionProtocolHandlerTest {
8585
}
8686
}
8787
val operationMessage = SubscriptionOperationMessage(GQL_CONNECTION_INIT.type)
88-
val session: WebSocketSession = mockk()
88+
val session: WebSocketSession = mockk {
89+
every { id } returns "123"
90+
}
8991
val subscriptionHandler: SubscriptionHandler = mockk()
9092

9193
val handler = ApolloSubscriptionProtocolHandler(config, subscriptionHandler, objectMapper)
@@ -104,7 +106,9 @@ class ApolloSubscriptionProtocolHandlerTest {
104106
}
105107
}
106108
val operationMessage = SubscriptionOperationMessage(GQL_CONNECTION_INIT.type)
107-
val session: WebSocketSession = mockk()
109+
val session: WebSocketSession = mockk {
110+
every { id } returns "123"
111+
}
108112
val subscriptionHandler: SubscriptionHandler = mockk()
109113

110114
val handler = ApolloSubscriptionProtocolHandler(config, subscriptionHandler, objectMapper)
@@ -143,6 +147,7 @@ class ApolloSubscriptionProtocolHandlerTest {
143147
val config: GraphQLConfigurationProperties = mockk()
144148
val operationMessage = SubscriptionOperationMessage(GQL_CONNECTION_TERMINATE.type)
145149
val session: WebSocketSession = mockk {
150+
every { id } returns "123"
146151
every { close() } returns mockk()
147152
}
148153
val subscriptionHandler: SubscriptionHandler = mockk()

0 commit comments

Comments
 (0)