Skip to content

Commit eb70849

Browse files
JosemyDuarte authored and
not-napoleon committed
Refactor and DRY up Kahan Sum algorithm (#48558)
1 parent ee8853f commit eb70849

File tree

12 files changed

+240
-132
lines changed

12 files changed

+240
-132
lines changed

server/src/main/java/org/elasticsearch/search/aggregations/metrics/AvgAggregator.java

+5-10
Original file line numberDiff line numberDiff line change
@@ -87,20 +87,15 @@ public void collect(int doc, long bucket) throws IOException {
8787
// accurate than naive summation.
8888
double sum = sums.get(bucket);
8989
double compensation = compensations.get(bucket);
90+
CompensatedSum kahanSummation = new CompensatedSum(sum, compensation);
9091

9192
for (int i = 0; i < valueCount; i++) {
9293
double value = values.nextValue();
93-
if (Double.isFinite(value) == false) {
94-
sum += value;
95-
} else if (Double.isFinite(sum)) {
96-
double corrected = value - compensation;
97-
double newSum = sum + corrected;
98-
compensation = (newSum - sum) - corrected;
99-
sum = newSum;
100-
}
94+
kahanSummation.add(value);
10195
}
102-
sums.set(bucket, sum);
103-
compensations.set(bucket, compensation);
96+
97+
sums.set(bucket, kahanSummation.value());
98+
compensations.set(bucket, kahanSummation.delta());
10499
}
105100
}
106101
};
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
/*
2+
* Licensed to Elasticsearch under one or more contributor
3+
* license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright
5+
* ownership. Elasticsearch licenses this file to you under
6+
* the Apache License, Version 2.0 (the "License"); you may
7+
* not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.elasticsearch.search.aggregations.metrics;
21+
22+
23+
/**
 * Accumulates a sum of {@code double} values using the Kahan summation algorithm.
 *
 * <p>The Kahan summation algorithm (also known as compensated summation) reduces the numerical errors that
 * occur when adding a sequence of finite precision floating point numbers. Numerical errors arise due to
 * truncation and rounding. These errors can lead to numerical instability.
 *
 * <p>The accumulator keeps two pieces of state: the running sum ({@link #value()}) and a correction term
 * ({@link #delta()}) holding the low-order part that was lost when the running sum was last rounded. The
 * correction term is <em>added</em> to the next incoming value before it is folded into the sum.
 *
 * <p>Instances are mutable and perform no synchronization.
 *
 * @see <a href="https://en.wikipedia.org/wiki/Kahan_summation_algorithm">Kahan Summation Algorithm</a>
 */
public class CompensatedSum {

    /** Correction passed along when a plain value (not a partial sum) is added. */
    private static final double NO_CORRECTION = 0.0;

    private double value;
    private double delta;

    /**
     * Creates an accumulator that resumes from a previously computed state.
     *
     * @param value the sum accumulated so far
     * @param delta the correction term that belongs to {@code value}
     */
    public CompensatedSum(double value, double delta) {
        this.value = value;
        this.delta = delta;
    }

    /**
     * The value of the sum.
     *
     * @return the running sum
     */
    public double value() {
        return value;
    }

    /**
     * The correction term.
     *
     * @return the correction term associated with the running sum
     */
    public double delta() {
        return delta;
    }

    /**
     * Increments the Kahan sum by adding a value without a correction term.
     *
     * @param value the value to add
     * @return this accumulator, to allow chaining
     */
    public CompensatedSum add(double value) {
        return add(value, NO_CORRECTION);
    }

    /**
     * Increments the Kahan sum by adding two sums, and updating the correction term for reducing numeric errors.
     *
     * <p>A non-finite {@code value} (infinity or NaN) is added naively to the running sum and the correction
     * term is left untouched. Once the running sum is non-finite, finite values no longer change it.
     *
     * @param value the value (or partial sum) to add
     * @param delta the correction term that accompanies {@code value}
     * @return this accumulator, to allow chaining
     */
    public CompensatedSum add(double value, double delta) {
        if (Double.isFinite(value) == false) {
            // If the value is Inf or NaN, just add it to the running tally to "convert" to
            // Inf/NaN. This keeps the behavior bwc from before kahan summing.
            // The result is always non-finite, so the compensated branch can never apply as well.
            this.value = value + this.value;
        } else if (Double.isFinite(this.value)) {
            // Fold both pending corrections into the incoming value before adding it.
            double correctedSum = value + (this.delta + delta);
            double updatedValue = this.value + correctedSum;
            // What was meant to be added minus what actually got added is the new correction.
            this.delta = correctedSum - (updatedValue - this.value);
            this.value = updatedValue;
        }

        return this;
    }

}
93+

server/src/main/java/org/elasticsearch/search/aggregations/metrics/ExtendedStatsAggregator.java

+11-21
Original file line numberDiff line numberDiff line change
@@ -117,34 +117,24 @@ public void collect(int doc, long bucket) throws IOException {
117117
// which is more accurate than naive summation.
118118
double sum = sums.get(bucket);
119119
double compensation = compensations.get(bucket);
120+
CompensatedSum compensatedSum = new CompensatedSum(sum, compensation);
121+
120122
double sumOfSqr = sumOfSqrs.get(bucket);
121123
double compensationOfSqr = compensationOfSqrs.get(bucket);
124+
CompensatedSum compensatedSumOfSqr = new CompensatedSum(sumOfSqr, compensationOfSqr);
125+
122126
for (int i = 0; i < valuesCount; i++) {
123127
double value = values.nextValue();
124-
if (Double.isFinite(value) == false) {
125-
sum += value;
126-
sumOfSqr += value * value;
127-
} else {
128-
if (Double.isFinite(sum)) {
129-
double corrected = value - compensation;
130-
double newSum = sum + corrected;
131-
compensation = (newSum - sum) - corrected;
132-
sum = newSum;
133-
}
134-
if (Double.isFinite(sumOfSqr)) {
135-
double correctedOfSqr = value * value - compensationOfSqr;
136-
double newSumOfSqr = sumOfSqr + correctedOfSqr;
137-
compensationOfSqr = (newSumOfSqr - sumOfSqr) - correctedOfSqr;
138-
sumOfSqr = newSumOfSqr;
139-
}
140-
}
128+
compensatedSum.add(value);
129+
compensatedSumOfSqr.add(value * value);
141130
min = Math.min(min, value);
142131
max = Math.max(max, value);
143132
}
144-
sums.set(bucket, sum);
145-
compensations.set(bucket, compensation);
146-
sumOfSqrs.set(bucket, sumOfSqr);
147-
compensationOfSqrs.set(bucket, compensationOfSqr);
133+
134+
sums.set(bucket, compensatedSum.value());
135+
compensations.set(bucket, compensatedSum.delta());
136+
sumOfSqrs.set(bucket, compensatedSumOfSqr.value());
137+
compensationOfSqrs.set(bucket, compensatedSumOfSqr.delta());
148138
mins.set(bucket, min);
149139
maxes.set(bucket, max);
150140
}

server/src/main/java/org/elasticsearch/search/aggregations/metrics/GeoCentroidAggregator.java

+9-12
Original file line numberDiff line numberDiff line change
@@ -88,24 +88,21 @@ public void collect(int doc, long bucket) throws IOException {
8888
double sumLon = lonSum.get(bucket);
8989
double compensationLon = lonCompensations.get(bucket);
9090

91+
CompensatedSum compensatedSumLat = new CompensatedSum(sumLat, compensationLat);
92+
CompensatedSum compensatedSumLon = new CompensatedSum(sumLon, compensationLon);
93+
9194
// update the sum
9295
for (int i = 0; i < valueCount; ++i) {
9396
GeoPoint value = values.nextValue();
9497
//latitude
95-
double correctedLat = value.getLat() - compensationLat;
96-
double newSumLat = sumLat + correctedLat;
97-
compensationLat = (newSumLat - sumLat) - correctedLat;
98-
sumLat = newSumLat;
98+
compensatedSumLat.add(value.getLat());
9999
//longitude
100-
double correctedLon = value.getLon() - compensationLon;
101-
double newSumLon = sumLon + correctedLon;
102-
compensationLon = (newSumLon - sumLon) - correctedLon;
103-
sumLon = newSumLon;
100+
compensatedSumLon.add(value.getLon());
104101
}
105-
lonSum.set(bucket, sumLon);
106-
lonCompensations.set(bucket, compensationLon);
107-
latSum.set(bucket, sumLat);
108-
latCompensations.set(bucket, compensationLat);
102+
lonSum.set(bucket, compensatedSumLon.value());
103+
lonCompensations.set(bucket, compensatedSumLon.delta());
104+
latSum.set(bucket, compensatedSumLat.value());
105+
latCompensations.set(bucket, compensatedSumLat.delta());
109106
}
110107
}
111108
};

server/src/main/java/org/elasticsearch/search/aggregations/metrics/InternalAvg.java

+3-11
Original file line numberDiff line numberDiff line change
@@ -88,24 +88,16 @@ public String getWriteableName() {
8888

8989
@Override
9090
public InternalAvg doReduce(List<InternalAggregation> aggregations, ReduceContext reduceContext) {
91+
CompensatedSum kahanSummation = new CompensatedSum(0, 0);
9192
long count = 0;
92-
double sum = 0;
93-
double compensation = 0;
9493
// Compute the sum of double values with Kahan summation algorithm which is more
9594
// accurate than naive summation.
9695
for (InternalAggregation aggregation : aggregations) {
9796
InternalAvg avg = (InternalAvg) aggregation;
9897
count += avg.count;
99-
if (Double.isFinite(avg.sum) == false) {
100-
sum += avg.sum;
101-
} else if (Double.isFinite(sum)) {
102-
double corrected = avg.sum - compensation;
103-
double newSum = sum + corrected;
104-
compensation = (newSum - sum) - corrected;
105-
sum = newSum;
106-
}
98+
kahanSummation.add(avg.sum);
10799
}
108-
return new InternalAvg(getName(), sum, count, format, pipelineAggregators(), getMetaData());
100+
return new InternalAvg(getName(), kahanSummation.value(), count, format, pipelineAggregators(), getMetaData());
109101
}
110102

111103
@Override

server/src/main/java/org/elasticsearch/search/aggregations/metrics/InternalStats.java

+4-12
Original file line numberDiff line numberDiff line change
@@ -149,26 +149,18 @@ public InternalStats doReduce(List<InternalAggregation> aggregations, ReduceCont
149149
long count = 0;
150150
double min = Double.POSITIVE_INFINITY;
151151
double max = Double.NEGATIVE_INFINITY;
152-
double sum = 0;
153-
double compensation = 0;
152+
CompensatedSum kahanSummation = new CompensatedSum(0, 0);
153+
154154
for (InternalAggregation aggregation : aggregations) {
155155
InternalStats stats = (InternalStats) aggregation;
156156
count += stats.getCount();
157157
min = Math.min(min, stats.getMin());
158158
max = Math.max(max, stats.getMax());
159159
// Compute the sum of double values with Kahan summation algorithm which is more
160160
// accurate than naive summation.
161-
double value = stats.getSum();
162-
if (Double.isFinite(value) == false) {
163-
sum += value;
164-
} else if (Double.isFinite(sum)) {
165-
double corrected = value - compensation;
166-
double newSum = sum + corrected;
167-
compensation = (newSum - sum) - corrected;
168-
sum = newSum;
169-
}
161+
kahanSummation.add(stats.getSum());
170162
}
171-
return new InternalStats(name, count, sum, min, max, format, pipelineAggregators(), getMetaData());
163+
return new InternalStats(name, count, kahanSummation.value(), min, max, format, pipelineAggregators(), getMetaData());
172164
}
173165

174166
static class Fields {

server/src/main/java/org/elasticsearch/search/aggregations/metrics/InternalSum.java

+3-11
Original file line numberDiff line numberDiff line change
@@ -74,20 +74,12 @@ public double getValue() {
7474
public InternalSum doReduce(List<InternalAggregation> aggregations, ReduceContext reduceContext) {
7575
// Compute the sum of double values with Kahan summation algorithm which is more
7676
// accurate than naive summation.
77-
double sum = 0;
78-
double compensation = 0;
77+
CompensatedSum kahanSummation = new CompensatedSum(0, 0);
7978
for (InternalAggregation aggregation : aggregations) {
8079
double value = ((InternalSum) aggregation).sum;
81-
if (Double.isFinite(value) == false) {
82-
sum += value;
83-
} else if (Double.isFinite(sum)) {
84-
double corrected = value - compensation;
85-
double newSum = sum + corrected;
86-
compensation = (newSum - sum) - corrected;
87-
sum = newSum;
88-
}
80+
kahanSummation.add(value);
8981
}
90-
return new InternalSum(name, sum, format, pipelineAggregators(), getMetaData());
82+
return new InternalSum(name, kahanSummation.value(), format, pipelineAggregators(), getMetaData());
9183
}
9284

9385
@Override

server/src/main/java/org/elasticsearch/search/aggregations/metrics/InternalWeightedAvg.java

+9-25
Original file line numberDiff line numberDiff line change
@@ -88,37 +88,21 @@ public String getWriteableName() {
8888

8989
@Override
9090
public InternalWeightedAvg doReduce(List<InternalAggregation> aggregations, ReduceContext reduceContext) {
91-
double weight = 0;
92-
double sum = 0;
93-
double sumCompensation = 0;
94-
double weightCompensation = 0;
91+
CompensatedSum sumCompensation = new CompensatedSum(0, 0);
92+
CompensatedSum weightCompensation = new CompensatedSum(0, 0);
93+
9594
// Compute the sum of double values with Kahan summation algorithm which is more
9695
// accurate than naive summation.
9796
for (InternalAggregation aggregation : aggregations) {
9897
InternalWeightedAvg avg = (InternalWeightedAvg) aggregation;
99-
// If the weight is Inf or NaN, just add it to the running tally to "convert" to
100-
// Inf/NaN. This keeps the behavior bwc from before kahan summing
101-
if (Double.isFinite(avg.weight) == false) {
102-
weight += avg.weight;
103-
} else if (Double.isFinite(weight)) {
104-
double corrected = avg.weight - weightCompensation;
105-
double newWeight = weight + corrected;
106-
weightCompensation = (newWeight - weight) - corrected;
107-
weight = newWeight;
108-
}
109-
// If the avg is Inf or NaN, just add it to the running tally to "convert" to
110-
// Inf/NaN. This keeps the behavior bwc from before kahan summing
111-
if (Double.isFinite(avg.sum) == false) {
112-
sum += avg.sum;
113-
} else if (Double.isFinite(sum)) {
114-
double corrected = avg.sum - sumCompensation;
115-
double newSum = sum + corrected;
116-
sumCompensation = (newSum - sum) - corrected;
117-
sum = newSum;
118-
}
98+
weightCompensation.add(avg.weight);
99+
sumCompensation.add(avg.sum);
119100
}
120-
return new InternalWeightedAvg(getName(), sum, weight, format, pipelineAggregators(), getMetaData());
101+
102+
return new InternalWeightedAvg(getName(), sumCompensation.value(), weightCompensation.value(),
103+
format, pipelineAggregators(), getMetaData());
121104
}
105+
122106
@Override
123107
public XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException {
124108
builder.field(CommonFields.VALUE.getPreferredName(), weight != 0 ? getValue() : null);

server/src/main/java/org/elasticsearch/search/aggregations/metrics/StatsAggregator.java

+4-10
Original file line numberDiff line numberDiff line change
@@ -105,22 +105,16 @@ public void collect(int doc, long bucket) throws IOException {
105105
// accurate than naive summation.
106106
double sum = sums.get(bucket);
107107
double compensation = compensations.get(bucket);
108+
CompensatedSum kahanSummation = new CompensatedSum(sum, compensation);
108109

109110
for (int i = 0; i < valuesCount; i++) {
110111
double value = values.nextValue();
111-
if (Double.isFinite(value) == false) {
112-
sum += value;
113-
} else if (Double.isFinite(sum)) {
114-
double corrected = value - compensation;
115-
double newSum = sum + corrected;
116-
compensation = (newSum - sum) - corrected;
117-
sum = newSum;
118-
}
112+
kahanSummation.add(value);
119113
min = Math.min(min, value);
120114
max = Math.max(max, value);
121115
}
122-
sums.set(bucket, sum);
123-
compensations.set(bucket, compensation);
116+
sums.set(bucket, kahanSummation.value());
117+
compensations.set(bucket, kahanSummation.delta());
124118
mins.set(bucket, min);
125119
maxes.set(bucket, max);
126120
}

server/src/main/java/org/elasticsearch/search/aggregations/metrics/SumAggregator.java

+6-10
Original file line numberDiff line numberDiff line change
@@ -81,19 +81,15 @@ public void collect(int doc, long bucket) throws IOException {
8181
// accurate than naive summation.
8282
double sum = sums.get(bucket);
8383
double compensation = compensations.get(bucket);
84+
CompensatedSum kahanSummation = new CompensatedSum(sum, compensation);
85+
8486
for (int i = 0; i < valuesCount; i++) {
8587
double value = values.nextValue();
86-
if (Double.isFinite(value) == false) {
87-
sum += value;
88-
} else if (Double.isFinite(sum)) {
89-
double corrected = value - compensation;
90-
double newSum = sum + corrected;
91-
compensation = (newSum - sum) - corrected;
92-
sum = newSum;
93-
}
88+
kahanSummation.add(value);
9489
}
95-
compensations.set(bucket, compensation);
96-
sums.set(bucket, sum);
90+
91+
compensations.set(bucket, kahanSummation.delta());
92+
sums.set(bucket, kahanSummation.value());
9793
}
9894
}
9995
};

0 commit comments

Comments
 (0)