Skip to content

Commit 29a4dab

Browse files
committed
Support @ModelAttribute with suspending function in WebFlux
Closes gh-30894
1 parent f5f8eab commit 29a4dab

File tree

2 files changed

+112
-5
lines changed

2 files changed

+112
-5
lines changed

Diff for: spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/ModelInitializer.java

+12-5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2022 the original author or authors.
2+
* Copyright 2002-2023 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -16,6 +16,7 @@
1616

1717
package org.springframework.web.reactive.result.method.annotation;
1818

19+
import java.lang.reflect.Method;
1920
import java.util.ArrayList;
2021
import java.util.Arrays;
2122
import java.util.List;
@@ -25,6 +26,7 @@
2526
import reactor.core.publisher.Mono;
2627

2728
import org.springframework.core.Conventions;
29+
import org.springframework.core.KotlinDetector;
2830
import org.springframework.core.MethodParameter;
2931
import org.springframework.core.ReactiveAdapter;
3032
import org.springframework.core.ReactiveAdapterRegistry;
@@ -45,6 +47,7 @@
4547
* default model initialization through {@code @ModelAttribute} methods.
4648
*
4749
* @author Rossen Stoyanchev
50+
* @author Sebastien Deleuze
4851
* @since 5.0
4952
*/
5053
class ModelInitializer {
@@ -119,18 +122,22 @@ private Mono<Void> handleResult(HandlerResult handlerResult, BindingContext bind
119122
Object value = handlerResult.getReturnValue();
120123
if (value != null) {
121124
ResolvableType type = handlerResult.getReturnType();
125+
MethodParameter typeSource = handlerResult.getReturnTypeSource();
122126
ReactiveAdapter adapter = this.adapterRegistry.getAdapter(type.resolve(), value);
123-
if (isAsyncVoidType(type, adapter)) {
127+
if (isAsyncVoidType(type, typeSource, adapter)) {
124128
return Mono.from(adapter.toPublisher(value));
125129
}
126-
String name = getAttributeName(handlerResult.getReturnTypeSource());
130+
String name = getAttributeName(typeSource);
127131
bindingContext.getModel().asMap().putIfAbsent(name, value);
128132
}
129133
return Mono.empty();
130134
}
131135

132-
private boolean isAsyncVoidType(ResolvableType type, @Nullable ReactiveAdapter adapter) {
133-
return (adapter != null && (adapter.isNoValue() || type.resolveGeneric() == Void.class));
136+
137+
private boolean isAsyncVoidType(ResolvableType type, MethodParameter typeSource, @Nullable ReactiveAdapter adapter) {
138+
Method method = typeSource.getMethod();
139+
return (adapter != null && (adapter.isNoValue() || type.resolveGeneric() == Void.class)) ||
140+
(method != null && KotlinDetector.isSuspendingFunction(method) && typeSource.getParameterType() == void.class);
134141
}
135142

136143
private String getAttributeName(MethodParameter param) {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
/*
2+
* Copyright 2002-2023 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.web.reactive.result.method.annotation
18+
19+
import kotlinx.coroutines.delay
20+
import org.assertj.core.api.Assertions
21+
import org.junit.jupiter.api.BeforeEach
22+
import org.junit.jupiter.api.Test
23+
import org.springframework.context.support.StaticApplicationContext
24+
import org.springframework.core.ReactiveAdapterRegistry
25+
import org.springframework.ui.Model
26+
import org.springframework.web.bind.annotation.GetMapping
27+
import org.springframework.web.bind.annotation.ModelAttribute
28+
import org.springframework.web.bind.support.ConfigurableWebBindingInitializer
29+
import org.springframework.web.method.HandlerMethod
30+
import org.springframework.web.server.ServerWebExchange
31+
import org.springframework.web.testfixture.http.server.reactive.MockServerHttpRequest
32+
import org.springframework.web.testfixture.method.ResolvableMethod
33+
import org.springframework.web.testfixture.server.MockServerWebExchange
34+
import reactor.core.publisher.Mono
35+
import java.time.Duration
36+
37+
/**
38+
* Kotlin test fixture for [ModelInitializer].
39+
*
40+
* @author Sebastien Deleuze
41+
*/
42+
class ModelInitializerKotlinTests {
43+
44+
private val timeout = Duration.ofMillis(5000)
45+
46+
private lateinit var modelInitializer: ModelInitializer
47+
48+
private val exchange: ServerWebExchange = MockServerWebExchange.from(MockServerHttpRequest.get("/path"))
49+
50+
@BeforeEach
51+
fun setup() {
52+
val adapterRegistry = ReactiveAdapterRegistry.getSharedInstance()
53+
val resolverConfigurer = ArgumentResolverConfigurer()
54+
resolverConfigurer.addCustomResolver(ModelMethodArgumentResolver(adapterRegistry))
55+
val methodResolver = ControllerMethodResolver(resolverConfigurer, adapterRegistry, StaticApplicationContext(),
56+
emptyList())
57+
modelInitializer = ModelInitializer(methodResolver, adapterRegistry)
58+
}
59+
60+
@Test
61+
@Suppress("UNCHECKED_CAST")
62+
fun modelAttributeMethods() {
63+
val controller = TestController()
64+
val method = ResolvableMethod.on(TestController::class.java).annotPresent(GetMapping::class.java)
65+
.resolveMethod()
66+
val handlerMethod = HandlerMethod(controller, method)
67+
val context = InitBinderBindingContext(ConfigurableWebBindingInitializer(), emptyList())
68+
this.modelInitializer.initModel(handlerMethod, context, this.exchange).block(timeout)
69+
val model = context.model.asMap()
70+
Assertions.assertThat(model).hasSize(2)
71+
val monoValue = model["suspendingReturnValue"] as Mono<TestBean>
72+
Assertions.assertThat(monoValue.block(timeout)!!.name).isEqualTo("Suspending return value")
73+
val value = model["suspendingModelParameter"] as TestBean
74+
Assertions.assertThat(value.name).isEqualTo("Suspending model parameter")
75+
}
76+
77+
78+
private data class TestBean(val name: String)
79+
80+
private class TestController {
81+
82+
@ModelAttribute("suspendingReturnValue")
83+
suspend fun suspendingReturnValue(): TestBean {
84+
delay(1)
85+
return TestBean("Suspending return value")
86+
}
87+
88+
@ModelAttribute
89+
suspend fun suspendingModelParameter(model: Model) {
90+
delay(1)
91+
model.addAttribute("suspendingModelParameter", TestBean("Suspending model parameter"))
92+
}
93+
94+
@GetMapping
95+
fun handleGet() {
96+
}
97+
98+
}
99+
100+
}

0 commit comments

Comments
 (0)