[7.x] Refactor and DRY up Kahan Sum algorithm (#48558) #48959

Merged · 1 commit · Nov 11, 2019
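
In broad strokes, the change replaces the hand-rolled Kahan update that each aggregator carried with calls into a new CompensatedSum class. A condensed before/after sketch (paraphrased from the diffs below, not taken verbatim from any single file):

    // Before: every aggregator carried its own copy of the Kahan update.
    if (Double.isFinite(value) == false) {
        sum += value;                               // propagate Inf/NaN as-is
    } else if (Double.isFinite(sum)) {
        double corrected = value - compensation;
        double newSum = sum + corrected;
        compensation = (newSum - sum) - corrected;  // recover the rounding error
        sum = newSum;
    }

    // After: the bookkeeping lives in the new CompensatedSum class.
    kahanSummation.add(value);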
@@ -87,20 +87,15 @@ public void collect(int doc, long bucket) throws IOException {
                // accurate than naive summation.
                double sum = sums.get(bucket);
                double compensation = compensations.get(bucket);
+               CompensatedSum kahanSummation = new CompensatedSum(sum, compensation);

                for (int i = 0; i < valueCount; i++) {
                    double value = values.nextValue();
-                   if (Double.isFinite(value) == false) {
-                       sum += value;
-                   } else if (Double.isFinite(sum)) {
-                       double corrected = value - compensation;
-                       double newSum = sum + corrected;
-                       compensation = (newSum - sum) - corrected;
-                       sum = newSum;
-                   }
+                   kahanSummation.add(value);
                }
-               sums.set(bucket, sum);
-               compensations.set(bucket, compensation);

+               sums.set(bucket, kahanSummation.value());
+               compensations.set(bucket, kahanSummation.delta());
            }
        }
    };
@@ -0,0 +1,93 @@ (new file: CompensatedSum.java)
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.elasticsearch.search.aggregations.metrics;


/**
* Used to calculate sums using the Kahan summation algorithm.
*
* <p>The Kahan summation algorithm (also known as compensated summation) reduces the numerical errors that
* occur when adding a sequence of finite precision floating point numbers. Numerical errors arise due to
* truncation and rounding. These errors can lead to numerical instability.
*
* @see <a href="http://en.wikipedia.org/wiki/Kahan_summation_algorithm">Kahan Summation Algorithm</a>
*/
public class CompensatedSum {

private static final double NO_CORRECTION = 0.0;

private double value;
private double delta;

/**
* Used to calculate sums using the Kahan summation algorithm.
*
* @param value the sum
* @param delta correction term
*/
public CompensatedSum(double value, double delta) {
this.value = value;
this.delta = delta;
}

/**
* The value of the sum.
*/
public double value() {
return value;
}

/**
* The correction term.
*/
public double delta() {
return delta;
}

/**
* Increments the Kahan sum by adding a value without a correction term.
*/
public CompensatedSum add(double value) {
return add(value, NO_CORRECTION);
}

/**
* Increments the Kahan sum by adding two sums, and updating the correction term for reducing numeric errors.
*/
public CompensatedSum add(double value, double delta) {
// If the value is Inf or NaN, just add it to the running tally to "convert" to
// Inf/NaN. This keeps the behavior bwc from before kahan summing
if (Double.isFinite(value) == false) {
this.value = value + this.value;
}

if (Double.isFinite(this.value)) {
double correctedSum = value + (this.delta + delta);
double updatedValue = this.value + correctedSum;
this.delta = correctedSum - (updatedValue - this.value);
this.value = updatedValue;
}

return this;
}


}
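
For a sense of what compensated summation buys, here is a small sketch (not part of the PR; the loop count and the 1e-16 increment are made-up values chosen so the rounding is visible) that exercises the class above:

    // Sketch only: why the aggregators switch to compensated summation.
    // Each 1e-16 increment is below half an ulp of 1.0, so a naive double sum never
    // moves, while CompensatedSum carries the lost low-order bits in its correction term.
    double naive = 1.0;
    CompensatedSum kahan = new CompensatedSum(1.0, 0.0);
    for (int i = 0; i < 10_000_000; i++) {
        naive += 1e-16;
        kahan.add(1e-16);
    }
    // naive          == 1.0          (every increment was rounded away)
    // kahan.value()  ≈  1.000000001  (the 1e-9 total is recovered to within a few ulps)

The aggregator changes below all follow the same pattern: seed a CompensatedSum from the stored sum and compensation (or from zero in the doReduce methods), add each value, then persist value() and delta() back.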

@@ -117,34 +117,24 @@ public void collect(int doc, long bucket) throws IOException {
                // which is more accurate than naive summation.
                double sum = sums.get(bucket);
                double compensation = compensations.get(bucket);
+               CompensatedSum compensatedSum = new CompensatedSum(sum, compensation);

                double sumOfSqr = sumOfSqrs.get(bucket);
                double compensationOfSqr = compensationOfSqrs.get(bucket);
+               CompensatedSum compensatedSumOfSqr = new CompensatedSum(sumOfSqr, compensationOfSqr);

                for (int i = 0; i < valuesCount; i++) {
                    double value = values.nextValue();
-                   if (Double.isFinite(value) == false) {
-                       sum += value;
-                       sumOfSqr += value * value;
-                   } else {
-                       if (Double.isFinite(sum)) {
-                           double corrected = value - compensation;
-                           double newSum = sum + corrected;
-                           compensation = (newSum - sum) - corrected;
-                           sum = newSum;
-                       }
-                       if (Double.isFinite(sumOfSqr)) {
-                           double correctedOfSqr = value * value - compensationOfSqr;
-                           double newSumOfSqr = sumOfSqr + correctedOfSqr;
-                           compensationOfSqr = (newSumOfSqr - sumOfSqr) - correctedOfSqr;
-                           sumOfSqr = newSumOfSqr;
-                       }
-                   }
+                   compensatedSum.add(value);
+                   compensatedSumOfSqr.add(value * value);
                    min = Math.min(min, value);
                    max = Math.max(max, value);
                }
-               sums.set(bucket, sum);
-               compensations.set(bucket, compensation);
-               sumOfSqrs.set(bucket, sumOfSqr);
-               compensationOfSqrs.set(bucket, compensationOfSqr);

+               sums.set(bucket, compensatedSum.value());
+               compensations.set(bucket, compensatedSum.delta());
+               sumOfSqrs.set(bucket, compensatedSumOfSqr.value());
+               compensationOfSqrs.set(bucket, compensatedSumOfSqr.delta());
                mins.set(bucket, min);
                maxes.set(bucket, max);
            }
@@ -88,24 +88,21 @@ public void collect(int doc, long bucket) throws IOException {
                double sumLon = lonSum.get(bucket);
                double compensationLon = lonCompensations.get(bucket);

+               CompensatedSum compensatedSumLat = new CompensatedSum(sumLat, compensationLat);
+               CompensatedSum compensatedSumLon = new CompensatedSum(sumLon, compensationLon);

                // update the sum
                for (int i = 0; i < valueCount; ++i) {
                    GeoPoint value = values.nextValue();
                    //latitude
-                   double correctedLat = value.getLat() - compensationLat;
-                   double newSumLat = sumLat + correctedLat;
-                   compensationLat = (newSumLat - sumLat) - correctedLat;
-                   sumLat = newSumLat;
+                   compensatedSumLat.add(value.getLat());
                    //longitude
-                   double correctedLon = value.getLon() - compensationLon;
-                   double newSumLon = sumLon + correctedLon;
-                   compensationLon = (newSumLon - sumLon) - correctedLon;
-                   sumLon = newSumLon;
+                   compensatedSumLon.add(value.getLon());
                }
-               lonSum.set(bucket, sumLon);
-               lonCompensations.set(bucket, compensationLon);
-               latSum.set(bucket, sumLat);
-               latCompensations.set(bucket, compensationLat);
+               lonSum.set(bucket, compensatedSumLon.value());
+               lonCompensations.set(bucket, compensatedSumLon.delta());
+               latSum.set(bucket, compensatedSumLat.value());
+               latCompensations.set(bucket, compensatedSumLat.delta());
            }
        }
    };
@@ -88,24 +88,16 @@ public String getWriteableName() {

    @Override
    public InternalAvg doReduce(List<InternalAggregation> aggregations, ReduceContext reduceContext) {
+       CompensatedSum kahanSummation = new CompensatedSum(0, 0);
        long count = 0;
-       double sum = 0;
-       double compensation = 0;
        // Compute the sum of double values with Kahan summation algorithm which is more
        // accurate than naive summation.
        for (InternalAggregation aggregation : aggregations) {
            InternalAvg avg = (InternalAvg) aggregation;
            count += avg.count;
-           if (Double.isFinite(avg.sum) == false) {
-               sum += avg.sum;
-           } else if (Double.isFinite(sum)) {
-               double corrected = avg.sum - compensation;
-               double newSum = sum + corrected;
-               compensation = (newSum - sum) - corrected;
-               sum = newSum;
-           }
+           kahanSummation.add(avg.sum);
        }
-       return new InternalAvg(getName(), sum, count, format, pipelineAggregators(), getMetaData());
+       return new InternalAvg(getName(), kahanSummation.value(), count, format, pipelineAggregators(), getMetaData());
    }

    @Override
@@ -149,26 +149,18 @@ public InternalStats doReduce(List<InternalAggregation> aggregations, ReduceCont
        long count = 0;
        double min = Double.POSITIVE_INFINITY;
        double max = Double.NEGATIVE_INFINITY;
-       double sum = 0;
-       double compensation = 0;
+       CompensatedSum kahanSummation = new CompensatedSum(0, 0);

        for (InternalAggregation aggregation : aggregations) {
            InternalStats stats = (InternalStats) aggregation;
            count += stats.getCount();
            min = Math.min(min, stats.getMin());
            max = Math.max(max, stats.getMax());
            // Compute the sum of double values with Kahan summation algorithm which is more
            // accurate than naive summation.
-           double value = stats.getSum();
-           if (Double.isFinite(value) == false) {
-               sum += value;
-           } else if (Double.isFinite(sum)) {
-               double corrected = value - compensation;
-               double newSum = sum + corrected;
-               compensation = (newSum - sum) - corrected;
-               sum = newSum;
-           }
+           kahanSummation.add(stats.getSum());
        }
-       return new InternalStats(name, count, sum, min, max, format, pipelineAggregators(), getMetaData());
+       return new InternalStats(name, count, kahanSummation.value(), min, max, format, pipelineAggregators(), getMetaData());
    }

    static class Fields {
@@ -74,20 +74,12 @@ public double getValue() {
    public InternalSum doReduce(List<InternalAggregation> aggregations, ReduceContext reduceContext) {
        // Compute the sum of double values with Kahan summation algorithm which is more
        // accurate than naive summation.
-       double sum = 0;
-       double compensation = 0;
+       CompensatedSum kahanSummation = new CompensatedSum(0, 0);
        for (InternalAggregation aggregation : aggregations) {
            double value = ((InternalSum) aggregation).sum;
-           if (Double.isFinite(value) == false) {
-               sum += value;
-           } else if (Double.isFinite(sum)) {
-               double corrected = value - compensation;
-               double newSum = sum + corrected;
-               compensation = (newSum - sum) - corrected;
-               sum = newSum;
-           }
+           kahanSummation.add(value);
        }
-       return new InternalSum(name, sum, format, pipelineAggregators(), getMetaData());
+       return new InternalSum(name, kahanSummation.value(), format, pipelineAggregators(), getMetaData());
    }

    @Override
@@ -88,37 +88,21 @@ public String getWriteableName() {

    @Override
    public InternalWeightedAvg doReduce(List<InternalAggregation> aggregations, ReduceContext reduceContext) {
-       double weight = 0;
-       double sum = 0;
-       double sumCompensation = 0;
-       double weightCompensation = 0;
+       CompensatedSum sumCompensation = new CompensatedSum(0, 0);
+       CompensatedSum weightCompensation = new CompensatedSum(0, 0);

+       // Compute the sum of double values with Kahan summation algorithm which is more
+       // accurate than naive summation.
        for (InternalAggregation aggregation : aggregations) {
            InternalWeightedAvg avg = (InternalWeightedAvg) aggregation;
-           // If the weight is Inf or NaN, just add it to the running tally to "convert" to
-           // Inf/NaN. This keeps the behavior bwc from before kahan summing
-           if (Double.isFinite(avg.weight) == false) {
-               weight += avg.weight;
-           } else if (Double.isFinite(weight)) {
-               double corrected = avg.weight - weightCompensation;
-               double newWeight = weight + corrected;
-               weightCompensation = (newWeight - weight) - corrected;
-               weight = newWeight;
-           }
-           // If the avg is Inf or NaN, just add it to the running tally to "convert" to
-           // Inf/NaN. This keeps the behavior bwc from before kahan summing
-           if (Double.isFinite(avg.sum) == false) {
-               sum += avg.sum;
-           } else if (Double.isFinite(sum)) {
-               double corrected = avg.sum - sumCompensation;
-               double newSum = sum + corrected;
-               sumCompensation = (newSum - sum) - corrected;
-               sum = newSum;
-           }
+           weightCompensation.add(avg.weight);
+           sumCompensation.add(avg.sum);
        }
-       return new InternalWeightedAvg(getName(), sum, weight, format, pipelineAggregators(), getMetaData());

+       return new InternalWeightedAvg(getName(), sumCompensation.value(), weightCompensation.value(),
+           format, pipelineAggregators(), getMetaData());
    }

    @Override
    public XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException {
        builder.field(CommonFields.VALUE.getPreferredName(), weight != 0 ? getValue() : null);
@@ -105,22 +105,16 @@ public void collect(int doc, long bucket) throws IOException {
                // accurate than naive summation.
                double sum = sums.get(bucket);
                double compensation = compensations.get(bucket);
+               CompensatedSum kahanSummation = new CompensatedSum(sum, compensation);

                for (int i = 0; i < valuesCount; i++) {
                    double value = values.nextValue();
-                   if (Double.isFinite(value) == false) {
-                       sum += value;
-                   } else if (Double.isFinite(sum)) {
-                       double corrected = value - compensation;
-                       double newSum = sum + corrected;
-                       compensation = (newSum - sum) - corrected;
-                       sum = newSum;
-                   }
+                   kahanSummation.add(value);
                    min = Math.min(min, value);
                    max = Math.max(max, value);
                }
-               sums.set(bucket, sum);
-               compensations.set(bucket, compensation);
+               sums.set(bucket, kahanSummation.value());
+               compensations.set(bucket, kahanSummation.delta());
                mins.set(bucket, min);
                maxes.set(bucket, max);
            }
@@ -81,19 +81,15 @@ public void collect(int doc, long bucket) throws IOException {
                // accurate than naive summation.
                double sum = sums.get(bucket);
                double compensation = compensations.get(bucket);
+               CompensatedSum kahanSummation = new CompensatedSum(sum, compensation);

                for (int i = 0; i < valuesCount; i++) {
                    double value = values.nextValue();
-                   if (Double.isFinite(value) == false) {
-                       sum += value;
-                   } else if (Double.isFinite(sum)) {
-                       double corrected = value - compensation;
-                       double newSum = sum + corrected;
-                       compensation = (newSum - sum) - corrected;
-                       sum = newSum;
-                   }
+                   kahanSummation.add(value);
                }
-               compensations.set(bucket, compensation);
-               sums.set(bucket, sum);

+               compensations.set(bucket, kahanSummation.delta());
+               sums.set(bucket, kahanSummation.value());
            }
        }
    };