Skip to content

Commit f02a21c

Browse files
committed
Polishing.
Introduce doWithPlainSelect(…) callback for easier filtering of Select subtypes. Add test for known (previously) failing case. See: #3869 Original pull request: #3870
1 parent 7520e29 commit f02a21c

File tree

2 files changed

+113
-75
lines changed

2 files changed

+113
-75
lines changed

spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JSqlParserQueryEnhancer.java

Lines changed: 101 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,13 @@
4848
import java.util.List;
4949
import java.util.Set;
5050
import java.util.StringJoiner;
51+
import java.util.function.Predicate;
52+
import java.util.function.Supplier;
5153

5254
import org.jspecify.annotations.Nullable;
5355

5456
import org.springframework.data.domain.Sort;
57+
import org.springframework.data.util.Predicates;
5558
import org.springframework.util.Assert;
5659
import org.springframework.util.CollectionUtils;
5760
import org.springframework.util.ObjectUtils;
@@ -149,24 +152,8 @@ static <T extends Statement> T parseStatement(String sql, Class<T> classOfT) {
149152

150153
if (ParsedType.SELECT.equals(parsedType)) {
151154

152-
Select selectStatement = (Select) statement;
153-
154-
/*
155-
* For all the other types ({@link ValuesStatement} and {@link SetOperationList}) it does not make sense to provide
156-
* alias since:
157-
* ValuesStatement has no alias
158-
* SetOperation can have multiple alias for each operation item
159-
*/
160-
if (!(selectStatement instanceof PlainSelect selectBody)) {
161-
return null;
162-
}
163-
164-
if (selectBody.getFromItem() == null) {
165-
return null;
166-
}
167-
168-
Alias alias = selectBody.getFromItem().getAlias();
169-
return alias == null ? null : alias.getName();
155+
return doWithPlainSelect(statement, it -> it.getFromItem() == null || it.getFromItem().getAlias() == null,
156+
it -> it.getFromItem().getAlias().getName(), () -> null);
170157
}
171158

172159
return null;
@@ -179,20 +166,24 @@ static <T extends Statement> T parseStatement(String sql, Class<T> classOfT) {
179166
*/
180167
private static Set<String> getSelectionAliases(Statement statement) {
181168

182-
if (!(statement instanceof PlainSelect select) || CollectionUtils.isEmpty(select.getSelectItems())) {
183-
return Collections.emptySet();
169+
if (statement instanceof SetOperationList sel) {
170+
statement = sel.getSelect(0);
184171
}
185172

186-
Set<String> set = new HashSet<>(select.getSelectItems().size());
173+
return doWithPlainSelect(statement, it -> CollectionUtils.isEmpty(it.getSelectItems()), it -> {
174+
175+
Set<String> set = new HashSet<>(it.getSelectItems().size(), 1.0f);
187176

188-
for (SelectItem<?> selectItem : select.getSelectItems()) {
189-
Alias alias = selectItem.getAlias();
190-
if (alias != null) {
191-
set.add(alias.getName());
177+
for (SelectItem<?> selectItem : it.getSelectItems()) {
178+
Alias alias = selectItem.getAlias();
179+
if (alias != null) {
180+
set.add(alias.getName());
181+
}
192182
}
193-
}
194183

195-
return set;
184+
return set;
185+
186+
}, Collections::emptySet);
196187
}
197188

198189
/**
@@ -202,21 +193,74 @@ private static Set<String> getSelectionAliases(Statement statement) {
202193
*/
203194
private static Set<String> getJoinAliases(Statement statement) {
204195

205-
if (!(statement instanceof PlainSelect selectBody) || CollectionUtils.isEmpty(selectBody.getJoins())) {
206-
return Collections.emptySet();
196+
if (statement instanceof SetOperationList sel) {
197+
statement = sel.getSelect(0);
207198
}
208199

209-
Set<String> set = new HashSet<>(selectBody.getJoins().size());
200+
return doWithPlainSelect(statement, it -> CollectionUtils.isEmpty(it.getJoins()), it -> {
210201

211-
for (Join join : selectBody.getJoins()) {
212-
Alias alias = join.getRightItem().getAlias();
213-
if (alias != null) {
214-
set.add(alias.getName());
202+
Set<String> set = new HashSet<>(it.getJoins().size(), 1.0f);
203+
204+
for (Join join : it.getJoins()) {
205+
Alias alias = join.getRightItem().getAlias();
206+
if (alias != null) {
207+
set.add(alias.getName());
208+
}
215209
}
210+
return set;
211+
212+
}, Collections::emptySet);
213+
}
214+
215+
/**
216+
* Apply a {@link java.util.function.Function mapping function} to the {@link PlainSelect} of the given
217+
* {@link Statement} is or contains a {@link PlainSelect}.
218+
*
219+
* @param statement
220+
* @param mapper
221+
* @param fallback
222+
* @return
223+
* @param <T>
224+
*/
225+
private static <T> T doWithPlainSelect(Statement statement, java.util.function.Function<PlainSelect, T> mapper,
226+
Supplier<T> fallback) {
227+
228+
Predicate<PlainSelect> neverSkip = Predicates.isFalse();
229+
return doWithPlainSelect(statement, neverSkip, mapper, fallback);
230+
}
231+
232+
/**
233+
* Apply a {@link java.util.function.Function mapping function} to the {@link PlainSelect} of the given
234+
* {@link Statement} is or contains a {@link PlainSelect}.
235+
* <p>
236+
* The operation is only applied if {@link Predicate skipIf} returns {@literal false} for the given statement
237+
* returning the fallback value from {@code fallback}.
238+
*
239+
* @param statement
240+
* @param skipIf
241+
* @param mapper
242+
* @param fallback
243+
* @return
244+
* @param <T>
245+
*/
246+
private static <T> T doWithPlainSelect(Statement statement, Predicate<PlainSelect> skipIf,
247+
java.util.function.Function<PlainSelect, T> mapper, Supplier<T> fallback) {
248+
249+
if (!(statement instanceof Select select)) {
250+
return fallback.get();
216251
}
217252

218-
return set;
253+
try {
254+
if (skipIf.test(select.getPlainSelect())) {
255+
return fallback.get();
256+
}
257+
}
258+
// e.g. SetOperationList is a subclass of Select but it is not a PlainSelect
259+
catch (ClassCastException e) {
260+
return fallback.get();
261+
}
219262

263+
return mapper.apply(select.getPlainSelect());
220264
}
221265

222266
private static String detectProjection(Statement statement) {
@@ -235,18 +279,17 @@ private static String detectProjection(Statement statement) {
235279

236280
// using the first one since for setoperations the projection has to be the same
237281
selectBody = setOperationList.getSelects().get(0);
238-
239-
if (!(selectBody instanceof PlainSelect)) {
240-
return "";
241-
}
242282
}
243283

244-
StringJoiner joiner = new StringJoiner(", ");
245-
for (SelectItem<?> selectItem : selectBody.getPlainSelect().getSelectItems()) {
246-
joiner.add(selectItem.toString());
247-
}
248-
return joiner.toString().trim();
284+
return doWithPlainSelect(selectBody, it -> CollectionUtils.isEmpty(it.getSelectItems()), it -> {
285+
286+
StringJoiner joiner = new StringJoiner(", ");
287+
for (SelectItem<?> selectItem : it.getSelectItems()) {
288+
joiner.add(selectItem.toString());
289+
}
290+
return joiner.toString().trim();
249291

292+
}, () -> "");
250293
}
251294

252295
/**
@@ -317,24 +360,22 @@ private String applySorting(@Nullable Select selectStatement, Sort sort, @Nullab
317360
return applySortingToSetOperationList(setOperationList, sort);
318361
}
319362

320-
if (!(selectStatement instanceof PlainSelect selectBody)) {
321-
if (selectStatement != null) {
322-
return selectStatement.toString();
363+
doWithPlainSelect(selectStatement, it -> {
364+
365+
List<OrderByElement> orderByElements = new ArrayList<>(16);
366+
for (Sort.Order order : sort) {
367+
orderByElements.add(getOrderClause(joinAliases, selectAliases, alias, order));
368+
}
369+
370+
if (CollectionUtils.isEmpty(it.getOrderByElements())) {
371+
it.setOrderByElements(orderByElements);
323372
} else {
324-
throw new IllegalArgumentException("Select must not be null");
373+
it.getOrderByElements().addAll(orderByElements);
325374
}
326-
}
327375

328-
List<OrderByElement> orderByElements = new ArrayList<>(16);
329-
for (Sort.Order order : sort) {
330-
orderByElements.add(getOrderClause(joinAliases, selectAliases, alias, order));
331-
}
376+
return null;
332377

333-
if (CollectionUtils.isEmpty(selectBody.getOrderByElements())) {
334-
selectBody.setOrderByElements(orderByElements);
335-
} else {
336-
selectBody.getOrderByElements().addAll(orderByElements);
337-
}
378+
}, () -> "");
338379

339380
return selectStatement.toString();
340381
}
@@ -349,14 +390,9 @@ public String createCountQueryFor(@Nullable String countProjection) {
349390
Assert.hasText(this.query.getQueryString(), "OriginalQuery must not be null or empty");
350391

351392
Statement statement = (Statement) deserialize(this.serialized);
352-
/*
353-
We only support count queries for {@link PlainSelect}.
354-
*/
355-
if (!(statement instanceof PlainSelect selectBody)) {
356-
return this.query.getQueryString();
357-
}
358393

359-
return createCountQueryFor(selectBody, countProjection, primaryAlias);
394+
return doWithPlainSelect(statement, it -> createCountQueryFor(it, countProjection, primaryAlias),
395+
this.query::getQueryString);
360396
}
361397

362398
private static String createCountQueryFor(PlainSelect selectBody, @Nullable String countProjection,

spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/JSqlParserQueryEnhancerUnitTests.java

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
* @author Geoffrey Deremetz
3737
* @author Christoph Strobl
3838
*/
39-
public class JSqlParserQueryEnhancerUnitTests extends QueryEnhancerTckTests {
39+
class JSqlParserQueryEnhancerUnitTests extends QueryEnhancerTckTests {
4040

4141
@Override
4242
QueryEnhancer createQueryEnhancer(DeclaredQuery query) {
@@ -232,6 +232,17 @@ void truncateStatementShouldWork() {
232232
assertThat(queryEnhancer.hasConstructorExpression()).isFalse();
233233
}
234234

235+
@Test // GH-3869
236+
void shouldWorkWithParenthesedSelect() {
237+
238+
DefaultEntityQuery query = new TestEntityQuery("(SELECT is_contained_in(:innerId, :outerId))", true);
239+
QueryEnhancer queryEnhancer = QueryEnhancerFactory.forQuery(query).create(query);
240+
241+
assertThat(query.getQueryString()).isEqualTo("(SELECT is_contained_in(:innerId, :outerId))");
242+
assertThat(query.getAlias()).isNull();
243+
assertThat(queryEnhancer.getProjection()).isEqualTo("is_contained_in(:innerId, :outerId)");
244+
}
245+
235246
@ParameterizedTest // GH-2641
236247
@MethodSource("mergeStatementWorksSource")
237248
void mergeStatementWorksWithJSqlParser(String queryString, String alias) {
@@ -263,13 +274,4 @@ private static DefaultQueryRewriteInformation getRewriteInformation(Sort sort) {
263274
ReturnedType.of(Object.class, Object.class, new SpelAwareProxyProjectionFactory()));
264275
}
265276

266-
@Test // GH-3869
267-
void shouldWorkWithoutFromClause() {
268-
String query = "SELECT is_contained_in(:innerId, :outerId)";
269-
270-
StringQuery stringQuery = new StringQuery(query, true);
271-
272-
assertThat(stringQuery.getQueryString()).isEqualTo(query);
273-
}
274-
275277
}

0 commit comments

Comments
 (0)