Skip to content

Commit 345befd

Browse files
committed
1 parent 8f96ca4 commit 345befd

File tree

7 files changed

+255
-190
lines changed

7 files changed

+255
-190
lines changed

spring-web/src/main/java/org/springframework/web/bind/ServletRequestDataBinder.java

+39
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,21 @@
1616

1717
package org.springframework.web.bind;
1818

19+
import java.lang.reflect.Constructor;
20+
import java.util.List;
21+
1922
import jakarta.servlet.ServletRequest;
2023
import jakarta.servlet.http.HttpServletRequest;
24+
import jakarta.servlet.http.Part;
2125

2226
import org.springframework.beans.MutablePropertyValues;
2327
import org.springframework.http.HttpMethod;
2428
import org.springframework.http.MediaType;
29+
import org.springframework.core.MethodParameter;
2530
import org.springframework.lang.Nullable;
2631
import org.springframework.util.StringUtils;
2732
import org.springframework.validation.BindException;
33+
import org.springframework.web.multipart.MultipartFile;
2834
import org.springframework.web.multipart.MultipartRequest;
2935
import org.springframework.web.multipart.support.StandardServletPartUtils;
3036
import org.springframework.web.util.WebUtils;
@@ -142,4 +148,37 @@ public void closeNoCatch() throws ServletRequestBindingException {
142148
}
143149
}
144150

151+
public <T> T construct(ServletRequest request, Constructor<T> ctor, Callback callback, @Nullable MethodParameter parameter) throws Exception {
152+
return super.construct(ctor, (name, type) -> getBindValue(request, name, type), callback, parameter);
153+
}
154+
155+
@Nullable
156+
protected Object getBindValue(ServletRequest request, String name, Class<?> type) {
157+
Object value = request.getParameterValues(name);
158+
if (value != null) {
159+
return value;
160+
}
161+
else {
162+
MultipartRequest multipartRequest = WebUtils.getNativeRequest(request, MultipartRequest.class);
163+
if (multipartRequest != null) {
164+
List<MultipartFile> files = multipartRequest.getFiles(name);
165+
if (!files.isEmpty()) {
166+
return (files.size() == 1 ? files.get(0) : files);
167+
}
168+
}
169+
else if (StringUtils.startsWithIgnoreCase(request.getContentType(), "multipart/")) {
170+
HttpServletRequest httpServletRequest = WebUtils.getNativeRequest(request, HttpServletRequest.class);
171+
if (httpServletRequest != null) {
172+
List<Part> parts = StandardServletPartUtils.getParts(httpServletRequest, name);
173+
if (!parts.isEmpty()) {
174+
return (parts.size() == 1 ? parts.get(0) : parts);
175+
}
176+
}
177+
}
178+
}
179+
return null;
180+
}
181+
182+
183+
145184
}

spring-web/src/main/java/org/springframework/web/bind/WebDataBinder.java

+163-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2020 the original author or authors.
2+
* Copyright 2002-2021 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -16,15 +16,31 @@
1616

1717
package org.springframework.web.bind;
1818

19+
import java.lang.annotation.Annotation;
1920
import java.lang.reflect.Array;
21+
import java.lang.reflect.Constructor;
22+
import java.lang.reflect.Field;
23+
import java.util.ArrayList;
24+
import java.util.Arrays;
2025
import java.util.Collection;
26+
import java.util.HashSet;
2127
import java.util.List;
2228
import java.util.Map;
29+
import java.util.Optional;
30+
import java.util.Set;
31+
import java.util.function.BiFunction;
2332

33+
import org.springframework.beans.BeanInstantiationException;
34+
import org.springframework.beans.BeanUtils;
2435
import org.springframework.beans.MutablePropertyValues;
2536
import org.springframework.beans.PropertyValue;
37+
import org.springframework.beans.TypeMismatchException;
2638
import org.springframework.core.CollectionFactory;
39+
import org.springframework.core.MethodParameter;
2740
import org.springframework.lang.Nullable;
41+
import org.springframework.util.ObjectUtils;
42+
import org.springframework.validation.BindException;
43+
import org.springframework.validation.BindingResult;
2844
import org.springframework.validation.DataBinder;
2945
import org.springframework.web.multipart.MultipartFile;
3046

@@ -350,4 +366,150 @@ protected void bindMultipart(Map<String, List<MultipartFile>> multipartFiles, Mu
350366
});
351367
}
352368

369+
@SuppressWarnings("serial")
370+
protected <T> T construct(Constructor<T> ctor, BiFunction<String, Class<?>, Object> values,
371+
@Nullable Callback callback, @Nullable MethodParameter parameter) throws Exception {
372+
373+
// A single data class constructor -> resolve constructor arguments from request parameters.
374+
String[] paramNames = BeanUtils.getParameterNames(ctor);
375+
Class<?>[] paramTypes = ctor.getParameterTypes();
376+
Object[] args = new Object[paramTypes.length];
377+
String fieldDefaultPrefix = getFieldDefaultPrefix();
378+
String fieldMarkerPrefix = getFieldMarkerPrefix();
379+
boolean bindingFailure = false;
380+
Set<String> failedParams = new HashSet<>(4);
381+
382+
for (int i = 0; i < paramNames.length; i++) {
383+
String paramName = paramNames[i];
384+
Class<?> paramType = paramTypes[i];
385+
Object value = values.apply(paramName, paramType);
386+
387+
if (ObjectUtils.isArray(value) && Array.getLength(value) == 1) {
388+
value = Array.get(value, 0);
389+
}
390+
391+
if (value == null) {
392+
if (fieldDefaultPrefix != null) {
393+
value = values.apply(fieldDefaultPrefix + paramName, paramType);
394+
}
395+
if (value == null) {
396+
if (fieldMarkerPrefix != null &&
397+
values.apply(fieldMarkerPrefix + paramName, paramType) != null) {
398+
value = getEmptyValue(paramType);
399+
}
400+
}
401+
}
402+
try {
403+
MethodParameter methodParam = new FieldAwareConstructorParameter(ctor, i, paramName);
404+
if (value == null && methodParam.isOptional()) {
405+
args[i] = (methodParam.getParameterType() == Optional.class ? Optional.empty() : null);
406+
}
407+
else {
408+
args[i] = convertIfNecessary(value, paramType, methodParam);
409+
}
410+
}
411+
catch (TypeMismatchException ex) {
412+
ex.initPropertyName(paramName);
413+
args[i] = null;
414+
failedParams.add(paramName);
415+
getBindingResult().recordFieldValue(paramName, paramType, value);
416+
getBindingErrorProcessor().processPropertyAccessException(ex, getBindingResult());
417+
bindingFailure = true;
418+
}
419+
}
420+
421+
if (bindingFailure) {
422+
BindingResult result = getBindingResult();
423+
for (int i = 0; i < paramNames.length; i++) {
424+
String paramName = paramNames[i];
425+
if (!failedParams.contains(paramName)) {
426+
Object value = args[i];
427+
result.recordFieldValue(paramName, paramTypes[i], value);
428+
if (parameter != null && callback != null) {
429+
callback.validateValue(this, parameter, ctor.getDeclaringClass(), paramName, value);
430+
}
431+
}
432+
}
433+
if (parameter != null && !parameter.isOptional()) {
434+
try {
435+
Object target = BeanUtils.instantiateClass(ctor, args);
436+
throw new BindException(result) {
437+
@Override
438+
public Object getTarget() {
439+
return target;
440+
}
441+
};
442+
}
443+
catch (BeanInstantiationException ex) {
444+
// swallow and proceed without target instance
445+
}
446+
}
447+
throw new BindException(result);
448+
}
449+
450+
return BeanUtils.instantiateClass(ctor, args);
451+
}
452+
453+
public interface Callback {
454+
455+
void validateValue(WebDataBinder dataBinder, MethodParameter parameter, Class<?> declaringClass, String paramName, Object value);
456+
}
457+
458+
459+
/**
460+
* {@link MethodParameter} subclass which detects field annotations as well.
461+
*/
462+
private static class FieldAwareConstructorParameter extends MethodParameter {
463+
464+
private final String parameterName;
465+
466+
@Nullable
467+
private volatile Annotation[] combinedAnnotations;
468+
469+
public FieldAwareConstructorParameter(Constructor<?> constructor, int parameterIndex, String parameterName) {
470+
super(constructor, parameterIndex);
471+
this.parameterName = parameterName;
472+
}
473+
474+
@Override
475+
public Annotation[] getParameterAnnotations() {
476+
Annotation[] anns = this.combinedAnnotations;
477+
if (anns == null) {
478+
anns = super.getParameterAnnotations();
479+
try {
480+
Field field = getDeclaringClass().getDeclaredField(this.parameterName);
481+
Annotation[] fieldAnns = field.getAnnotations();
482+
if (fieldAnns.length > 0) {
483+
List<Annotation> merged = new ArrayList<>(anns.length + fieldAnns.length);
484+
merged.addAll(Arrays.asList(anns));
485+
for (Annotation fieldAnn : fieldAnns) {
486+
boolean existingType = false;
487+
for (Annotation ann : anns) {
488+
if (ann.annotationType() == fieldAnn.annotationType()) {
489+
existingType = true;
490+
break;
491+
}
492+
}
493+
if (!existingType) {
494+
merged.add(fieldAnn);
495+
}
496+
}
497+
anns = merged.toArray(new Annotation[0]);
498+
}
499+
}
500+
catch (NoSuchFieldException | SecurityException ex) {
501+
// ignore
502+
}
503+
this.combinedAnnotations = anns;
504+
}
505+
return anns;
506+
}
507+
508+
@Override
509+
public String getParameterName() {
510+
return this.parameterName;
511+
}
512+
}
513+
514+
353515
}

spring-web/src/main/java/org/springframework/web/bind/support/WebExchangeDataBinder.java

+15
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
package org.springframework.web.bind.support;
1818

19+
import java.lang.reflect.Constructor;
1920
import java.util.List;
2021
import java.util.Map;
2122
import java.util.TreeMap;
@@ -24,6 +25,7 @@
2425
import reactor.core.publisher.Mono;
2526

2627
import org.springframework.beans.MutablePropertyValues;
28+
import org.springframework.core.MethodParameter;
2729
import org.springframework.http.codec.multipart.FormFieldPart;
2830
import org.springframework.http.codec.multipart.Part;
2931
import org.springframework.lang.Nullable;
@@ -85,6 +87,19 @@ public Mono<Map<String, Object>> getValuesToBind(ServerWebExchange exchange) {
8587
return extractValuesToBind(exchange);
8688
}
8789

90+
public <T> Mono<T> construct(ServerWebExchange exchange, Constructor<T> ctor,
91+
@Nullable MethodParameter parameter) {
92+
return getValuesToBind(exchange).flatMap(bindValues -> {
93+
try {
94+
return Mono.just(super.construct(ctor, (name, type) -> bindValues.get(name), null, parameter));
95+
}
96+
catch (Exception ex) {
97+
return Mono.error(ex);
98+
}
99+
});
100+
}
101+
102+
88103

89104
/**
90105
* Combine query params and form data for multipart form data from the body

0 commit comments

Comments
 (0)