Skip to content

Commit e49289b

Browse files
Add native subscription support for coroutine Flows
Implement a SubscriptionExecutionStrategy that allows for `Flow`s and `Publisher`s to be returned from graphql schema elements, and can be processed as a `Flow` by subscription consumers. Relax restrictions that look for `Publisher`s to also allow `Flow`s. Fixes ExpediaGroup#358
1 parent b129244 commit e49289b

File tree

6 files changed

+306
-2
lines changed

6 files changed

+306
-2
lines changed

graphql-kotlin-schema-generator/build.gradle.kts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ dependencies {
1111
api("com.graphql-java:graphql-java:$graphQLJavaVersion")
1212
// TODO change below from api to implementation?
1313
api("org.jetbrains.kotlinx:kotlinx-coroutines-jdk8:$kotlinCoroutinesVersion")
14+
api("org.jetbrains.kotlinx:kotlinx-coroutines-reactive:$kotlinCoroutinesVersion")
1415
api("org.jetbrains.kotlin:kotlin-reflect:$kotlinVersion")
1516
api("io.github.classgraph:classgraph:$classGraphVersion")
1617
api("com.fasterxml.jackson.module:jackson-module-kotlin:$jacksonVersion")

graphql-kotlin-schema-generator/src/main/kotlin/com/expediagroup/graphql/exceptions/InvalidSubscriptionTypeException.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import kotlin.reflect.KFunction
2222

2323
class InvalidSubscriptionTypeException(kClass: KClass<*>, kFunction: KFunction<*>? = null) :
2424
GraphQLKotlinException(
25-
"Schema requires all subscriptions to be public and return a type of Publisher. " +
25+
"Schema requires all subscriptions to be public and return a type of Publisher or Flow. " +
2626
"${kClass.simpleName} has ${kClass.visibility} visibility modifier. " +
2727
if (kFunction != null) "The function return type is ${kFunction.returnType.getSimpleName()}" else ""
2828
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
package com.expediagroup.graphql.execution
2+
3+
import graphql.AssertException
4+
import graphql.ExecutionResult
5+
import graphql.ExecutionResultImpl
6+
import graphql.execution.DataFetcherExceptionHandler
7+
import graphql.execution.ExecutionContext
8+
import graphql.execution.ExecutionStrategy
9+
import graphql.execution.ExecutionStrategyParameters
10+
import graphql.execution.FetchedValue
11+
import graphql.execution.SimpleDataFetcherExceptionHandler
12+
import graphql.execution.SubscriptionExecutionStrategy
13+
import kotlinx.coroutines.flow.Flow
14+
import kotlinx.coroutines.flow.map
15+
import kotlinx.coroutines.future.await
16+
import kotlinx.coroutines.reactive.asFlow
17+
import org.reactivestreams.Publisher
18+
import java.util.Collections
19+
import java.util.concurrent.CompletableFuture
20+
21+
class FlowSubscriptionExecutionStrategy(dfe: DataFetcherExceptionHandler) : SubscriptionExecutionStrategy(dfe) {
22+
constructor() : this(SimpleDataFetcherExceptionHandler())
23+
24+
override fun execute(
25+
executionContext: ExecutionContext,
26+
parameters: ExecutionStrategyParameters
27+
): CompletableFuture<ExecutionResult> {
28+
29+
val sourceEventStream = createSourceEventStream(executionContext, parameters)
30+
31+
//
32+
// when the upstream source event stream completes, subscribe to it and wire in our adapter
33+
return sourceEventStream.thenApply { sourceFlow ->
34+
if (sourceFlow == null) {
35+
ExecutionResultImpl(null, executionContext.errors)
36+
} else {
37+
val returnFlow = sourceFlow.map {
38+
executeSubscriptionEvent(executionContext, parameters, it).await()
39+
}
40+
ExecutionResultImpl(returnFlow, executionContext.errors)
41+
}
42+
}
43+
}
44+
45+
/*
46+
https://github.com/facebook/graphql/blob/master/spec/Section%206%20--%20Execution.md
47+
48+
CreateSourceEventStream(subscription, schema, variableValues, initialValue):
49+
50+
Let {subscriptionType} be the root Subscription type in {schema}.
51+
Assert: {subscriptionType} is an Object type.
52+
Let {selectionSet} be the top level Selection Set in {subscription}.
53+
Let {rootField} be the first top level field in {selectionSet}.
54+
Let {argumentValues} be the result of {CoerceArgumentValues(subscriptionType, rootField, variableValues)}.
55+
Let {fieldStream} be the result of running {ResolveFieldEventStream(subscriptionType, initialValue, rootField, argumentValues)}.
56+
Return {fieldStream}.
57+
*/
58+
private fun createSourceEventStream(
59+
executionContext: ExecutionContext,
60+
parameters: ExecutionStrategyParameters
61+
): CompletableFuture<Flow<*>> {
62+
val newParameters = firstFieldOfSubscriptionSelection(parameters)
63+
64+
val fieldFetched = fetchField(executionContext, newParameters)
65+
return fieldFetched.thenApply { fetchedValue ->
66+
val flow = when (val publisherOrFlow = fetchedValue.fetchedValue) {
67+
null -> null
68+
is Publisher<*> -> publisherOrFlow.asFlow()
69+
is Flow<*> -> publisherOrFlow
70+
else -> throw AssertException(
71+
"You data fetcher must return a Flow of events when using graphql subscriptions"
72+
)
73+
}
74+
flow
75+
}
76+
}
77+
78+
/*
79+
ExecuteSubscriptionEvent(subscription, schema, variableValues, initialValue):
80+
81+
Let {subscriptionType} be the root Subscription type in {schema}.
82+
Assert: {subscriptionType} is an Object type.
83+
Let {selectionSet} be the top level Selection Set in {subscription}.
84+
Let {data} be the result of running {ExecuteSelectionSet(selectionSet, subscriptionType, initialValue, variableValues)} normally (allowing parallelization).
85+
Let {errors} be any field errors produced while executing the selection set.
86+
Return an unordered map containing {data} and {errors}.
87+
88+
Note: The {ExecuteSubscriptionEvent()} algorithm is intentionally similar to {ExecuteQuery()} since this is how each event result is produced.
89+
*/
90+
91+
private fun executeSubscriptionEvent(
92+
executionContext: ExecutionContext,
93+
parameters: ExecutionStrategyParameters,
94+
eventPayload: Any?
95+
): CompletableFuture<ExecutionResult> {
96+
val newExecutionContext = executionContext.transform { builder -> builder.root(eventPayload) }
97+
98+
val newParameters = firstFieldOfSubscriptionSelection(parameters)
99+
val fetchedValue = FetchedValue.newFetchedValue().fetchedValue(eventPayload)
100+
.rawFetchedValue(eventPayload)
101+
.localContext(parameters.localContext)
102+
.build()
103+
return completeField(newExecutionContext, newParameters, fetchedValue).fieldValue
104+
.thenApply { executionResult -> wrapWithRootFieldName(newParameters, executionResult) }
105+
}
106+
107+
private fun wrapWithRootFieldName(
108+
parameters: ExecutionStrategyParameters,
109+
executionResult: ExecutionResult
110+
): ExecutionResult {
111+
val rootFieldName = getRootFieldName(parameters)
112+
return ExecutionResultImpl(
113+
Collections.singletonMap<String, Any>(rootFieldName, executionResult.getData<Any>()),
114+
executionResult.errors
115+
)
116+
}
117+
118+
private fun getRootFieldName(parameters: ExecutionStrategyParameters): String {
119+
val rootField = parameters.field.singleField
120+
return if (rootField.alias != null) rootField.alias else rootField.name
121+
}
122+
123+
private fun firstFieldOfSubscriptionSelection(
124+
parameters: ExecutionStrategyParameters
125+
): ExecutionStrategyParameters {
126+
val fields = parameters.fields
127+
val firstField = fields.getSubField(fields.keys[0])
128+
129+
val fieldPath = parameters.path.segment(ExecutionStrategy.mkNameForPath(firstField.singleField))
130+
return parameters.transform { builder -> builder.field(firstField).path(fieldPath) }
131+
}
132+
}

graphql-kotlin-schema-generator/src/main/kotlin/com/expediagroup/graphql/generator/types/generateSubscription.kt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import com.expediagroup.graphql.generator.extensions.getValidFunctions
2323
import com.expediagroup.graphql.generator.extensions.isNotPublic
2424
import com.expediagroup.graphql.generator.extensions.isSubclassOf
2525
import graphql.schema.GraphQLObjectType
26+
import kotlinx.coroutines.flow.Flow
2627
import org.reactivestreams.Publisher
2728

2829
internal fun generateSubscriptions(generator: SchemaGenerator, subscriptions: List<TopLevelObject>): GraphQLObjectType? {
@@ -40,7 +41,7 @@ internal fun generateSubscriptions(generator: SchemaGenerator, subscriptions: Li
4041

4142
subscription.kClass.getValidFunctions(generator.config.hooks)
4243
.forEach {
43-
if (it.returnType.isSubclassOf(Publisher::class).not()) {
44+
if (it.returnType.isSubclassOf(Publisher::class).or(it.returnType.isSubclassOf(Flow::class)).not()) {
4445
throw InvalidSubscriptionTypeException(subscription.kClass, it)
4546
}
4647

graphql-kotlin-schema-generator/src/main/kotlin/com/expediagroup/graphql/generator/types/utils/functionReturnTypes.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package com.expediagroup.graphql.generator.types.utils
1919
import com.expediagroup.graphql.generator.extensions.getTypeOfFirstArgument
2020
import com.expediagroup.graphql.generator.extensions.isSubclassOf
2121
import graphql.execution.DataFetcherResult
22+
import kotlinx.coroutines.flow.Flow
2223
import org.reactivestreams.Publisher
2324
import java.util.concurrent.CompletableFuture
2425
import kotlin.reflect.KType
@@ -41,6 +42,7 @@ import kotlin.reflect.KType
4142
internal fun getWrappedReturnType(returnType: KType): KType {
4243
return when {
4344
returnType.isSubclassOf(Publisher::class) -> returnType.getTypeOfFirstArgument()
45+
returnType.isSubclassOf(Flow::class) -> returnType.getTypeOfFirstArgument()
4446
returnType.isSubclassOf(DataFetcherResult::class) -> returnType.getTypeOfFirstArgument()
4547
returnType.isSubclassOf(CompletableFuture::class) -> {
4648
val wrappedType = returnType.getTypeOfFirstArgument()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
/*
2+
* Copyright 2020 Expedia, Inc
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.expediagroup.graphql.execution
18+
19+
import com.expediagroup.graphql.SchemaGeneratorConfig
20+
import com.expediagroup.graphql.TopLevelObject
21+
import com.expediagroup.graphql.annotations.GraphQLContext
22+
import com.expediagroup.graphql.exceptions.GraphQLKotlinException
23+
import com.expediagroup.graphql.toSchema
24+
import graphql.ExecutionInput
25+
import graphql.ExecutionResult
26+
import graphql.ExecutionResultImpl
27+
import graphql.GraphQL
28+
import graphql.GraphQLError
29+
import graphql.GraphqlErrorBuilder
30+
import graphql.schema.GraphQLSchema
31+
import kotlinx.coroutines.InternalCoroutinesApi
32+
import kotlinx.coroutines.delay
33+
import kotlinx.coroutines.flow.Flow
34+
import kotlinx.coroutines.flow.catch
35+
import kotlinx.coroutines.flow.collect
36+
import kotlinx.coroutines.flow.flow
37+
import kotlinx.coroutines.flow.onEach
38+
import kotlinx.coroutines.runBlocking
39+
import org.junit.jupiter.api.Test
40+
import kotlin.test.assertEquals
41+
import kotlin.test.assertNull
42+
import kotlin.test.assertTrue
43+
44+
@InternalCoroutinesApi
45+
class FlowSubscriptionExecutionStrategyTest {
46+
47+
private val testSchema: GraphQLSchema = toSchema(
48+
config = SchemaGeneratorConfig(supportedPackages = listOf("com.expediagroup.graphql.spring.execution")),
49+
queries = listOf(TopLevelObject(BasicQuery())),
50+
subscriptions = listOf(TopLevelObject(FlowSubscription()))
51+
)
52+
private val testGraphQL: GraphQL = GraphQL.newGraphQL(testSchema).subscriptionExecutionStrategy(FlowSubscriptionExecutionStrategy()).build()
53+
54+
@Test
55+
fun `verify subscription`() = runBlocking {
56+
val request = ExecutionInput.newExecutionInput().query("subscription { ticker }").build()
57+
val response = testGraphQL.execute(request)
58+
val flow = response.getData<Flow<ExecutionResult>>()
59+
val list = mutableListOf<Int>()
60+
flow.collect {
61+
list.add(it.getData<Map<String, Int>>().getValue("ticker"))
62+
}
63+
assertEquals(5, list.size)
64+
for (i in list.indices) {
65+
assertEquals(i + 1, list[i])
66+
}
67+
}
68+
69+
@Test
70+
fun `verify subscription with context`() = runBlocking {
71+
val request = ExecutionInput.newExecutionInput()
72+
.query("subscription { contextualTicker }")
73+
.context(SubscriptionContext("junitHandler"))
74+
.build()
75+
val response = testGraphQL.execute(request)
76+
val flow = response.getData<Flow<ExecutionResult>>()
77+
val list = mutableListOf<Int>()
78+
flow.collect {
79+
val contextValue = it.getData<Map<String, String>>().getValue("contextualTicker")
80+
assertTrue(contextValue.startsWith("junitHandler:"))
81+
list.add(contextValue.substringAfter("junitHandler:").toInt())
82+
}
83+
assertEquals(5, list.size)
84+
for (i in list.indices) {
85+
assertEquals(i + 1, list[i])
86+
}
87+
}
88+
89+
@Test
90+
fun `verify subscription to failing publisher`() = runBlocking {
91+
val request = ExecutionInput.newExecutionInput().query("subscription { alwaysThrows }").build()
92+
val response = testGraphQL.execute(request)
93+
val flow = response.getData<Flow<ExecutionResult>>()
94+
val errors = mutableListOf<GraphQLError>()
95+
val results = mutableListOf<Int>()
96+
flow.onEach {
97+
val dataMap = it.getData<Map<String, Int>>()
98+
if (dataMap != null) {
99+
results.add(dataMap.getValue("alwaysThrows"))
100+
}
101+
errors.addAll(it.errors)
102+
}.catch {
103+
errors.add(GraphqlErrorBuilder.newError().message(it.message).build())
104+
}.collect()
105+
assertEquals(2, results.size)
106+
for (i in results.indices) {
107+
assertEquals(i + 1, results[i])
108+
}
109+
assertEquals(1, errors.size)
110+
assertEquals("JUNIT subscription failure", errors[0].message)
111+
}
112+
113+
@Test
114+
fun `verify subscription to exploding publisher`() = runBlocking {
115+
val request = ExecutionInput.newExecutionInput().query("subscription { throwsFast }").build()
116+
val response = testGraphQL.execute(request)
117+
val flow = response.getData<Flow<ExecutionResult>>()
118+
val errors = response.errors
119+
assertNull(flow)
120+
assertEquals(1, errors.size)
121+
assertEquals("JUNIT flow failure", errors[0].message.substringAfter(" : "))
122+
}
123+
124+
// GraphQL spec requires at least single query to be present as Query type is needed to run introspection queries
125+
// see: https://github.com/graphql/graphql-spec/issues/490 and https://github.com/graphql/graphql-spec/issues/568
126+
class BasicQuery {
127+
@Suppress("Detekt.FunctionOnlyReturningConstant")
128+
fun query(): String = "hello"
129+
}
130+
131+
class FlowSubscription {
132+
fun ticker(): Flow<Int> {
133+
return flow {
134+
for (i in 1..5) {
135+
delay(100)
136+
emit(i)
137+
}
138+
}
139+
}
140+
141+
fun throwsFast(): Flow<Int> {
142+
throw GraphQLKotlinException("JUNIT flow failure")
143+
}
144+
145+
fun alwaysThrows(): Flow<Int> {
146+
return flow {
147+
for (i in 1..5) {
148+
if (i > 2) {
149+
throw GraphQLKotlinException("JUNIT subscription failure")
150+
}
151+
delay(100)
152+
emit(i)
153+
}
154+
}
155+
}
156+
157+
fun contextualTicker(@GraphQLContext context: SubscriptionContext): Flow<String> {
158+
return flow {
159+
for (i in 1..5) {
160+
delay(100)
161+
emit("${context.value}:$i")
162+
}
163+
}
164+
}
165+
}
166+
167+
data class SubscriptionContext(val value: String)
168+
}

0 commit comments

Comments
 (0)