Skip to content

Commit 8b7b47e

Browse files
authored
update FlowSubscriptionExecutionStrategy to support flow natively (ExpediaGroup#1120)
Update `FlowSubscriptionExecutionStrategy` to support Kotlin `Flow` natively (i.e. convert publishers to flows vs old logic of converting flows to publishers). While the Reactor `Flux` has a concept of subscriber context (and has full interop with Kotlin `Flow`), generic `Publisher` does not. This means that whenever we convert to a publisher we loose any contextual data that was available. This PR makes minimal changes to get our Spring server working with the updated subscription execution strategy. related: ExpediaGroup#1116
1 parent 3a950ed commit 8b7b47e

File tree

7 files changed

+61
-58
lines changed

7 files changed

+61
-58
lines changed

generator/graphql-kotlin-schema-generator/src/main/kotlin/com/expediagroup/graphql/generator/execution/FlowSubscriptionExecutionStrategy.kt

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2020 Expedia, Inc
2+
* Copyright 2021 Expedia, Inc
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -28,15 +28,14 @@ import graphql.execution.SubscriptionExecutionStrategy
2828
import graphql.execution.instrumentation.parameters.InstrumentationExecutionParameters
2929
import graphql.execution.instrumentation.parameters.InstrumentationExecutionStrategyParameters
3030
import graphql.execution.instrumentation.parameters.InstrumentationFieldParameters
31-
import graphql.execution.reactive.CompletionStageMappingPublisher
3231
import graphql.schema.GraphQLObjectType
3332
import kotlinx.coroutines.flow.Flow
34-
import kotlinx.coroutines.reactive.asPublisher
33+
import kotlinx.coroutines.flow.map
34+
import kotlinx.coroutines.future.await
35+
import kotlinx.coroutines.reactive.asFlow
3536
import org.reactivestreams.Publisher
3637
import java.util.Collections
3738
import java.util.concurrent.CompletableFuture
38-
import java.util.concurrent.CompletionStage
39-
import java.util.function.Function
4039

4140
/**
4241
* [SubscriptionExecutionStrategy] replacement that and allows schema subscription functions
@@ -62,20 +61,18 @@ class FlowSubscriptionExecutionStrategy(dfe: DataFetcherExceptionHandler) : Exec
6261

6362
//
6463
// when the upstream source event stream completes, subscribe to it and wire in our adapter
65-
val overallResult: CompletableFuture<ExecutionResult> = sourceEventStream.thenApply { publisher ->
66-
if (publisher == null) {
64+
val overallResult: CompletableFuture<ExecutionResult> = sourceEventStream.thenApply { flow ->
65+
if (flow == null) {
6766
ExecutionResultImpl(null, executionContext.errors)
6867
} else {
69-
val mapperFunction = Function<Any, CompletionStage<ExecutionResult>> { eventPayload: Any? ->
68+
val returnFlow = flow.map { eventPayload: Any? ->
7069
executeSubscriptionEvent(
7170
executionContext,
7271
parameters,
7372
eventPayload
74-
)
73+
).await()
7574
}
76-
// we need explicit cast as Kotlin Flow is covariant (Flow<out T> vs Publisher<T>)
77-
val mapSourceToResponse = CompletionStageMappingPublisher<ExecutionResult, Any>(publisher as Publisher<Any>, mapperFunction)
78-
ExecutionResultImpl(mapSourceToResponse, executionContext.errors)
75+
ExecutionResultImpl(returnFlow, executionContext.errors)
7976
}
8077
}
8178

@@ -102,18 +99,18 @@ class FlowSubscriptionExecutionStrategy(dfe: DataFetcherExceptionHandler) : Exec
10299
private fun createSourceEventStream(
103100
executionContext: ExecutionContext,
104101
parameters: ExecutionStrategyParameters
105-
): CompletableFuture<Publisher<out Any>?> {
102+
): CompletableFuture<Flow<*>?> {
106103
val newParameters = firstFieldOfSubscriptionSelection(parameters)
107104

108105
val fieldFetched = fetchField(executionContext, newParameters)
109106
return fieldFetched.thenApply { fetchedValue ->
110-
val publisher = when (val publisherOrFlow: Any? = fetchedValue.fetchedValue) {
111-
is Publisher<*> -> publisherOrFlow
107+
val flow = when (val publisherOrFlow: Any? = fetchedValue.fetchedValue) {
108+
is Publisher<*> -> publisherOrFlow.asFlow()
112109
// below explicit cast is required due to the type erasure and Kotlin declaration-site variance vs Java use-site variance
113-
is Flow<*> -> (publisherOrFlow as? Flow<Any>)?.asPublisher()
110+
is Flow<*> -> publisherOrFlow
114111
else -> null
115112
}
116-
publisher
113+
flow
117114
}
118115
}
119116

generator/graphql-kotlin-schema-generator/src/test/kotlin/com/expediagroup/graphql/generator/execution/FlowSubscriptionExecutionStrategyTest.kt

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2020 Expedia, Inc
2+
* Copyright 2021 Expedia, Inc
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -34,9 +34,9 @@ import graphql.schema.GraphQLSchema
3434
import kotlinx.coroutines.InternalCoroutinesApi
3535
import kotlinx.coroutines.delay
3636
import kotlinx.coroutines.flow.Flow
37+
import kotlinx.coroutines.flow.collect
3738
import kotlinx.coroutines.flow.flow
3839
import kotlinx.coroutines.reactive.asPublisher
39-
import kotlinx.coroutines.reactive.collect
4040
import kotlinx.coroutines.runBlocking
4141
import org.junit.jupiter.api.Test
4242
import org.reactivestreams.Publisher
@@ -66,9 +66,9 @@ class FlowSubscriptionExecutionStrategyTest {
6666
fun `verify subscription to flow`() = runBlocking {
6767
val request = ExecutionInput.newExecutionInput().query("subscription { ticker }").build()
6868
val response = testGraphQL.execute(request)
69-
val publisher = response.getData<Publisher<ExecutionResult>>()
69+
val flow = response.getData<Flow<ExecutionResult>>()
7070
val list = mutableListOf<Int>()
71-
publisher.collect {
71+
flow.collect {
7272
list.add(it.getData<Map<String, Int>>().getValue("ticker"))
7373
assertEquals(it.extensions["testKey"], "testValue")
7474
}
@@ -82,9 +82,9 @@ class FlowSubscriptionExecutionStrategyTest {
8282
fun `verify subscription to datafetcher flow`() = runBlocking {
8383
val request = ExecutionInput.newExecutionInput().query("subscription { datafetcher }").build()
8484
val response = testGraphQL.execute(request)
85-
val publisher = response.getData<Publisher<ExecutionResult>>()
85+
val flow = response.getData<Flow<ExecutionResult>>()
8686
val list = mutableListOf<Int>()
87-
publisher.collect {
87+
flow.collect {
8888
val intVal = it.getData<Map<String, Int>>().getValue("datafetcher")
8989
list.add(intVal)
9090
assertEquals(it.extensions["testKey"], "testValue")
@@ -99,9 +99,9 @@ class FlowSubscriptionExecutionStrategyTest {
9999
fun `verify subscription to publisher`() = runBlocking {
100100
val request = ExecutionInput.newExecutionInput().query("subscription { publisherTicker }").build()
101101
val response = testGraphQL.execute(request)
102-
val publisher = response.getData<Publisher<ExecutionResult>>()
102+
val flow = response.getData<Flow<ExecutionResult>>()
103103
val list = mutableListOf<Int>()
104-
publisher.collect {
104+
flow.collect {
105105
list.add(it.getData<Map<String, Int>>().getValue("publisherTicker"))
106106
}
107107
assertEquals(5, list.size)
@@ -117,9 +117,9 @@ class FlowSubscriptionExecutionStrategyTest {
117117
.context(SubscriptionContext("junitHandler"))
118118
.build()
119119
val response = testGraphQL.execute(request)
120-
val publisher = response.getData<Publisher<ExecutionResult>>()
120+
val flow = response.getData<Flow<ExecutionResult>>()
121121
val list = mutableListOf<Int>()
122-
publisher.collect {
122+
flow.collect {
123123
val contextValue = it.getData<Map<String, String>>().getValue("contextualTicker")
124124
assertTrue(contextValue.startsWith("junitHandler:"))
125125
list.add(contextValue.substringAfter("junitHandler:").toInt())
@@ -134,11 +134,11 @@ class FlowSubscriptionExecutionStrategyTest {
134134
fun `verify subscription to failing flow`() = runBlocking {
135135
val request = ExecutionInput.newExecutionInput().query("subscription { alwaysThrows }").build()
136136
val response = testGraphQL.execute(request)
137-
val publisher = response.getData<Publisher<ExecutionResult>>()
137+
val flow = response.getData<Flow<ExecutionResult>>()
138138
val errors = mutableListOf<GraphQLError>()
139139
val results = mutableListOf<Int>()
140140
try {
141-
publisher.collect {
141+
flow.collect {
142142
val dataMap = it.getData<Map<String, Int>>()
143143
if (dataMap != null) {
144144
results.add(dataMap.getValue("alwaysThrows"))
@@ -161,9 +161,9 @@ class FlowSubscriptionExecutionStrategyTest {
161161
fun `verify subscription to exploding flow`() = runBlocking {
162162
val request = ExecutionInput.newExecutionInput().query("subscription { throwsFast }").build()
163163
val response = testGraphQL.execute(request)
164-
val publisher = response.getData<Publisher<ExecutionResult>>()
164+
val flow = response.getData<Flow<ExecutionResult>>()
165165
val errors = response.errors
166-
assertNull(publisher)
166+
assertNull(flow)
167167
assertEquals(1, errors.size)
168168
assertEquals("JUNIT flow failure", errors[0].message.substringAfter(" : "))
169169
}
@@ -172,10 +172,10 @@ class FlowSubscriptionExecutionStrategyTest {
172172
fun `verify subscription alias`() = runBlocking {
173173
val request = ExecutionInput.newExecutionInput().query("subscription { t: ticker }").build()
174174
val response = testGraphQL.execute(request)
175-
val publisher = response.getData<Publisher<ExecutionResult>>()
175+
val flow = response.getData<Flow<ExecutionResult>>()
176176
val list = mutableListOf<Int>()
177-
publisher.collect {
178-
list.add(it.getData<Map<String, Int>>().getValue("t"))
177+
flow.collect { executionResult ->
178+
list.add(executionResult.getData<Map<String, Int>>().getValue("t"))
179179
}
180180
assertEquals(5, list.size)
181181
for (i in list.indices) {

servers/graphql-kotlin-spring-server/src/main/kotlin/com/expediagroup/graphql/server/spring/subscriptions/ApolloSubscriptionProtocolHandler.kt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ import com.fasterxml.jackson.databind.ObjectMapper
3131
import com.fasterxml.jackson.module.kotlin.convertValue
3232
import com.fasterxml.jackson.module.kotlin.readValue
3333
import kotlinx.coroutines.ExperimentalCoroutinesApi
34+
import kotlinx.coroutines.reactor.asFlux
3435
import kotlinx.coroutines.runBlocking
3536
import org.slf4j.LoggerFactory
3637
import org.springframework.web.reactive.socket.WebSocketSession
@@ -130,7 +131,7 @@ class ApolloSubscriptionProtocolHandler(
130131
try {
131132
val request = objectMapper.convertValue<GraphQLRequest>(payload)
132133
return subscriptionHandler.executeSubscription(request, context)
133-
.toFlux()
134+
.asFlux()
134135
.map {
135136
if (it.errors?.isNotEmpty() == true) {
136137
SubscriptionOperationMessage(type = GQL_ERROR.type, id = operationMessage.id, payload = it)

servers/graphql-kotlin-spring-server/src/main/kotlin/com/expediagroup/graphql/server/spring/subscriptions/SpringGraphQLSubscriptionHandler.kt

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2019 Expedia, Inc
2+
* Copyright 2021 Expedia, Inc
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -26,9 +26,9 @@ import com.expediagroup.graphql.server.types.GraphQLRequest
2626
import com.expediagroup.graphql.server.types.GraphQLResponse
2727
import graphql.ExecutionResult
2828
import graphql.GraphQL
29-
import org.reactivestreams.Publisher
30-
import reactor.core.publisher.Flux
31-
import reactor.kotlin.core.publisher.toFlux
29+
import kotlinx.coroutines.flow.Flow
30+
import kotlinx.coroutines.flow.catch
31+
import kotlinx.coroutines.flow.map
3232

3333
/**
3434
* Default Spring implementation of GraphQL subscription handler.
@@ -38,17 +38,16 @@ open class SpringGraphQLSubscriptionHandler(
3838
private val dataLoaderRegistryFactory: DataLoaderRegistryFactory? = null
3939
) {
4040

41-
fun executeSubscription(graphQLRequest: GraphQLRequest, graphQLContext: GraphQLContext?): Flux<GraphQLResponse<*>> {
41+
fun executeSubscription(graphQLRequest: GraphQLRequest, graphQLContext: GraphQLContext?): Flow<GraphQLResponse<*>> {
4242
val dataLoaderRegistry = dataLoaderRegistryFactory?.generate()
4343
val input = graphQLRequest.toExecutionInput(graphQLContext, dataLoaderRegistry)
4444

4545
return graphQL.execute(input)
46-
.getData<Publisher<ExecutionResult>>()
47-
.toFlux()
46+
.getData<Flow<ExecutionResult>>()
4847
.map { result -> result.toGraphQLResponse() }
49-
.onErrorResume { throwable ->
48+
.catch { throwable ->
5049
val error = throwable.toGraphQLError()
51-
Flux.just(GraphQLResponse<Any?>(errors = listOf(error.toGraphQLKotlinType())))
50+
emit(GraphQLResponse<Any?>(errors = listOf(error.toGraphQLKotlinType())))
5251
}
5352
}
5453
}

servers/graphql-kotlin-spring-server/src/test/kotlin/com/expediagroup/graphql/server/spring/SubscriptionConfigurationTest.kt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import graphql.GraphQL
2727
import graphql.schema.GraphQLSchema
2828
import io.mockk.every
2929
import io.mockk.mockk
30+
import kotlinx.coroutines.flow.flowOf
3031
import org.assertj.core.api.Assertions.assertThat
3132
import org.junit.jupiter.api.Test
3233
import org.springframework.boot.autoconfigure.AutoConfigurations
@@ -125,7 +126,7 @@ class SubscriptionConfigurationTest {
125126

126127
@Bean
127128
fun subscriptionHandler(): SpringGraphQLSubscriptionHandler = mockk {
128-
every { executeSubscription(any(), any()) } returns Flux.empty()
129+
every { executeSubscription(any(), any()) } returns flowOf()
129130
}
130131

131132
@Bean

servers/graphql-kotlin-spring-server/src/test/kotlin/com/expediagroup/graphql/server/spring/execution/SpringGraphQLSubscriptionHandlerTest.kt

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2020 Expedia, Inc
2+
* Copyright 2021 Expedia, Inc
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -19,6 +19,7 @@ package com.expediagroup.graphql.server.spring.execution
1919
import com.expediagroup.graphql.generator.SchemaGeneratorConfig
2020
import com.expediagroup.graphql.generator.TopLevelObject
2121
import com.expediagroup.graphql.generator.exceptions.GraphQLKotlinException
22+
import com.expediagroup.graphql.generator.execution.FlowSubscriptionExecutionStrategy
2223
import com.expediagroup.graphql.generator.execution.GraphQLContext
2324
import com.expediagroup.graphql.generator.toSchema
2425
import com.expediagroup.graphql.server.execution.DefaultDataLoaderRegistryFactory
@@ -30,6 +31,7 @@ import graphql.GraphQL
3031
import graphql.schema.DataFetchingEnvironment
3132
import graphql.schema.GraphQLSchema
3233
import io.mockk.mockk
34+
import kotlinx.coroutines.reactor.asFlux
3335
import org.dataloader.DataLoader
3436
import org.junit.jupiter.api.Test
3537
import reactor.core.publisher.Flux
@@ -51,7 +53,9 @@ class SpringGraphQLSubscriptionHandlerTest {
5153
queries = listOf(TopLevelObject(BasicQuery())),
5254
subscriptions = listOf(TopLevelObject(BasicSubscription()))
5355
)
54-
private val testGraphQL: GraphQL = GraphQL.newGraphQL(testSchema).build()
56+
private val testGraphQL: GraphQL = GraphQL.newGraphQL(testSchema)
57+
.subscriptionExecutionStrategy(FlowSubscriptionExecutionStrategy())
58+
.build()
5559
private val mockLoader: KotlinDataLoader<String, String> = object : KotlinDataLoader<String, String> {
5660
override val dataLoaderName: String = "MockDataLoader"
5761
override fun getDataLoader(): DataLoader<String, String> = DataLoader<String, String> { ids ->
@@ -66,7 +70,7 @@ class SpringGraphQLSubscriptionHandlerTest {
6670
@Test
6771
fun `verify subscription`() {
6872
val request = GraphQLRequest(query = "subscription { ticker }")
69-
val responseFlux = subscriptionHandler.executeSubscription(request, mockk())
73+
val responseFlux = subscriptionHandler.executeSubscription(request, mockk()).asFlux()
7074

7175
StepVerifier.create(responseFlux)
7276
.thenConsumeWhile { response ->
@@ -84,7 +88,7 @@ class SpringGraphQLSubscriptionHandlerTest {
8488
@Test
8589
fun `verify subscription with data loader`() {
8690
val request = GraphQLRequest(query = "subscription { dataLoaderValue }")
87-
val responseFlux = subscriptionHandler.executeSubscription(request, mockk())
91+
val responseFlux = subscriptionHandler.executeSubscription(request, mockk()).asFlux()
8892

8993
StepVerifier.create(responseFlux)
9094
.thenConsumeWhile { response ->
@@ -105,7 +109,7 @@ class SpringGraphQLSubscriptionHandlerTest {
105109
fun `verify subscription with context`() {
106110
val request = GraphQLRequest(query = "subscription { contextualTicker }")
107111
val context = SubscriptionContext("junitHandler")
108-
val responseFlux = subscriptionHandler.executeSubscription(request, context)
112+
val responseFlux = subscriptionHandler.executeSubscription(request, context).asFlux()
109113

110114
StepVerifier.create(responseFlux)
111115
.thenConsumeWhile { response ->
@@ -126,7 +130,7 @@ class SpringGraphQLSubscriptionHandlerTest {
126130
@Test
127131
fun `verify subscription to failing publisher`() {
128132
val request = GraphQLRequest(query = "subscription { alwaysThrows }")
129-
val responseFlux = subscriptionHandler.executeSubscription(request, mockk())
133+
val responseFlux = subscriptionHandler.executeSubscription(request, mockk()).asFlux()
130134

131135
StepVerifier.create(responseFlux)
132136
.assertNext { response ->

servers/graphql-kotlin-spring-server/src/test/kotlin/com/expediagroup/graphql/server/spring/subscriptions/ApolloSubscriptionProtocolHandlerTest.kt

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,10 @@ import io.mockk.mockk
3737
import io.mockk.verify
3838
import io.mockk.verifyOrder
3939
import kotlinx.coroutines.ExperimentalCoroutinesApi
40+
import kotlinx.coroutines.flow.flowOf
41+
import kotlinx.coroutines.flow.map
4042
import org.junit.jupiter.api.Test
4143
import org.springframework.web.reactive.socket.WebSocketSession
42-
import reactor.core.publisher.Flux
4344
import reactor.test.StepVerifier
4445
import java.time.Duration
4546
import kotlin.test.assertEquals
@@ -297,7 +298,7 @@ class ApolloSubscriptionProtocolHandlerTest {
297298
every { id } returns "123"
298299
}
299300
val subscriptionHandler: SpringGraphQLSubscriptionHandler = mockk {
300-
every { executeSubscription(eq(graphQLRequest), any()) } returns Flux.just(GraphQLResponse("myData"))
301+
every { executeSubscription(eq(graphQLRequest), any()) } returns flowOf(GraphQLResponse("myData"))
301302
}
302303

303304
val handler = ApolloSubscriptionProtocolHandler(config, nullContextFactory, subscriptionHandler, objectMapper, subscriptionHooks)
@@ -329,7 +330,7 @@ class ApolloSubscriptionProtocolHandlerTest {
329330
}
330331
val subscriptionHandler: SpringGraphQLSubscriptionHandler = mockk {
331332
// Never closes
332-
every { executeSubscription(eq(graphQLRequest), any()) } returns Flux.interval(Duration.ofSeconds(1)).map { GraphQLResponse("myData") }
333+
every { executeSubscription(eq(graphQLRequest), any()) } returns flowOf(Duration.ofSeconds(1)).map { GraphQLResponse("myData") }
333334
}
334335

335336
val handler = ApolloSubscriptionProtocolHandler(config, nullContextFactory, subscriptionHandler, objectMapper, subscriptionHooks)
@@ -360,7 +361,7 @@ class ApolloSubscriptionProtocolHandlerTest {
360361
every { id } returns "123"
361362
}
362363
val subscriptionHandler: SpringGraphQLSubscriptionHandler = mockk {
363-
every { executeSubscription(eq(graphQLRequest), any()) } returns Flux.just(GraphQLResponse("myData"))
364+
every { executeSubscription(eq(graphQLRequest), any()) } returns flowOf(GraphQLResponse("myData"))
364365
}
365366

366367
val handler = ApolloSubscriptionProtocolHandler(config, nullContextFactory, subscriptionHandler, objectMapper, subscriptionHooks)
@@ -394,7 +395,7 @@ class ApolloSubscriptionProtocolHandlerTest {
394395
every { id } returns "123"
395396
}
396397
val subscriptionHandler: SpringGraphQLSubscriptionHandler = mockk {
397-
every { executeSubscription(eq(graphQLRequest), any()) } returns Flux.just(GraphQLResponse("myData"))
398+
every { executeSubscription(eq(graphQLRequest), any()) } returns flowOf(GraphQLResponse("myData"))
398399
}
399400

400401
val handler = ApolloSubscriptionProtocolHandler(config, nullContextFactory, subscriptionHandler, objectMapper, subscriptionHooks)
@@ -427,7 +428,7 @@ class ApolloSubscriptionProtocolHandlerTest {
427428
}
428429
val errors = listOf(GraphQLServerError("My GraphQL Error"))
429430
val subscriptionHandler: SpringGraphQLSubscriptionHandler = mockk {
430-
every { executeSubscription(eq(graphQLRequest), any()) } returns Flux.just(GraphQLResponse<Any>(errors = errors))
431+
every { executeSubscription(eq(graphQLRequest), any()) } returns flowOf(GraphQLResponse<Any>(errors = errors))
431432
}
432433

433434
val handler = ApolloSubscriptionProtocolHandler(config, nullContextFactory, subscriptionHandler, objectMapper, subscriptionHooks)
@@ -503,7 +504,7 @@ class ApolloSubscriptionProtocolHandlerTest {
503504
}
504505
val expectedResponse = GraphQLResponse("myData")
505506
val subscriptionHandler: SpringGraphQLSubscriptionHandler = mockk {
506-
every { executeSubscription(eq(graphQLRequest), any()) } returns Flux.just(expectedResponse)
507+
every { executeSubscription(eq(graphQLRequest), any()) } returns flowOf(expectedResponse)
507508
}
508509
val subscriptionHooks: ApolloSubscriptionHooks = mockk {
509510
every { onConnect(any(), any(), any()) } returns null

0 commit comments

Comments
 (0)