Skip to content

update FlowSubscriptionExecutionStrategy to support flow natively #1120

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2020 Expedia, Inc
* Copyright 2021 Expedia, Inc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -28,15 +28,14 @@ import graphql.execution.SubscriptionExecutionStrategy
import graphql.execution.instrumentation.parameters.InstrumentationExecutionParameters
import graphql.execution.instrumentation.parameters.InstrumentationExecutionStrategyParameters
import graphql.execution.instrumentation.parameters.InstrumentationFieldParameters
import graphql.execution.reactive.CompletionStageMappingPublisher
import graphql.schema.GraphQLObjectType
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.reactive.asPublisher
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.future.await
import kotlinx.coroutines.reactive.asFlow
import org.reactivestreams.Publisher
import java.util.Collections
import java.util.concurrent.CompletableFuture
import java.util.concurrent.CompletionStage
import java.util.function.Function

/**
* [SubscriptionExecutionStrategy] replacement that and allows schema subscription functions
Expand All @@ -62,20 +61,18 @@ class FlowSubscriptionExecutionStrategy(dfe: DataFetcherExceptionHandler) : Exec

//
// when the upstream source event stream completes, subscribe to it and wire in our adapter
val overallResult: CompletableFuture<ExecutionResult> = sourceEventStream.thenApply { publisher ->
if (publisher == null) {
val overallResult: CompletableFuture<ExecutionResult> = sourceEventStream.thenApply { flow ->
if (flow == null) {
ExecutionResultImpl(null, executionContext.errors)
} else {
val mapperFunction = Function<Any, CompletionStage<ExecutionResult>> { eventPayload: Any? ->
val returnFlow = flow.map { eventPayload: Any? ->
executeSubscriptionEvent(
executionContext,
parameters,
eventPayload
)
).await()
}
// we need explicit cast as Kotlin Flow is covariant (Flow<out T> vs Publisher<T>)
val mapSourceToResponse = CompletionStageMappingPublisher<ExecutionResult, Any>(publisher as Publisher<Any>, mapperFunction)
ExecutionResultImpl(mapSourceToResponse, executionContext.errors)
ExecutionResultImpl(returnFlow, executionContext.errors)
}
}

Expand All @@ -102,18 +99,18 @@ class FlowSubscriptionExecutionStrategy(dfe: DataFetcherExceptionHandler) : Exec
private fun createSourceEventStream(
executionContext: ExecutionContext,
parameters: ExecutionStrategyParameters
): CompletableFuture<Publisher<out Any>?> {
): CompletableFuture<Flow<*>?> {
val newParameters = firstFieldOfSubscriptionSelection(parameters)

val fieldFetched = fetchField(executionContext, newParameters)
return fieldFetched.thenApply { fetchedValue ->
val publisher = when (val publisherOrFlow: Any? = fetchedValue.fetchedValue) {
is Publisher<*> -> publisherOrFlow
val flow = when (val publisherOrFlow: Any? = fetchedValue.fetchedValue) {
is Publisher<*> -> publisherOrFlow.asFlow()
// below explicit cast is required due to the type erasure and Kotlin declaration-site variance vs Java use-site variance
is Flow<*> -> (publisherOrFlow as? Flow<Any>)?.asPublisher()
is Flow<*> -> publisherOrFlow
else -> null
}
publisher
flow
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2020 Expedia, Inc
* Copyright 2021 Expedia, Inc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -34,9 +34,9 @@ import graphql.schema.GraphQLSchema
import kotlinx.coroutines.InternalCoroutinesApi
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.collect
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.reactive.asPublisher
import kotlinx.coroutines.reactive.collect
import kotlinx.coroutines.runBlocking
import org.junit.jupiter.api.Test
import org.reactivestreams.Publisher
Expand Down Expand Up @@ -66,9 +66,9 @@ class FlowSubscriptionExecutionStrategyTest {
fun `verify subscription to flow`() = runBlocking {
val request = ExecutionInput.newExecutionInput().query("subscription { ticker }").build()
val response = testGraphQL.execute(request)
val publisher = response.getData<Publisher<ExecutionResult>>()
val flow = response.getData<Flow<ExecutionResult>>()
val list = mutableListOf<Int>()
publisher.collect {
flow.collect {
list.add(it.getData<Map<String, Int>>().getValue("ticker"))
assertEquals(it.extensions["testKey"], "testValue")
}
Expand All @@ -82,9 +82,9 @@ class FlowSubscriptionExecutionStrategyTest {
fun `verify subscription to datafetcher flow`() = runBlocking {
val request = ExecutionInput.newExecutionInput().query("subscription { datafetcher }").build()
val response = testGraphQL.execute(request)
val publisher = response.getData<Publisher<ExecutionResult>>()
val flow = response.getData<Flow<ExecutionResult>>()
val list = mutableListOf<Int>()
publisher.collect {
flow.collect {
val intVal = it.getData<Map<String, Int>>().getValue("datafetcher")
list.add(intVal)
assertEquals(it.extensions["testKey"], "testValue")
Expand All @@ -99,9 +99,9 @@ class FlowSubscriptionExecutionStrategyTest {
fun `verify subscription to publisher`() = runBlocking {
val request = ExecutionInput.newExecutionInput().query("subscription { publisherTicker }").build()
val response = testGraphQL.execute(request)
val publisher = response.getData<Publisher<ExecutionResult>>()
val flow = response.getData<Flow<ExecutionResult>>()
val list = mutableListOf<Int>()
publisher.collect {
flow.collect {
list.add(it.getData<Map<String, Int>>().getValue("publisherTicker"))
}
assertEquals(5, list.size)
Expand All @@ -117,9 +117,9 @@ class FlowSubscriptionExecutionStrategyTest {
.context(SubscriptionContext("junitHandler"))
.build()
val response = testGraphQL.execute(request)
val publisher = response.getData<Publisher<ExecutionResult>>()
val flow = response.getData<Flow<ExecutionResult>>()
val list = mutableListOf<Int>()
publisher.collect {
flow.collect {
val contextValue = it.getData<Map<String, String>>().getValue("contextualTicker")
assertTrue(contextValue.startsWith("junitHandler:"))
list.add(contextValue.substringAfter("junitHandler:").toInt())
Expand All @@ -134,11 +134,11 @@ class FlowSubscriptionExecutionStrategyTest {
fun `verify subscription to failing flow`() = runBlocking {
val request = ExecutionInput.newExecutionInput().query("subscription { alwaysThrows }").build()
val response = testGraphQL.execute(request)
val publisher = response.getData<Publisher<ExecutionResult>>()
val flow = response.getData<Flow<ExecutionResult>>()
val errors = mutableListOf<GraphQLError>()
val results = mutableListOf<Int>()
try {
publisher.collect {
flow.collect {
val dataMap = it.getData<Map<String, Int>>()
if (dataMap != null) {
results.add(dataMap.getValue("alwaysThrows"))
Expand All @@ -161,9 +161,9 @@ class FlowSubscriptionExecutionStrategyTest {
fun `verify subscription to exploding flow`() = runBlocking {
val request = ExecutionInput.newExecutionInput().query("subscription { throwsFast }").build()
val response = testGraphQL.execute(request)
val publisher = response.getData<Publisher<ExecutionResult>>()
val flow = response.getData<Flow<ExecutionResult>>()
val errors = response.errors
assertNull(publisher)
assertNull(flow)
assertEquals(1, errors.size)
assertEquals("JUNIT flow failure", errors[0].message.substringAfter(" : "))
}
Expand All @@ -172,10 +172,10 @@ class FlowSubscriptionExecutionStrategyTest {
fun `verify subscription alias`() = runBlocking {
val request = ExecutionInput.newExecutionInput().query("subscription { t: ticker }").build()
val response = testGraphQL.execute(request)
val publisher = response.getData<Publisher<ExecutionResult>>()
val flow = response.getData<Flow<ExecutionResult>>()
val list = mutableListOf<Int>()
publisher.collect {
list.add(it.getData<Map<String, Int>>().getValue("t"))
flow.collect { executionResult ->
list.add(executionResult.getData<Map<String, Int>>().getValue("t"))
}
assertEquals(5, list.size)
for (i in list.indices) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.module.kotlin.convertValue
import com.fasterxml.jackson.module.kotlin.readValue
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.reactor.asFlux
import kotlinx.coroutines.runBlocking
import org.slf4j.LoggerFactory
import org.springframework.web.reactive.socket.WebSocketSession
Expand Down Expand Up @@ -130,7 +131,7 @@ class ApolloSubscriptionProtocolHandler(
try {
val request = objectMapper.convertValue<GraphQLRequest>(payload)
return subscriptionHandler.executeSubscription(request, context)
.toFlux()
.asFlux()
.map {
if (it.errors?.isNotEmpty() == true) {
SubscriptionOperationMessage(type = GQL_ERROR.type, id = operationMessage.id, payload = it)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2019 Expedia, Inc
* Copyright 2021 Expedia, Inc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -26,9 +26,9 @@ import com.expediagroup.graphql.server.types.GraphQLRequest
import com.expediagroup.graphql.server.types.GraphQLResponse
import graphql.ExecutionResult
import graphql.GraphQL
import org.reactivestreams.Publisher
import reactor.core.publisher.Flux
import reactor.kotlin.core.publisher.toFlux
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.catch
import kotlinx.coroutines.flow.map

/**
* Default Spring implementation of GraphQL subscription handler.
Expand All @@ -38,17 +38,16 @@ open class SpringGraphQLSubscriptionHandler(
private val dataLoaderRegistryFactory: DataLoaderRegistryFactory? = null
) {

fun executeSubscription(graphQLRequest: GraphQLRequest, graphQLContext: GraphQLContext?): Flux<GraphQLResponse<*>> {
fun executeSubscription(graphQLRequest: GraphQLRequest, graphQLContext: GraphQLContext?): Flow<GraphQLResponse<*>> {
val dataLoaderRegistry = dataLoaderRegistryFactory?.generate()
val input = graphQLRequest.toExecutionInput(graphQLContext, dataLoaderRegistry)

return graphQL.execute(input)
.getData<Publisher<ExecutionResult>>()
.toFlux()
.getData<Flow<ExecutionResult>>()
.map { result -> result.toGraphQLResponse() }
.onErrorResume { throwable ->
.catch { throwable ->
val error = throwable.toGraphQLError()
Flux.just(GraphQLResponse<Any?>(errors = listOf(error.toGraphQLKotlinType())))
emit(GraphQLResponse<Any?>(errors = listOf(error.toGraphQLKotlinType())))
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import graphql.GraphQL
import graphql.schema.GraphQLSchema
import io.mockk.every
import io.mockk.mockk
import kotlinx.coroutines.flow.flowOf
import org.assertj.core.api.Assertions.assertThat
import org.junit.jupiter.api.Test
import org.springframework.boot.autoconfigure.AutoConfigurations
Expand Down Expand Up @@ -125,7 +126,7 @@ class SubscriptionConfigurationTest {

@Bean
fun subscriptionHandler(): SpringGraphQLSubscriptionHandler = mockk {
every { executeSubscription(any(), any()) } returns Flux.empty()
every { executeSubscription(any(), any()) } returns flowOf()
}

@Bean
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2020 Expedia, Inc
* Copyright 2021 Expedia, Inc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -19,6 +19,7 @@ package com.expediagroup.graphql.server.spring.execution
import com.expediagroup.graphql.generator.SchemaGeneratorConfig
import com.expediagroup.graphql.generator.TopLevelObject
import com.expediagroup.graphql.generator.exceptions.GraphQLKotlinException
import com.expediagroup.graphql.generator.execution.FlowSubscriptionExecutionStrategy
import com.expediagroup.graphql.generator.execution.GraphQLContext
import com.expediagroup.graphql.generator.toSchema
import com.expediagroup.graphql.server.execution.DefaultDataLoaderRegistryFactory
Expand All @@ -30,6 +31,7 @@ import graphql.GraphQL
import graphql.schema.DataFetchingEnvironment
import graphql.schema.GraphQLSchema
import io.mockk.mockk
import kotlinx.coroutines.reactor.asFlux
import org.dataloader.DataLoader
import org.junit.jupiter.api.Test
import reactor.core.publisher.Flux
Expand All @@ -51,7 +53,9 @@ class SpringGraphQLSubscriptionHandlerTest {
queries = listOf(TopLevelObject(BasicQuery())),
subscriptions = listOf(TopLevelObject(BasicSubscription()))
)
private val testGraphQL: GraphQL = GraphQL.newGraphQL(testSchema).build()
private val testGraphQL: GraphQL = GraphQL.newGraphQL(testSchema)
.subscriptionExecutionStrategy(FlowSubscriptionExecutionStrategy())
.build()
private val mockLoader: KotlinDataLoader<String, String> = object : KotlinDataLoader<String, String> {
override val dataLoaderName: String = "MockDataLoader"
override fun getDataLoader(): DataLoader<String, String> = DataLoader<String, String> { ids ->
Expand All @@ -66,7 +70,7 @@ class SpringGraphQLSubscriptionHandlerTest {
@Test
fun `verify subscription`() {
val request = GraphQLRequest(query = "subscription { ticker }")
val responseFlux = subscriptionHandler.executeSubscription(request, mockk())
val responseFlux = subscriptionHandler.executeSubscription(request, mockk()).asFlux()

StepVerifier.create(responseFlux)
.thenConsumeWhile { response ->
Expand All @@ -84,7 +88,7 @@ class SpringGraphQLSubscriptionHandlerTest {
@Test
fun `verify subscription with data loader`() {
val request = GraphQLRequest(query = "subscription { dataLoaderValue }")
val responseFlux = subscriptionHandler.executeSubscription(request, mockk())
val responseFlux = subscriptionHandler.executeSubscription(request, mockk()).asFlux()

StepVerifier.create(responseFlux)
.thenConsumeWhile { response ->
Expand All @@ -105,7 +109,7 @@ class SpringGraphQLSubscriptionHandlerTest {
fun `verify subscription with context`() {
val request = GraphQLRequest(query = "subscription { contextualTicker }")
val context = SubscriptionContext("junitHandler")
val responseFlux = subscriptionHandler.executeSubscription(request, context)
val responseFlux = subscriptionHandler.executeSubscription(request, context).asFlux()

StepVerifier.create(responseFlux)
.thenConsumeWhile { response ->
Expand All @@ -126,7 +130,7 @@ class SpringGraphQLSubscriptionHandlerTest {
@Test
fun `verify subscription to failing publisher`() {
val request = GraphQLRequest(query = "subscription { alwaysThrows }")
val responseFlux = subscriptionHandler.executeSubscription(request, mockk())
val responseFlux = subscriptionHandler.executeSubscription(request, mockk()).asFlux()

StepVerifier.create(responseFlux)
.assertNext { response ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,10 @@ import io.mockk.mockk
import io.mockk.verify
import io.mockk.verifyOrder
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.flow.flowOf
import kotlinx.coroutines.flow.map
import org.junit.jupiter.api.Test
import org.springframework.web.reactive.socket.WebSocketSession
import reactor.core.publisher.Flux
import reactor.test.StepVerifier
import java.time.Duration
import kotlin.test.assertEquals
Expand Down Expand Up @@ -297,7 +298,7 @@ class ApolloSubscriptionProtocolHandlerTest {
every { id } returns "123"
}
val subscriptionHandler: SpringGraphQLSubscriptionHandler = mockk {
every { executeSubscription(eq(graphQLRequest), any()) } returns Flux.just(GraphQLResponse("myData"))
every { executeSubscription(eq(graphQLRequest), any()) } returns flowOf(GraphQLResponse("myData"))
}

val handler = ApolloSubscriptionProtocolHandler(config, nullContextFactory, subscriptionHandler, objectMapper, subscriptionHooks)
Expand Down Expand Up @@ -329,7 +330,7 @@ class ApolloSubscriptionProtocolHandlerTest {
}
val subscriptionHandler: SpringGraphQLSubscriptionHandler = mockk {
// Never closes
every { executeSubscription(eq(graphQLRequest), any()) } returns Flux.interval(Duration.ofSeconds(1)).map { GraphQLResponse("myData") }
every { executeSubscription(eq(graphQLRequest), any()) } returns flowOf(Duration.ofSeconds(1)).map { GraphQLResponse("myData") }
}

val handler = ApolloSubscriptionProtocolHandler(config, nullContextFactory, subscriptionHandler, objectMapper, subscriptionHooks)
Expand Down Expand Up @@ -360,7 +361,7 @@ class ApolloSubscriptionProtocolHandlerTest {
every { id } returns "123"
}
val subscriptionHandler: SpringGraphQLSubscriptionHandler = mockk {
every { executeSubscription(eq(graphQLRequest), any()) } returns Flux.just(GraphQLResponse("myData"))
every { executeSubscription(eq(graphQLRequest), any()) } returns flowOf(GraphQLResponse("myData"))
}

val handler = ApolloSubscriptionProtocolHandler(config, nullContextFactory, subscriptionHandler, objectMapper, subscriptionHooks)
Expand Down Expand Up @@ -394,7 +395,7 @@ class ApolloSubscriptionProtocolHandlerTest {
every { id } returns "123"
}
val subscriptionHandler: SpringGraphQLSubscriptionHandler = mockk {
every { executeSubscription(eq(graphQLRequest), any()) } returns Flux.just(GraphQLResponse("myData"))
every { executeSubscription(eq(graphQLRequest), any()) } returns flowOf(GraphQLResponse("myData"))
}

val handler = ApolloSubscriptionProtocolHandler(config, nullContextFactory, subscriptionHandler, objectMapper, subscriptionHooks)
Expand Down Expand Up @@ -427,7 +428,7 @@ class ApolloSubscriptionProtocolHandlerTest {
}
val errors = listOf(GraphQLServerError("My GraphQL Error"))
val subscriptionHandler: SpringGraphQLSubscriptionHandler = mockk {
every { executeSubscription(eq(graphQLRequest), any()) } returns Flux.just(GraphQLResponse<Any>(errors = errors))
every { executeSubscription(eq(graphQLRequest), any()) } returns flowOf(GraphQLResponse<Any>(errors = errors))
}

val handler = ApolloSubscriptionProtocolHandler(config, nullContextFactory, subscriptionHandler, objectMapper, subscriptionHooks)
Expand Down Expand Up @@ -503,7 +504,7 @@ class ApolloSubscriptionProtocolHandlerTest {
}
val expectedResponse = GraphQLResponse("myData")
val subscriptionHandler: SpringGraphQLSubscriptionHandler = mockk {
every { executeSubscription(eq(graphQLRequest), any()) } returns Flux.just(expectedResponse)
every { executeSubscription(eq(graphQLRequest), any()) } returns flowOf(expectedResponse)
}
val subscriptionHooks: ApolloSubscriptionHooks = mockk {
every { onConnect(any(), any(), any()) } returns null
Expand Down