|
20 | 20 | import org.elasticsearch.xpack.ql.expression.UnresolvedAttribute;
|
21 | 21 | import org.elasticsearch.xpack.ql.expression.function.Function;
|
22 | 22 | import org.elasticsearch.xpack.ql.expression.function.aggregate.AggregateFunction;
|
| 23 | +import org.elasticsearch.xpack.ql.expression.function.aggregate.Count; |
23 | 24 | import org.elasticsearch.xpack.ql.expression.function.aggregate.InnerAggregate;
|
24 | 25 | import org.elasticsearch.xpack.ql.expression.predicate.logical.And;
|
25 | 26 | import org.elasticsearch.xpack.ql.expression.predicate.logical.Or;
|
|
58 | 59 | import org.elasticsearch.xpack.ql.util.Holder;
|
59 | 60 | import org.elasticsearch.xpack.sql.SqlIllegalArgumentException;
|
60 | 61 | import org.elasticsearch.xpack.sql.analysis.analyzer.Analyzer.CleanAliases;
|
| 62 | +import org.elasticsearch.xpack.sql.expression.function.aggregate.Avg; |
61 | 63 | import org.elasticsearch.xpack.sql.expression.function.aggregate.ExtendedStats;
|
62 | 64 | import org.elasticsearch.xpack.sql.expression.function.aggregate.ExtendedStatsEnclosed;
|
63 | 65 | import org.elasticsearch.xpack.sql.expression.function.aggregate.First;
|
|
66 | 68 | import org.elasticsearch.xpack.sql.expression.function.aggregate.MatrixStatsEnclosed;
|
67 | 69 | import org.elasticsearch.xpack.sql.expression.function.aggregate.Max;
|
68 | 70 | import org.elasticsearch.xpack.sql.expression.function.aggregate.Min;
|
| 71 | +import org.elasticsearch.xpack.sql.expression.function.aggregate.NumericAggregate; |
69 | 72 | import org.elasticsearch.xpack.sql.expression.function.aggregate.Percentile;
|
70 | 73 | import org.elasticsearch.xpack.sql.expression.function.aggregate.PercentileRank;
|
71 | 74 | import org.elasticsearch.xpack.sql.expression.function.aggregate.PercentileRanks;
|
72 | 75 | import org.elasticsearch.xpack.sql.expression.function.aggregate.Percentiles;
|
73 | 76 | import org.elasticsearch.xpack.sql.expression.function.aggregate.Stats;
|
| 77 | +import org.elasticsearch.xpack.sql.expression.function.aggregate.Sum; |
74 | 78 | import org.elasticsearch.xpack.sql.expression.function.aggregate.TopHits;
|
75 | 79 | import org.elasticsearch.xpack.sql.expression.function.scalar.Cast;
|
76 | 80 | import org.elasticsearch.xpack.sql.expression.predicate.conditional.ArbitraryConditionalFunction;
|
77 | 81 | import org.elasticsearch.xpack.sql.expression.predicate.conditional.Case;
|
78 | 82 | import org.elasticsearch.xpack.sql.expression.predicate.conditional.Coalesce;
|
79 | 83 | import org.elasticsearch.xpack.sql.expression.predicate.conditional.IfConditional;
|
| 84 | +import org.elasticsearch.xpack.sql.expression.predicate.conditional.Iif; |
| 85 | +import org.elasticsearch.xpack.sql.expression.predicate.operator.arithmetic.Mul; |
80 | 86 | import org.elasticsearch.xpack.sql.expression.predicate.operator.comparison.In;
|
81 | 87 | import org.elasticsearch.xpack.sql.plan.logical.LocalRelation;
|
82 | 88 | import org.elasticsearch.xpack.sql.plan.logical.Pivot;
|
@@ -119,7 +125,10 @@ protected Iterable<RuleExecutor<LogicalPlan>.Batch> batches() {
|
119 | 125 | new RewritePivot());
|
120 | 126 |
|
121 | 127 | Batch refs = new Batch("Replace References", Limiter.ONCE,
|
122 |
| - new ReplaceReferenceAttributeWithSource()); |
| 128 | + new ReplaceReferenceAttributeWithSource(), |
| 129 | + new ReplaceAggregatesWithLiterals(), |
| 130 | + new ReplaceCountInLocalRelation() |
| 131 | + ); |
123 | 132 |
|
124 | 133 | Batch operators = new Batch("Operator Optimization",
|
125 | 134 | // combining
|
@@ -776,6 +785,52 @@ private Expression simplify(BinaryComparison bc) {
|
776 | 785 | }
|
777 | 786 | }
|
778 | 787 |
|
| 788 | + /** |
| 789 | + * Any numeric aggregates (avg, min, max, sum) acting on literals are converted to an iif(count(1)=0, null, literal*count(1)) for sum, |
| 790 | + * and to iif(count(1)=0,null,literal) for the other three. |
| 791 | + */ |
| 792 | + private static class ReplaceAggregatesWithLiterals extends OptimizerRule<LogicalPlan> { |
| 793 | + |
| 794 | + @Override |
| 795 | + protected LogicalPlan rule(LogicalPlan p) { |
| 796 | + return p.transformExpressionsDown(e -> { |
| 797 | + if (e instanceof Min || e instanceof Max || e instanceof Avg || e instanceof Sum) { |
| 798 | + NumericAggregate a = (NumericAggregate) e; |
| 799 | + |
| 800 | + if (a.field().foldable()) { |
| 801 | + Expression countOne = new Count(a.source(), new Literal(Source.EMPTY, 1, a.dataType()), false); |
| 802 | + Equals countEqZero = new Equals(a.source(), countOne, new Literal(Source.EMPTY, 0, a.dataType())); |
| 803 | + Expression argument = a.field(); |
| 804 | + Literal foldedArgument = new Literal(argument.source(), argument.fold(), a.dataType()); |
| 805 | + |
| 806 | + Expression iifElseResult = foldedArgument; |
| 807 | + if (e instanceof Sum) { |
| 808 | + iifElseResult = new Mul(a.source(), countOne, foldedArgument); |
| 809 | + } |
| 810 | + |
| 811 | + return new Iif(a.source(), countEqZero, Literal.NULL, iifElseResult); |
| 812 | + } |
| 813 | + } |
| 814 | + return e; |
| 815 | + }); |
| 816 | + } |
| 817 | + } |
| 818 | + |
| 819 | + /** |
| 820 | + * A COUNT in a local relation will always be 1. |
| 821 | + */ |
| 822 | + private static class ReplaceCountInLocalRelation extends OptimizerRule<Aggregate> { |
| 823 | + |
| 824 | + @Override |
| 825 | + protected LogicalPlan rule(Aggregate a) { |
| 826 | + boolean hasLocalRelation = a.anyMatch(LocalRelation.class::isInstance); |
| 827 | + |
| 828 | + return hasLocalRelation ? a.transformExpressionsDown(c -> { |
| 829 | + return c instanceof Count ? new Literal(c.source(), 1, c.dataType()) : c; |
| 830 | + }) : a; |
| 831 | + } |
| 832 | + } |
| 833 | + |
779 | 834 | static class ReplaceAggsWithMatrixStats extends OptimizerBasicRule {
|
780 | 835 |
|
781 | 836 | @Override
|
@@ -1157,8 +1212,7 @@ private List<Object> extractConstants(List<? extends NamedExpression> named) {
|
1157 | 1212 | }
|
1158 | 1213 | } else if (n.foldable()) {
|
1159 | 1214 | values.add(n.fold());
|
1160 |
| - } |
1161 |
| - else { |
| 1215 | + } else { |
1162 | 1216 | // not everything is foldable, bail-out early
|
1163 | 1217 | return values;
|
1164 | 1218 | }
|
|
0 commit comments