Skip to content

Commit 7e0087a

Browse files
committed
Prevent cancelling subscription for someone else of operation name is the same, by including operation name in the id. Keep hasmaps clean by removing canceled subscriptions
1 parent 5ea81a8 commit 7e0087a

File tree

2 files changed

+45
-28
lines changed

2 files changed

+45
-28
lines changed

graphql-kotlin-spring-server/src/main/kotlin/com/expediagroup/graphql/spring/execution/ApolloSubscriptionProtocolHandler.kt

Lines changed: 38 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,10 @@ class ApolloSubscriptionProtocolHandler(
4444
private val subscriptionHandler: SubscriptionHandler,
4545
private val objectMapper: ObjectMapper
4646
) {
47-
// Keep Alive subscriptions are saved by web socket session id since they are sent on connection init
48-
private val keepAliveSubscriptions = ConcurrentHashMap<String, Subscription>()
49-
// Data subscriptions are saved by SubscriptionOperationMessage.id
47+
// Data subscriptions are saved by web socket session id + SubscriptionOperationMessage.id
5048
private val subscriptions = ConcurrentHashMap<String, Subscription>()
49+
// Mapping from client id to active subscriptions
50+
private val subscriptionsForClient = ConcurrentHashMap<String, MutableList<String>>()
5151

5252
private val logger = LoggerFactory.getLogger(ApolloSubscriptionProtocolHandler::class.java)
5353
private val keepAliveMessage = SubscriptionOperationMessage(type = GQL_CONNECTION_KEEP_ALIVE.type)
@@ -57,35 +57,36 @@ class ApolloSubscriptionProtocolHandler(
5757
try {
5858
val operationMessage: SubscriptionOperationMessage = objectMapper.readValue(payload)
5959

60-
return when {
61-
operationMessage.type == GQL_CONNECTION_INIT.type -> {
60+
return when (operationMessage.type) {
61+
GQL_CONNECTION_INIT.type -> {
6262
val flux = Flux.just(SubscriptionOperationMessage(GQL_CONNECTION_ACK.type))
6363
val keepAliveInterval = config.subscriptions.keepAliveInterval
64+
subscriptionsForClient[session.id] = mutableListOf()
6465
if (keepAliveInterval != null) {
66+
subscriptionsForClient[session.id]?.add(session.id)
6567
// Send the GQL_CONNECTION_KEEP_ALIVE message every interval until the connection is closed or terminated
6668
val keepAliveFlux = Flux.interval(Duration.ofMillis(keepAliveInterval))
67-
.map { keepAliveMessage }
68-
.doOnSubscribe {
69-
keepAliveSubscriptions[session.id] = it
70-
}
69+
.map { keepAliveMessage }
70+
.doOnSubscribe {
71+
subscriptions[session.id] = it
72+
}
7173
return flux.concatWith(keepAliveFlux)
7274
}
73-
7475
return flux
7576
}
76-
operationMessage.type == GQL_START.type -> startSubscription(operationMessage, session)
77-
operationMessage.type == GQL_STOP.type -> {
78-
stopSubscription(operationMessage, session, false)
79-
Flux.empty()
77+
GQL_START.type -> startSubscription(operationMessage, session)
78+
GQL_STOP.type -> {
79+
stopSubscription(operationMessage, session)
80+
return Flux.empty()
8081
}
81-
operationMessage.type == GQL_CONNECTION_TERMINATE.type -> {
82-
stopSubscription(operationMessage, session, true)
82+
GQL_CONNECTION_TERMINATE.type -> {
83+
terminateSubscription(session)
8384
session.close()
84-
Flux.empty()
85+
return Flux.empty()
8586
}
8687
else -> {
8788
logger.error("Unknown subscription operation $operationMessage")
88-
Flux.just(SubscriptionOperationMessage(type = GQL_CONNECTION_ERROR.type, id = operationMessage.id))
89+
return Flux.just(SubscriptionOperationMessage(type = GQL_CONNECTION_ERROR.type, id = operationMessage.id))
8990
}
9091
}
9192
} catch (exception: Exception) {
@@ -108,9 +109,9 @@ class ApolloSubscriptionProtocolHandler(
108109
return Flux.just(SubscriptionOperationMessage(type = GQL_CONNECTION_ERROR.type, id = operationMessage.id))
109110
}
110111

111-
return try {
112+
try {
112113
val request = objectMapper.convertValue<GraphQLRequest>(payload)
113-
subscriptionHandler.executeSubscription(request)
114+
return subscriptionHandler.executeSubscription(request)
114115
.map {
115116
if (it.errors?.isNotEmpty() == true) {
116117
SubscriptionOperationMessage(type = GQL_ERROR.type, id = operationMessage.id, payload = it)
@@ -121,22 +122,33 @@ class ApolloSubscriptionProtocolHandler(
121122
.concatWith(Flux.just(SubscriptionOperationMessage(type = GQL_COMPLETE.type, id = operationMessage.id)))
122123
.doOnSubscribe {
123124
logger.trace("WebSocket GraphQL subscription subscribe, WebSocketSessionID=${session.id} OperationMessageID=${operationMessage.id}")
124-
subscriptions[operationMessage.id] = it
125+
subscriptions[session.id + operationMessage.id] = it
126+
subscriptionsForClient[session.id]?.add(session.id + operationMessage.id)
125127
}
126128
.doOnCancel { logger.trace("WebSocket GraphQL subscription cancel, WebSocketSessionID=${session.id} OperationMessageID=${operationMessage.id}") }
127129
.doOnComplete { logger.trace("WebSocket GraphQL subscription complete, WebSocketSessionID=${session.id} OperationMessageID=${operationMessage.id}") }
128130
} catch (exception: Exception) {
129131
logger.error("Error running graphql subscription", exception)
130-
Flux.just(SubscriptionOperationMessage(type = GQL_CONNECTION_ERROR.type, id = operationMessage.id))
132+
return Flux.just(SubscriptionOperationMessage(type = GQL_CONNECTION_ERROR.type, id = operationMessage.id))
131133
}
132134
}
133135

134-
private fun stopSubscription(operationMessage: SubscriptionOperationMessage, session: WebSocketSession, terminate: Boolean) {
136+
private fun stopSubscription(operationMessage: SubscriptionOperationMessage, session: WebSocketSession) {
135137
if (operationMessage.id != null) {
136-
if (terminate) {
137-
keepAliveSubscriptions[session.id]?.cancel()
138+
val key = session.id + operationMessage.id
139+
subscriptions[key]?.let {
140+
it.cancel()
141+
subscriptions.remove(key)
142+
subscriptionsForClient[session.id]?.remove(key)
138143
}
139-
subscriptions[operationMessage.id]?.cancel()
140144
}
141145
}
146+
147+
private fun terminateSubscription(session: WebSocketSession) {
148+
subscriptionsForClient[session.id]?.let {
149+
it.forEach { subscriptions[it]?.cancel() }
150+
it.forEach { subscriptions.remove(it) }
151+
}
152+
subscriptionsForClient.remove(session.id)
153+
}
142154
}

graphql-kotlin-spring-server/src/test/kotlin/com/expediagroup/graphql/spring/execution/ApolloSubscriptionProtocolHandlerTest.kt

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,9 @@ class ApolloSubscriptionProtocolHandlerTest {
8585
}
8686
}
8787
val operationMessage = SubscriptionOperationMessage(GQL_CONNECTION_INIT.type)
88-
val session: WebSocketSession = mockk()
88+
val session: WebSocketSession = mockk {
89+
every { id } returns "123"
90+
}
8991
val subscriptionHandler: SubscriptionHandler = mockk()
9092

9193
val handler = ApolloSubscriptionProtocolHandler(config, subscriptionHandler, objectMapper)
@@ -104,7 +106,9 @@ class ApolloSubscriptionProtocolHandlerTest {
104106
}
105107
}
106108
val operationMessage = SubscriptionOperationMessage(GQL_CONNECTION_INIT.type)
107-
val session: WebSocketSession = mockk()
109+
val session: WebSocketSession = mockk {
110+
every { id } returns "123"
111+
}
108112
val subscriptionHandler: SubscriptionHandler = mockk()
109113

110114
val handler = ApolloSubscriptionProtocolHandler(config, subscriptionHandler, objectMapper)
@@ -143,6 +147,7 @@ class ApolloSubscriptionProtocolHandlerTest {
143147
val config: GraphQLConfigurationProperties = mockk()
144148
val operationMessage = SubscriptionOperationMessage(GQL_CONNECTION_TERMINATE.type)
145149
val session: WebSocketSession = mockk {
150+
every { id } returns "123"
146151
every { close() } returns mockk()
147152
}
148153
val subscriptionHandler: SubscriptionHandler = mockk()

0 commit comments

Comments
 (0)