Skip to content

Commit 15cdec2

Browse files
committed
Properly calculate sum of NaN and infinities
1 parent 95aff79 commit 15cdec2

File tree

10 files changed

+113
-46
lines changed

10 files changed

+113
-46
lines changed

core/src/main/java/org/elasticsearch/search/aggregations/metrics/avg/AvgAggregator.java

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,17 @@ public void collect(int doc, long bucket) throws IOException {
8989
double compensation = compensations.get(bucket);
9090

9191
for (int i = 0; i < valueCount; i++) {
92-
double corrected = values.nextValue() - compensation;
93-
double newSum = sum + corrected;
94-
compensation = (newSum - sum) - corrected;
95-
sum = newSum;
92+
double value = values.nextValue();
93+
if (Double.isNaN(value) || Double.isInfinite(value)) {
94+
sum += value;
95+
if (Double.isNaN(sum))
96+
break;
97+
} else if (Double.isFinite(sum)) {
98+
double corrected = value - compensation;
99+
double newSum = sum + corrected;
100+
compensation = (newSum - sum) - corrected;
101+
sum = newSum;
102+
}
96103
}
97104
sums.set(bucket, sum);
98105
compensations.set(bucket, compensation);

core/src/main/java/org/elasticsearch/search/aggregations/metrics/avg/InternalAvg.java

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -95,11 +95,18 @@ public InternalAvg doReduce(List<InternalAggregation> aggregations, ReduceContex
9595
// Compute the sum of double values with Kahan summation algorithm which is more
9696
// accurate than naive summation.
9797
for (InternalAggregation aggregation : aggregations) {
98-
count += ((InternalAvg) aggregation).count;
99-
double corrected = ((InternalAvg) aggregation).sum - compensation;
100-
double newSum = sum + corrected;
101-
compensation = (newSum - sum) - corrected;
102-
sum = newSum;
98+
InternalAvg avg = (InternalAvg) aggregation;
99+
count += avg.count;
100+
if (Double.isNaN(sum) == false) {
101+
if (Double.isNaN(avg.sum) || Double.isInfinite(avg.sum)) {
102+
sum += avg.sum;
103+
} else if (Double.isFinite(sum)) {
104+
double corrected = avg.sum - compensation;
105+
double newSum = sum + corrected;
106+
compensation = (newSum - sum) - corrected;
107+
sum = newSum;
108+
}
109+
}
103110
}
104111
return new InternalAvg(getName(), sum, count, format, pipelineAggregators(), getMetaData());
105112
}

core/src/main/java/org/elasticsearch/search/aggregations/metrics/stats/InternalStats.java

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -160,10 +160,17 @@ public InternalStats doReduce(List<InternalAggregation> aggregations, ReduceCont
160160
max = Math.max(max, stats.getMax());
161161
// Compute the sum of double values with Kahan summation algorithm which is more
162162
// accurate than naive summation.
163-
double corrected = stats.getSum() - compensation;
164-
double newSum = sum + corrected;
165-
compensation = (newSum - sum) - corrected;
166-
sum = newSum;
163+
if (Double.isNaN(sum) == false) {
164+
double value = stats.getSum();
165+
if (Double.isNaN(value) || Double.isInfinite(value)) {
166+
sum += value;
167+
} else if (Double.isFinite(sum)) {
168+
double corrected = value - compensation;
169+
double newSum = sum + corrected;
170+
compensation = (newSum - sum) - corrected;
171+
sum = newSum;
172+
}
173+
}
167174
}
168175
return new InternalStats(name, count, sum, min, max, format, pipelineAggregators(), getMetaData());
169176
}

core/src/main/java/org/elasticsearch/search/aggregations/metrics/stats/StatsAggregator.java

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -110,10 +110,16 @@ public void collect(int doc, long bucket) throws IOException {
110110

111111
for (int i = 0; i < valuesCount; i++) {
112112
double value = values.nextValue();
113-
double corrected = value - compensation;
114-
double newSum = sum + corrected;
115-
compensation = (newSum - sum) - corrected;
116-
sum = newSum;
113+
if (Double.isNaN(sum) == false) {
114+
if (Double.isNaN(value) || Double.isInfinite(value)) {
115+
sum += value;
116+
} else if (Double.isFinite(sum)) {
117+
double corrected = value - compensation;
118+
double newSum = sum + corrected;
119+
compensation = (newSum - sum) - corrected;
120+
sum = newSum;
121+
}
122+
}
117123
min = Math.min(min, value);
118124
max = Math.max(max, value);
119125
}

core/src/main/java/org/elasticsearch/search/aggregations/metrics/stats/extended/ExtendedStatsAggregator.java

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -121,15 +121,23 @@ public void collect(int doc, long bucket) throws IOException {
121121
double compensationOfSqr = compensationOfSqrs.get(bucket);
122122
for (int i = 0; i < valuesCount; i++) {
123123
double value = values.nextValue();
124-
double corrected = value - compensation;
125-
double newSum = sum + corrected;
126-
compensation = (newSum - sum) - corrected;
127-
sum = newSum;
128-
129-
double correctedOfSqr = value * value - compensationOfSqr;
130-
double newSumOfSqr = sumOfSqr + correctedOfSqr;
131-
compensationOfSqr = (newSumOfSqr - sumOfSqr) - correctedOfSqr;
132-
sumOfSqr = newSumOfSqr;
124+
if (Double.isNaN(value) || Double.isInfinite(value)) {
125+
sum += value;
126+
sumOfSqr += value * value;
127+
} else {
128+
if (Double.isFinite(sum)) {
129+
double corrected = value - compensation;
130+
double newSum = sum + corrected;
131+
compensation = (newSum - sum) - corrected;
132+
sum = newSum;
133+
}
134+
if (Double.isFinite(sumOfSqr)) {
135+
double correctedOfSqr = value * value - compensationOfSqr;
136+
double newSumOfSqr = sumOfSqr + correctedOfSqr;
137+
compensationOfSqr = (newSumOfSqr - sumOfSqr) - correctedOfSqr;
138+
sumOfSqr = newSumOfSqr;
139+
}
140+
}
133141
min = Math.min(min, value);
134142
max = Math.max(max, value);
135143
}

core/src/main/java/org/elasticsearch/search/aggregations/metrics/stats/extended/InternalExtendedStats.java

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,12 +142,23 @@ public String getStdDeviationBoundAsString(Bounds bound) {
142142
@Override
143143
public InternalExtendedStats doReduce(List<InternalAggregation> aggregations, ReduceContext reduceContext) {
144144
double sumOfSqrs = 0;
145+
double compensationOfSqrs = 0;
145146
for (InternalAggregation aggregation : aggregations) {
146147
InternalExtendedStats stats = (InternalExtendedStats) aggregation;
147148
if (stats.sigma != sigma) {
148149
throw new IllegalStateException("Cannot reduce other stats aggregations that have a different sigma");
149150
}
150-
sumOfSqrs += stats.getSumOfSquares();
151+
if (Double.isNaN(sumOfSqrs) == false) {
152+
double value = stats.getSumOfSquares();
153+
if (Double.isNaN(value) || Double.isInfinite(value)) {
154+
sumOfSqrs += value;
155+
} else if (Double.isFinite(sumOfSqrs)) {
156+
double correctedOfSqrs = value - compensationOfSqrs;
157+
double newSumOfSqrs = sumOfSqrs + correctedOfSqrs;
158+
compensationOfSqrs = (newSumOfSqrs - sumOfSqrs) - correctedOfSqrs;
159+
sumOfSqrs = newSumOfSqrs;
160+
}
161+
}
151162
}
152163
final InternalStats stats = super.doReduce(aggregations, reduceContext);
153164
return new InternalExtendedStats(name, stats.getCount(), stats.getSum(), stats.getMin(), stats.getMax(), sumOfSqrs, sigma,

core/src/main/java/org/elasticsearch/search/aggregations/metrics/sum/InternalSum.java

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
import org.elasticsearch.search.DocValueFormat;
2525
import org.elasticsearch.search.aggregations.InternalAggregation;
2626
import org.elasticsearch.search.aggregations.metrics.InternalNumericMetricsAggregation;
27-
import org.elasticsearch.search.aggregations.metrics.avg.InternalAvg;
2827
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
2928

3029
import java.io.IOException;
@@ -36,7 +35,7 @@ public class InternalSum extends InternalNumericMetricsAggregation.SingleValue i
3635
private final double sum;
3736

3837
public InternalSum(String name, double sum, DocValueFormat formatter, List<PipelineAggregator> pipelineAggregators,
39-
Map<String, Object> metaData) {
38+
Map<String, Object> metaData) {
4039
super(name, pipelineAggregators, metaData);
4140
this.sum = sum;
4241
this.format = formatter;
@@ -79,10 +78,17 @@ public InternalSum doReduce(List<InternalAggregation> aggregations, ReduceContex
7978
double sum = 0;
8079
double compensation = 0;
8180
for (InternalAggregation aggregation : aggregations) {
82-
double corrected = ((InternalSum) aggregation).sum - compensation;
83-
double newSum = sum + corrected;
84-
compensation = (newSum - sum) - corrected;
85-
sum = newSum;
81+
double value = ((InternalSum) aggregation).sum;
82+
if (Double.isNaN(value) || Double.isInfinite(value)) {
83+
sum += value;
84+
if (Double.isNaN(sum))
85+
break;
86+
} else if (Double.isFinite(sum)) {
87+
double corrected = value - compensation;
88+
double newSum = sum + corrected;
89+
compensation = (newSum - sum) - corrected;
90+
sum = newSum;
91+
}
8692
}
8793
return new InternalSum(name, sum, format, pipelineAggregators(), getMetaData());
8894
}

core/src/main/java/org/elasticsearch/search/aggregations/metrics/sum/SumAggregator.java

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,10 +82,17 @@ public void collect(int doc, long bucket) throws IOException {
8282
double sum = sums.get(bucket);
8383
double compensation = compensations.get(bucket);
8484
for (int i = 0; i < valuesCount; i++) {
85-
double corrected = values.nextValue() - compensation;
86-
double newSum = sum + corrected;
87-
compensation = (newSum - sum) - corrected;
88-
sum = newSum;
85+
double value = values.nextValue();
86+
if (Double.isNaN(value) || Double.isInfinite(value)) {
87+
sum += value;
88+
if (Double.isNaN(sum))
89+
break;
90+
} else if (Double.isFinite(sum)) {
91+
double corrected = value - compensation;
92+
double newSum = sum + corrected;
93+
compensation = (newSum - sum) - corrected;
94+
sum = newSum;
95+
}
8996
}
9097
compensations.set(bucket, compensation);
9198
sums.set(bucket, sum);

core/src/test/java/org/elasticsearch/search/aggregations/metrics/InternalExtendedStatsTests.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,15 @@
2121

2222
import org.elasticsearch.common.io.stream.Writeable;
2323
import org.elasticsearch.search.DocValueFormat;
24+
import org.elasticsearch.search.aggregations.InternalAggregation;
2425
import org.elasticsearch.search.aggregations.ParsedAggregation;
2526
import org.elasticsearch.search.aggregations.metrics.stats.extended.ExtendedStats.Bounds;
2627
import org.elasticsearch.search.aggregations.metrics.stats.extended.InternalExtendedStats;
2728
import org.elasticsearch.search.aggregations.metrics.stats.extended.ParsedExtendedStats;
2829
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
2930
import org.elasticsearch.test.InternalAggregationTestCase;
3031

32+
import java.util.ArrayList;
3133
import java.util.HashMap;
3234
import java.util.List;
3335
import java.util.Map;
@@ -188,4 +190,17 @@ protected InternalExtendedStats mutateInstance(InternalExtendedStats instance) {
188190
}
189191
return new InternalExtendedStats(name, count, sum, min, max, sumOfSqrs, sigma, formatter, pipelineAggregators, metaData);
190192
}
193+
194+
public void testSummationAccuracy() {
195+
double[] values = new double[]{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.9, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7};
196+
double sigma = randomDouble();
197+
List<InternalAggregation> aggregations = new ArrayList<>(values.length);
198+
for (double sumOfSqrs : values) {
199+
aggregations.add(new InternalExtendedStats("dummy1", 1, 0.0, 0.0, 0.0, sumOfSqrs, sigma, null, null, null));
200+
}
201+
InternalExtendedStats stats = new InternalExtendedStats("dummy", 1, 0.0, 0.0, 0.0, 0.0, sigma, null, null, null);
202+
InternalExtendedStats reduced = stats.doReduce(aggregations, null);
203+
assertEquals(13.5, reduced.getSumOfSquares(), 0d);
204+
assertEquals("dummy", reduced.getName());
205+
}
191206
}

core/src/test/java/org/elasticsearch/search/aggregations/metrics/InternalSumTests.java

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -49,15 +49,8 @@ protected Writeable.Reader<InternalSum> instanceReader() {
4949

5050
@Override
5151
protected void assertReduced(InternalSum reduced, List<InternalSum> inputs) {
52-
double expectedSum = 0;
53-
double compensation = 0;
54-
for (InternalSum aggregation : inputs) {
55-
double corrected = aggregation.value() - compensation;
56-
double newSum = expectedSum + corrected;
57-
compensation = (newSum - expectedSum) - corrected;
58-
expectedSum = newSum;
59-
}
60-
assertEquals(expectedSum, reduced.getValue(), 0.000d);
52+
double expectedSum = inputs.stream().mapToDouble(InternalSum::getValue).sum();
53+
assertEquals(expectedSum, reduced.getValue(), 0.0001d);
6154
}
6255

6356
public void testSummationAccuracy() throws IOException {

0 commit comments

Comments
 (0)