Skip to content

Allow for custom subscription return types #644

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions docs/execution/subscriptions.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,15 @@ toSchema(

### Subscription Hooks

Through the hooks a new method was added `didGenerateSubscriptionType` which is called after a new subscription type is
generated but before it is added to the schema. The other hook are still called so you can add logic for the types and
#### `didGenerateSubscriptionType`
This hook is called after a new subscription type is generated but before it is added to the schema. The other generator hooks are still called so you can add logic for the types and
validation of subscriptions the same as queries and mutations.

#### `isValidSubscriptionReturnType`
This hook is called when generating the functions for each subscription. It allows for changing the rules of what classes can be used as the return type. By default, graphql-java supports `org.reactivestreams.Publisher`.

To effectively use this hook, you should also override the `willResolveMonad` hook, and if you are using `graphql-kotlin-spring-server` you should override the `GraphQL` bean to specify a custom subscription execution strategy.

### Server Implementation

The server that runs your GraphQL schema will have to support some method for subscriptions, like WebSockets.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import kotlin.reflect.KFunction

class InvalidSubscriptionTypeException(kClass: KClass<*>, kFunction: KFunction<*>? = null) :
GraphQLKotlinException(
"Schema requires all subscriptions to be public and return a type of Publisher. " +
"Schema requires all subscriptions to be public and return a valid type from the hooks. " +
"${kClass.simpleName} has ${kClass.visibility} visibility modifier. " +
if (kFunction != null) "The function return type is ${kFunction.returnType.getSimpleName()}" else ""
)
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@ import com.expediagroup.graphql.exceptions.InvalidSubscriptionTypeException
import com.expediagroup.graphql.generator.SchemaGenerator
import com.expediagroup.graphql.generator.extensions.getValidFunctions
import com.expediagroup.graphql.generator.extensions.isNotPublic
import com.expediagroup.graphql.generator.extensions.isSubclassOf
import graphql.schema.GraphQLObjectType
import org.reactivestreams.Publisher

internal fun generateSubscriptions(generator: SchemaGenerator, subscriptions: List<TopLevelObject>): GraphQLObjectType? {
if (subscriptions.isEmpty()) {
Expand All @@ -34,18 +32,20 @@ internal fun generateSubscriptions(generator: SchemaGenerator, subscriptions: Li
subscriptionBuilder.name(generator.config.topLevelNames.subscription)

for (subscription in subscriptions) {
if (subscription.kClass.isNotPublic()) {
throw InvalidSubscriptionTypeException(subscription.kClass)
val kClass = subscription.kClass

if (kClass.isNotPublic()) {
throw InvalidSubscriptionTypeException(kClass)
}

subscription.kClass.getValidFunctions(generator.config.hooks)
kClass.getValidFunctions(generator.config.hooks)
.forEach {
if (it.returnType.isSubclassOf(Publisher::class).not()) {
throw InvalidSubscriptionTypeException(subscription.kClass, it)
if (generator.config.hooks.isValidSubscriptionReturnType(kClass, it).not()) {
throw InvalidSubscriptionTypeException(kClass, it)
}

val function = generateFunction(generator, it, generator.config.topLevelNames.subscription, subscription.obj)
val functionFromHook = generator.config.hooks.didGenerateSubscriptionField(subscription.kClass, it, function)
val functionFromHook = generator.config.hooks.didGenerateSubscriptionField(kClass, it, function)
subscriptionBuilder.field(functionFromHook)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import com.expediagroup.graphql.exceptions.EmptyMutationTypeException
import com.expediagroup.graphql.exceptions.EmptyObjectTypeException
import com.expediagroup.graphql.exceptions.EmptyQueryTypeException
import com.expediagroup.graphql.exceptions.EmptySubscriptionTypeException
import com.expediagroup.graphql.generator.extensions.isSubclassOf
import graphql.schema.FieldCoordinates
import graphql.schema.GraphQLCodeRegistry
import graphql.schema.GraphQLFieldDefinition
Expand All @@ -33,6 +34,7 @@ import graphql.schema.GraphQLSchema
import graphql.schema.GraphQLSchemaElement
import graphql.schema.GraphQLType
import graphql.schema.GraphQLTypeUtil
import org.reactivestreams.Publisher
import kotlin.reflect.KClass
import kotlin.reflect.KFunction
import kotlin.reflect.KProperty
Expand Down Expand Up @@ -91,6 +93,15 @@ interface SchemaGeneratorHooks {
@Suppress("Detekt.FunctionOnlyReturningConstant")
fun isValidFunction(kClass: KClass<*>, function: KFunction<*>): Boolean = true

/**
* Called when looking at the subscription functions to determine if it is using a valid return type.
* By default, graphql-java supports org.reactivestreams.Publisher in the subscription execution strategy.
* If you want to provide a custom execution strategy, you may need to override this hook.
*
* NOTE: You will most likely need to also override the [willResolveMonad] hook to allow for your custom type to be generated.
*/
fun isValidSubscriptionReturnType(kClass: KClass<*>, function: KFunction<*>): Boolean = function.returnType.isSubclassOf(Publisher::class)

/**
* Called after `willGenerateGraphQLType` and before `didGenerateGraphQLType`.
* Enables you to change the wiring, e.g. apply directives to alter the target type.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,20 @@ import com.expediagroup.graphql.TopLevelNames
import com.expediagroup.graphql.TopLevelObject
import com.expediagroup.graphql.exceptions.EmptySubscriptionTypeException
import com.expediagroup.graphql.exceptions.InvalidSubscriptionTypeException
import com.expediagroup.graphql.generator.extensions.getTypeOfFirstArgument
import com.expediagroup.graphql.generator.extensions.isSubclassOf
import com.expediagroup.graphql.hooks.SchemaGeneratorHooks
import graphql.schema.GraphQLFieldDefinition
import io.mockk.every
import io.reactivex.rxjava3.core.Flowable
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.flowOf
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.assertThrows
import org.reactivestreams.Publisher
import kotlin.reflect.KClass
import kotlin.reflect.KFunction
import kotlin.reflect.KType
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
import kotlin.test.assertNotNull
Expand Down Expand Up @@ -123,6 +128,26 @@ internal class GenerateSubscriptionTest : TypeTestHelper() {
assertEquals(3, result?.fieldDefinitions?.size)
assertNotNull(result?.fieldDefinitions?.find { it.name == "changedField" })
}

@Test
fun `given custom hooks that allow custom subscription return types, it should generate a valid schema`() {
val subscriptions = listOf(TopLevelObject(MyCustomSubscriptionClass()))

class CustomHooks : SchemaGeneratorHooks {
override fun isValidSubscriptionReturnType(kClass: KClass<*>, function: KFunction<*>) = function.returnType.isSubclassOf(Flow::class)
override fun willResolveMonad(type: KType): KType = when {
type.isSubclassOf(Flow::class) -> type.getTypeOfFirstArgument()
else -> this.willResolveMonad(type)
}
}

every { config.hooks } returns CustomHooks()

val result = generateSubscriptions(generator, subscriptions)

assertEquals(1, result?.fieldDefinitions?.size)
assertNotNull(result?.fieldDefinitions?.find { it.name == "number" })
}
}

class MyPublicTestSubscription {
Expand All @@ -138,6 +163,11 @@ class MyInvalidSubscriptionClass {
fun number(): Int = 1
}

class MyCustomSubscriptionClass {
@Suppress("Detekt.FunctionOnlyReturningConstant")
fun number(): Flow<Int> = flowOf(1)
}

private class MyPrivateTestSubscription {
fun counter(): Publisher<Int> = Flowable.just(3)
}
Expand Down