Skip to content

Calculate sum in Kahan summation algorithm in aggregations (#27807) #27848

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jan 22, 2018
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ public class AvgAggregator extends NumericMetricsAggregator.SingleValue {

LongArray counts;
DoubleArray sums;
DoubleArray compensations;
DocValueFormat format;

public AvgAggregator(String name, ValuesSource.Numeric valuesSource, DocValueFormat formatter, SearchContext context,
Expand All @@ -55,6 +56,7 @@ public AvgAggregator(String name, ValuesSource.Numeric valuesSource, DocValueFor
final BigArrays bigArrays = context.bigArrays();
counts = bigArrays.newLongArray(1, true);
sums = bigArrays.newDoubleArray(1, true);
compensations = bigArrays.newDoubleArray(1, true);
}
}

Expand All @@ -76,15 +78,29 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx,
public void collect(int doc, long bucket) throws IOException {
counts = bigArrays.grow(counts, bucket + 1);
sums = bigArrays.grow(sums, bucket + 1);
compensations = bigArrays.grow(compensations, bucket + 1);

if (values.advanceExact(doc)) {
final int valueCount = values.docValueCount();
counts.increment(bucket, valueCount);
double sum = 0;
// Compute the sum of double values with Kahan summation algorithm which is more
// accurate than naive summation.
double sum = sums.get(bucket);
double compensation = compensations.get(bucket);

for (int i = 0; i < valueCount; i++) {
sum += values.nextValue();
double value = values.nextValue();
if (Double.isFinite(value) == false) {
sum += value;
} else if (Double.isFinite(sum)) {
double corrected = value - compensation;
double newSum = sum + corrected;
compensation = (newSum - sum) - corrected;
sum = newSum;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think one corner-case is not covered with this logic. Imagine that all values are finite but 2: -Inf and +Inf, for instance:
[+Inf, 4, -Inf]. The expected result is NaN but your logic will make it return +Inf if I'm not mistaken. Maybe this should be something like that instead:

if (Double.isFinite(value) == false || Double.isFinite(sum) == false) {
  sum += value;
} else {
  // kahan summation
}

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @jpountz Thanks a lot! I made a test for both of them:

        double sum = 0;
        double compensation = 0;

        double[] values = new double[]{Double.POSITIVE_INFINITY, 4, Double.NEGATIVE_INFINITY};

        for (int i = 0; i < values.length; i++) {
            double value = values[i];
            if (Double.isFinite(value) == false) {
                sum += value;
            } else if (Double.isFinite(sum)) {
                double corrected = value - compensation;
                double newSum = sum + corrected;
                compensation = (newSum - sum) - corrected;
                sum = newSum;
            }
        }
        System.out.println(sum);

        sum = 0;
        compensation = 0;

        for (int i = 0; i < values.length; i++) {
            double value = values[i];
            if (Double.isFinite(value) == false || Double.isFinite(sum) == false) {
                sum += value;
            } else {
                double corrected = value - compensation;
                double newSum = sum + corrected;
                compensation = (newSum - sum) - corrected;
                sum = newSum;
            }
        }
        System.out.println(sum);

Both of their result are NaN. Because in my code, for each value, if it's not finite, it will be summed up to sum, no matter what sum is. I can also made this change because I also think your way is more intuitive. I also have a random test case for this kind corner-case:

        int n = randomIntBetween(5, 10);
        values = new double[n];
        double sum = 0;
        for (int i = 0; i < n; i++) {
            values[i] = frequently()
                ? randomFrom(Double.NaN, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY)
                : randomDoubleBetween(Double.MIN_VALUE, Double.MAX_VALUE, true);
            sum += values[i];
        }
        verifyAvgOfDoubles(values, sum / n, 1e-10);

I can make it be regular for example:

        double[] values = new double[]{Double.POSITIVE_INFINITY, 4, Double.NEGATIVE_INFINITY};
        verifyAvgOfDoubles(values, NaN, 0d);

WDYT?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had misread your code, sorry! I am fine either way.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's keep things the way they are.

}
sums.increment(bucket, sum);
sums.set(bucket, sum);
compensations.set(bucket, compensation);
}
}
};
Expand Down Expand Up @@ -113,7 +129,7 @@ public InternalAggregation buildEmptyAggregation() {

@Override
public void doClose() {
Releasables.close(counts, sums);
Releasables.close(counts, sums, compensations);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,20 @@ public String getWriteableName() {
public InternalAvg doReduce(List<InternalAggregation> aggregations, ReduceContext reduceContext) {
long count = 0;
double sum = 0;
double compensation = 0;
// Compute the sum of double values with Kahan summation algorithm which is more
// accurate than naive summation.
for (InternalAggregation aggregation : aggregations) {
count += ((InternalAvg) aggregation).count;
sum += ((InternalAvg) aggregation).sum;
InternalAvg avg = (InternalAvg) aggregation;
count += avg.count;
if (Double.isFinite(avg.sum) == false) {
sum += avg.sum;
} else if (Double.isFinite(sum)) {
double corrected = avg.sum - compensation;
double newSum = sum + corrected;
compensation = (newSum - sum) - corrected;
sum = newSum;
}
}
return new InternalAvg(getName(), sum, count, format, pipelineAggregators(), getMetaData());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,12 +152,23 @@ public InternalStats doReduce(List<InternalAggregation> aggregations, ReduceCont
double min = Double.POSITIVE_INFINITY;
double max = Double.NEGATIVE_INFINITY;
double sum = 0;
double compensation = 0;
for (InternalAggregation aggregation : aggregations) {
InternalStats stats = (InternalStats) aggregation;
count += stats.getCount();
min = Math.min(min, stats.getMin());
max = Math.max(max, stats.getMax());
sum += stats.getSum();
// Compute the sum of double values with Kahan summation algorithm which is more
// accurate than naive summation.
double value = stats.getSum();
if (Double.isFinite(value) == false) {
sum += value;
} else if (Double.isFinite(sum)) {
double corrected = value - compensation;
double newSum = sum + corrected;
compensation = (newSum - sum) - corrected;
sum = newSum;
}
}
return new InternalStats(name, count, sum, min, max, format, pipelineAggregators(), getMetaData());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ public class StatsAggregator extends NumericMetricsAggregator.MultiValue {

LongArray counts;
DoubleArray sums;
DoubleArray compensations;
DoubleArray mins;
DoubleArray maxes;

Expand All @@ -59,6 +60,7 @@ public StatsAggregator(String name, ValuesSource.Numeric valuesSource, DocValueF
final BigArrays bigArrays = context.bigArrays();
counts = bigArrays.newLongArray(1, true);
sums = bigArrays.newDoubleArray(1, true);
compensations = bigArrays.newDoubleArray(1, true);
mins = bigArrays.newDoubleArray(1, false);
mins.fill(0, mins.size(), Double.POSITIVE_INFINITY);
maxes = bigArrays.newDoubleArray(1, false);
Expand Down Expand Up @@ -88,6 +90,7 @@ public void collect(int doc, long bucket) throws IOException {
final long overSize = BigArrays.overSize(bucket + 1);
counts = bigArrays.resize(counts, overSize);
sums = bigArrays.resize(sums, overSize);
compensations = bigArrays.resize(compensations, overSize);
mins = bigArrays.resize(mins, overSize);
maxes = bigArrays.resize(maxes, overSize);
mins.fill(from, overSize, Double.POSITIVE_INFINITY);
Expand All @@ -97,16 +100,28 @@ public void collect(int doc, long bucket) throws IOException {
if (values.advanceExact(doc)) {
final int valuesCount = values.docValueCount();
counts.increment(bucket, valuesCount);
double sum = 0;
double min = mins.get(bucket);
double max = maxes.get(bucket);
// Compute the sum of double values with Kahan summation algorithm which is more
// accurate than naive summation.
double sum = sums.get(bucket);
double compensation = compensations.get(bucket);

for (int i = 0; i < valuesCount; i++) {
double value = values.nextValue();
sum += value;
if (Double.isFinite(value) == false) {
sum += value;
} else if (Double.isFinite(sum)) {
double corrected = value - compensation;
double newSum = sum + corrected;
compensation = (newSum - sum) - corrected;
sum = newSum;
}
min = Math.min(min, value);
max = Math.max(max, value);
}
sums.increment(bucket, sum);
sums.set(bucket, sum);
compensations.set(bucket, compensation);
mins.set(bucket, min);
maxes.set(bucket, max);
}
Expand Down Expand Up @@ -164,6 +179,6 @@ public InternalAggregation buildEmptyAggregation() {

@Override
public void doClose() {
Releasables.close(counts, maxes, mins, sums);
Releasables.close(counts, maxes, mins, sums, compensations);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,11 @@ public class ExtendedStatsAggregator extends NumericMetricsAggregator.MultiValue

LongArray counts;
DoubleArray sums;
DoubleArray compensations;
DoubleArray mins;
DoubleArray maxes;
DoubleArray sumOfSqrs;
DoubleArray compensationOfSqrs;

public ExtendedStatsAggregator(String name, ValuesSource.Numeric valuesSource, DocValueFormat formatter,
SearchContext context, Aggregator parent, double sigma, List<PipelineAggregator> pipelineAggregators,
Expand All @@ -65,11 +67,13 @@ public ExtendedStatsAggregator(String name, ValuesSource.Numeric valuesSource, D
final BigArrays bigArrays = context.bigArrays();
counts = bigArrays.newLongArray(1, true);
sums = bigArrays.newDoubleArray(1, true);
compensations = bigArrays.newDoubleArray(1, true);
mins = bigArrays.newDoubleArray(1, false);
mins.fill(0, mins.size(), Double.POSITIVE_INFINITY);
maxes = bigArrays.newDoubleArray(1, false);
maxes.fill(0, maxes.size(), Double.NEGATIVE_INFINITY);
sumOfSqrs = bigArrays.newDoubleArray(1, true);
compensationOfSqrs = bigArrays.newDoubleArray(1, true);
}
}

Expand All @@ -95,29 +99,52 @@ public void collect(int doc, long bucket) throws IOException {
final long overSize = BigArrays.overSize(bucket + 1);
counts = bigArrays.resize(counts, overSize);
sums = bigArrays.resize(sums, overSize);
compensations = bigArrays.resize(compensations, overSize);
mins = bigArrays.resize(mins, overSize);
maxes = bigArrays.resize(maxes, overSize);
sumOfSqrs = bigArrays.resize(sumOfSqrs, overSize);
compensationOfSqrs = bigArrays.resize(compensationOfSqrs, overSize);
mins.fill(from, overSize, Double.POSITIVE_INFINITY);
maxes.fill(from, overSize, Double.NEGATIVE_INFINITY);
}

if (values.advanceExact(doc)) {
final int valuesCount = values.docValueCount();
counts.increment(bucket, valuesCount);
double sum = 0;
double sumOfSqr = 0;
double min = mins.get(bucket);
double max = maxes.get(bucket);
// Compute the sum and sum of squires for double values with Kahan summation algorithm
// which is more accurate than naive summation.
double sum = sums.get(bucket);
double compensation = compensations.get(bucket);
double sumOfSqr = sumOfSqrs.get(bucket);
double compensationOfSqr = compensationOfSqrs.get(bucket);
for (int i = 0; i < valuesCount; i++) {
double value = values.nextValue();
sum += value;
sumOfSqr += value * value;
if (Double.isFinite(value) == false) {
sum += value;
sumOfSqr += value * value;
} else {
if (Double.isFinite(sum)) {
double corrected = value - compensation;
double newSum = sum + corrected;
compensation = (newSum - sum) - corrected;
sum = newSum;
}
if (Double.isFinite(sumOfSqr)) {
double correctedOfSqr = value * value - compensationOfSqr;
double newSumOfSqr = sumOfSqr + correctedOfSqr;
compensationOfSqr = (newSumOfSqr - sumOfSqr) - correctedOfSqr;
sumOfSqr = newSumOfSqr;
}
}
min = Math.min(min, value);
max = Math.max(max, value);
}
sums.increment(bucket, sum);
sumOfSqrs.increment(bucket, sumOfSqr);
sums.set(bucket, sum);
compensations.set(bucket, compensation);
sumOfSqrs.set(bucket, sumOfSqr);
compensationOfSqrs.set(bucket, compensationOfSqr);
mins.set(bucket, min);
maxes.set(bucket, max);
}
Expand Down Expand Up @@ -196,6 +223,6 @@ public InternalAggregation buildEmptyAggregation() {

@Override
public void doClose() {
Releasables.close(counts, maxes, mins, sumOfSqrs, sums);
Releasables.close(counts, maxes, mins, sumOfSqrs, compensationOfSqrs, sums, compensations);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ public static Metrics resolve(String name) {
private final double sigma;

public InternalExtendedStats(String name, long count, double sum, double min, double max, double sumOfSqrs, double sigma,
DocValueFormat formatter, List<PipelineAggregator> pipelineAggregators, Map<String, Object> metaData) {
DocValueFormat formatter, List<PipelineAggregator> pipelineAggregators, Map<String, Object> metaData) {
super(name, count, sum, min, max, formatter, pipelineAggregators, metaData);
this.sumOfSqrs = sumOfSqrs;
this.sigma = sigma;
Expand Down Expand Up @@ -142,16 +142,25 @@ public String getStdDeviationBoundAsString(Bounds bound) {
@Override
public InternalExtendedStats doReduce(List<InternalAggregation> aggregations, ReduceContext reduceContext) {
double sumOfSqrs = 0;
double compensationOfSqrs = 0;
for (InternalAggregation aggregation : aggregations) {
InternalExtendedStats stats = (InternalExtendedStats) aggregation;
if (stats.sigma != sigma) {
throw new IllegalStateException("Cannot reduce other stats aggregations that have a different sigma");
}
sumOfSqrs += stats.getSumOfSquares();
double value = stats.getSumOfSquares();
if (Double.isFinite(value) == false) {
sumOfSqrs += value;
} else if (Double.isFinite(sumOfSqrs)) {
double correctedOfSqrs = value - compensationOfSqrs;
double newSumOfSqrs = sumOfSqrs + correctedOfSqrs;
compensationOfSqrs = (newSumOfSqrs - sumOfSqrs) - correctedOfSqrs;
sumOfSqrs = newSumOfSqrs;
}
}
final InternalStats stats = super.doReduce(aggregations, reduceContext);
return new InternalExtendedStats(name, stats.getCount(), stats.getSum(), stats.getMin(), stats.getMax(), sumOfSqrs, sigma,
format, pipelineAggregators(), getMetaData());
format, pipelineAggregators(), getMetaData());
}

static class Fields {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public class InternalSum extends InternalNumericMetricsAggregation.SingleValue i
private final double sum;

public InternalSum(String name, double sum, DocValueFormat formatter, List<PipelineAggregator> pipelineAggregators,
Map<String, Object> metaData) {
Map<String, Object> metaData) {
super(name, pipelineAggregators, metaData);
this.sum = sum;
this.format = formatter;
Expand Down Expand Up @@ -73,9 +73,20 @@ public double getValue() {

@Override
public InternalSum doReduce(List<InternalAggregation> aggregations, ReduceContext reduceContext) {
// Compute the sum of double values with Kahan summation algorithm which is more
// accurate than naive summation.
double sum = 0;
double compensation = 0;
for (InternalAggregation aggregation : aggregations) {
sum += ((InternalSum) aggregation).sum;
double value = ((InternalSum) aggregation).sum;
if (Double.isFinite(value) == false) {
sum += value;
} else if (Double.isFinite(sum)) {
double corrected = value - compensation;
double newSum = sum + corrected;
compensation = (newSum - sum) - corrected;
sum = newSum;
}
}
return new InternalSum(name, sum, format, pipelineAggregators(), getMetaData());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ public class SumAggregator extends NumericMetricsAggregator.SingleValue {
private final DocValueFormat format;

private DoubleArray sums;
private DoubleArray compensations;

SumAggregator(String name, ValuesSource.Numeric valuesSource, DocValueFormat formatter, SearchContext context,
Aggregator parent, List<PipelineAggregator> pipelineAggregators, Map<String, Object> metaData) throws IOException {
Expand All @@ -51,6 +52,7 @@ public class SumAggregator extends NumericMetricsAggregator.SingleValue {
this.format = formatter;
if (valuesSource != null) {
sums = context.bigArrays().newDoubleArray(1, true);
compensations = context.bigArrays().newDoubleArray(1, true);
}
}

Expand All @@ -71,13 +73,27 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx,
@Override
public void collect(int doc, long bucket) throws IOException {
sums = bigArrays.grow(sums, bucket + 1);
compensations = bigArrays.grow(compensations, bucket + 1);

if (values.advanceExact(doc)) {
final int valuesCount = values.docValueCount();
double sum = 0;
// Compute the sum of double values with Kahan summation algorithm which is more
// accurate than naive summation.
double sum = sums.get(bucket);
double compensation = compensations.get(bucket);
for (int i = 0; i < valuesCount; i++) {
sum += values.nextValue();
double value = values.nextValue();
if (Double.isFinite(value) == false) {
sum += value;
} else if (Double.isFinite(sum)) {
double corrected = value - compensation;
double newSum = sum + corrected;
compensation = (newSum - sum) - corrected;
sum = newSum;
}
}
sums.increment(bucket, sum);
compensations.set(bucket, compensation);
sums.set(bucket, sum);
}
}
};
Expand Down Expand Up @@ -106,6 +122,6 @@ public InternalAggregation buildEmptyAggregation() {

@Override
public void doClose() {
Releasables.close(sums);
Releasables.close(sums, compensations);
}
}
Loading