diff --git a/graphql-kotlin-spring-server/src/main/kotlin/com/expediagroup/graphql/spring/execution/ApolloSubscriptionProtocolHandler.kt b/graphql-kotlin-spring-server/src/main/kotlin/com/expediagroup/graphql/spring/execution/ApolloSubscriptionProtocolHandler.kt index dde78c1ee1..dba57e82cd 100644 --- a/graphql-kotlin-spring-server/src/main/kotlin/com/expediagroup/graphql/spring/execution/ApolloSubscriptionProtocolHandler.kt +++ b/graphql-kotlin-spring-server/src/main/kotlin/com/expediagroup/graphql/spring/execution/ApolloSubscriptionProtocolHandler.kt @@ -23,7 +23,6 @@ import com.expediagroup.graphql.spring.model.SubscriptionOperationMessage.Client import com.expediagroup.graphql.spring.model.SubscriptionOperationMessage.ClientMessages.GQL_CONNECTION_TERMINATE import com.expediagroup.graphql.spring.model.SubscriptionOperationMessage.ClientMessages.GQL_START import com.expediagroup.graphql.spring.model.SubscriptionOperationMessage.ClientMessages.GQL_STOP -import com.expediagroup.graphql.spring.model.SubscriptionOperationMessage.ServerMessages.GQL_COMPLETE import com.expediagroup.graphql.spring.model.SubscriptionOperationMessage.ServerMessages.GQL_CONNECTION_ACK import com.expediagroup.graphql.spring.model.SubscriptionOperationMessage.ServerMessages.GQL_CONNECTION_ERROR import com.expediagroup.graphql.spring.model.SubscriptionOperationMessage.ServerMessages.GQL_CONNECTION_KEEP_ALIVE @@ -32,12 +31,10 @@ import com.expediagroup.graphql.spring.model.SubscriptionOperationMessage.Server import com.fasterxml.jackson.databind.ObjectMapper import com.fasterxml.jackson.module.kotlin.convertValue import com.fasterxml.jackson.module.kotlin.readValue -import org.reactivestreams.Subscription import org.slf4j.LoggerFactory import org.springframework.web.reactive.socket.WebSocketSession import reactor.core.publisher.Flux import java.time.Duration -import java.util.concurrent.ConcurrentHashMap /** * Implementation of the `graphql-ws` protocol defined by Apollo @@ -48,11 +45,7 @@ class ApolloSubscriptionProtocolHandler( private val subscriptionHandler: SubscriptionHandler, private val objectMapper: ObjectMapper ) { - // Sessions are saved by web socket session id - private val activeKeepAliveSessions = ConcurrentHashMap() - // Operations are saved by web socket session id, then operation id - private val activeOperations = ConcurrentHashMap>() - + private val sessionState = ApolloSubscriptionSessionState() private val logger = LoggerFactory.getLogger(ApolloSubscriptionProtocolHandler::class.java) private val keepAliveMessage = SubscriptionOperationMessage(type = GQL_CONNECTION_KEEP_ALIVE.type) private val basicConnectionErrorMessage = SubscriptionOperationMessage(type = GQL_CONNECTION_ERROR.type) @@ -63,36 +56,26 @@ class ApolloSubscriptionProtocolHandler( try { val operationMessage: SubscriptionOperationMessage = objectMapper.readValue(payload) - return when (operationMessage.type) { - GQL_CONNECTION_INIT.type -> { - val flux = Flux.just(acknowledgeMessage) - val keepAliveInterval = config.subscriptions.keepAliveInterval - if (keepAliveInterval != null) { - // Send the GQL_CONNECTION_KEEP_ALIVE message every interval until the connection is closed or terminated - val keepAliveFlux = Flux.interval(Duration.ofMillis(keepAliveInterval)) - .map { keepAliveMessage } - .doOnSubscribe { - logger.debug("GraphQL subscription INIT, sessionId=${session.id} activeSessions=${activeKeepAliveSessions.count()}") - activeKeepAliveSessions[session.id] = it - } - - return flux.concatWith(keepAliveFlux) - } + logger.debug("GraphQL subscription client message, sessionId=${session.id} operationMessage=$operationMessage") - return flux + when (operationMessage.type) { + GQL_CONNECTION_INIT.type -> { + val ackowledgeMessageFlux = Flux.just(acknowledgeMessage) + val keepAliveFlux = getKeepAliveFlux(session) + return ackowledgeMessageFlux.concatWith(keepAliveFlux) } - GQL_START.type -> startSubscription(operationMessage, session) + GQL_START.type -> return startSubscription(operationMessage, session) GQL_STOP.type -> { - stopSubscription(operationMessage, session) + sessionState.stopOperation(session, operationMessage) return Flux.empty() } GQL_CONNECTION_TERMINATE.type -> { - terminateSession(session) + sessionState.terminateSession(session) return Flux.empty() } else -> { logger.error("Unknown subscription operation $operationMessage") - stopSubscription(operationMessage, session) + sessionState.stopOperation(session, operationMessage) return Flux.just(SubscriptionOperationMessage(type = GQL_CONNECTION_ERROR.type, id = operationMessage.id)) } } @@ -102,6 +85,21 @@ class ApolloSubscriptionProtocolHandler( } } + /** + * If the keep alive configuraation is set, send a message back to client at every interval until the session is terminated. + * Otherwise just return empty flux to append to the acknowledge message. + */ + private fun getKeepAliveFlux(session: WebSocketSession): Flux { + val keepAliveInterval: Long? = config.subscriptions.keepAliveInterval + if (keepAliveInterval != null) { + return Flux.interval(Duration.ofMillis(keepAliveInterval)) + .map { keepAliveMessage } + .doOnSubscribe { sessionState.saveKeepAliveSubscription(session, it) } + } + + return Flux.empty() + } + @Suppress("Detekt.TooGenericExceptionCaught") private fun startSubscription(operationMessage: SubscriptionOperationMessage, session: WebSocketSession): Flux { if (operationMessage.id == null) { @@ -113,7 +111,7 @@ class ApolloSubscriptionProtocolHandler( if (payload == null) { logger.error("GraphQL subscription payload was null instead of a GraphQLRequest object") - stopSubscription(operationMessage, session) + sessionState.stopOperation(session, operationMessage) return Flux.just(SubscriptionOperationMessage(type = GQL_CONNECTION_ERROR.type, id = operationMessage.id)) } @@ -127,35 +125,13 @@ class ApolloSubscriptionProtocolHandler( SubscriptionOperationMessage(type = GQL_DATA.type, id = operationMessage.id, payload = it) } } - .concatWith(Flux.just(SubscriptionOperationMessage(type = GQL_COMPLETE.type, id = operationMessage.id))) - .doOnSubscribe { - logger.debug("GraphQL subscription START, sessionId=${session.id} operationId=${operationMessage.id}") - activeOperations[session.id]?.put(operationMessage.id, it) - } - .doOnCancel { logger.debug("GraphQL subscription CANCEL, sessionId=${session.id} operationId=${operationMessage.id}") } - .doOnComplete { logger.debug("GraphQL subscription COMPELTE, sessionId=${session.id} operationId=${operationMessage.id}") } + .concatWith(Flux.just(SubscriptionOperationMessage(type = SubscriptionOperationMessage.ServerMessages.GQL_COMPLETE.type, id = operationMessage.id))) + .doOnSubscribe { sessionState.saveOperation(session, operationMessage, it) } } catch (exception: Exception) { logger.error("Error running graphql subscription", exception) - stopSubscription(operationMessage, session) + // Do not terminate the session, just stop the operation messages + sessionState.stopOperation(session, operationMessage) return Flux.just(SubscriptionOperationMessage(type = GQL_CONNECTION_ERROR.type, id = operationMessage.id)) } } - - private fun stopSubscription(operationMessage: SubscriptionOperationMessage, session: WebSocketSession) { - logger.debug("GraphQL subscription STOP, sessionId=${session.id} operationId=${operationMessage.id}") - if (operationMessage.id != null) { - val operationsForSession = activeOperations[session.id] - operationsForSession?.get(operationMessage.id)?.cancel() - operationsForSession?.remove(operationMessage.id) - } - } - - private fun terminateSession(session: WebSocketSession) { - logger.debug("GraphQL subscription TERMINATE, sessionId=${session.id}") - activeOperations[session.id]?.forEach { _, subscription -> subscription.cancel() } - activeOperations.remove(session.id) - activeKeepAliveSessions[session.id]?.cancel() - activeKeepAliveSessions.remove(session.id) - session.close() - } } diff --git a/graphql-kotlin-spring-server/src/main/kotlin/com/expediagroup/graphql/spring/execution/ApolloSubscriptionSessionState.kt b/graphql-kotlin-spring-server/src/main/kotlin/com/expediagroup/graphql/spring/execution/ApolloSubscriptionSessionState.kt new file mode 100644 index 0000000000..ff7e0b2ecb --- /dev/null +++ b/graphql-kotlin-spring-server/src/main/kotlin/com/expediagroup/graphql/spring/execution/ApolloSubscriptionSessionState.kt @@ -0,0 +1,78 @@ +/* + * Copyright 2019 Expedia, Inc + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.expediagroup.graphql.spring.execution + +import com.expediagroup.graphql.spring.model.SubscriptionOperationMessage +import org.reactivestreams.Subscription +import org.springframework.web.reactive.socket.WebSocketSession +import java.util.concurrent.ConcurrentHashMap + +internal class ApolloSubscriptionSessionState { + + // Sessions are saved by web socket session id + internal val activeKeepAliveSessions = ConcurrentHashMap() + + // Operations are saved by web socket session id, then operation id + internal val activeOperations = ConcurrentHashMap>() + + /** + * Save the session that is sending keep alive messages. + * This will override values without cancelling the subscription so it is the responsbility of the consumer to cancel. + * These messages will be stopped on [terminateSession]. + */ + fun saveKeepAliveSubscription(session: WebSocketSession, subscription: Subscription) { + activeKeepAliveSessions[session.id] = subscription + } + + /** + * Save the operation that is sending data to the client. + * This will override values without cancelling the subscription so it is the responsbility of the consumer to cancel. + * These messages will be stopped on [stopOperation]. + */ + fun saveOperation(session: WebSocketSession, operationMessage: SubscriptionOperationMessage, subscription: Subscription) { + if (operationMessage.id != null) { + val operationsForSession: ConcurrentHashMap = activeOperations.getOrPut(session.id) { ConcurrentHashMap() } + operationsForSession[operationMessage.id] = subscription + } + } + + /** + * Stop the subscription sending data. Does NOT terminate the session. + */ + fun stopOperation(session: WebSocketSession, operationMessage: SubscriptionOperationMessage) { + if (operationMessage.id != null) { + val operationsForSession = activeOperations[session.id] + operationsForSession?.get(operationMessage.id)?.cancel() + operationsForSession?.remove(operationMessage.id) + + if (operationsForSession?.isEmpty() == true) { + activeOperations.remove(session.id) + } + } + } + + /** + * Terminate the session, cancelling the keep alive messages and all operations active for this session. + */ + fun terminateSession(session: WebSocketSession) { + activeOperations[session.id]?.forEach { _, subscription -> subscription.cancel() } + activeOperations.remove(session.id) + activeKeepAliveSessions[session.id]?.cancel() + activeKeepAliveSessions.remove(session.id) + session.close() + } +} diff --git a/graphql-kotlin-spring-server/src/test/kotlin/com/expediagroup/graphql/spring/execution/ApolloSubscriptionSessionStateTest.kt b/graphql-kotlin-spring-server/src/test/kotlin/com/expediagroup/graphql/spring/execution/ApolloSubscriptionSessionStateTest.kt new file mode 100644 index 0000000000..fb3a407c54 --- /dev/null +++ b/graphql-kotlin-spring-server/src/test/kotlin/com/expediagroup/graphql/spring/execution/ApolloSubscriptionSessionStateTest.kt @@ -0,0 +1,251 @@ +/* + * Copyright 2019 Expedia, Inc + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.expediagroup.graphql.spring.execution + +import com.expediagroup.graphql.spring.model.SubscriptionOperationMessage +import io.mockk.every +import io.mockk.mockk +import io.mockk.verify +import org.junit.jupiter.api.Test +import org.reactivestreams.Subscription +import org.springframework.web.reactive.socket.WebSocketSession +import reactor.core.publisher.Mono +import kotlin.test.assertEquals +import kotlin.test.assertNull + +class ApolloSubscriptionSessionStateTest { + + @Test + fun `saveKeepAliveSubscription saves the subscription by session id`() { + val state = ApolloSubscriptionSessionState() + val mockSubscription: Subscription = mockk() + val mockSession: WebSocketSession = mockk { every { id } returns "123" } + + assertEquals(expected = 0, actual = state.activeKeepAliveSessions.size) + + state.saveKeepAliveSubscription(mockSession, mockSubscription) + + assertEquals(expected = 1, actual = state.activeKeepAliveSessions.size) + assertEquals(expected = mockSubscription, actual = state.activeKeepAliveSessions["123"]) + } + + @Test + fun `saveOperation does not save the subscription if operation id is null`() { + val state = ApolloSubscriptionSessionState() + val mockSubscription: Subscription = mockk() + val mockSession: WebSocketSession = mockk { every { id } returns "123" } + val mockOperationMessage: SubscriptionOperationMessage = mockk { every { id } returns null } + + assertEquals(expected = 0, actual = state.activeOperations.size) + + state.saveOperation(mockSession, mockOperationMessage, mockSubscription) + + assertEquals(expected = 0, actual = state.activeOperations.size) + } + + @Test + fun `saveOperation saves the subscription if operation id is valid`() { + val state = ApolloSubscriptionSessionState() + val mockSubscription: Subscription = mockk() + val mockSession: WebSocketSession = mockk { every { id } returns "123" } + val mockOperationMessage: SubscriptionOperationMessage = mockk { every { id } returns "abc" } + + assertEquals(expected = 0, actual = state.activeOperations.size) + + state.saveOperation(mockSession, mockOperationMessage, mockSubscription) + + assertEquals(expected = 1, actual = state.activeOperations.size) + assertEquals(expected = mockSubscription, actual = state.activeOperations["123"]?.get("abc")) + } + + @Test + fun `saveOperation saves the subscription and does not duplicate session ids if operation id is valid`() { + val state = ApolloSubscriptionSessionState() + val mockSubscription: Subscription = mockk() + val mockSession: WebSocketSession = mockk { every { id } returns "123" } + val mockOperationMessage1: SubscriptionOperationMessage = mockk { every { id } returns "abc" } + val mockOperationMessage2: SubscriptionOperationMessage = mockk { every { id } returns "def" } + + assertEquals(expected = 0, actual = state.activeOperations.size) + + state.saveOperation(mockSession, mockOperationMessage1, mockSubscription) + + assertEquals(expected = 1, actual = state.activeOperations.size) + assertEquals(expected = 1, actual = state.activeOperations["123"]?.size) + + state.saveOperation(mockSession, mockOperationMessage2, mockSubscription) + + assertEquals(expected = 1, actual = state.activeOperations.size) + assertEquals(expected = 2, actual = state.activeOperations["123"]?.size) + } + + @Test + fun `stopOperation does not cancel the subscription if operation id is null`() { + val state = ApolloSubscriptionSessionState() + val mockSubscription: Subscription = mockk() + val mockSession: WebSocketSession = mockk { every { id } returns "123" } + val inputOperation: SubscriptionOperationMessage = mockk { every { id } returns "abc" } + + state.saveOperation(mockSession, inputOperation, mockSubscription) + + assertEquals(expected = 1, actual = state.activeOperations.size) + + val cancelOperation: SubscriptionOperationMessage = mockk { every { id } returns null } + state.stopOperation(mockSession, cancelOperation) + + assertEquals(expected = 1, actual = state.activeOperations.size) + assertEquals(expected = mockSubscription, actual = state.activeOperations["123"]?.get("abc")) + verify(exactly = 0) { mockSubscription.cancel() } + } + + @Test + fun `stopOperation does not cancel the subscription if operation id not match`() { + val state = ApolloSubscriptionSessionState() + val mockSubscription: Subscription = mockk() + val mockSession: WebSocketSession = mockk { every { id } returns "123" } + val inputOperation: SubscriptionOperationMessage = mockk { every { id } returns "abc" } + + state.saveOperation(mockSession, inputOperation, mockSubscription) + + assertEquals(expected = 1, actual = state.activeOperations.size) + + val cancelOperation: SubscriptionOperationMessage = mockk { every { id } returns "xyz" } + state.stopOperation(mockSession, cancelOperation) + + assertEquals(expected = 1, actual = state.activeOperations.size) + assertEquals(expected = mockSubscription, actual = state.activeOperations["123"]?.get("abc")) + verify(exactly = 0) { mockSubscription.cancel() } + } + + @Test + fun `stopOperation clears entire operation cache if it is empty after removal`() { + val state = ApolloSubscriptionSessionState() + val mockSubscription: Subscription = mockk { every { cancel() } returns Unit } + val mockSession: WebSocketSession = mockk { every { id } returns "123" } + val inputOperation: SubscriptionOperationMessage = mockk { every { id } returns "abc" } + + state.saveOperation(mockSession, inputOperation, mockSubscription) + + assertEquals(expected = 1, actual = state.activeOperations.size) + assertEquals(expected = 1, actual = state.activeOperations["123"]?.size) + + state.stopOperation(mockSession, inputOperation) + + assertEquals(expected = 0, actual = state.activeOperations.size) + assertNull(state.activeOperations["123"]) + verify(exactly = 1) { mockSubscription.cancel() } + } + + @Test + fun `stopOperation cancels the subscription if operation id is valid`() { + val state = ApolloSubscriptionSessionState() + val mockSession: WebSocketSession = mockk { every { id } returns "123" } + val mockSubscription1: Subscription = mockk { every { cancel() } returns Unit } + val mockSubscription2: Subscription = mockk { every { cancel() } returns Unit } + val inputOperation1: SubscriptionOperationMessage = mockk { every { id } returns "abc" } + val inputOperation2: SubscriptionOperationMessage = mockk { every { id } returns "def" } + + state.saveOperation(mockSession, inputOperation1, mockSubscription1) + state.saveOperation(mockSession, inputOperation2, mockSubscription2) + + assertEquals(expected = 1, actual = state.activeOperations.size) + assertEquals(expected = 2, actual = state.activeOperations["123"]?.size) + + state.stopOperation(mockSession, inputOperation1) + + assertEquals(expected = 1, actual = state.activeOperations.size) + assertEquals(expected = 1, actual = state.activeOperations["123"]?.size) + verify(exactly = 1) { mockSubscription1.cancel() } + } + + @Test + fun `terminateSession cancels the keep alive subscription`() { + val state = ApolloSubscriptionSessionState() + val mockSubscription: Subscription = mockk { every { cancel() } returns Unit } + val mockSession: WebSocketSession = mockk { + every { id } returns "123" + every { close() } returns Mono.empty() + } + + state.saveKeepAliveSubscription(mockSession, mockSubscription) + + assertEquals(expected = 1, actual = state.activeKeepAliveSessions.size) + + state.terminateSession(mockSession) + + assertEquals(expected = 0, actual = state.activeKeepAliveSessions.size) + verify(exactly = 1) { mockSubscription.cancel() } + verify(exactly = 1) { mockSession.close() } + } + + @Test + fun `terminateSession cancels all subscriptions for the session and operations`() { + val state = ApolloSubscriptionSessionState() + val mockSessionSubscription: Subscription = mockk { every { cancel() } returns Unit } + val mockOperationSubscription: Subscription = mockk { every { cancel() } returns Unit } + val inputOperation: SubscriptionOperationMessage = mockk { every { id } returns "abc" } + val mockSession: WebSocketSession = mockk { + every { id } returns "123" + every { close() } returns Mono.empty() + } + + state.saveKeepAliveSubscription(mockSession, mockSessionSubscription) + state.saveOperation(mockSession, inputOperation, mockOperationSubscription) + + assertEquals(expected = 1, actual = state.activeKeepAliveSessions.size) + assertEquals(expected = 1, actual = state.activeOperations.size) + + state.terminateSession(mockSession) + + assertEquals(expected = 0, actual = state.activeKeepAliveSessions.size) + assertEquals(expected = 0, actual = state.activeOperations.size) + verify(exactly = 1) { mockSessionSubscription.cancel() } + verify(exactly = 1) { mockOperationSubscription.cancel() } + verify(exactly = 1) { mockSession.close() } + } + + @Test + fun `terminateSession does not cancel any subscriptions if the session id does not match`() { + val state = ApolloSubscriptionSessionState() + val mockSessionSubscription: Subscription = mockk { every { cancel() } returns Unit } + val mockOperationSubscription: Subscription = mockk { every { cancel() } returns Unit } + val inputOperation: SubscriptionOperationMessage = mockk { every { id } returns "abc" } + val mockSession: WebSocketSession = mockk { + every { id } returns "123" + every { close() } returns Mono.empty() + } + + state.saveKeepAliveSubscription(mockSession, mockSessionSubscription) + state.saveOperation(mockSession, inputOperation, mockOperationSubscription) + + assertEquals(expected = 1, actual = state.activeKeepAliveSessions.size) + assertEquals(expected = 1, actual = state.activeOperations.size) + + val nonMatchingSession: WebSocketSession = mockk { + every { id } returns "xyz" + every { close() } returns Mono.empty() + } + state.terminateSession(nonMatchingSession) + + assertEquals(expected = 1, actual = state.activeKeepAliveSessions.size) + assertEquals(expected = 1, actual = state.activeOperations.size) + verify(exactly = 0) { mockSessionSubscription.cancel() } + verify(exactly = 0) { mockOperationSubscription.cancel() } + verify(exactly = 0) { mockSession.close() } + verify(exactly = 1) { nonMatchingSession.close() } + } +}