Skip to content

Commit db24e5c

Browse files
authored
SQL: Allow sorting of groups by aggregates (#38042) (#38255)
Introduce client-side sorting of groups based on aggregate functions. To allow this, the Analyzer has been extended to push aggregate functions down to the underlying Aggregate, and the Querier has been extended to identify this case, consume the results in order, and sort them based on the given columns. The underlying QueryContainer has been slightly modified to allow a view of the underlying values being extracted, as the columns used for sorting might not be requested by the user. The PR also adds minor tweaks, mainly related to tree output. Close #35118 (cherry picked from commit 783c9ed)
1 parent 1c845d6 commit db24e5c

File tree

60 files changed

+1343
-401
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

60 files changed

+1343
-401
lines changed

docs/reference/sql/limitations.asciidoc

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,18 @@ a field is an array (has multiple values) or not, so without reading all the dat
6767
=== Sorting by aggregation
6868

6969
When doing aggregations (`GROUP BY`) {es-sql} relies on {es}'s `composite` aggregation for its support for paginating results.
70-
But this type of aggregation does come with a limitation: sorting can only be applied on the key used for the aggregation's buckets. This
71-
means that queries like `SELECT * FROM test GROUP BY age ORDER BY COUNT(*)` are not possible.
70+
However, this type of aggregation does come with a limitation: sorting can only be applied on the key used for the aggregation's buckets.
71+
{es-sql} overcomes this limitation by doing client-side sorting; however, as a safety measure, it allows only up to *512* rows.
72+
73+
It is recommended to use `LIMIT` for queries that use sorting by aggregation, essentially indicating the top N results that are desired:
74+
75+
[source, sql]
76+
--------------------------------------------------
77+
SELECT * FROM test GROUP BY age ORDER BY COUNT(*) LIMIT 100;
78+
--------------------------------------------------
79+
80+
It is possible to run the same queries without a `LIMIT`; however, in that case, if the maximum size (*512*) is exceeded, an exception will be
81+
returned as {es-sql} is unable to track (and sort) all the results returned.
7282

7383
[float]
7484
=== Using aggregation functions on top of scalar functions

x-pack/plugin/sql/qa/single-node/src/test/java/org/elasticsearch/xpack/sql/qa/single_node/CliExplainIT.java

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ public void testExplainBasic() throws IOException {
2020
assertThat(readLine(), startsWith("----------"));
2121
assertThat(readLine(), startsWith("With[{}]"));
2222
assertThat(readLine(), startsWith("\\_Project[[?*]]"));
23-
assertThat(readLine(), startsWith(" \\_UnresolvedRelation[[][index=test],null,Unknown index [test]]"));
23+
assertThat(readLine(), startsWith(" \\_UnresolvedRelation[test]"));
2424
assertEquals("", readLine());
2525

2626
assertThat(command("EXPLAIN " + (randomBoolean() ? "" : "(PLAN ANALYZED) ") + "SELECT * FROM test"), containsString("plan"));
@@ -64,22 +64,22 @@ public void testExplainWithWhere() throws IOException {
6464
assertThat(readLine(), startsWith("----------"));
6565
assertThat(readLine(), startsWith("With[{}]"));
6666
assertThat(readLine(), startsWith("\\_Project[[?*]]"));
67-
assertThat(readLine(), startsWith(" \\_Filter[i = 2#"));
68-
assertThat(readLine(), startsWith(" \\_UnresolvedRelation[[][index=test],null,Unknown index [test]]"));
67+
assertThat(readLine(), startsWith(" \\_Filter[Equals[?i,2"));
68+
assertThat(readLine(), startsWith(" \\_UnresolvedRelation[test]"));
6969
assertEquals("", readLine());
7070

7171
assertThat(command("EXPLAIN " + (randomBoolean() ? "" : "(PLAN ANALYZED) ") + "SELECT * FROM test WHERE i = 2"),
7272
containsString("plan"));
7373
assertThat(readLine(), startsWith("----------"));
7474
assertThat(readLine(), startsWith("Project[[i{f}#"));
75-
assertThat(readLine(), startsWith("\\_Filter[i = 2#"));
75+
assertThat(readLine(), startsWith("\\_Filter[Equals[i"));
7676
assertThat(readLine(), startsWith(" \\_EsRelation[test][i{f}#"));
7777
assertEquals("", readLine());
7878

7979
assertThat(command("EXPLAIN (PLAN OPTIMIZED) SELECT * FROM test WHERE i = 2"), containsString("plan"));
8080
assertThat(readLine(), startsWith("----------"));
8181
assertThat(readLine(), startsWith("Project[[i{f}#"));
82-
assertThat(readLine(), startsWith("\\_Filter[i = 2#"));
82+
assertThat(readLine(), startsWith("\\_Filter[Equals[i"));
8383
assertThat(readLine(), startsWith(" \\_EsRelation[test][i{f}#"));
8484
assertEquals("", readLine());
8585

@@ -124,20 +124,20 @@ public void testExplainWithCount() throws IOException {
124124
assertThat(command("EXPLAIN (PLAN PARSED) SELECT COUNT(*) FROM test"), containsString("plan"));
125125
assertThat(readLine(), startsWith("----------"));
126126
assertThat(readLine(), startsWith("With[{}]"));
127-
assertThat(readLine(), startsWith("\\_Project[[?COUNT(*)]]"));
128-
assertThat(readLine(), startsWith(" \\_UnresolvedRelation[[][index=test],null,Unknown index [test]]"));
127+
assertThat(readLine(), startsWith("\\_Project[[?COUNT[?*]]]"));
128+
assertThat(readLine(), startsWith(" \\_UnresolvedRelation[test]"));
129129
assertEquals("", readLine());
130130

131131
assertThat(command("EXPLAIN " + (randomBoolean() ? "" : "(PLAN ANALYZED) ") + "SELECT COUNT(*) FROM test"),
132132
containsString("plan"));
133133
assertThat(readLine(), startsWith("----------"));
134-
assertThat(readLine(), startsWith("Aggregate[[],[COUNT(*)#"));
134+
assertThat(readLine(), startsWith("Aggregate[[],[Count[*=1"));
135135
assertThat(readLine(), startsWith("\\_EsRelation[test][i{f}#"));
136136
assertEquals("", readLine());
137137

138138
assertThat(command("EXPLAIN (PLAN OPTIMIZED) SELECT COUNT(*) FROM test"), containsString("plan"));
139139
assertThat(readLine(), startsWith("----------"));
140-
assertThat(readLine(), startsWith("Aggregate[[],[COUNT(*)#"));
140+
assertThat(readLine(), startsWith("Aggregate[[],[Count[*=1"));
141141
assertThat(readLine(), startsWith("\\_EsRelation[test][i{f}#"));
142142
assertEquals("", readLine());
143143

x-pack/plugin/sql/qa/src/main/java/org/elasticsearch/xpack/sql/qa/cli/ErrorsTestCase.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ public void testSelectProjectScoreInAggContext() throws Exception {
7373
public void testSelectOrderByScoreInAggContext() throws Exception {
7474
index("test", body -> body.field("foo", 1));
7575
assertFoundOneProblem(command("SELECT foo, COUNT(*) FROM test GROUP BY foo ORDER BY SCORE()"));
76-
assertEquals("line 1:54: Cannot order by non-grouped column [SCORE()], expected [foo]" + END, readLine());
76+
assertEquals("line 1:54: Cannot order by non-grouped column [SCORE()], expected [foo] or an aggregate function" + END, readLine());
7777
}
7878

7979
@Override

x-pack/plugin/sql/qa/src/main/java/org/elasticsearch/xpack/sql/qa/jdbc/ErrorsTestCase.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,9 @@ public void testSelectOrderByScoreInAggContext() throws Exception {
8181
try (Connection c = esJdbc()) {
8282
SQLException e = expectThrows(SQLException.class, () ->
8383
c.prepareStatement("SELECT foo, COUNT(*) FROM test GROUP BY foo ORDER BY SCORE()").executeQuery());
84-
assertEquals("Found 1 problem(s)\nline 1:54: Cannot order by non-grouped column [SCORE()], expected [foo]", e.getMessage());
84+
assertEquals(
85+
"Found 1 problem(s)\nline 1:54: Cannot order by non-grouped column [SCORE()], expected [foo] or an aggregate function",
86+
e.getMessage());
8587
}
8688
}
8789

x-pack/plugin/sql/qa/src/main/java/org/elasticsearch/xpack/sql/qa/jdbc/SqlSpecTestCase.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ public static List<Object[]> readScriptSpec() throws Exception {
3838
tests.addAll(readScriptSpec("/datetime.sql-spec", parser));
3939
tests.addAll(readScriptSpec("/math.sql-spec", parser));
4040
tests.addAll(readScriptSpec("/agg.sql-spec", parser));
41+
tests.addAll(readScriptSpec("/agg-ordering.sql-spec", parser));
4142
tests.addAll(readScriptSpec("/arithmetic.sql-spec", parser));
4243
tests.addAll(readScriptSpec("/string-functions.sql-spec", parser));
4344
tests.addAll(readScriptSpec("/case-functions.sql-spec", parser));
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
//
2+
// Custom sorting/ordering on aggregates
3+
//
4+
5+
countWithImplicitGroupBy
6+
SELECT MAX(salary) AS m FROM test_emp ORDER BY COUNT(*);
7+
8+
countWithImplicitGroupByWithHaving
9+
SELECT MAX(salary) AS m FROM test_emp HAVING MIN(salary) > 1 ORDER BY COUNT(*);
10+
11+
countAndMaxWithImplicitGroupBy
12+
SELECT MAX(salary) AS m FROM test_emp ORDER BY MAX(salary), COUNT(*);
13+
14+
maxWithAliasWithImplicitGroupBy
15+
SELECT MAX(salary) AS m FROM test_emp ORDER BY m;
16+
17+
maxWithAliasWithImplicitGroupByAndHaving
18+
SELECT MAX(salary) AS m FROM test_emp HAVING COUNT(*) > 1 ORDER BY m;
19+
20+
multipleOrderWithImplicitGroupByWithHaving
21+
SELECT MAX(salary) AS m FROM test_emp HAVING MIN(salary) > 1 ORDER BY COUNT(*), m DESC;
22+
23+
multipleOrderWithImplicitGroupByWithoutAlias
24+
SELECT MAX(salary) AS m FROM test_emp HAVING MIN(salary) > 1 ORDER BY COUNT(*), MIN(salary) DESC;
25+
26+
multipleOrderWithImplicitGroupByOfOrdinals
27+
SELECT MAX(salary) AS max, MIN(salary) AS min FROM test_emp HAVING MIN(salary) > 1 ORDER BY 1, COUNT(*), 2 DESC;
28+
29+
aggWithoutAlias
30+
SELECT MAX(salary) AS max FROM test_emp GROUP BY gender ORDER BY MAX(salary);
31+
32+
aggWithAlias
33+
SELECT MAX(salary) AS m FROM test_emp GROUP BY gender ORDER BY m;
34+
35+
multipleAggsThatGetRewrittenWithoutAlias
36+
SELECT MAX(salary) AS max, MIN(salary) AS min FROM test_emp GROUP BY gender ORDER BY MAX(salary);
37+
38+
multipleAggsThatGetRewrittenWithAliasDesc
39+
SELECT MAX(salary) AS max, MIN(salary) AS min FROM test_emp GROUP BY gender ORDER BY 1 DESC;
40+
41+
multipleAggsThatGetRewrittenWithAlias
42+
SELECT MAX(salary) AS max, MIN(salary) AS min FROM test_emp GROUP BY gender ORDER BY max;
43+
44+
aggNotSpecifiedInTheAggregate
45+
SELECT MIN(salary) AS min, COUNT(*) AS c FROM test_emp GROUP BY gender ORDER BY MAX(salary);
46+
47+
aggNotSpecifiedInTheAggregatePlusOrdinal
48+
SELECT MIN(salary) AS min, COUNT(*) AS c FROM test_emp GROUP BY gender ORDER BY MAX(salary), 2 DESC;
49+
50+
aggNotSpecifiedInTheAggregateWithHaving
51+
SELECT MIN(salary) AS min, COUNT(*) AS c FROM test_emp GROUP BY gender HAVING c > 1 ORDER BY MAX(salary);
52+
53+
aggNotSpecifiedInTheAggregateWithHavingDesc
54+
SELECT MIN(salary) AS min, COUNT(*) AS c FROM test_emp GROUP BY gender HAVING c > 1 ORDER BY MAX(salary) DESC;
55+
56+
aggNotSpecifiedInTheAggregateAndGroupWithHaving
57+
SELECT gender, MIN(salary) AS min, COUNT(*) AS c FROM test_emp GROUP BY gender HAVING c > 1 ORDER BY MAX(salary), gender;
58+
59+
groupAndAggNotSpecifiedInTheAggregateWithHaving
60+
SELECT gender, MIN(salary) AS min, COUNT(*) AS c FROM test_emp GROUP BY gender HAVING c > 1 ORDER BY gender, MAX(salary);
61+
62+
multipleAggsThatGetRewrittenWithAliasOnAMediumGroupBy
63+
SELECT languages, MAX(salary) AS max, MIN(salary) AS min FROM test_emp GROUP BY languages ORDER BY max;
64+
65+
multipleAggsThatGetRewrittenWithAliasOnALargeGroupBy
66+
SELECT emp_no, MAX(salary) AS max, MIN(salary) AS min FROM test_emp GROUP BY emp_no ORDER BY max;
67+
68+
multipleAggsThatGetRewrittenWithAliasOnAMediumGroupByWithHaving
69+
SELECT languages, MAX(salary) AS max, MIN(salary) AS min FROM test_emp GROUP BY languages HAVING min BETWEEN 1000 AND 99999 ORDER BY max;
70+
71+
aggNotSpecifiedInTheAggregatemultipleAggsThatGetRewrittenWithAliasOnALargeGroupBy
72+
SELECT emp_no, MIN(salary) AS min FROM test_emp GROUP BY emp_no ORDER BY MAX(salary);
73+
74+
aggNotSpecifiedWithHavingOnLargeGroupBy
75+
SELECT MAX(salary) AS max FROM test_emp GROUP BY emp_no HAVING AVG(salary) > 1000 ORDER BY MIN(salary);
76+
77+
aggWithTieBreakerDescAsc
78+
SELECT emp_no, MIN(languages) AS min FROM test_emp GROUP BY emp_no ORDER BY MIN(languages) DESC NULLS FIRST, emp_no ASC;
79+
80+
aggWithTieBreakerDescDesc
81+
SELECT emp_no, MIN(languages) AS min FROM test_emp GROUP BY emp_no ORDER BY MIN(languages) DESC NULLS FIRST, emp_no DESC;
82+
83+
aggWithTieBreakerAscDesc
84+
SELECT emp_no, MIN(languages) AS min FROM test_emp GROUP BY emp_no ORDER BY MAX(languages) ASC NULLS FIRST, emp_no DESC;
85+
86+
aggWithMixOfOrdinals
87+
SELECT gender AS g, MAX(salary) AS m FROM test_emp GROUP BY gender ORDER BY 2 DESC LIMIT 3;

x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/analysis/analyzer/Analyzer.java

Lines changed: 97 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@
5252
import org.elasticsearch.xpack.sql.type.DataTypes;
5353
import org.elasticsearch.xpack.sql.type.InvalidMappedField;
5454
import org.elasticsearch.xpack.sql.type.UnsupportedEsField;
55+
import org.elasticsearch.xpack.sql.util.CollectionUtils;
56+
import org.elasticsearch.xpack.sql.util.Holder;
5557

5658
import java.util.ArrayList;
5759
import java.util.Arrays;
@@ -106,7 +108,8 @@ protected Iterable<RuleExecutor<LogicalPlan>.Batch> batches() {
106108
new ResolveFunctions(),
107109
new ResolveAliases(),
108110
new ProjectedAggregations(),
109-
new ResolveAggsInHaving()
111+
new ResolveAggsInHaving(),
112+
new ResolveAggsInOrderBy()
110113
//new ImplicitCasting()
111114
);
112115
Batch finish = new Batch("Finish Analysis",
@@ -926,62 +929,57 @@ protected LogicalPlan rule(Project p) {
926929
// Handle aggs in HAVING. To help folding any aggs not found in Aggregation
927930
// will be pushed down to the Aggregate and then projected. This also simplifies the Verifier's job.
928931
//
929-
private class ResolveAggsInHaving extends AnalyzeRule<LogicalPlan> {
932+
private class ResolveAggsInHaving extends AnalyzeRule<Filter> {
930933

931934
@Override
932935
protected boolean skipResolved() {
933936
return false;
934937
}
935938

936939
@Override
937-
protected LogicalPlan rule(LogicalPlan plan) {
940+
protected LogicalPlan rule(Filter f) {
938941
// HAVING = Filter followed by an Agg
939-
if (plan instanceof Filter) {
940-
Filter f = (Filter) plan;
941-
if (f.child() instanceof Aggregate && f.child().resolved()) {
942-
Aggregate agg = (Aggregate) f.child();
942+
if (f.child() instanceof Aggregate && f.child().resolved()) {
943+
Aggregate agg = (Aggregate) f.child();
943944

944-
Set<NamedExpression> missing = null;
945-
Expression condition = f.condition();
945+
Set<NamedExpression> missing = null;
946+
Expression condition = f.condition();
946947

947-
// the condition might contain an agg (AVG(salary)) that could have been resolved
948-
// (salary cannot be pushed down to Aggregate since there's no grouping and thus the function wasn't resolved either)
948+
// the condition might contain an agg (AVG(salary)) that could have been resolved
949+
// (salary cannot be pushed down to Aggregate since there's no grouping and thus the function wasn't resolved either)
949950

950-
// so try resolving the condition in one go through a 'dummy' aggregate
951-
if (!condition.resolved()) {
952-
// that's why try to resolve the condition
953-
Aggregate tryResolvingCondition = new Aggregate(agg.source(), agg.child(), agg.groupings(),
954-
combine(agg.aggregates(), new Alias(f.source(), ".having", condition)));
951+
// so try resolving the condition in one go through a 'dummy' aggregate
952+
if (!condition.resolved()) {
953+
// that's why try to resolve the condition
954+
Aggregate tryResolvingCondition = new Aggregate(agg.source(), agg.child(), agg.groupings(),
955+
combine(agg.aggregates(), new Alias(f.source(), ".having", condition)));
955956

956-
tryResolvingCondition = (Aggregate) analyze(tryResolvingCondition, false);
957+
tryResolvingCondition = (Aggregate) analyze(tryResolvingCondition, false);
957958

958-
// if it got resolved
959-
if (tryResolvingCondition.resolved()) {
960-
// replace the condition with the resolved one
961-
condition = ((Alias) tryResolvingCondition.aggregates()
962-
.get(tryResolvingCondition.aggregates().size() - 1)).child();
963-
} else {
964-
// else bail out
965-
return plan;
966-
}
959+
// if it got resolved
960+
if (tryResolvingCondition.resolved()) {
961+
// replace the condition with the resolved one
962+
condition = ((Alias) tryResolvingCondition.aggregates()
963+
.get(tryResolvingCondition.aggregates().size() - 1)).child();
964+
} else {
965+
// else bail out
966+
return f;
967967
}
968+
}
968969

969-
missing = findMissingAggregate(agg, condition);
970-
971-
if (!missing.isEmpty()) {
972-
Aggregate newAgg = new Aggregate(agg.source(), agg.child(), agg.groupings(),
973-
combine(agg.aggregates(), missing));
974-
Filter newFilter = new Filter(f.source(), newAgg, condition);
975-
// preserve old output
976-
return new Project(f.source(), newFilter, f.output());
977-
}
970+
missing = findMissingAggregate(agg, condition);
978971

979-
return new Filter(f.source(), f.child(), condition);
972+
if (!missing.isEmpty()) {
973+
Aggregate newAgg = new Aggregate(agg.source(), agg.child(), agg.groupings(),
974+
combine(agg.aggregates(), missing));
975+
Filter newFilter = new Filter(f.source(), newAgg, condition);
976+
// preserve old output
977+
return new Project(f.source(), newFilter, f.output());
980978
}
981-
return plan;
982-
}
983979

984-
return plan;
980+
return new Filter(f.source(), f.child(), condition);
981+
}
982+
return f;
985983
}
986984

987985
private Set<NamedExpression> findMissingAggregate(Aggregate target, Expression from) {
@@ -1001,6 +999,66 @@ private Set<NamedExpression> findMissingAggregate(Aggregate target, Expression f
1001999
}
10021000
}
10031001

1002+
1003+
//
1004+
// Handle aggs in ORDER BY. To help folding any aggs not found in Aggregation
1005+
// will be pushed down to the Aggregate and then projected. This also simplifies the Verifier's job.
1006+
// Similar to Having however using a different matching pattern since HAVING is always Filter with Agg,
1007+
// while an OrderBy can have multiple intermediate nodes (Filter,Project, etc...)
1008+
//
1009+
private static class ResolveAggsInOrderBy extends AnalyzeRule<OrderBy> {
1010+
1011+
@Override
1012+
protected boolean skipResolved() {
1013+
return false;
1014+
}
1015+
1016+
@Override
1017+
protected LogicalPlan rule(OrderBy ob) {
1018+
List<Order> orders = ob.order();
1019+
1020+
// 1. collect aggs inside an order by
1021+
List<NamedExpression> aggs = new ArrayList<>();
1022+
for (Order order : orders) {
1023+
if (Functions.isAggregate(order.child())) {
1024+
aggs.add(Expressions.wrapAsNamed(order.child()));
1025+
}
1026+
}
1027+
if (aggs.isEmpty()) {
1028+
return ob;
1029+
}
1030+
1031+
// 2. find first Aggregate child and update it
1032+
final Holder<Boolean> found = new Holder<>(Boolean.FALSE);
1033+
1034+
LogicalPlan plan = ob.transformDown(a -> {
1035+
if (found.get() == Boolean.FALSE) {
1036+
found.set(Boolean.TRUE);
1037+
1038+
List<NamedExpression> missing = new ArrayList<>();
1039+
1040+
for (NamedExpression orderedAgg : aggs) {
1041+
if (Expressions.anyMatch(a.aggregates(), e -> Expressions.equalsAsAttribute(e, orderedAgg)) == false) {
1042+
missing.add(orderedAgg);
1043+
}
1044+
}
1045+
// some of the ordered aggs are missing from the Aggregate - append them
1046+
if (missing.isEmpty() == false) {
1047+
// save aggregates
1048+
return new Aggregate(a.source(), a.child(), a.groupings(), CollectionUtils.combine(a.aggregates(), missing));
1049+
}
1050+
}
1051+
return a;
1052+
}, Aggregate.class);
1053+
1054+
// if the plan was updated, project the initial aggregates
1055+
if (plan != ob) {
1056+
return new Project(ob.source(), plan, ob.output());
1057+
}
1058+
return ob;
1059+
}
1060+
}
1061+
10041062
private class PruneDuplicateFunctions extends AnalyzeRule<LogicalPlan> {
10051063

10061064
@Override

0 commit comments

Comments
 (0)