Skip to content

Commit 6c45e08

Browse files
committed
Compute summation with Kahan summation algorithm for internal aggregators
1 parent 30f7228 commit 6c45e08

File tree

11 files changed

+150
-16
lines changed

11 files changed

+150
-16
lines changed

core/src/main/java/org/elasticsearch/search/aggregations/metrics/avg/AvgAggregator.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,9 @@ public class AvgAggregator extends NumericMetricsAggregator.SingleValue {
4444

4545
LongArray counts;
4646
DoubleArray sums;
47+
DoubleArray compensations;
4748
DocValueFormat format;
4849

49-
private DoubleArray compensations;
50-
5150
public AvgAggregator(String name, ValuesSource.Numeric valuesSource, DocValueFormat formatter, SearchContext context,
5251
Aggregator parent, List<PipelineAggregator> pipelineAggregators, Map<String, Object> metaData) throws IOException {
5352
super(name, context, parent, pipelineAggregators, metaData);
@@ -84,6 +83,8 @@ public void collect(int doc, long bucket) throws IOException {
8483
if (values.advanceExact(doc)) {
8584
final int valueCount = values.docValueCount();
8685
counts.increment(bucket, valueCount);
86+
// Compute the sum of double values with Kahan summation algorithm which is more
87+
// accurate than naive summation.
8788
double sum = sums.get(bucket);
8889
double compensation = compensations.get(bucket);
8990

core/src/main/java/org/elasticsearch/search/aggregations/metrics/avg/InternalAvg.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,15 @@ public String getWriteableName() {
9191
public InternalAvg doReduce(List<InternalAggregation> aggregations, ReduceContext reduceContext) {
9292
long count = 0;
9393
double sum = 0;
94+
double compensation = 0;
95+
// Compute the sum of double values with Kahan summation algorithm which is more
96+
// accurate than naive summation.
9497
for (InternalAggregation aggregation : aggregations) {
9598
count += ((InternalAvg) aggregation).count;
96-
sum += ((InternalAvg) aggregation).sum;
99+
double corrected = ((InternalAvg) aggregation).sum - compensation;
100+
double newSum = sum + corrected;
101+
compensation = (newSum - sum) - corrected;
102+
sum = newSum;
97103
}
98104
return new InternalAvg(getName(), sum, count, format, pipelineAggregators(), getMetaData());
99105
}

core/src/main/java/org/elasticsearch/search/aggregations/metrics/stats/InternalStats.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,12 +152,18 @@ public InternalStats doReduce(List<InternalAggregation> aggregations, ReduceCont
152152
double min = Double.POSITIVE_INFINITY;
153153
double max = Double.NEGATIVE_INFINITY;
154154
double sum = 0;
155+
double compensation = 0;
155156
for (InternalAggregation aggregation : aggregations) {
156157
InternalStats stats = (InternalStats) aggregation;
157158
count += stats.getCount();
158159
min = Math.min(min, stats.getMin());
159160
max = Math.max(max, stats.getMax());
160-
sum += stats.getSum();
161+
// Compute the sum of double values with Kahan summation algorithm which is more
162+
// accurate than naive summation.
163+
double corrected = stats.getSum() - compensation;
164+
double newSum = sum + corrected;
165+
compensation = (newSum - sum) - corrected;
166+
sum = newSum;
161167
}
162168
return new InternalStats(name, count, sum, min, max, format, pipelineAggregators(), getMetaData());
163169
}

core/src/main/java/org/elasticsearch/search/aggregations/metrics/stats/StatsAggregator.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,8 @@ public void collect(int doc, long bucket) throws IOException {
103103
counts.increment(bucket, valuesCount);
104104
double min = mins.get(bucket);
105105
double max = maxes.get(bucket);
106+
// Compute the sum of double values with Kahan summation algorithm which is more
107+
// accurate than naive summation.
106108
double sum = sums.get(bucket);
107109
double compensation = compensations.get(bucket);
108110

core/src/main/java/org/elasticsearch/search/aggregations/metrics/stats/extended/ExtendedStatsAggregator.java

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,11 @@ public class ExtendedStatsAggregator extends NumericMetricsAggregator.MultiValue
4949

5050
LongArray counts;
5151
DoubleArray sums;
52+
DoubleArray compensations;
5253
DoubleArray mins;
5354
DoubleArray maxes;
5455
DoubleArray sumOfSqrs;
56+
DoubleArray compensationOfSqrs;
5557

5658
public ExtendedStatsAggregator(String name, ValuesSource.Numeric valuesSource, DocValueFormat formatter,
5759
SearchContext context, Aggregator parent, double sigma, List<PipelineAggregator> pipelineAggregators,
@@ -65,11 +67,13 @@ public ExtendedStatsAggregator(String name, ValuesSource.Numeric valuesSource, D
6567
final BigArrays bigArrays = context.bigArrays();
6668
counts = bigArrays.newLongArray(1, true);
6769
sums = bigArrays.newDoubleArray(1, true);
70+
compensations = bigArrays.newDoubleArray(1, true);
6871
mins = bigArrays.newDoubleArray(1, false);
6972
mins.fill(0, mins.size(), Double.POSITIVE_INFINITY);
7073
maxes = bigArrays.newDoubleArray(1, false);
7174
maxes.fill(0, maxes.size(), Double.NEGATIVE_INFINITY);
7275
sumOfSqrs = bigArrays.newDoubleArray(1, true);
76+
compensationOfSqrs = bigArrays.newDoubleArray(1, true);
7377
}
7478
}
7579

@@ -95,29 +99,44 @@ public void collect(int doc, long bucket) throws IOException {
9599
final long overSize = BigArrays.overSize(bucket + 1);
96100
counts = bigArrays.resize(counts, overSize);
97101
sums = bigArrays.resize(sums, overSize);
102+
compensations = bigArrays.resize(compensations, overSize);
98103
mins = bigArrays.resize(mins, overSize);
99104
maxes = bigArrays.resize(maxes, overSize);
100105
sumOfSqrs = bigArrays.resize(sumOfSqrs, overSize);
106+
compensationOfSqrs = bigArrays.resize(compensationOfSqrs, overSize);
101107
mins.fill(from, overSize, Double.POSITIVE_INFINITY);
102108
maxes.fill(from, overSize, Double.NEGATIVE_INFINITY);
103109
}
104110

105111
if (values.advanceExact(doc)) {
106112
final int valuesCount = values.docValueCount();
107113
counts.increment(bucket, valuesCount);
108-
double sum = 0;
109-
double sumOfSqr = 0;
110114
double min = mins.get(bucket);
111115
double max = maxes.get(bucket);
116+
// Compute the sum and sum of squares for double values with the Kahan summation algorithm,
117+
// which is more accurate than naive summation.
118+
double sum = sums.get(bucket);
119+
double compensation = compensations.get(bucket);
120+
double sumOfSqr = sumOfSqrs.get(bucket);
121+
double compensationOfSqr = compensationOfSqrs.get(bucket);
112122
for (int i = 0; i < valuesCount; i++) {
113123
double value = values.nextValue();
114-
sum += value;
115-
sumOfSqr += value * value;
124+
double corrected = value - compensation;
125+
double newSum = sum + corrected;
126+
compensation = (newSum - sum) - corrected;
127+
sum = newSum;
128+
129+
double correctedOfSqr = value * value - compensationOfSqr;
130+
double newSumOfSqr = sumOfSqr + correctedOfSqr;
131+
compensationOfSqr = (newSumOfSqr - sumOfSqr) - correctedOfSqr;
132+
sumOfSqr = newSumOfSqr;
116133
min = Math.min(min, value);
117134
max = Math.max(max, value);
118135
}
119-
sums.increment(bucket, sum);
120-
sumOfSqrs.increment(bucket, sumOfSqr);
136+
sums.set(bucket, sum);
137+
compensations.set(bucket, compensation);
138+
sumOfSqrs.set(bucket, sumOfSqr);
139+
compensationOfSqrs.set(bucket, compensationOfSqr);
121140
mins.set(bucket, min);
122141
maxes.set(bucket, max);
123142
}
@@ -196,6 +215,6 @@ public InternalAggregation buildEmptyAggregation() {
196215

197216
@Override
198217
public void doClose() {
199-
Releasables.close(counts, maxes, mins, sumOfSqrs, sums);
218+
Releasables.close(counts, maxes, mins, sumOfSqrs, compensationOfSqrs, sums, compensations);
200219
}
201220
}

core/src/main/java/org/elasticsearch/search/aggregations/metrics/sum/InternalSum.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.elasticsearch.search.DocValueFormat;
2525
import org.elasticsearch.search.aggregations.InternalAggregation;
2626
import org.elasticsearch.search.aggregations.metrics.InternalNumericMetricsAggregation;
27+
import org.elasticsearch.search.aggregations.metrics.avg.InternalAvg;
2728
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
2829

2930
import java.io.IOException;
@@ -73,9 +74,15 @@ public double getValue() {
7374

7475
@Override
7576
public InternalSum doReduce(List<InternalAggregation> aggregations, ReduceContext reduceContext) {
77+
// Compute the sum of double values with Kahan summation algorithm which is more
78+
// accurate than naive summation.
7679
double sum = 0;
80+
double compensation = 0;
7781
for (InternalAggregation aggregation : aggregations) {
78-
sum += ((InternalSum) aggregation).sum;
82+
double corrected = ((InternalSum) aggregation).sum - compensation;
83+
double newSum = sum + corrected;
84+
compensation = (newSum - sum) - corrected;
85+
sum = newSum;
7986
}
8087
return new InternalSum(name, sum, format, pipelineAggregators(), getMetaData());
8188
}

core/src/main/java/org/elasticsearch/search/aggregations/metrics/sum/SumAggregator.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@ public void collect(int doc, long bucket) throws IOException {
7777

7878
if (values.advanceExact(doc)) {
7979
final int valuesCount = values.docValueCount();
80+
// Compute the sum of double values with Kahan summation algorithm which is more
81+
// accurate than naive summation.
8082
double sum = sums.get(bucket);
8183
double compensation = compensations.get(bucket);
8284
for (int i = 0; i < valuesCount; i++) {

core/src/test/java/org/elasticsearch/search/aggregations/metrics/ExtendedStatsAggregatorTests.java

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
package org.elasticsearch.search.aggregations.metrics;
2121

2222
import org.apache.lucene.document.Document;
23+
import org.apache.lucene.document.DoubleDocValuesField;
2324
import org.apache.lucene.document.SortedNumericDocValuesField;
2425
import org.apache.lucene.index.IndexReader;
2526
import org.apache.lucene.index.RandomIndexWriter;
@@ -38,6 +39,8 @@
3839
import java.io.IOException;
3940
import java.util.function.Consumer;
4041

42+
import static java.util.Collections.singleton;
43+
4144
public class ExtendedStatsAggregatorTests extends AggregatorTestCase {
4245
private static final double TOLERANCE = 1e-5;
4346

@@ -132,6 +135,37 @@ public void testRandomLongs() throws IOException {
132135
);
133136
}
134137

138+
public void testSummationAccuracy() throws IOException {
139+
MappedFieldType ft = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.DOUBLE);
140+
final String fieldName = "field";
141+
ft.setName(fieldName);
142+
testCase(ft,
143+
iw -> {
144+
double[] values = new double[]{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.9, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7};
145+
for (double value : values) {
146+
iw.addDocument(singleton(new DoubleDocValuesField(fieldName, value)));
147+
}
148+
},
149+
stats -> {
150+
assertEquals(15, stats.getCount());
151+
assertEquals(0.9, stats.getAvg(), 0d);
152+
assertEquals(13.5, stats.getSum(), 0d);
153+
assertEquals(1.7, stats.getMax(), 0d);
154+
assertEquals(0.1, stats.getMin(), 0d);
155+
assertEquals(0.1, stats.getMin(), 0d);
156+
}
157+
);
158+
testCase(ft,
159+
iw -> {
160+
double[] values = new double[]{2.1, 0.4, 0.4, 0.5, 0.5, 0.7, 0.9, 1.001, 1.222, 1.3, 1.4, 1.5, 1.6, 1.9};
161+
for (double value : values) {
162+
iw.addDocument(singleton(new DoubleDocValuesField(fieldName, value)));
163+
}
164+
},
165+
stats -> assertEquals(21.095285, stats.getSumOfSquares(), 0d)
166+
);
167+
}
168+
135169
public void testCase(MappedFieldType ft,
136170
CheckedConsumer<RandomIndexWriter, IOException> buildIndex,
137171
Consumer<InternalExtendedStats> verify) throws IOException {

core/src/test/java/org/elasticsearch/search/aggregations/metrics/InternalStatsTests.java

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,15 @@
2323
import org.elasticsearch.common.xcontent.XContentBuilder;
2424
import org.elasticsearch.common.xcontent.json.JsonXContent;
2525
import org.elasticsearch.search.DocValueFormat;
26+
import org.elasticsearch.search.aggregations.InternalAggregation;
2627
import org.elasticsearch.search.aggregations.ParsedAggregation;
2728
import org.elasticsearch.search.aggregations.metrics.stats.InternalStats;
2829
import org.elasticsearch.search.aggregations.metrics.stats.ParsedStats;
2930
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
3031
import org.elasticsearch.test.InternalAggregationTestCase;
3132

3233
import java.io.IOException;
34+
import java.util.ArrayList;
3335
import java.util.Collections;
3436
import java.util.HashMap;
3537
import java.util.List;
@@ -48,7 +50,7 @@ protected InternalStats createTestInstance(String name, List<PipelineAggregator>
4850
}
4951

5052
protected InternalStats createInstance(String name, long count, double sum, double min, double max, DocValueFormat formatter,
51-
List<PipelineAggregator> pipelineAggregators, Map<String, Object> metaData) {
53+
List<PipelineAggregator> pipelineAggregators, Map<String, Object> metaData) {
5254
return new InternalStats(name, count, sum, min, max, formatter, pipelineAggregators, metaData);
5355
}
5456

@@ -74,6 +76,22 @@ protected void assertReduced(InternalStats reduced, List<InternalStats> inputs)
7476
assertEquals(expectedMax, reduced.getMax(), 0d);
7577
}
7678

79+
public void testSummationAccuracy() throws IOException {
80+
double[] values = new double[]{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.9, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7};
81+
List<InternalAggregation> aggregations = new ArrayList<>(values.length);
82+
for (double value : values) {
83+
aggregations.add(new InternalStats("dummy1", 1, value, value, value, null, null, null));
84+
}
85+
InternalStats internalStats = new InternalStats("dummy2", 0, 0.0, 2.0, 0.0, null, null, null);
86+
InternalStats reduced = internalStats.doReduce(aggregations, null);
87+
assertEquals("dummy2", reduced.getName());
88+
assertEquals(values.length, reduced.getCount());
89+
assertEquals(13.5, reduced.getSum(), 0d);
90+
assertEquals(0.9, reduced.getAvg(), 0d);
91+
assertEquals(0.1, reduced.getMin(), 0d);
92+
assertEquals(1.7, reduced.getMax(), 0d);
93+
}
94+
7795
@Override
7896
protected void assertFromXContent(InternalStats aggregation, ParsedAggregation parsedAggregation) {
7997
assertTrue(parsedAggregation instanceof ParsedStats);

core/src/test/java/org/elasticsearch/search/aggregations/metrics/InternalSumTests.java

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,15 @@
2020

2121
import org.elasticsearch.common.io.stream.Writeable;
2222
import org.elasticsearch.search.DocValueFormat;
23+
import org.elasticsearch.search.aggregations.InternalAggregation;
2324
import org.elasticsearch.search.aggregations.ParsedAggregation;
2425
import org.elasticsearch.search.aggregations.metrics.sum.InternalSum;
2526
import org.elasticsearch.search.aggregations.metrics.sum.ParsedSum;
2627
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
2728
import org.elasticsearch.test.InternalAggregationTestCase;
2829

30+
import java.io.IOException;
31+
import java.util.ArrayList;
2932
import java.util.HashMap;
3033
import java.util.List;
3134
import java.util.Map;
@@ -34,7 +37,7 @@ public class InternalSumTests extends InternalAggregationTestCase<InternalSum> {
3437

3538
@Override
3639
protected InternalSum createTestInstance(String name, List<PipelineAggregator> pipelineAggregators, Map<String, Object> metaData) {
37-
double value = frequently() ? randomDouble() : randomFrom(new Double[] { Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY });
40+
double value = frequently() ? randomDouble() : randomFrom(Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY);
3841
DocValueFormat formatter = randomFrom(new DocValueFormat.Decimal("###.##"), DocValueFormat.BOOLEAN, DocValueFormat.RAW);
3942
return new InternalSum(name, value, formatter, pipelineAggregators, metaData);
4043
}
@@ -46,8 +49,27 @@ protected Writeable.Reader<InternalSum> instanceReader() {
4649

4750
@Override
4851
protected void assertReduced(InternalSum reduced, List<InternalSum> inputs) {
49-
double expectedSum = inputs.stream().mapToDouble(InternalSum::getValue).sum();
50-
assertEquals(expectedSum, reduced.getValue(), 0.0001d);
52+
double expectedSum = 0;
53+
double compensation = 0;
54+
for (InternalSum aggregation : inputs) {
55+
double corrected = aggregation.value() - compensation;
56+
double newSum = expectedSum + corrected;
57+
compensation = (newSum - expectedSum) - corrected;
58+
expectedSum = newSum;
59+
}
60+
assertEquals(expectedSum, reduced.getValue(), 0.000d);
61+
}
62+
63+
public void testSummationAccuracy() throws IOException {
64+
double[] values = new double[]{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.9, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7};
65+
List<InternalAggregation> aggregations = new ArrayList<>(values.length);
66+
for (double value : values) {
67+
aggregations.add(new InternalSum("dummy1", value, null, null, null));
68+
}
69+
InternalSum internalSum = new InternalSum("dummy", 0, null, null, null);
70+
InternalSum reduced = internalSum.doReduce(aggregations, null);
71+
assertEquals(13.5, reduced.value(), 0d);
72+
assertEquals("dummy", reduced.getName());
5173
}
5274

5375
@Override

core/src/test/java/org/elasticsearch/search/aggregations/metrics/avg/InternalAvgTests.java

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,13 @@
2121

2222
import org.elasticsearch.common.io.stream.Writeable.Reader;
2323
import org.elasticsearch.search.DocValueFormat;
24+
import org.elasticsearch.search.aggregations.InternalAggregation;
2425
import org.elasticsearch.search.aggregations.ParsedAggregation;
2526
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
2627
import org.elasticsearch.test.InternalAggregationTestCase;
2728

29+
import java.io.IOException;
30+
import java.util.ArrayList;
2831
import java.util.HashMap;
2932
import java.util.List;
3033
import java.util.Map;
@@ -56,6 +59,20 @@ protected void assertReduced(InternalAvg reduced, List<InternalAvg> inputs) {
5659
assertEquals(sum / counts, reduced.value(), 0.0000001);
5760
}
5861

62+
public void testSummationAccuracy() throws IOException {
63+
double[] values = new double[]{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.9, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7};
64+
List<InternalAggregation> aggregations = new ArrayList<>(values.length);
65+
for (double value : values) {
66+
aggregations.add(new InternalAvg("dummy1", value, 1, null, null, null));
67+
}
68+
InternalAvg internalAvg = new InternalAvg("dummy2", 0, 0, null, null, null);
69+
InternalAvg reduced = internalAvg.doReduce(aggregations, null);
70+
assertEquals(values.length, reduced.getCount());
71+
assertEquals(13.5, reduced.getSum(), 0d);
72+
assertEquals(0.9, reduced.getValue(), 0d);
73+
assertEquals("dummy2", reduced.getName());
74+
}
75+
5976
@Override
6077
protected void assertFromXContent(InternalAvg avg, ParsedAggregation parsedAggregation) {
6178
ParsedAvg parsed = ((ParsedAvg) parsedAggregation);

0 commit comments

Comments
 (0)