Skip to content

Fix subscription caching logic #515

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 16, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import com.expediagroup.graphql.spring.model.SubscriptionOperationMessage.Client
import com.expediagroup.graphql.spring.model.SubscriptionOperationMessage.ClientMessages.GQL_CONNECTION_TERMINATE
import com.expediagroup.graphql.spring.model.SubscriptionOperationMessage.ClientMessages.GQL_START
import com.expediagroup.graphql.spring.model.SubscriptionOperationMessage.ClientMessages.GQL_STOP
import com.expediagroup.graphql.spring.model.SubscriptionOperationMessage.ServerMessages.GQL_COMPLETE
import com.expediagroup.graphql.spring.model.SubscriptionOperationMessage.ServerMessages.GQL_CONNECTION_ACK
import com.expediagroup.graphql.spring.model.SubscriptionOperationMessage.ServerMessages.GQL_CONNECTION_ERROR
import com.expediagroup.graphql.spring.model.SubscriptionOperationMessage.ServerMessages.GQL_CONNECTION_KEEP_ALIVE
Expand All @@ -32,12 +31,10 @@ import com.expediagroup.graphql.spring.model.SubscriptionOperationMessage.Server
import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.module.kotlin.convertValue
import com.fasterxml.jackson.module.kotlin.readValue
import org.reactivestreams.Subscription
import org.slf4j.LoggerFactory
import org.springframework.web.reactive.socket.WebSocketSession
import reactor.core.publisher.Flux
import java.time.Duration
import java.util.concurrent.ConcurrentHashMap

/**
* Implementation of the `graphql-ws` protocol defined by Apollo
Expand All @@ -48,11 +45,7 @@ class ApolloSubscriptionProtocolHandler(
private val subscriptionHandler: SubscriptionHandler,
private val objectMapper: ObjectMapper
) {
// Sessions are saved by web socket session id
private val activeKeepAliveSessions = ConcurrentHashMap<String, Subscription>()
// Operations are saved by web socket session id, then operation id
private val activeOperations = ConcurrentHashMap<String, ConcurrentHashMap<String, Subscription>>()

private val sessionState = ApolloSubscriptionSessionState()
private val logger = LoggerFactory.getLogger(ApolloSubscriptionProtocolHandler::class.java)
private val keepAliveMessage = SubscriptionOperationMessage(type = GQL_CONNECTION_KEEP_ALIVE.type)
private val basicConnectionErrorMessage = SubscriptionOperationMessage(type = GQL_CONNECTION_ERROR.type)
Expand All @@ -63,36 +56,26 @@ class ApolloSubscriptionProtocolHandler(
try {
val operationMessage: SubscriptionOperationMessage = objectMapper.readValue(payload)

return when (operationMessage.type) {
GQL_CONNECTION_INIT.type -> {
val flux = Flux.just(acknowledgeMessage)
val keepAliveInterval = config.subscriptions.keepAliveInterval
if (keepAliveInterval != null) {
// Send the GQL_CONNECTION_KEEP_ALIVE message every interval until the connection is closed or terminated
val keepAliveFlux = Flux.interval(Duration.ofMillis(keepAliveInterval))
.map { keepAliveMessage }
.doOnSubscribe {
logger.debug("GraphQL subscription INIT, sessionId=${session.id} activeSessions=${activeKeepAliveSessions.count()}")
activeKeepAliveSessions[session.id] = it
}

return flux.concatWith(keepAliveFlux)
}
logger.debug("GraphQL subscription client message, sessionId=${session.id} operationMessage=$operationMessage")

return flux
when (operationMessage.type) {
GQL_CONNECTION_INIT.type -> {
val ackowledgeMessageFlux = Flux.just(acknowledgeMessage)
val keepAliveFlux = getKeepAliveFlux(session)
return ackowledgeMessageFlux.concatWith(keepAliveFlux)
}
GQL_START.type -> startSubscription(operationMessage, session)
GQL_START.type -> return startSubscription(operationMessage, session)
GQL_STOP.type -> {
stopSubscription(operationMessage, session)
sessionState.stopOperation(session, operationMessage)
return Flux.empty()
}
GQL_CONNECTION_TERMINATE.type -> {
terminateSession(session)
sessionState.terminateSession(session)
return Flux.empty()
}
else -> {
logger.error("Unknown subscription operation $operationMessage")
stopSubscription(operationMessage, session)
sessionState.stopOperation(session, operationMessage)
return Flux.just(SubscriptionOperationMessage(type = GQL_CONNECTION_ERROR.type, id = operationMessage.id))
}
}
Expand All @@ -102,6 +85,21 @@ class ApolloSubscriptionProtocolHandler(
}
}

/**
* If the keep alive configuraation is set, send a message back to client at every interval until the session is terminated.
* Otherwise just return empty flux to append to the acknowledge message.
*/
private fun getKeepAliveFlux(session: WebSocketSession): Flux<SubscriptionOperationMessage> {
val keepAliveInterval: Long? = config.subscriptions.keepAliveInterval
if (keepAliveInterval != null) {
return Flux.interval(Duration.ofMillis(keepAliveInterval))
.map { keepAliveMessage }
.doOnSubscribe { sessionState.saveKeepAliveSubscription(session, it) }
}

return Flux.empty()
}

@Suppress("Detekt.TooGenericExceptionCaught")
private fun startSubscription(operationMessage: SubscriptionOperationMessage, session: WebSocketSession): Flux<SubscriptionOperationMessage> {
if (operationMessage.id == null) {
Expand All @@ -113,7 +111,7 @@ class ApolloSubscriptionProtocolHandler(

if (payload == null) {
logger.error("GraphQL subscription payload was null instead of a GraphQLRequest object")
stopSubscription(operationMessage, session)
sessionState.stopOperation(session, operationMessage)
return Flux.just(SubscriptionOperationMessage(type = GQL_CONNECTION_ERROR.type, id = operationMessage.id))
}

Expand All @@ -127,35 +125,13 @@ class ApolloSubscriptionProtocolHandler(
SubscriptionOperationMessage(type = GQL_DATA.type, id = operationMessage.id, payload = it)
}
}
.concatWith(Flux.just(SubscriptionOperationMessage(type = GQL_COMPLETE.type, id = operationMessage.id)))
.doOnSubscribe {
logger.debug("GraphQL subscription START, sessionId=${session.id} operationId=${operationMessage.id}")
activeOperations[session.id]?.put(operationMessage.id, it)
}
.doOnCancel { logger.debug("GraphQL subscription CANCEL, sessionId=${session.id} operationId=${operationMessage.id}") }
.doOnComplete { logger.debug("GraphQL subscription COMPELTE, sessionId=${session.id} operationId=${operationMessage.id}") }
.concatWith(Flux.just(SubscriptionOperationMessage(type = SubscriptionOperationMessage.ServerMessages.GQL_COMPLETE.type, id = operationMessage.id)))
.doOnSubscribe { sessionState.saveOperation(session, operationMessage, it) }
} catch (exception: Exception) {
logger.error("Error running graphql subscription", exception)
stopSubscription(operationMessage, session)
// Do not terminate the session, just stop the operation messages
sessionState.stopOperation(session, operationMessage)
return Flux.just(SubscriptionOperationMessage(type = GQL_CONNECTION_ERROR.type, id = operationMessage.id))
}
}

private fun stopSubscription(operationMessage: SubscriptionOperationMessage, session: WebSocketSession) {
logger.debug("GraphQL subscription STOP, sessionId=${session.id} operationId=${operationMessage.id}")
if (operationMessage.id != null) {
val operationsForSession = activeOperations[session.id]
operationsForSession?.get(operationMessage.id)?.cancel()
operationsForSession?.remove(operationMessage.id)
}
}

private fun terminateSession(session: WebSocketSession) {
logger.debug("GraphQL subscription TERMINATE, sessionId=${session.id}")
activeOperations[session.id]?.forEach { _, subscription -> subscription.cancel() }
activeOperations.remove(session.id)
activeKeepAliveSessions[session.id]?.cancel()
activeKeepAliveSessions.remove(session.id)
session.close()
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/*
* Copyright 2019 Expedia, Inc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.expediagroup.graphql.spring.execution

import com.expediagroup.graphql.spring.model.SubscriptionOperationMessage
import org.reactivestreams.Subscription
import org.springframework.web.reactive.socket.WebSocketSession
import java.util.concurrent.ConcurrentHashMap

internal class ApolloSubscriptionSessionState {

// Sessions are saved by web socket session id
internal val activeKeepAliveSessions = ConcurrentHashMap<String, Subscription>()

// Operations are saved by web socket session id, then operation id
internal val activeOperations = ConcurrentHashMap<String, ConcurrentHashMap<String, Subscription>>()

/**
* Save the session that is sending keep alive messages.
* This will override values without cancelling the subscription so it is the responsbility of the consumer to cancel.
* These messages will be stopped on [terminateSession].
*/
fun saveKeepAliveSubscription(session: WebSocketSession, subscription: Subscription) {
activeKeepAliveSessions[session.id] = subscription
}

/**
* Save the operation that is sending data to the client.
* This will override values without cancelling the subscription so it is the responsbility of the consumer to cancel.
* These messages will be stopped on [stopOperation].
*/
fun saveOperation(session: WebSocketSession, operationMessage: SubscriptionOperationMessage, subscription: Subscription) {
if (operationMessage.id != null) {
val operationsForSession: ConcurrentHashMap<String, Subscription> = activeOperations.getOrPut(session.id) { ConcurrentHashMap() }
operationsForSession[operationMessage.id] = subscription
}
}

/**
* Stop the subscription sending data. Does NOT terminate the session.
*/
fun stopOperation(session: WebSocketSession, operationMessage: SubscriptionOperationMessage) {
if (operationMessage.id != null) {
val operationsForSession = activeOperations[session.id]
operationsForSession?.get(operationMessage.id)?.cancel()
operationsForSession?.remove(operationMessage.id)

if (operationsForSession?.isEmpty() == true) {
activeOperations.remove(session.id)
}
}
}

/**
* Terminate the session, cancelling the keep alive messages and all operations active for this session.
*/
fun terminateSession(session: WebSocketSession) {
activeOperations[session.id]?.forEach { _, subscription -> subscription.cancel() }
activeOperations.remove(session.id)
activeKeepAliveSessions[session.id]?.cancel()
activeKeepAliveSessions.remove(session.id)
session.close()
}
}
Loading