Skip to content

Commit 30f7228

Browse files
committed
Calculate sum in Kahan summation algorithm in aggregations (#27807)
1 parent c93cc1b commit 30f7228

File tree

6 files changed

+109
-19
lines changed

6 files changed

+109
-19
lines changed

core/src/main/java/org/elasticsearch/search/aggregations/metrics/avg/AvgAggregator.java

Lines changed: 14 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -46,6 +46,8 @@ public class AvgAggregator extends NumericMetricsAggregator.SingleValue {
4646
DoubleArray sums;
4747
DocValueFormat format;
4848

49+
private DoubleArray compensations;
50+
4951
public AvgAggregator(String name, ValuesSource.Numeric valuesSource, DocValueFormat formatter, SearchContext context,
5052
Aggregator parent, List<PipelineAggregator> pipelineAggregators, Map<String, Object> metaData) throws IOException {
5153
super(name, context, parent, pipelineAggregators, metaData);
@@ -55,6 +57,7 @@ public AvgAggregator(String name, ValuesSource.Numeric valuesSource, DocValueFor
5557
final BigArrays bigArrays = context.bigArrays();
5658
counts = bigArrays.newLongArray(1, true);
5759
sums = bigArrays.newDoubleArray(1, true);
60+
compensations = bigArrays.newDoubleArray(1, true);
5861
}
5962
}
6063

@@ -76,15 +79,22 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx,
7679
public void collect(int doc, long bucket) throws IOException {
7780
counts = bigArrays.grow(counts, bucket + 1);
7881
sums = bigArrays.grow(sums, bucket + 1);
82+
compensations = bigArrays.grow(compensations, bucket + 1);
7983

8084
if (values.advanceExact(doc)) {
8185
final int valueCount = values.docValueCount();
8286
counts.increment(bucket, valueCount);
83-
double sum = 0;
87+
double sum = sums.get(bucket);
88+
double compensation = compensations.get(bucket);
89+
8490
for (int i = 0; i < valueCount; i++) {
85-
sum += values.nextValue();
91+
double corrected = values.nextValue() - compensation;
92+
double newSum = sum + corrected;
93+
compensation = (newSum - sum) - corrected;
94+
sum = newSum;
8695
}
87-
sums.increment(bucket, sum);
96+
sums.set(bucket, sum);
97+
compensations.set(bucket, compensation);
8898
}
8999
}
90100
};
@@ -113,7 +123,7 @@ public InternalAggregation buildEmptyAggregation() {
113123

114124
// Releases the arrays held by this aggregator via Releasables.close.
// Diff view: the "-" line is the call before the commit (counts, sums); the "+" line is the
// call after it, which additionally releases the new compensations array.
@Override
114125
public void doClose() {
116-
Releasables.close(counts, sums);
126+
Releasables.close(counts, sums, compensations);
117127
}
118128

119129
}

core/src/main/java/org/elasticsearch/search/aggregations/metrics/stats/StatsAggregator.java

Lines changed: 14 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -48,6 +48,8 @@ public class StatsAggregator extends NumericMetricsAggregator.MultiValue {
4848
DoubleArray mins;
4949
DoubleArray maxes;
5050

51+
private DoubleArray compensations;
52+
5153

5254
public StatsAggregator(String name, ValuesSource.Numeric valuesSource, DocValueFormat format,
5355
SearchContext context,
@@ -59,6 +61,7 @@ public StatsAggregator(String name, ValuesSource.Numeric valuesSource, DocValueF
5961
final BigArrays bigArrays = context.bigArrays();
6062
counts = bigArrays.newLongArray(1, true);
6163
sums = bigArrays.newDoubleArray(1, true);
64+
compensations = bigArrays.newDoubleArray(1, true);
6265
mins = bigArrays.newDoubleArray(1, false);
6366
mins.fill(0, mins.size(), Double.POSITIVE_INFINITY);
6467
maxes = bigArrays.newDoubleArray(1, false);
@@ -88,6 +91,7 @@ public void collect(int doc, long bucket) throws IOException {
8891
final long overSize = BigArrays.overSize(bucket + 1);
8992
counts = bigArrays.resize(counts, overSize);
9093
sums = bigArrays.resize(sums, overSize);
94+
compensations = bigArrays.resize(compensations, overSize);
9195
mins = bigArrays.resize(mins, overSize);
9296
maxes = bigArrays.resize(maxes, overSize);
9397
mins.fill(from, overSize, Double.POSITIVE_INFINITY);
@@ -97,16 +101,22 @@ public void collect(int doc, long bucket) throws IOException {
97101
if (values.advanceExact(doc)) {
98102
final int valuesCount = values.docValueCount();
99103
counts.increment(bucket, valuesCount);
100-
double sum = 0;
101104
double min = mins.get(bucket);
102105
double max = maxes.get(bucket);
106+
double sum = sums.get(bucket);
107+
double compensation = compensations.get(bucket);
108+
103109
for (int i = 0; i < valuesCount; i++) {
104110
double value = values.nextValue();
105-
sum += value;
111+
double corrected = value - compensation;
112+
double newSum = sum + corrected;
113+
compensation = (newSum - sum) - corrected;
114+
sum = newSum;
106115
min = Math.min(min, value);
107116
max = Math.max(max, value);
108117
}
109-
sums.increment(bucket, sum);
118+
sums.set(bucket, sum);
119+
compensations.set(bucket, compensation);
110120
mins.set(bucket, min);
111121
maxes.set(bucket, max);
112122
}
@@ -164,6 +174,6 @@ public InternalAggregation buildEmptyAggregation() {
164174

165175
// Releases the arrays held by this aggregator via Releasables.close.
// Diff view: the "-" line is the call before the commit (counts, maxes, mins, sums); the "+"
// line is the call after it, which additionally releases the new compensations array.
@Override
166176
public void doClose() {
167-
Releasables.close(counts, maxes, mins, sums);
177+
Releasables.close(counts, maxes, mins, sums, compensations);
168178
}
169179
}

core/src/main/java/org/elasticsearch/search/aggregations/metrics/sum/SumAggregator.java

Lines changed: 13 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -43,6 +43,7 @@ public class SumAggregator extends NumericMetricsAggregator.SingleValue {
4343
private final DocValueFormat format;
4444

4545
private DoubleArray sums;
46+
private DoubleArray compensations;
4647

4748
SumAggregator(String name, ValuesSource.Numeric valuesSource, DocValueFormat formatter, SearchContext context,
4849
Aggregator parent, List<PipelineAggregator> pipelineAggregators, Map<String, Object> metaData) throws IOException {
@@ -51,6 +52,7 @@ public class SumAggregator extends NumericMetricsAggregator.SingleValue {
5152
this.format = formatter;
5253
if (valuesSource != null) {
5354
sums = context.bigArrays().newDoubleArray(1, true);
55+
compensations = context.bigArrays().newDoubleArray(1, true);
5456
}
5557
}
5658

@@ -71,13 +73,20 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx,
7173
@Override
7274
public void collect(int doc, long bucket) throws IOException {
7375
sums = bigArrays.grow(sums, bucket + 1);
76+
compensations = bigArrays.grow(compensations, bucket + 1);
77+
7478
if (values.advanceExact(doc)) {
7579
final int valuesCount = values.docValueCount();
76-
double sum = 0;
80+
double sum = sums.get(bucket);
81+
double compensation = compensations.get(bucket);
7782
for (int i = 0; i < valuesCount; i++) {
78-
sum += values.nextValue();
83+
double corrected = values.nextValue() - compensation;
84+
double newSum = sum + corrected;
85+
compensation = (newSum - sum) - corrected;
86+
sum = newSum;
7987
}
80-
sums.increment(bucket, sum);
88+
compensations.set(bucket, compensation);
89+
sums.set(bucket, sum);
8190
}
8291
}
8392
};
@@ -106,6 +115,6 @@ public InternalAggregation buildEmptyAggregation() {
106115

107116
// Releases the arrays held by this aggregator via Releasables.close.
// Diff view: the "-" line is the call before the commit (sums only); the "+" line is the call
// after it, which additionally releases the new compensations array.
@Override
108117
public void doClose() {
109-
Releasables.close(sums);
118+
Releasables.close(sums, compensations);
110119
}
111120
}

core/src/test/java/org/elasticsearch/search/aggregations/metrics/StatsAggregatorTests.java

Lines changed: 24 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@
1919
package org.elasticsearch.search.aggregations.metrics;
2020

2121
import org.apache.lucene.document.Document;
22+
import org.apache.lucene.document.DoubleDocValuesField;
2223
import org.apache.lucene.document.SortedNumericDocValuesField;
2324
import org.apache.lucene.index.IndexReader;
2425
import org.apache.lucene.index.RandomIndexWriter;
@@ -36,6 +37,8 @@
3637
import java.io.IOException;
3738
import java.util.function.Consumer;
3839

40+
import static java.util.Collections.singleton;
41+
3942
public class StatsAggregatorTests extends AggregatorTestCase {
4043
static final double TOLERANCE = 1e-10;
4144

@@ -113,6 +116,27 @@ public void testRandomLongs() throws IOException {
113116
);
114117
}
115118

119+
public void testSummationAccuracy() throws IOException {
120+
MappedFieldType ft = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.DOUBLE);
121+
final String fieldName = "field";
122+
ft.setName(fieldName);
123+
testCase(ft,
124+
iw -> {
125+
double[] values = new double[]{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.9, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7};
126+
for (double value : values) {
127+
iw.addDocument(singleton(new DoubleDocValuesField(fieldName, value)));
128+
}
129+
},
130+
stats -> {
131+
assertEquals(15, stats.getCount());
132+
assertEquals(0.9, stats.getAvg(), 0d);
133+
assertEquals(13.5, stats.getSum(), 0d);
134+
assertEquals(1.7, stats.getMax(), 0d);
135+
assertEquals(0.1, stats.getMin(), 0d);
136+
}
137+
);
138+
}
139+
116140
public void testCase(MappedFieldType ft,
117141
CheckedConsumer<RandomIndexWriter, IOException> buildIndex,
118142
Consumer<InternalStats> verify) throws IOException {

core/src/test/java/org/elasticsearch/search/aggregations/metrics/SumAggregatorTests.java

Lines changed: 20 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@
1818
*/
1919
package org.elasticsearch.search.aggregations.metrics;
2020

21+
import org.apache.lucene.document.DoubleDocValuesField;
2122
import org.apache.lucene.document.Field;
2223
import org.apache.lucene.document.NumericDocValuesField;
2324
import org.apache.lucene.document.SortedDocValuesField;
@@ -116,10 +117,28 @@ public void testStringField() throws IOException {
116117
"Re-index with correct docvalues type.", e.getMessage());
117118
}
118119

120+
public void testSummationAccuracy() throws IOException {
121+
testCase(new MatchAllDocsQuery(),
122+
iw -> {
123+
double[] values = new double[]{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7};
124+
for (double value : values) {
125+
iw.addDocument(singleton(new DoubleDocValuesField(FIELD_NAME, value)));
126+
}
127+
},
128+
count -> assertEquals(15.3, count.getValue(), 0d),
129+
NumberFieldMapper.NumberType.DOUBLE);
130+
}
131+
119132
// Convenience overload: delegates to the four-argument testCase, passing
// NumberFieldMapper.NumberType.LONG as the field number type.
private void testCase(Query query,
120133
CheckedConsumer<RandomIndexWriter, IOException> indexer,
121134
Consumer<Sum> verify) throws IOException {
135+
testCase(query, indexer, verify, NumberFieldMapper.NumberType.LONG);
136+
}
122137

138+
private void testCase(Query query,
139+
CheckedConsumer<RandomIndexWriter, IOException> indexer,
140+
Consumer<Sum> verify,
141+
NumberFieldMapper.NumberType fieldNumberType) throws IOException {
123142
try (Directory directory = newDirectory()) {
124143
try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) {
125144
indexer.accept(indexWriter);
@@ -128,7 +147,7 @@ private void testCase(Query query,
128147
try (IndexReader indexReader = DirectoryReader.open(directory)) {
129148
IndexSearcher indexSearcher = newSearcher(indexReader, true, true);
130149

131-
MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.LONG);
150+
MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType(fieldNumberType);
132151
fieldType.setName(FIELD_NAME);
133152
fieldType.setHasDocValues(true);
134153

core/src/test/java/org/elasticsearch/search/aggregations/metrics/avg/AvgAggregatorTests.java

Lines changed: 24 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@
1919

2020
package org.elasticsearch.search.aggregations.metrics.avg;
2121

22+
import org.apache.lucene.document.DoubleDocValuesField;
2223
import org.apache.lucene.document.IntPoint;
2324
import org.apache.lucene.document.NumericDocValuesField;
2425
import org.apache.lucene.document.SortedNumericDocValuesField;
@@ -34,9 +35,6 @@
3435
import org.elasticsearch.index.mapper.MappedFieldType;
3536
import org.elasticsearch.index.mapper.NumberFieldMapper;
3637
import org.elasticsearch.search.aggregations.AggregatorTestCase;
37-
import org.elasticsearch.search.aggregations.metrics.avg.AvgAggregationBuilder;
38-
import org.elasticsearch.search.aggregations.metrics.avg.AvgAggregator;
39-
import org.elasticsearch.search.aggregations.metrics.avg.InternalAvg;
4038

4139
import java.io.IOException;
4240
import java.util.Arrays;
@@ -103,8 +101,28 @@ public void testQueryFiltersAll() throws IOException {
103101
});
104102
}
105103

106-
private void testCase(Query query, CheckedConsumer<RandomIndexWriter, IOException> buildIndex, Consumer<InternalAvg> verify)
107-
throws IOException {
104+
public void testSummationAccuracy() throws IOException {
105+
testCase(new MatchAllDocsQuery(),
106+
iw -> {
107+
double[] values = new double[]{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.9, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7};
108+
for (double value : values) {
109+
iw.addDocument(singleton(new DoubleDocValuesField("number", value)));
110+
}
111+
},
112+
avg -> assertEquals(0.9, avg.getValue(), 0d),
113+
NumberFieldMapper.NumberType.DOUBLE);
114+
}
115+
116+
// Convenience overload: delegates to the four-argument testCase, passing
// NumberFieldMapper.NumberType.LONG as the field number type.
private void testCase(Query query,
117+
CheckedConsumer<RandomIndexWriter, IOException> buildIndex,
118+
Consumer<InternalAvg> verify) throws IOException {
119+
testCase(query, buildIndex, verify, NumberFieldMapper.NumberType.LONG);
120+
}
121+
122+
private void testCase(Query query,
123+
CheckedConsumer<RandomIndexWriter, IOException> buildIndex,
124+
Consumer<InternalAvg> verify,
125+
NumberFieldMapper.NumberType fieldNumberType) throws IOException {
108126
Directory directory = newDirectory();
109127
RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory);
110128
buildIndex.accept(indexWriter);
@@ -114,7 +132,7 @@ private void testCase(Query query, CheckedConsumer<RandomIndexWriter, IOExceptio
114132
IndexSearcher indexSearcher = newSearcher(indexReader, true, true);
115133

116134
AvgAggregationBuilder aggregationBuilder = new AvgAggregationBuilder("_name").field("number");
117-
MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.LONG);
135+
MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType(fieldNumberType);
118136
fieldType.setName("number");
119137

120138
AvgAggregator aggregator = createAggregator(aggregationBuilder, indexSearcher, fieldType);

0 commit comments

Comments
 (0)