Skip to content

Commit d45e6ec

Browse files
committed
Support Flux<ServerSentEvent<Fragment>> in WebFlux
Closes gh-33975
1 parent c4b100a commit d45e6ec

File tree

3 files changed

+116
-41
lines changed

3 files changed

+116
-41
lines changed

spring-webflux/src/main/java/org/springframework/web/reactive/result/view/ViewResolutionResultHandler.java

+59-26
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
import org.springframework.http.HttpHeaders;
4646
import org.springframework.http.HttpStatusCode;
4747
import org.springframework.http.MediaType;
48+
import org.springframework.http.codec.ServerSentEvent;
4849
import org.springframework.http.server.reactive.ServerHttpRequest;
4950
import org.springframework.http.server.reactive.ServerHttpResponse;
5051
import org.springframework.http.server.reactive.ServerHttpResponseDecorator;
@@ -101,7 +102,7 @@ public class ViewResolutionResultHandler extends HandlerResultHandlerSupport imp
101102

102103
private final List<View> defaultViews = new ArrayList<>(4);
103104

104-
private final List<StreamHandler> streamHandlers = List.of(new SseStreamHandler());
105+
private final SseStreamHandler sseHandler = new SseStreamHandler();
105106

106107

107108
/**
@@ -175,7 +176,7 @@ public boolean supports(HandlerResult result) {
175176
returnType = returnType.getNested(2);
176177

177178
if (adapter.isMultiValue()) {
178-
return Fragment.class.isAssignableFrom(type);
179+
return (Fragment.class.isAssignableFrom(type) || isSseFragmentStream(returnType));
179180
}
180181
}
181182

@@ -194,8 +195,13 @@ private boolean hasModelAnnotation(MethodParameter parameter) {
194195
}
195196

196197
private static boolean isFragmentCollection(ResolvableType returnType) {
197-
Class<?> clazz = returnType.resolve(Object.class);
198-
return (Collection.class.isAssignableFrom(clazz) && Fragment.class.equals(returnType.getNested(2).resolve()));
198+
return (Collection.class.isAssignableFrom(returnType.resolve(Object.class)) &&
199+
Fragment.class.equals(returnType.getNested(2).resolve()));
200+
}
201+
202+
private static boolean isSseFragmentStream(ResolvableType returnType) {
203+
return (ServerSentEvent.class.equals(returnType.resolve()) &&
204+
Fragment.class.equals(returnType.getNested(2).resolve()));
199205
}
200206

201207
@Override
@@ -204,9 +210,15 @@ public Mono<Void> handleResult(ServerWebExchange exchange, HandlerResult result)
204210
Mono<Object> valueMono;
205211
ResolvableType valueType;
206212
ReactiveAdapter adapter = getAdapter(result);
213+
BindingContext bindingContext = result.getBindingContext();
214+
Locale locale = LocaleContextHolder.getLocale(exchange.getLocaleContext());
207215

208216
if (adapter != null) {
209217
if (adapter.isMultiValue()) {
218+
if (isSseFragmentStream(result.getReturnType().getNested(2))) {
219+
return handleSseFragmentStream(exchange, result, adapter, locale, bindingContext);
220+
}
221+
210222
valueMono = (result.getReturnValue() != null ?
211223
Mono.just(FragmentsRendering.fragmentsPublisher(adapter.toPublisher(result.getReturnValue())).build()) :
212224
Mono.empty());
@@ -233,8 +245,6 @@ public Mono<Void> handleResult(ServerWebExchange exchange, HandlerResult result)
233245
Mono<List<View>> viewsMono;
234246
Model model = result.getModel();
235247
MethodParameter parameter = result.getReturnTypeSource();
236-
BindingContext bindingContext = result.getBindingContext();
237-
Locale locale = LocaleContextHolder.getLocale(exchange.getLocaleContext());
238248

239249
Class<?> clazz = valueType.toClass();
240250
if (clazz == Object.class) {
@@ -277,13 +287,15 @@ else if (FragmentsRendering.class.isAssignableFrom(clazz)) {
277287
response.getHeaders().putAll(render.headers());
278288
bindingContext.updateModel(exchange);
279289

280-
StreamHandler streamHandler = getStreamHandler(exchange);
290+
StreamHandler streamHandler =
291+
(this.sseHandler.supports(exchange.getRequest()) ? this.sseHandler : null);
292+
281293
if (streamHandler != null) {
282294
streamHandler.updateResponse(exchange);
283295
}
284296

285297
Flux<Flux<DataBuffer>> renderFlux = render.fragments()
286-
.concatMap(fragment -> renderFragment(fragment, streamHandler, locale, bindingContext, exchange))
298+
.concatMap(fragment -> renderFragment(fragment, null, streamHandler, locale, bindingContext, exchange))
287299
.doOnDiscard(DataBuffer.class, DataBufferUtils::release);
288300

289301
return response.writeAndFlushWith(renderFlux);
@@ -338,9 +350,29 @@ private Mono<List<View>> resolveViews(String viewName, Locale locale) {
338350
});
339351
}
340352

353+
private Mono<Void> handleSseFragmentStream(
354+
ServerWebExchange exchange, HandlerResult result, ReactiveAdapter adapter, Locale locale,
355+
BindingContext bindingContext) {
356+
357+
this.sseHandler.updateResponse(exchange);
358+
359+
Flux<ServerSentEvent<Fragment>> eventFlux =
360+
Flux.from(adapter.toPublisher(result.getReturnValue()));
361+
362+
Flux<Flux<DataBuffer>> dataBufferFlux = eventFlux
363+
.concatMap(event -> renderFragment(event.data(), event, this.sseHandler, locale, bindingContext, exchange))
364+
.doOnDiscard(DataBuffer.class, DataBufferUtils::release);
365+
366+
return exchange.getResponse().writeAndFlushWith(dataBufferFlux);
367+
}
368+
341369
private Mono<Flux<DataBuffer>> renderFragment(
342-
Fragment fragment, @Nullable StreamHandler streamHandler, Locale locale,
343-
BindingContext bindingContext, ServerWebExchange exchange) {
370+
@Nullable Fragment fragment, @Nullable Object streamingHints, @Nullable StreamHandler streamHandler,
371+
Locale locale, BindingContext bindingContext, ServerWebExchange exchange) {
372+
373+
if (fragment == null) {
374+
return Mono.empty();
375+
}
344376

345377
// Merge attributes from top-level model
346378
fragment.mergeAttributes(bindingContext.getModel());
@@ -355,25 +387,18 @@ private Mono<Flux<DataBuffer>> renderFragment(
355387
Map<String, Object> model = fragment.model();
356388

357389
if (streamHandler != null) {
358-
return selectedViews.flatMap(views -> render(views, model, MediaType.TEXT_HTML, bindingContext, mutatedExchange))
359-
.then(Mono.fromSupplier(() -> streamHandler.format(response.getBodyFlux(), fragment, exchange)));
390+
return selectedViews
391+
.flatMap(views ->
392+
render(views, model, MediaType.TEXT_HTML, bindingContext, mutatedExchange))
393+
.then(Mono.fromSupplier(() -> streamHandler.format(
394+
response.getBodyFlux(), fragment, streamingHints, exchange)));
360395
}
361396
else {
362397
return selectedViews.flatMap(views -> render(views, model, null, bindingContext, mutatedExchange))
363398
.then(Mono.fromSupplier(response::getBodyFlux));
364399
}
365400
}
366401

367-
@Nullable
368-
private StreamHandler getStreamHandler(ServerWebExchange exchange) {
369-
for (StreamHandler handler : this.streamHandlers) {
370-
if (handler.supports(exchange.getRequest())) {
371-
return handler;
372-
}
373-
}
374-
return null;
375-
}
376-
377402
private String getNameForReturnValue(MethodParameter returnType) {
378403
return Optional.ofNullable(returnType.getMethodAnnotation(ModelAttribute.class))
379404
.filter(ann -> StringUtils.hasText(ann.value()))
@@ -499,10 +524,13 @@ private interface StreamHandler {
499524
* Format the given fragment.
500525
* @param fragmentContent the fragment serialized to data buffers
501526
* @param fragment the fragment being rendered
527+
* @param streamingHints extra hints for the stream format (e.g. ServerSentEvent wrapper)
502528
* @param exchange the current exchange
503529
* @return the formatted fragment
504530
*/
505-
Flux<DataBuffer> format(Flux<DataBuffer> fragmentContent, Fragment fragment, ServerWebExchange exchange);
531+
Flux<DataBuffer> format(
532+
Flux<DataBuffer> fragmentContent, Fragment fragment, @Nullable Object streamingHints,
533+
ServerWebExchange exchange);
506534
}
507535

508536

@@ -540,16 +568,21 @@ private Charset getCharset(ServerHttpRequest request) {
540568

541569
@Override
542570
public Flux<DataBuffer> format(
543-
Flux<DataBuffer> fragmentFlux, Fragment fragment, ServerWebExchange exchange) {
571+
Flux<DataBuffer> fragmentFlux, Fragment fragment, @Nullable Object hints,
572+
ServerWebExchange exchange) {
544573

545574
MediaType mediaType = exchange.getResponse().getHeaders().getContentType();
546575
Charset charset = (mediaType != null && mediaType.getCharset() != null ?
547576
mediaType.getCharset() : StandardCharsets.UTF_8);
577+
Assert.state(hints == null || hints instanceof ServerSentEvent, "Expected ServerSentEvent");
548578

549579
DataBufferFactory bufferFactory = exchange.getResponse().bufferFactory();
550580

551-
String eventLine = (fragment.viewName() != null ? "event:" + fragment.viewName() + "\n" : "");
552-
DataBuffer prefix = encodeText(eventLine + "data:", charset, bufferFactory);
581+
ServerSentEvent<?> sse = (ServerSentEvent<?>) hints;
582+
CharSequence eventText = (sse != null ? sse.format() :
583+
(fragment.viewName() != null ? "event:" + fragment.viewName() + "\n" : "") + "data:");
584+
585+
DataBuffer prefix = encodeText(eventText.toString(), charset, bufferFactory);
553586
DataBuffer suffix = encodeText("\n\n", charset, bufferFactory);
554587

555588
Mono<DataBuffer> content = DataBufferUtils.join(fragmentFlux)

spring-webflux/src/test/java/org/springframework/web/reactive/result/view/FragmentViewResolutionResultHandlerTests.java

+52-15
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,9 @@
3434
import org.springframework.context.annotation.Bean;
3535
import org.springframework.context.annotation.Configuration;
3636
import org.springframework.core.MethodParameter;
37+
import org.springframework.core.ResolvableType;
3738
import org.springframework.http.MediaType;
39+
import org.springframework.http.codec.ServerSentEvent;
3840
import org.springframework.web.reactive.BindingContext;
3941
import org.springframework.web.reactive.HandlerResult;
4042
import org.springframework.web.reactive.accept.HeaderContentTypeResolver;
@@ -99,7 +101,51 @@ void render(Object returnValue, MethodParameter parameter) {
99101
}
100102

101103
@Test
102-
void renderSse() {
104+
void renderFragmentStream() {
105+
106+
testSse(Flux.just(fragment1, fragment2),
107+
on(Handler.class).resolveReturnType(Flux.class, Fragment.class),
108+
"""
109+
event:fragment1
110+
data:<p>
111+
data: Hello Foo
112+
data:</p>
113+
114+
event:fragment2
115+
data:<p>
116+
data: Hello Bar
117+
data:</p>
118+
119+
""");
120+
}
121+
122+
@Test
123+
void renderServerSentEventFragmentStream() {
124+
125+
ServerSentEvent<Fragment> event1 = ServerSentEvent.builder(fragment1).id("id1").event("event1").build();
126+
ServerSentEvent<Fragment> event2 = ServerSentEvent.builder(fragment2).id("id2").event("event2").build();
127+
128+
MethodParameter returnType = on(Handler.class).resolveReturnType(
129+
Flux.class, ResolvableType.forClassWithGenerics(ServerSentEvent.class, Fragment.class));
130+
131+
testSse(Flux.just(event1, event2), returnType,
132+
"""
133+
id:id1
134+
event:event1
135+
data:<p>
136+
data: Hello Foo
137+
data:</p>
138+
139+
id:id2
140+
event:event2
141+
data:<p>
142+
data: Hello Bar
143+
data:</p>
144+
145+
""");
146+
}
147+
148+
private void testSse(Flux<?> dataFlux, MethodParameter returnType, String output) {
103149
MockServerHttpRequest request = MockServerHttpRequest.get("/")
104150
.accept(MediaType.TEXT_EVENT_STREAM)
105151
.acceptLanguageAsLocales(Locale.ENGLISH)
@@ -110,27 +156,16 @@ void renderSse() {
110156

111157
HandlerResult result = new HandlerResult(
112158
new Handler(),
113-
Flux.just(fragment1, fragment2).subscribeOn(Schedulers.boundedElastic()),
114-
on(Handler.class).resolveReturnType(Flux.class, Fragment.class),
159+
dataFlux.subscribeOn(Schedulers.boundedElastic()),
160+
returnType,
115161
new BindingContext());
116162

117163
String body = initHandler().handleResult(exchange, result)
118164
.then(Mono.defer(response::getBodyAsString))
119165
.block(Duration.ofSeconds(60));
120166

121167
assertThat(response.getHeaders().getContentType()).isEqualTo(MediaType.TEXT_EVENT_STREAM);
122-
assertThat(body).isEqualTo("""
123-
event:fragment1
124-
data:<p>
125-
data: Hello Foo
126-
data:</p>
127-
128-
event:fragment2
129-
data:<p>
130-
data: Hello Bar
131-
data:</p>
132-
133-
""");
168+
assertThat(body).isEqualTo(output);
134169
}
135170

136171
private ViewResolutionResultHandler initHandler() {
@@ -155,6 +190,8 @@ private static class Handler {
155190

156191
Flux<Fragment> renderFlux() { return null; }
157192

193+
Flux<ServerSentEvent<Fragment>> renderSseFlux() { return null; }
194+
158195
List<Fragment> renderList() { return null; }
159196

160197
}

spring-webflux/src/test/java/org/springframework/web/reactive/result/view/ViewResolutionResultHandlerTests.java

+5
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
import org.springframework.core.io.buffer.DefaultDataBufferFactory;
4242
import org.springframework.http.HttpStatus;
4343
import org.springframework.http.MediaType;
44+
import org.springframework.http.codec.ServerSentEvent;
4445
import org.springframework.http.server.reactive.ServerHttpResponse;
4546
import org.springframework.lang.Nullable;
4647
import org.springframework.ui.ConcurrentModel;
@@ -84,6 +85,9 @@ void supports() {
8485

8586
testSupports(on(Handler.class).resolveReturnType(FragmentsRendering.class));
8687
testSupports(on(Handler.class).resolveReturnType(Flux.class, Fragment.class));
88+
testSupports(on(Handler.class).resolveReturnType(
89+
Flux.class, ResolvableType.forClassWithGenerics(ServerSentEvent.class, Fragment.class)));
90+
8791
testSupports(on(Handler.class).resolveReturnType(List.class, Fragment.class));
8892
testSupports(on(Handler.class).resolveReturnType(
8993
Mono.class, ResolvableType.forClassWithGenerics(List.class, Fragment.class)));
@@ -457,6 +461,7 @@ private static class Handler {
457461

458462
FragmentsRendering fragmentsRendering() { return null; }
459463
Flux<Fragment> fragmentFlux() { return null; }
464+
Flux<ServerSentEvent<Fragment>> fragmentServerSentEventFlux() { return null; }
460465
Mono<List<Fragment>> monoFragmentList() { return null; }
461466
List<Fragment> fragmentList() { return null; }
462467

0 commit comments

Comments
 (0)