Skip to content

Commit 967f702

Browse files
codebirdmatriv
authored andcommitted
SQL: Fix issues with GROUP BY queries (#41964)
Translate to an agg query even if only literals are selected, so that the correct number of rows is returned (number of buckets). Fix issue with key only in GROUP BY (not in select) and WHERE clause: Resolve aggregates and groupings based on the child plan which holds the info info for all the fields of the underlying table. Fixes: #41951 Fixes: #41413 (cherry picked from commit 45b8580)
1 parent 317d146 commit 967f702

File tree

5 files changed

+97
-9
lines changed

5 files changed

+97
-9
lines changed

x-pack/plugin/sql/qa/src/main/resources/agg.sql-spec

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,26 @@ countDistinctAlias
144144
SELECT COUNT(DISTINCT hire_date) AS count FROM test_emp;
145145
countDistinctAndCountSimpleWithAlias
146146
SELECT COUNT(*) cnt, COUNT(DISTINCT first_name) as names, gender FROM test_emp GROUP BY gender ORDER BY gender;
147+
aliasedCountWithFunctionFilterAndGroupBy
148+
SELECT COUNT(*) as c FROM test_emp WHERE ABS(salary) > 0 GROUP BY gender ORDER BY gender;
149+
countWithFunctionFilterAndGroupBy
150+
SELECT COUNT(*) FROM test_emp WHERE ABS(salary) > 0 GROUP BY gender ORDER BY gender;
151+
aliasedCountWithMultiFunctionFilterAndGroupBy
152+
SELECT COUNT(*) as c FROM test_emp WHERE ABS(salary) > 0 AND YEAR(birth_date) > 1980 GROUP BY gender ORDER BY gender;
153+
countWithMultiFunctionFilterAndGroupBy
154+
SELECT COUNT(*) FROM test_emp WHERE ABS(salary) > 0 AND YEAR(birth_date) > 1980 GROUP BY gender ORDER BY gender;
155+
aliasedCountWithFunctionFilterAndMultiGroupBy
156+
SELECT COUNT(*) as c FROM test_emp WHERE ABS(salary) > 0 GROUP BY gender, salary ORDER BY gender;
157+
countWithFunctionFilterAndMultiGroupBy
158+
SELECT COUNT(*) FROM test_emp WHERE ABS(salary) > 0 GROUP BY gender, salary ORDER BY gender;
159+
aliasedCountWithMultiFunctionFilterAndMultiGroupBy
160+
SELECT COUNT(*) as c FROM test_emp WHERE ABS(salary) > 0 AND YEAR(birth_date) > 1980 GROUP BY gender, salary ORDER BY gender;
161+
countWithMultiFunctionFilterAndMultiGroupBy
162+
SELECT COUNT(*) FROM test_emp WHERE ABS(salary) > 0 AND YEAR(birth_date) > 1980 GROUP BY gender, salary ORDER BY gender;
163+
aliasedCountLiteralColumnWithFunctionFilterAndMultiGroupBy
164+
SELECT 1, gender as g, COUNT(*) as c FROM test_emp WHERE ABS(salary) > 0 GROUP BY g, salary ORDER BY gender;
165+
aliasedCountLiteralColumnWithFunctionFilterAndMultiGroupByWithFunction
166+
SELECT 1, gender as g, COUNT(*) as c FROM test_emp WHERE ABS(salary) > 0 GROUP BY g, YEAR(birth_date) ORDER BY gender, YEAR(birth_date);
147167

148168
aggCountAliasAndWhereClauseMultiGroupBy
149169
SELECT gender g, languages l, COUNT(*) c FROM "test_emp" WHERE emp_no < 10020 GROUP BY gender, languages ORDER BY gender, languages;
@@ -563,6 +583,22 @@ SELECT MIN(salary) min, MAX(salary) max, gender g, languages l, COUNT(*) c FROM
563583
// group by with literal
564584
implicitGroupByWithLiteral
565585
SELECT 10, MAX("salary") FROM test_emp;
586+
literalWithGroupBy
587+
SELECT 1 FROM test_emp GROUP BY gender;
588+
literalsWithGroupBy
589+
SELECT 1, 2 FROM test_emp GROUP BY gender;
590+
aliasedLiteralWithGroupBy
591+
SELECT 1 AS s FROM test_emp GROUP BY gender;
592+
aliasedLiteralsWithGroupBy
593+
SELECT 1 AS s, 2 FROM test_emp GROUP BY gender;
594+
literalsWithMultipleGroupBy
595+
SELECT 1, 2 FROM test_emp GROUP BY gender, salary;
596+
divisionLiteralsAdditionWithMultipleGroupBy
597+
SELECT 144 / 12 AS division, 1, 2 AS x, 1 + 2 AS addition FROM test_emp GROUP BY gender, salary;
598+
aliasedLiteralsWithMultipleGroupBy
599+
SELECT 1 as s, 2 FROM test_emp GROUP BY gender, salary;
600+
aliasedLiteralsWithMultipleGroupByWithFunction
601+
SELECT 1 as s, 2 FROM test_emp GROUP BY gender, YEAR(birth_date);
566602
implicitGroupByWithLiterals
567603
SELECT 10, 'foo', MAX("salary"), 20, 'bar' FROM test_emp;
568604
groupByWithLiteral

x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/analysis/analyzer/Analyzer.java

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,7 @@ else if (plan instanceof Aggregate) {
341341
List<Expression> groupings = a.groupings();
342342
List<Expression> newGroupings = new ArrayList<>();
343343
AttributeMap<Expression> resolved = Expressions.aliases(a.aggregates());
344+
344345
boolean changed = false;
345346
for (Expression grouping : groupings) {
346347
if (grouping instanceof UnresolvedAttribute) {
@@ -618,7 +619,7 @@ protected LogicalPlan rule(LogicalPlan plan) {
618619
for (Order or : o.order()) {
619620
maybeResolved.add(or.resolved() ? or : tryResolveExpression(or, child));
620621
}
621-
622+
622623
Stream<Order> referencesStream = maybeResolved.stream()
623624
.filter(Expression::resolved);
624625

@@ -629,7 +630,7 @@ protected LogicalPlan rule(LogicalPlan plan) {
629630
// a + 1 in SELECT is actually Alias("a + 1", a + 1) and translates to ReferenceAttribute
630631
// in the output. However it won't match the unnamed a + 1 despite being the same expression
631632
// so explicitly compare the source
632-
633+
633634
// if there's a match, remove the item from the reference stream
634635
if (Expressions.hasReferenceAttribute(child.outputSet())) {
635636
final Map<Attribute, Expression> collectRefs = new LinkedHashMap<>();
@@ -720,6 +721,18 @@ protected LogicalPlan rule(LogicalPlan plan) {
720721
}
721722
}
722723

724+
// Try to resolve aggregates and groupings based on the child plan
725+
if (plan instanceof Aggregate) {
726+
Aggregate a = (Aggregate) plan;
727+
LogicalPlan child = a.child();
728+
List<Expression> newGroupings = new ArrayList<>(a.groupings().size());
729+
a.groupings().forEach(e -> newGroupings.add(tryResolveExpression(e, child)));
730+
List<NamedExpression> newAggregates = new ArrayList<>(a.aggregates().size());
731+
a.aggregates().forEach(e -> newAggregates.add(tryResolveExpression(e, child)));
732+
if (newAggregates.equals(a.aggregates()) == false || newGroupings.equals(a.groupings()) == false) {
733+
return new Aggregate(a.source(), child, newGroupings, newAggregates);
734+
}
735+
}
723736
return plan;
724737
}
725738

x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/optimizer/Optimizer.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1801,7 +1801,8 @@ protected LogicalPlan rule(LogicalPlan plan) {
18011801

18021802
plan.forEachDown(a -> {
18031803
List<Object> values = extractConstants(a.aggregates());
1804-
if (values.size() == a.aggregates().size() && isNotQueryWithFromClauseAndFilterFoldedToFalse(a)) {
1804+
if (values.size() == a.aggregates().size() && a.groupings().isEmpty()
1805+
&& isNotQueryWithFromClauseAndFilterFoldedToFalse(a)) {
18051806
optimizedPlan.set(new LocalRelation(a.source(), new SingletonExecutable(a.output(), values.toArray())));
18061807
}
18071808
}, Aggregate.class);

x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/planner/QueryFolder.java

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -392,20 +392,21 @@ protected PhysicalPlan rule(AggregateExec a) {
392392
}
393393
return a;
394394
}
395-
395+
396396
static EsQueryExec fold(AggregateExec a, EsQueryExec exec) {
397-
397+
398398
QueryContainer queryC = exec.queryContainer();
399-
399+
400400
// track aliases defined in the SELECT and used inside GROUP BY
401401
// SELECT x AS a ... GROUP BY a
402402
Map<Attribute, Expression> aliasMap = new LinkedHashMap<>();
403+
String id = null;
403404
for (NamedExpression ne : a.aggregates()) {
404405
if (ne instanceof Alias) {
405406
aliasMap.put(ne.toAttribute(), ((Alias) ne).child());
406407
}
407408
}
408-
409+
409410
if (aliasMap.isEmpty() == false) {
410411
Map<Attribute, Expression> newAliases = new LinkedHashMap<>(queryC.aliases());
411412
newAliases.putAll(aliasMap);
@@ -450,7 +451,7 @@ static EsQueryExec fold(AggregateExec a, EsQueryExec exec) {
450451
target = ((Alias) ne).child();
451452
}
452453

453-
String id = Expressions.id(target);
454+
id = Expressions.id(target);
454455

455456
// literal
456457
if (target.foldable()) {
@@ -586,7 +587,14 @@ else if (target.foldable()) {
586587
}
587588
}
588589
}
589-
590+
// If we're only selecting literals, we have to still execute the aggregation to create
591+
// the correct grouping buckets, in order to return the appropriate number of rows
592+
if (a.aggregates().stream().allMatch(e -> e.anyMatch(Expression::foldable))) {
593+
for (Expression grouping : a.groupings()) {
594+
GroupByKey matchingGroup = groupingContext.groupFor(grouping);
595+
queryC = queryC.addColumn(new GroupByRef(matchingGroup.id(), null, false), id);
596+
}
597+
}
590598
return new EsQueryExec(exec.source(), exec.index(), a.output(), queryC);
591599
}
592600

x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/planner/QueryTranslatorTests.java

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,14 @@
2323
import org.elasticsearch.xpack.sql.expression.FieldAttribute;
2424
import org.elasticsearch.xpack.sql.expression.Literal;
2525
import org.elasticsearch.xpack.sql.expression.function.FunctionRegistry;
26+
import org.elasticsearch.xpack.sql.expression.function.aggregate.Count;
2627
import org.elasticsearch.xpack.sql.expression.function.grouping.Histogram;
2728
import org.elasticsearch.xpack.sql.expression.function.scalar.Cast;
2829
import org.elasticsearch.xpack.sql.expression.function.scalar.datetime.DateTimeProcessor.DateTimeExtractor;
2930
import org.elasticsearch.xpack.sql.expression.function.scalar.math.MathProcessor.MathOperation;
3031
import org.elasticsearch.xpack.sql.expression.function.scalar.math.Round;
3132
import org.elasticsearch.xpack.sql.expression.gen.script.ScriptTemplate;
33+
import org.elasticsearch.xpack.sql.expression.predicate.operator.comparison.GreaterThan;
3234
import org.elasticsearch.xpack.sql.optimizer.Optimizer;
3335
import org.elasticsearch.xpack.sql.parser.SqlParser;
3436
import org.elasticsearch.xpack.sql.plan.logical.Aggregate;
@@ -123,6 +125,34 @@ public void testTermEqualityAnalyzer() {
123125
assertEquals("value", tq.value());
124126
}
125127

128+
public void testAliasAndGroupByResolution(){
129+
LogicalPlan p = plan("SELECT COUNT(*) AS c FROM test WHERE ABS(int) > 0 GROUP BY int");
130+
assertTrue(p instanceof Aggregate);
131+
Aggregate a = (Aggregate) p;
132+
LogicalPlan pc = ((Aggregate) p).child();
133+
assertTrue(pc instanceof Filter);
134+
Expression condition = ((Filter) pc).condition();
135+
assertEquals("GREATERTHAN", ((GreaterThan) condition).functionName());
136+
List<Expression> groupings = a.groupings();
137+
assertTrue(groupings.get(0).resolved());
138+
assertEquals("c", a.aggregates().get(0).name());
139+
assertEquals("COUNT", ((Count) ((Alias) a.aggregates().get(0)).child()).functionName());
140+
}
141+
public void testLiteralWithGroupBy(){
142+
LogicalPlan p = plan("SELECT 1 as t, 2 FROM test GROUP BY int");
143+
assertTrue(p instanceof Aggregate);
144+
Aggregate a = (Aggregate) p;
145+
List<Expression> groupings = a.groupings();
146+
assertEquals(1, groupings.size());
147+
assertTrue(groupings.get(0).resolved());
148+
assertTrue(groupings.get(0) instanceof FieldAttribute);
149+
assertEquals(2, a.aggregates().size());
150+
assertEquals("t", a.aggregates().get(0).name());
151+
assertTrue(((Alias) a.aggregates().get(0)).child() instanceof Literal);
152+
assertEquals("1", ((Alias) a.aggregates().get(0)).child().toString());
153+
assertEquals("2", ((Alias) a.aggregates().get(1)).child().toString());
154+
}
155+
126156
public void testTermEqualityNotAnalyzed() {
127157
LogicalPlan p = plan("SELECT some.string FROM test WHERE int = 5");
128158
assertTrue(p instanceof Project);

0 commit comments

Comments
 (0)