Skip to content

Commit a46f312

Browse files
committed
SQL: fix multi full-text functions usage with aggregate functions (#47444)
* Skip functions involving full-text predicates when replacing multiple aggregate functions with "stats" or "matrix_stats" aggregations. (cherry picked from commit bb14ba8)
1 parent 2b16d7b commit a46f312

File tree

3 files changed

+144
-2
lines changed

3 files changed

+144
-2
lines changed

x-pack/plugin/sql/qa/src/main/resources/fulltext.csv-spec

+61
Original file line numberDiff line numberDiff line change
@@ -141,3 +141,64 @@ SELECT emp_no, first_name, SCORE() as s FROM test_emp WHERE MATCH(first_name, 'E
141141
emp_no:i | first_name:s | s:f
142142
10076 |Erez |4.1053944
143143
;
144+
145+
//
146+
// Mixture of Aggs that triggers promotion of aggs to stats using multi full-text filtering
147+
//
148+
multiAggWithCountMatchAndQuery
149+
SELECT MIN(salary) min, MAX(salary) max, gender g, COUNT(*) c FROM "test_emp" WHERE languages > 0 AND (MATCH(gender, 'F') OR MATCH(gender, 'M')) AND QUERY('M*', 'default_field=last_name;lenient=true', 'fuzzy_rewrite=scoring_boolean') GROUP BY g HAVING max > 50000 ORDER BY gender;
150+
151+
min:i | max:i | g:s | c:l
152+
---------------+---------------+---------------+---------------
153+
37112 |69904 |F |3
154+
32568 |70011 |M |8
155+
;
156+
157+
multiAggWithCountAndMultiMatch
158+
SELECT MIN(salary) min, MAX(salary) max, gender g, COUNT(*) c FROM "test_emp" WHERE MATCH(gender, 'F') OR MATCH(gender, 'M') GROUP BY g HAVING max > 50000 ORDER BY gender;
159+
160+
min:i | max:i | g:s | c:l
161+
---------------+---------------+---------------+---------------
162+
25976 |74572 |F |33
163+
25945 |74999 |M |57
164+
;
165+
166+
multiAggWithMultiMatchOrderByCount
167+
SELECT MIN(salary) min, MAX(salary) max, ROUND(AVG(salary)) avg, gender g, COUNT(*) c FROM "test_emp" WHERE MATCH(gender, 'F') OR MATCH('first_name^3,last_name^5', 'geo hir', 'fuzziness=2;operator=or') GROUP BY g ORDER BY c DESC;
168+
169+
min:i | max:i | avg:d | g:s | c:l
170+
---------------+---------------+---------------+---------------+---------------
171+
25976 |74572 |50491 |F |33
172+
32568 |32568 |32568 |M |1
173+
;
174+
175+
multiAggWithMultiMatchOrderByCountAndSimpleCondition
176+
SELECT MIN(salary) min, MAX(salary) max, ROUND(AVG(salary)) avg, gender g, COUNT(*) c FROM "test_emp" WHERE (MATCH(gender, 'F') AND languages > 4) OR MATCH('first_name^3,last_name^5', 'geo hir', 'fuzziness=2;operator=or') GROUP BY g ORDER BY c DESC;
177+
178+
min:i | max:i | avg:d | g:s | c:l
179+
---------------+---------------+---------------+---------------+---------------
180+
32272 |66817 |48081 |F |11
181+
32568 |32568 |32568 |M |1
182+
;
183+
184+
multiAggWithPercentileAndMultiQuery
185+
SELECT languages, PERCENTILE(salary, 95) "95th", ROUND(PERCENTILE_RANK(salary, 65000)) AS rank, MAX(salary), MIN(salary), COUNT(*) c FROM test_emp WHERE QUERY('A*','default_field=first_name') OR QUERY('B*', 'default_field=first_name') OR languages IS NULL GROUP BY languages;
186+
187+
languages:bt | 95th:d | rank:d | MAX(salary):i | MIN(salary):i | c:l
188+
---------------+---------------+---------------+---------------+---------------+---------------
189+
null |74999 |74 |74999 |28336 |10
190+
2 |44307 |100 |44307 |29175 |3
191+
3 |65030 |100 |65030 |38376 |4
192+
5 |66817 |100 |66817 |37137 |4
193+
;
194+
195+
multiAggWithStatsAndMatrixStatsAndMultiQuery
196+
SELECT languages, KURTOSIS(salary) k, SKEWNESS(salary) s, MAX(salary), MIN(salary), COUNT(*) c FROM test_emp WHERE QUERY('A*','default_field=first_name') OR QUERY('B*', 'default_field=first_name') OR languages IS NULL GROUP BY languages;
197+
198+
languages:bt | k:d | s:d | MAX(salary):i | MIN(salary):i | c:l
199+
---------------+------------------+-------------------+---------------+---------------+---------------
200+
null |1.9161749939033146|0.1480828817161133 |74999 |28336 |10
201+
2 |1.5000000000000002|0.484743245141609 |44307 |29175 |3
202+
3 |1.0732551278666582|0.05483979801873433|65030 |38376 |4
203+
5 |1.322529094661261 |0.24501477738153868|66817 |37137 |4
204+
;

x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/optimizer/Optimizer.java

+3-2
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
import org.elasticsearch.xpack.sql.expression.predicate.conditional.Case;
5353
import org.elasticsearch.xpack.sql.expression.predicate.conditional.Coalesce;
5454
import org.elasticsearch.xpack.sql.expression.predicate.conditional.IfConditional;
55+
import org.elasticsearch.xpack.sql.expression.predicate.fulltext.FullTextPredicate;
5556
import org.elasticsearch.xpack.sql.expression.predicate.logical.And;
5657
import org.elasticsearch.xpack.sql.expression.predicate.logical.Not;
5758
import org.elasticsearch.xpack.sql.expression.predicate.logical.Or;
@@ -488,11 +489,11 @@ static LogicalPlan updateAggAttributes(LogicalPlan p, Map<String, AggregateFunct
488489
}
489490
}
490491

491-
else if (e instanceof ScalarFunction) {
492+
else if (e instanceof ScalarFunction && false == Expressions.anyMatch(e.children(), c -> c instanceof FullTextPredicate)) {
492493
ScalarFunction sf = (ScalarFunction) e;
493494

494495
// if it's a unseen function check if the function children/arguments refers to any of the promoted aggs
495-
if (!updatedScalarAttrs.containsKey(sf.functionId()) && e.anyMatch(c -> {
496+
if (newAggIds.isEmpty() == false && !updatedScalarAttrs.containsKey(sf.functionId()) && e.anyMatch(c -> {
496497
Attribute a = Expressions.attribute(c);
497498
if (a instanceof FunctionAttribute) {
498499
return newAggIds.contains(((FunctionAttribute) a).functionId());

x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/optimizer/OptimizerTests.java

+80
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,17 @@
2323
import org.elasticsearch.xpack.sql.expression.function.aggregate.AggregateFunction;
2424
import org.elasticsearch.xpack.sql.expression.function.aggregate.Avg;
2525
import org.elasticsearch.xpack.sql.expression.function.aggregate.Count;
26+
import org.elasticsearch.xpack.sql.expression.function.aggregate.ExtendedStats;
2627
import org.elasticsearch.xpack.sql.expression.function.aggregate.First;
28+
import org.elasticsearch.xpack.sql.expression.function.aggregate.InnerAggregate;
2729
import org.elasticsearch.xpack.sql.expression.function.aggregate.Last;
2830
import org.elasticsearch.xpack.sql.expression.function.aggregate.Max;
2931
import org.elasticsearch.xpack.sql.expression.function.aggregate.Min;
32+
import org.elasticsearch.xpack.sql.expression.function.aggregate.Stats;
33+
import org.elasticsearch.xpack.sql.expression.function.aggregate.StddevPop;
34+
import org.elasticsearch.xpack.sql.expression.function.aggregate.Sum;
35+
import org.elasticsearch.xpack.sql.expression.function.aggregate.SumOfSquares;
36+
import org.elasticsearch.xpack.sql.expression.function.aggregate.VarPop;
3037
import org.elasticsearch.xpack.sql.expression.function.scalar.Cast;
3138
import org.elasticsearch.xpack.sql.expression.function.scalar.datetime.DayName;
3239
import org.elasticsearch.xpack.sql.expression.function.scalar.datetime.DayOfMonth;
@@ -57,7 +64,12 @@
5764
import org.elasticsearch.xpack.sql.expression.predicate.conditional.Iif;
5865
import org.elasticsearch.xpack.sql.expression.predicate.conditional.Least;
5966
import org.elasticsearch.xpack.sql.expression.predicate.conditional.NullIf;
67+
import org.elasticsearch.xpack.sql.expression.predicate.fulltext.FullTextPredicate;
68+
import org.elasticsearch.xpack.sql.expression.predicate.fulltext.MatchQueryPredicate;
69+
import org.elasticsearch.xpack.sql.expression.predicate.fulltext.MultiMatchQueryPredicate;
70+
import org.elasticsearch.xpack.sql.expression.predicate.fulltext.StringQueryPredicate;
6071
import org.elasticsearch.xpack.sql.expression.predicate.logical.And;
72+
import org.elasticsearch.xpack.sql.expression.predicate.logical.BinaryLogic;
6173
import org.elasticsearch.xpack.sql.expression.predicate.logical.Not;
6274
import org.elasticsearch.xpack.sql.expression.predicate.logical.Or;
6375
import org.elasticsearch.xpack.sql.expression.predicate.nulls.IsNotNull;
@@ -87,6 +99,8 @@
8799
import org.elasticsearch.xpack.sql.optimizer.Optimizer.FoldNull;
88100
import org.elasticsearch.xpack.sql.optimizer.Optimizer.PropagateEquals;
89101
import org.elasticsearch.xpack.sql.optimizer.Optimizer.PruneDuplicateFunctions;
102+
import org.elasticsearch.xpack.sql.optimizer.Optimizer.ReplaceAggsWithExtendedStats;
103+
import org.elasticsearch.xpack.sql.optimizer.Optimizer.ReplaceAggsWithStats;
90104
import org.elasticsearch.xpack.sql.optimizer.Optimizer.ReplaceFoldableAttributes;
91105
import org.elasticsearch.xpack.sql.optimizer.Optimizer.ReplaceMinMaxWithTopHits;
92106
import org.elasticsearch.xpack.sql.optimizer.Optimizer.RewritePivot;
@@ -1522,4 +1536,70 @@ public void testPivotRewrite() {
15221536
assertEquals(column, in.value());
15231537
assertEquals(Arrays.asList(L(1), L(2)), in.list());
15241538
}
1539+
1540+
/**
1541+
* Test queries like SELECT MIN(agg_field), MAX(agg_field) FROM table WHERE MATCH(match_field,'A') AND/OR QUERY('match_field:A')
1542+
* or SELECT STDDEV_POP(agg_field), VAR_POP(agg_field) FROM table WHERE MATCH(match_field,'A') AND/OR QUERY('match_field:A')
1543+
*/
1544+
public void testAggregatesPromoteToStats_WithFullTextPredicatesConditions() {
1545+
FieldAttribute matchField = new FieldAttribute(EMPTY, "match_field", new EsField("match_field", DataType.TEXT, emptyMap(), true));
1546+
FieldAttribute aggField = new FieldAttribute(EMPTY, "agg_field", new EsField("agg_field", DataType.INTEGER, emptyMap(), true));
1547+
1548+
FullTextPredicate matchPredicate = new MatchQueryPredicate(EMPTY, matchField, "A", StringUtils.EMPTY);
1549+
FullTextPredicate multiMatchPredicate = new MultiMatchQueryPredicate(EMPTY, "match_field", "A", StringUtils.EMPTY);
1550+
FullTextPredicate stringQueryPredicate = new StringQueryPredicate(EMPTY, "match_field:A", StringUtils.EMPTY);
1551+
List<FullTextPredicate> predicates = Arrays.asList(matchPredicate, multiMatchPredicate, stringQueryPredicate);
1552+
1553+
FullTextPredicate left = randomFrom(predicates);
1554+
FullTextPredicate right = randomFrom(predicates);
1555+
1556+
BinaryLogic or = new Or(EMPTY, left, right);
1557+
BinaryLogic and = new And(EMPTY, left, right);
1558+
BinaryLogic condition = randomFrom(or, and);
1559+
Filter filter = new Filter(EMPTY, FROM(), condition);
1560+
1561+
List<AggregateFunction> aggregates;
1562+
boolean isSimpleStats = randomBoolean();
1563+
if (isSimpleStats) {
1564+
aggregates = Arrays.asList(new Avg(EMPTY, aggField), new Sum(EMPTY, aggField), new Min(EMPTY, aggField),
1565+
new Max(EMPTY, aggField));
1566+
} else {
1567+
aggregates = Arrays.asList(new StddevPop(EMPTY, aggField), new SumOfSquares(EMPTY, aggField), new VarPop(EMPTY, aggField));
1568+
}
1569+
AggregateFunction firstAggregate = randomFrom(aggregates);
1570+
AggregateFunction secondAggregate = randomValueOtherThan(firstAggregate, () -> randomFrom(aggregates));
1571+
Aggregate aggregatePlan = new Aggregate(EMPTY, filter, Collections.singletonList(matchField),
1572+
Arrays.asList(firstAggregate, secondAggregate));
1573+
LogicalPlan result;
1574+
if (isSimpleStats) {
1575+
result = new ReplaceAggsWithStats().apply(aggregatePlan);
1576+
} else {
1577+
result = new ReplaceAggsWithExtendedStats().apply(aggregatePlan);
1578+
}
1579+
1580+
assertTrue(result instanceof Aggregate);
1581+
Aggregate resultAgg = (Aggregate) result;
1582+
assertEquals(2, resultAgg.aggregates().size());
1583+
assertTrue(resultAgg.aggregates().get(0) instanceof InnerAggregate);
1584+
assertTrue(resultAgg.aggregates().get(1) instanceof InnerAggregate);
1585+
1586+
InnerAggregate resultFirstAgg = (InnerAggregate) resultAgg.aggregates().get(0);
1587+
InnerAggregate resultSecondAgg = (InnerAggregate) resultAgg.aggregates().get(1);
1588+
assertEquals(resultFirstAgg.inner(), firstAggregate);
1589+
assertEquals(resultSecondAgg.inner(), secondAggregate);
1590+
if (isSimpleStats) {
1591+
assertTrue(resultFirstAgg.outer() instanceof Stats);
1592+
assertTrue(resultSecondAgg.outer() instanceof Stats);
1593+
assertEquals(((Stats) resultFirstAgg.outer()).field(), aggField);
1594+
assertEquals(((Stats) resultSecondAgg.outer()).field(), aggField);
1595+
} else {
1596+
assertTrue(resultFirstAgg.outer() instanceof ExtendedStats);
1597+
assertTrue(resultSecondAgg.outer() instanceof ExtendedStats);
1598+
assertEquals(((ExtendedStats) resultFirstAgg.outer()).field(), aggField);
1599+
assertEquals(((ExtendedStats) resultSecondAgg.outer()).field(), aggField);
1600+
}
1601+
1602+
assertTrue(resultAgg.child() instanceof Filter);
1603+
assertEquals(resultAgg.child(), filter);
1604+
}
15251605
}

0 commit comments

Comments
 (0)