Skip to content

Commit 4b8fabb

Browse files
committed
Introduce MethodParameterFactory to Parameters.
We now provide a factory function to create MethodParameter objects associated with the enclosing class so parameters can resolve generics properly.
1 parent 3722208 commit 4b8fabb

File tree

4 files changed

+84
-8
lines changed

4 files changed

+84
-8
lines changed

src/main/java/org/springframework/data/repository/query/DefaultParameters.java

+21-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717

1818
import java.lang.reflect.Method;
1919
import java.util.List;
20+
import java.util.function.IntFunction;
2021

22+
import org.springframework.core.MethodParameter;
2123
import org.springframework.data.util.TypeInformation;
2224

2325
/**
@@ -40,13 +42,30 @@ public DefaultParameters(Method method) {
4042

4143
/**
4244
* Creates a new {@link DefaultParameters} instance from the given {@link Method} and aggregate
43-
* {@link TypeInformation}.
45+
* {@link TypeInformation}. Note that this constructor uses a {@link IntFunction MethodParameterFactory} that doesn't
46+
* resolve generic method parameter types that are declared on the containing class. Use
47+
* {@link DefaultParameters#DefaultParameters(Method, IntFunction, TypeInformation)} with a
48+
* {@link MethodParameter#withContainingClass(Class) contextual factory} to ensure proper gerneric resolution.
4449
*
4550
* @param method must not be {@literal null}.
4651
* @param aggregateType must not be {@literal null}.
4752
*/
4853
public DefaultParameters(Method method, TypeInformation<?> aggregateType) {
49-
super(method, param -> new Parameter(param, aggregateType));
54+
this(method, index -> new MethodParameter(method, index), aggregateType);
55+
}
56+
57+
/**
58+
* Creates a new {@link DefaultParameters} instance from the given {@link Method} and aggregate
59+
* {@link TypeInformation}.
60+
*
61+
* @param method must not be {@literal null}.
62+
* @param parameterFactory must not be {@literal null}.
63+
* @param aggregateType must not be {@literal null}.
64+
* @since 3.2.1
65+
*/
66+
public DefaultParameters(Method method, IntFunction<MethodParameter> parameterFactory,
67+
TypeInformation<?> aggregateType) {
68+
super(method, parameterFactory, param -> new Parameter(param, aggregateType));
5069
}
5170

5271
private DefaultParameters(List<Parameter> parameters) {

src/main/java/org/springframework/data/repository/query/Parameters.java

+21-2
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import java.util.Iterator;
2424
import java.util.List;
2525
import java.util.function.Function;
26+
import java.util.function.IntFunction;
2627

2728
import org.springframework.core.DefaultParameterNameDiscoverer;
2829
import org.springframework.core.MethodParameter;
@@ -44,7 +45,8 @@
4445
*/
4546
public abstract class Parameters<S extends Parameters<S, T>, T extends Parameter> implements Streamable<T> {
4647

47-
public static final List<Class<?>> TYPES = Arrays.asList(ScrollPosition.class, Pageable.class, Sort.class, Limit.class);
48+
public static final List<Class<?>> TYPES = Arrays.asList(ScrollPosition.class, Pageable.class, Sort.class,
49+
Limit.class);
4850

4951
private static final String PARAM_ON_SPECIAL = format("You must not use @%s on a parameter typed %s or %s",
5052
Param.class.getSimpleName(), Pageable.class.getSimpleName(), Sort.class.getSimpleName());
@@ -84,6 +86,20 @@ public Parameters(Method method) {
8486
* @since 3.0.2
8587
*/
8688
protected Parameters(Method method, Function<MethodParameter, T> parameterFactory) {
89+
this(method, index -> new MethodParameter(method, index), parameterFactory);
90+
}
91+
92+
/**
93+
* Creates a new {@link Parameters} instance for the given {@link Method} and {@link Function} to create a
94+
* {@link Parameter} instance from a {@link MethodParameter}.
95+
*
96+
* @param method must not be {@literal null}.
97+
* @param methodParameterFactory must not be {@literal null}.
98+
* @param parameterFactory must not be {@literal null}.
99+
* @since 3.2.1
100+
*/
101+
protected Parameters(Method method, IntFunction<MethodParameter> methodParameterFactory,
102+
Function<MethodParameter, T> parameterFactory) {
87103

88104
Assert.notNull(method, "Method must not be null");
89105

@@ -102,7 +118,8 @@ protected Parameters(Method method, Function<MethodParameter, T> parameterFactor
102118

103119
for (int i = 0; i < parameterCount; i++) {
104120

105-
MethodParameter methodParameter = new MethodParameter(method, i);
121+
MethodParameter methodParameter = methodParameterFactory.apply(i);
122+
106123
methodParameter.initParameterNameDiscovery(PARAMETER_NAME_DISCOVERER);
107124

108125
T parameter = parameterFactory == null //
@@ -421,7 +438,9 @@ public static boolean isBindable(Class<?> type) {
421438
return !TYPES.contains(type);
422439
}
423440

441+
@Override
424442
public Iterator<T> iterator() {
425443
return parameters.iterator();
426444
}
445+
427446
}

src/main/java/org/springframework/data/repository/query/QueryMethod.java

+18-3
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,11 @@
2020
import java.lang.reflect.Method;
2121
import java.util.Collections;
2222
import java.util.Set;
23+
import java.util.function.IntFunction;
2324
import java.util.function.Predicate;
2425
import java.util.stream.Stream;
2526

27+
import org.springframework.core.MethodParameter;
2628
import org.springframework.data.domain.Limit;
2729
import org.springframework.data.domain.Page;
2830
import org.springframework.data.domain.Pageable;
@@ -61,6 +63,7 @@ public class QueryMethod {
6163
private final ResultProcessor resultProcessor;
6264
private final Lazy<Class<?>> domainClass;
6365
private final Lazy<Boolean> isCollectionQuery;
66+
private final IntFunction<MethodParameter> methodParameterFactory;
6467

6568
/**
6669
* Creates a new {@link QueryMethod} from the given parameters. Looks up the correct query to use for following
@@ -85,6 +88,8 @@ public QueryMethod(Method method, RepositoryMetadata metadata, ProjectionFactory
8588
this.method = method;
8689
this.unwrappedReturnType = potentiallyUnwrapReturnTypeFor(metadata, method);
8790
this.metadata = metadata;
91+
this.methodParameterFactory = index -> new MethodParameter(method, index)
92+
.withContainingClass(metadata.getRepositoryInterface());
8893
this.parameters = createParameters(method);
8994

9095
this.domainClass = Lazy.of(() -> {
@@ -186,13 +191,23 @@ private boolean calculateIsCollectionQuery() {
186191
* @since 3.0.2
187192
*/
188193
protected Parameters<?, ?> createParameters(Method method, TypeInformation<?> domainType) {
189-
return new DefaultParameters(method, domainType);
194+
return new DefaultParameters(method, getMethodParameterFactory(), domainType);
195+
}
196+
197+
/**
198+
* The {@link IntFunction MethodParameterFactory} to create method parameters.
199+
*
200+
* @return factory to create method parameters.
201+
* @since 3.2.1
202+
*/
203+
protected IntFunction<MethodParameter> getMethodParameterFactory() {
204+
return methodParameterFactory;
190205
}
191206

192207
/**
193208
* Returns the method's name.
194209
*
195-
* @return
210+
* @return the method's name.
196211
*/
197212
public String getName() {
198213
return method.getName();
@@ -314,7 +329,7 @@ public ResultProcessor getResultProcessor() {
314329
return resultProcessor;
315330
}
316331

317-
RepositoryMetadata getMetadata() {
332+
protected RepositoryMetadata getMetadata() {
318333
return metadata;
319334
}
320335

src/test/java/org/springframework/data/repository/query/ParametersUnitTests.java

+24-1
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,16 @@
2626
import org.junit.jupiter.api.BeforeEach;
2727
import org.junit.jupiter.api.Test;
2828
import org.reactivestreams.Publisher;
29+
30+
import org.springframework.core.MethodParameter;
2931
import org.springframework.data.domain.Limit;
3032
import org.springframework.data.domain.OffsetScrollPosition;
3133
import org.springframework.data.domain.Page;
3234
import org.springframework.data.domain.Pageable;
3335
import org.springframework.data.domain.Sort;
3436
import org.springframework.data.domain.Window;
37+
import org.springframework.data.repository.Repository;
38+
import org.springframework.data.util.TypeInformation;
3539
import org.springframework.test.util.ReflectionTestUtils;
3640

3741
/**
@@ -202,6 +206,18 @@ void acceptsLimitParameter() throws Exception {
202206
assertThat(parameters.getLimitIndex()).isOne();
203207
}
204208

209+
@Test // GH-2995
210+
void considersGenericType() throws Exception {
211+
212+
var method = TypedInterface.class.getMethod("foo", Object.class);
213+
214+
var parameters = new DefaultParameters(method,
215+
index -> new MethodParameter(method, index).withContainingClass(TypedInterface.class),
216+
TypeInformation.of(User.class));
217+
218+
assertThat(parameters.getParameter(0).getType()).isEqualTo(Long.class);
219+
}
220+
205221
private Parameters<?, Parameter> getParametersFor(String methodName, Class<?>... parameterTypes)
206222
throws SecurityException, NoSuchMethodException {
207223

@@ -214,7 +230,7 @@ static class User {
214230

215231
}
216232

217-
static interface SampleDao {
233+
interface SampleDao {
218234

219235
User valid(@Param("username") String username);
220236

@@ -248,4 +264,11 @@ static interface SampleDao {
248264
}
249265

250266
interface SomePageable extends Pageable {}
267+
268+
interface Intermediate<T, ID> extends Repository<T, ID> {
269+
void foo(ID id);
270+
}
271+
272+
interface TypedInterface extends Intermediate<User, Long> {}
273+
251274
}

0 commit comments

Comments
 (0)