Skip to content

Commit 452c36c

Browse files
keljpountz
kel
authored andcommitted
Calculate sum in Kahan summation algorithm in aggregations (#27807) (#27848)
1 parent 700d9ec commit 452c36c

17 files changed

+557
-37
lines changed

server/src/main/java/org/elasticsearch/search/aggregations/metrics/avg/AvgAggregator.java

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ public class AvgAggregator extends NumericMetricsAggregator.SingleValue {
4444

4545
LongArray counts;
4646
DoubleArray sums;
47+
DoubleArray compensations;
4748
DocValueFormat format;
4849

4950
public AvgAggregator(String name, ValuesSource.Numeric valuesSource, DocValueFormat formatter, SearchContext context,
@@ -55,6 +56,7 @@ public AvgAggregator(String name, ValuesSource.Numeric valuesSource, DocValueFor
5556
final BigArrays bigArrays = context.bigArrays();
5657
counts = bigArrays.newLongArray(1, true);
5758
sums = bigArrays.newDoubleArray(1, true);
59+
compensations = bigArrays.newDoubleArray(1, true);
5860
}
5961
}
6062

@@ -76,15 +78,29 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx,
7678
public void collect(int doc, long bucket) throws IOException {
7779
counts = bigArrays.grow(counts, bucket + 1);
7880
sums = bigArrays.grow(sums, bucket + 1);
81+
compensations = bigArrays.grow(compensations, bucket + 1);
7982

8083
if (values.advanceExact(doc)) {
8184
final int valueCount = values.docValueCount();
8285
counts.increment(bucket, valueCount);
83-
double sum = 0;
86+
// Compute the sum of double values with Kahan summation algorithm which is more
87+
// accurate than naive summation.
88+
double sum = sums.get(bucket);
89+
double compensation = compensations.get(bucket);
90+
8491
for (int i = 0; i < valueCount; i++) {
85-
sum += values.nextValue();
92+
double value = values.nextValue();
93+
if (Double.isFinite(value) == false) {
94+
sum += value;
95+
} else if (Double.isFinite(sum)) {
96+
double corrected = value - compensation;
97+
double newSum = sum + corrected;
98+
compensation = (newSum - sum) - corrected;
99+
sum = newSum;
100+
}
86101
}
87-
sums.increment(bucket, sum);
102+
sums.set(bucket, sum);
103+
compensations.set(bucket, compensation);
88104
}
89105
}
90106
};
@@ -113,7 +129,7 @@ public InternalAggregation buildEmptyAggregation() {
113129

114130
@Override
115131
public void doClose() {
116-
Releasables.close(counts, sums);
132+
Releasables.close(counts, sums, compensations);
117133
}
118134

119135
}

server/src/main/java/org/elasticsearch/search/aggregations/metrics/avg/InternalAvg.java

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,20 @@ public String getWriteableName() {
9191
public InternalAvg doReduce(List<InternalAggregation> aggregations, ReduceContext reduceContext) {
9292
long count = 0;
9393
double sum = 0;
94+
double compensation = 0;
95+
// Compute the sum of double values with Kahan summation algorithm which is more
96+
// accurate than naive summation.
9497
for (InternalAggregation aggregation : aggregations) {
95-
count += ((InternalAvg) aggregation).count;
96-
sum += ((InternalAvg) aggregation).sum;
98+
InternalAvg avg = (InternalAvg) aggregation;
99+
count += avg.count;
100+
if (Double.isFinite(avg.sum) == false) {
101+
sum += avg.sum;
102+
} else if (Double.isFinite(sum)) {
103+
double corrected = avg.sum - compensation;
104+
double newSum = sum + corrected;
105+
compensation = (newSum - sum) - corrected;
106+
sum = newSum;
107+
}
97108
}
98109
return new InternalAvg(getName(), sum, count, format, pipelineAggregators(), getMetaData());
99110
}

server/src/main/java/org/elasticsearch/search/aggregations/metrics/stats/InternalStats.java

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,12 +152,23 @@ public InternalStats doReduce(List<InternalAggregation> aggregations, ReduceCont
152152
double min = Double.POSITIVE_INFINITY;
153153
double max = Double.NEGATIVE_INFINITY;
154154
double sum = 0;
155+
double compensation = 0;
155156
for (InternalAggregation aggregation : aggregations) {
156157
InternalStats stats = (InternalStats) aggregation;
157158
count += stats.getCount();
158159
min = Math.min(min, stats.getMin());
159160
max = Math.max(max, stats.getMax());
160-
sum += stats.getSum();
161+
// Compute the sum of double values with Kahan summation algorithm which is more
162+
// accurate than naive summation.
163+
double value = stats.getSum();
164+
if (Double.isFinite(value) == false) {
165+
sum += value;
166+
} else if (Double.isFinite(sum)) {
167+
double corrected = value - compensation;
168+
double newSum = sum + corrected;
169+
compensation = (newSum - sum) - corrected;
170+
sum = newSum;
171+
}
161172
}
162173
return new InternalStats(name, count, sum, min, max, format, pipelineAggregators(), getMetaData());
163174
}

server/src/main/java/org/elasticsearch/search/aggregations/metrics/stats/StatsAggregator.java

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ public class StatsAggregator extends NumericMetricsAggregator.MultiValue {
4545

4646
LongArray counts;
4747
DoubleArray sums;
48+
DoubleArray compensations;
4849
DoubleArray mins;
4950
DoubleArray maxes;
5051

@@ -59,6 +60,7 @@ public StatsAggregator(String name, ValuesSource.Numeric valuesSource, DocValueF
5960
final BigArrays bigArrays = context.bigArrays();
6061
counts = bigArrays.newLongArray(1, true);
6162
sums = bigArrays.newDoubleArray(1, true);
63+
compensations = bigArrays.newDoubleArray(1, true);
6264
mins = bigArrays.newDoubleArray(1, false);
6365
mins.fill(0, mins.size(), Double.POSITIVE_INFINITY);
6466
maxes = bigArrays.newDoubleArray(1, false);
@@ -88,6 +90,7 @@ public void collect(int doc, long bucket) throws IOException {
8890
final long overSize = BigArrays.overSize(bucket + 1);
8991
counts = bigArrays.resize(counts, overSize);
9092
sums = bigArrays.resize(sums, overSize);
93+
compensations = bigArrays.resize(compensations, overSize);
9194
mins = bigArrays.resize(mins, overSize);
9295
maxes = bigArrays.resize(maxes, overSize);
9396
mins.fill(from, overSize, Double.POSITIVE_INFINITY);
@@ -97,16 +100,28 @@ public void collect(int doc, long bucket) throws IOException {
97100
if (values.advanceExact(doc)) {
98101
final int valuesCount = values.docValueCount();
99102
counts.increment(bucket, valuesCount);
100-
double sum = 0;
101103
double min = mins.get(bucket);
102104
double max = maxes.get(bucket);
105+
// Compute the sum of double values with Kahan summation algorithm which is more
106+
// accurate than naive summation.
107+
double sum = sums.get(bucket);
108+
double compensation = compensations.get(bucket);
109+
103110
for (int i = 0; i < valuesCount; i++) {
104111
double value = values.nextValue();
105-
sum += value;
112+
if (Double.isFinite(value) == false) {
113+
sum += value;
114+
} else if (Double.isFinite(sum)) {
115+
double corrected = value - compensation;
116+
double newSum = sum + corrected;
117+
compensation = (newSum - sum) - corrected;
118+
sum = newSum;
119+
}
106120
min = Math.min(min, value);
107121
max = Math.max(max, value);
108122
}
109-
sums.increment(bucket, sum);
123+
sums.set(bucket, sum);
124+
compensations.set(bucket, compensation);
110125
mins.set(bucket, min);
111126
maxes.set(bucket, max);
112127
}
@@ -164,6 +179,6 @@ public InternalAggregation buildEmptyAggregation() {
164179

165180
@Override
166181
public void doClose() {
167-
Releasables.close(counts, maxes, mins, sums);
182+
Releasables.close(counts, maxes, mins, sums, compensations);
168183
}
169184
}

server/src/main/java/org/elasticsearch/search/aggregations/metrics/stats/extended/ExtendedStatsAggregator.java

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,11 @@ public class ExtendedStatsAggregator extends NumericMetricsAggregator.MultiValue
4949

5050
LongArray counts;
5151
DoubleArray sums;
52+
DoubleArray compensations;
5253
DoubleArray mins;
5354
DoubleArray maxes;
5455
DoubleArray sumOfSqrs;
56+
DoubleArray compensationOfSqrs;
5557

5658
public ExtendedStatsAggregator(String name, ValuesSource.Numeric valuesSource, DocValueFormat formatter,
5759
SearchContext context, Aggregator parent, double sigma, List<PipelineAggregator> pipelineAggregators,
@@ -65,11 +67,13 @@ public ExtendedStatsAggregator(String name, ValuesSource.Numeric valuesSource, D
6567
final BigArrays bigArrays = context.bigArrays();
6668
counts = bigArrays.newLongArray(1, true);
6769
sums = bigArrays.newDoubleArray(1, true);
70+
compensations = bigArrays.newDoubleArray(1, true);
6871
mins = bigArrays.newDoubleArray(1, false);
6972
mins.fill(0, mins.size(), Double.POSITIVE_INFINITY);
7073
maxes = bigArrays.newDoubleArray(1, false);
7174
maxes.fill(0, maxes.size(), Double.NEGATIVE_INFINITY);
7275
sumOfSqrs = bigArrays.newDoubleArray(1, true);
76+
compensationOfSqrs = bigArrays.newDoubleArray(1, true);
7377
}
7478
}
7579

@@ -95,29 +99,52 @@ public void collect(int doc, long bucket) throws IOException {
9599
final long overSize = BigArrays.overSize(bucket + 1);
96100
counts = bigArrays.resize(counts, overSize);
97101
sums = bigArrays.resize(sums, overSize);
102+
compensations = bigArrays.resize(compensations, overSize);
98103
mins = bigArrays.resize(mins, overSize);
99104
maxes = bigArrays.resize(maxes, overSize);
100105
sumOfSqrs = bigArrays.resize(sumOfSqrs, overSize);
106+
compensationOfSqrs = bigArrays.resize(compensationOfSqrs, overSize);
101107
mins.fill(from, overSize, Double.POSITIVE_INFINITY);
102108
maxes.fill(from, overSize, Double.NEGATIVE_INFINITY);
103109
}
104110

105111
if (values.advanceExact(doc)) {
106112
final int valuesCount = values.docValueCount();
107113
counts.increment(bucket, valuesCount);
108-
double sum = 0;
109-
double sumOfSqr = 0;
110114
double min = mins.get(bucket);
111115
double max = maxes.get(bucket);
116+
// Compute the sum and sum of squires for double values with Kahan summation algorithm
117+
// which is more accurate than naive summation.
118+
double sum = sums.get(bucket);
119+
double compensation = compensations.get(bucket);
120+
double sumOfSqr = sumOfSqrs.get(bucket);
121+
double compensationOfSqr = compensationOfSqrs.get(bucket);
112122
for (int i = 0; i < valuesCount; i++) {
113123
double value = values.nextValue();
114-
sum += value;
115-
sumOfSqr += value * value;
124+
if (Double.isFinite(value) == false) {
125+
sum += value;
126+
sumOfSqr += value * value;
127+
} else {
128+
if (Double.isFinite(sum)) {
129+
double corrected = value - compensation;
130+
double newSum = sum + corrected;
131+
compensation = (newSum - sum) - corrected;
132+
sum = newSum;
133+
}
134+
if (Double.isFinite(sumOfSqr)) {
135+
double correctedOfSqr = value * value - compensationOfSqr;
136+
double newSumOfSqr = sumOfSqr + correctedOfSqr;
137+
compensationOfSqr = (newSumOfSqr - sumOfSqr) - correctedOfSqr;
138+
sumOfSqr = newSumOfSqr;
139+
}
140+
}
116141
min = Math.min(min, value);
117142
max = Math.max(max, value);
118143
}
119-
sums.increment(bucket, sum);
120-
sumOfSqrs.increment(bucket, sumOfSqr);
144+
sums.set(bucket, sum);
145+
compensations.set(bucket, compensation);
146+
sumOfSqrs.set(bucket, sumOfSqr);
147+
compensationOfSqrs.set(bucket, compensationOfSqr);
121148
mins.set(bucket, min);
122149
maxes.set(bucket, max);
123150
}
@@ -196,6 +223,6 @@ public InternalAggregation buildEmptyAggregation() {
196223

197224
@Override
198225
public void doClose() {
199-
Releasables.close(counts, maxes, mins, sumOfSqrs, sums);
226+
Releasables.close(counts, maxes, mins, sumOfSqrs, compensationOfSqrs, sums, compensations);
200227
}
201228
}

server/src/main/java/org/elasticsearch/search/aggregations/metrics/stats/extended/InternalExtendedStats.java

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ public static Metrics resolve(String name) {
4545
private final double sigma;
4646

4747
public InternalExtendedStats(String name, long count, double sum, double min, double max, double sumOfSqrs, double sigma,
48-
DocValueFormat formatter, List<PipelineAggregator> pipelineAggregators, Map<String, Object> metaData) {
48+
DocValueFormat formatter, List<PipelineAggregator> pipelineAggregators, Map<String, Object> metaData) {
4949
super(name, count, sum, min, max, formatter, pipelineAggregators, metaData);
5050
this.sumOfSqrs = sumOfSqrs;
5151
this.sigma = sigma;
@@ -142,16 +142,25 @@ public String getStdDeviationBoundAsString(Bounds bound) {
142142
@Override
143143
public InternalExtendedStats doReduce(List<InternalAggregation> aggregations, ReduceContext reduceContext) {
144144
double sumOfSqrs = 0;
145+
double compensationOfSqrs = 0;
145146
for (InternalAggregation aggregation : aggregations) {
146147
InternalExtendedStats stats = (InternalExtendedStats) aggregation;
147148
if (stats.sigma != sigma) {
148149
throw new IllegalStateException("Cannot reduce other stats aggregations that have a different sigma");
149150
}
150-
sumOfSqrs += stats.getSumOfSquares();
151+
double value = stats.getSumOfSquares();
152+
if (Double.isFinite(value) == false) {
153+
sumOfSqrs += value;
154+
} else if (Double.isFinite(sumOfSqrs)) {
155+
double correctedOfSqrs = value - compensationOfSqrs;
156+
double newSumOfSqrs = sumOfSqrs + correctedOfSqrs;
157+
compensationOfSqrs = (newSumOfSqrs - sumOfSqrs) - correctedOfSqrs;
158+
sumOfSqrs = newSumOfSqrs;
159+
}
151160
}
152161
final InternalStats stats = super.doReduce(aggregations, reduceContext);
153162
return new InternalExtendedStats(name, stats.getCount(), stats.getSum(), stats.getMin(), stats.getMax(), sumOfSqrs, sigma,
154-
format, pipelineAggregators(), getMetaData());
163+
format, pipelineAggregators(), getMetaData());
155164
}
156165

157166
static class Fields {

server/src/main/java/org/elasticsearch/search/aggregations/metrics/sum/InternalSum.java

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ public class InternalSum extends InternalNumericMetricsAggregation.SingleValue i
3535
private final double sum;
3636

3737
public InternalSum(String name, double sum, DocValueFormat formatter, List<PipelineAggregator> pipelineAggregators,
38-
Map<String, Object> metaData) {
38+
Map<String, Object> metaData) {
3939
super(name, pipelineAggregators, metaData);
4040
this.sum = sum;
4141
this.format = formatter;
@@ -73,9 +73,20 @@ public double getValue() {
7373

7474
@Override
7575
public InternalSum doReduce(List<InternalAggregation> aggregations, ReduceContext reduceContext) {
76+
// Compute the sum of double values with Kahan summation algorithm which is more
77+
// accurate than naive summation.
7678
double sum = 0;
79+
double compensation = 0;
7780
for (InternalAggregation aggregation : aggregations) {
78-
sum += ((InternalSum) aggregation).sum;
81+
double value = ((InternalSum) aggregation).sum;
82+
if (Double.isFinite(value) == false) {
83+
sum += value;
84+
} else if (Double.isFinite(sum)) {
85+
double corrected = value - compensation;
86+
double newSum = sum + corrected;
87+
compensation = (newSum - sum) - corrected;
88+
sum = newSum;
89+
}
7990
}
8091
return new InternalSum(name, sum, format, pipelineAggregators(), getMetaData());
8192
}

server/src/main/java/org/elasticsearch/search/aggregations/metrics/sum/SumAggregator.java

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ public class SumAggregator extends NumericMetricsAggregator.SingleValue {
4343
private final DocValueFormat format;
4444

4545
private DoubleArray sums;
46+
private DoubleArray compensations;
4647

4748
SumAggregator(String name, ValuesSource.Numeric valuesSource, DocValueFormat formatter, SearchContext context,
4849
Aggregator parent, List<PipelineAggregator> pipelineAggregators, Map<String, Object> metaData) throws IOException {
@@ -51,6 +52,7 @@ public class SumAggregator extends NumericMetricsAggregator.SingleValue {
5152
this.format = formatter;
5253
if (valuesSource != null) {
5354
sums = context.bigArrays().newDoubleArray(1, true);
55+
compensations = context.bigArrays().newDoubleArray(1, true);
5456
}
5557
}
5658

@@ -71,13 +73,27 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx,
7173
@Override
7274
public void collect(int doc, long bucket) throws IOException {
7375
sums = bigArrays.grow(sums, bucket + 1);
76+
compensations = bigArrays.grow(compensations, bucket + 1);
77+
7478
if (values.advanceExact(doc)) {
7579
final int valuesCount = values.docValueCount();
76-
double sum = 0;
80+
// Compute the sum of double values with Kahan summation algorithm which is more
81+
// accurate than naive summation.
82+
double sum = sums.get(bucket);
83+
double compensation = compensations.get(bucket);
7784
for (int i = 0; i < valuesCount; i++) {
78-
sum += values.nextValue();
85+
double value = values.nextValue();
86+
if (Double.isFinite(value) == false) {
87+
sum += value;
88+
} else if (Double.isFinite(sum)) {
89+
double corrected = value - compensation;
90+
double newSum = sum + corrected;
91+
compensation = (newSum - sum) - corrected;
92+
sum = newSum;
93+
}
7994
}
80-
sums.increment(bucket, sum);
95+
compensations.set(bucket, compensation);
96+
sums.set(bucket, sum);
8197
}
8298
}
8399
};
@@ -106,6 +122,6 @@ public InternalAggregation buildEmptyAggregation() {
106122

107123
@Override
108124
public void doClose() {
109-
Releasables.close(sums);
125+
Releasables.close(sums, compensations);
110126
}
111127
}

0 commit comments

Comments
 (0)