Skip to content

Commit a01c6d5

Browse files
committed
Inherit parent context in coRouter DSL
This commit also allows context override, as it is useful for the nested router use case. Closes gh-31831
1 parent 8d4deca commit a01c6d5

File tree

2 files changed

+70
-4
lines changed

2 files changed

+70
-4
lines changed

Diff for: spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDsl.kt

+1-4
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
144144
* @see RouterFunctions.nest
145145
*/
146146
fun RequestPredicate.nest(r: (CoRouterFunctionDsl.() -> Unit)) {
147-
builder.add(nest(this, CoRouterFunctionDsl(r).build()))
147+
builder.add(nest(this, CoRouterFunctionDsl(r).also { it.contextProvider = contextProvider }.build()))
148148
}
149149

150150

@@ -628,9 +628,6 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
628628
* @since 6.1
629629
*/
630630
fun context(provider: suspend (ServerRequest) -> CoroutineContext) {
631-
if (this.contextProvider != null) {
632-
throw IllegalStateException("The Coroutine context provider should not be defined more than once")
633-
}
634631
this.contextProvider = provider
635632
}
636633

Diff for: spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDslTests.kt

+69
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,45 @@ class CoRouterFunctionDslTests {
193193
.verifyComplete()
194194
}
195195

196+
@Test
197+
fun nestedContextProvider() {
198+
val mockRequest = get("https://example.com/nested/")
199+
.header("Custom-Header", "foo")
200+
.build()
201+
val request = DefaultServerRequest(MockServerWebExchange.from(mockRequest), emptyList())
202+
StepVerifier.create(nestedRouterWithContextProvider.route(request).flatMap { it.handle(request) })
203+
.expectNextMatches { response ->
204+
response.headers().getFirst("context")!!.contains("foo")
205+
}
206+
.verifyComplete()
207+
}
208+
209+
@Test
210+
fun nestedContextProviderWithOverride() {
211+
val mockRequest = get("https://example.com/nested/")
212+
.header("Custom-Header", "foo")
213+
.build()
214+
val request = DefaultServerRequest(MockServerWebExchange.from(mockRequest), emptyList())
215+
StepVerifier.create(nestedRouterWithContextProviderOverride.route(request).flatMap { it.handle(request) })
216+
.expectNextMatches { response ->
217+
response.headers().getFirst("context")!!.contains("foo")
218+
}
219+
.verifyComplete()
220+
}
221+
222+
@Test
223+
fun doubleNestedContextProvider() {
224+
val mockRequest = get("https://example.com/nested/nested/")
225+
.header("Custom-Header", "foo")
226+
.build()
227+
val request = DefaultServerRequest(MockServerWebExchange.from(mockRequest), emptyList())
228+
StepVerifier.create(nestedRouterWithContextProvider.route(request).flatMap { it.handle(request) })
229+
.expectNextMatches { response ->
230+
response.headers().getFirst("context")!!.contains("foo")
231+
}
232+
.verifyComplete()
233+
}
234+
196235
@Test
197236
fun contextProviderAndFilter() {
198237
val mockRequest = get("https://example.com/")
@@ -323,6 +362,36 @@ class CoRouterFunctionDslTests {
323362
}
324363
}
325364

365+
private val nestedRouterWithContextProvider = coRouter {
366+
context {
367+
CoroutineName(it.headers().firstHeader("Custom-Header")!!)
368+
}
369+
"/nested".nest {
370+
GET("/") {
371+
ok().header("context", currentCoroutineContext().toString()).buildAndAwait()
372+
}
373+
"/nested".nest {
374+
GET("/") {
375+
ok().header("context", currentCoroutineContext().toString()).buildAndAwait()
376+
}
377+
}
378+
}
379+
}
380+
381+
private val nestedRouterWithContextProviderOverride = coRouter {
382+
context {
383+
CoroutineName("parent-context")
384+
}
385+
"/nested".nest {
386+
context {
387+
CoroutineName(it.headers().firstHeader("Custom-Header")!!)
388+
}
389+
GET("/") {
390+
ok().header("context", currentCoroutineContext().toString()).buildAndAwait()
391+
}
392+
}
393+
}
394+
326395
private val routerWithoutContext = coRouter {
327396
GET("/") {
328397
ok().header("context", currentCoroutineContext().toString()).buildAndAwait()

0 commit comments

Comments
 (0)