Skip to content

Commit 5bc0a3f

Browse files
authored
Reuse CompensatedSum object in agg collect loops (#49548)
The new CompensatedSum is a nice DRY refactor, but it had the unanticipated side effect of creating a lot of object allocation in the aggregation hot collection loop: one object per visited document, per aggregator. In some places it created two objects per document per aggregator (weighted avg, geo centroids, etc.), since multiple compensations were being maintained. This PR moves the object creation out of the hot loop so that each CompensatedSum is created once per segment, and its internal state is reset on each pass through the loop.
1 parent 9cc247d commit 5bc0a3f

File tree

7 files changed

+54
-32
lines changed

7 files changed

+54
-32
lines changed

server/src/main/java/org/elasticsearch/search/aggregations/metrics/AvgAggregator.java

+4-1
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx,
7373
}
7474
final BigArrays bigArrays = context.bigArrays();
7575
final SortedNumericDoubleValues values = valuesSource.doubleValues(ctx);
76+
final CompensatedSum kahanSummation = new CompensatedSum(0, 0);
77+
7678
return new LeafBucketCollectorBase(sub, values) {
7779
@Override
7880
public void collect(int doc, long bucket) throws IOException {
@@ -87,7 +89,8 @@ public void collect(int doc, long bucket) throws IOException {
8789
// accurate than naive summation.
8890
double sum = sums.get(bucket);
8991
double compensation = compensations.get(bucket);
90-
CompensatedSum kahanSummation = new CompensatedSum(sum, compensation);
92+
93+
kahanSummation.reset(sum, compensation);
9194

9295
for (int i = 0; i < valueCount; i++) {
9396
double value = values.nextValue();

server/src/main/java/org/elasticsearch/search/aggregations/metrics/CompensatedSum.java

+8
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,14 @@ public CompensatedSum add(double value) {
6868
return add(value, NO_CORRECTION);
6969
}
7070

71+
/**
72+
* Resets the internal state to use the new value and compensation delta
73+
*/
74+
public void reset(double value, double delta) {
75+
this.value = value;
76+
this.delta = delta;
77+
}
78+
7179
/**
7280
* Increments the Kahan sum by adding two sums, and updating the correction term for reducing numeric errors.
7381
*/

server/src/main/java/org/elasticsearch/search/aggregations/metrics/ExtendedStatsAggregator.java

+4-2
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,8 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx,
9090
}
9191
final BigArrays bigArrays = context.bigArrays();
9292
final SortedNumericDoubleValues values = valuesSource.doubleValues(ctx);
93+
final CompensatedSum compensatedSum = new CompensatedSum(0, 0);
94+
final CompensatedSum compensatedSumOfSqr = new CompensatedSum(0, 0);
9395
return new LeafBucketCollectorBase(sub, values) {
9496

9597
@Override
@@ -117,11 +119,11 @@ public void collect(int doc, long bucket) throws IOException {
117119
// which is more accurate than naive summation.
118120
double sum = sums.get(bucket);
119121
double compensation = compensations.get(bucket);
120-
CompensatedSum compensatedSum = new CompensatedSum(sum, compensation);
122+
compensatedSum.reset(sum, compensation);
121123

122124
double sumOfSqr = sumOfSqrs.get(bucket);
123125
double compensationOfSqr = compensationOfSqrs.get(bucket);
124-
CompensatedSum compensatedSumOfSqr = new CompensatedSum(sumOfSqr, compensationOfSqr);
126+
compensatedSumOfSqr.reset(sumOfSqr, compensationOfSqr);
125127

126128
for (int i = 0; i < valuesCount; i++) {
127129
double value = values.nextValue();

server/src/main/java/org/elasticsearch/search/aggregations/metrics/GeoCentroidAggregator.java

+5-2
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCol
6868
}
6969
final BigArrays bigArrays = context.bigArrays();
7070
final MultiGeoPointValues values = valuesSource.geoPointValues(ctx);
71+
final CompensatedSum compensatedSumLat = new CompensatedSum(0, 0);
72+
final CompensatedSum compensatedSumLon = new CompensatedSum(0, 0);
73+
7174
return new LeafBucketCollectorBase(sub, values) {
7275
@Override
7376
public void collect(int doc, long bucket) throws IOException {
@@ -88,8 +91,8 @@ public void collect(int doc, long bucket) throws IOException {
8891
double sumLon = lonSum.get(bucket);
8992
double compensationLon = lonCompensations.get(bucket);
9093

91-
CompensatedSum compensatedSumLat = new CompensatedSum(sumLat, compensationLat);
92-
CompensatedSum compensatedSumLon = new CompensatedSum(sumLon, compensationLon);
94+
compensatedSumLat.reset(sumLat, compensationLat);
95+
compensatedSumLon.reset(sumLon, compensationLon);
9396

9497
// update the sum
9598
for (int i = 0; i < valueCount; ++i) {

server/src/main/java/org/elasticsearch/search/aggregations/metrics/StatsAggregator.java

+3-1
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,8 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx,
8181
}
8282
final BigArrays bigArrays = context.bigArrays();
8383
final SortedNumericDoubleValues values = valuesSource.doubleValues(ctx);
84+
final CompensatedSum kahanSummation = new CompensatedSum(0, 0);
85+
8486
return new LeafBucketCollectorBase(sub, values) {
8587
@Override
8688
public void collect(int doc, long bucket) throws IOException {
@@ -105,7 +107,7 @@ public void collect(int doc, long bucket) throws IOException {
105107
// accurate than naive summation.
106108
double sum = sums.get(bucket);
107109
double compensation = compensations.get(bucket);
108-
CompensatedSum kahanSummation = new CompensatedSum(sum, compensation);
110+
kahanSummation.reset(sum, compensation);
109111

110112
for (int i = 0; i < valuesCount; i++) {
111113
double value = values.nextValue();

server/src/main/java/org/elasticsearch/search/aggregations/metrics/SumAggregator.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx,
6969
}
7070
final BigArrays bigArrays = context.bigArrays();
7171
final SortedNumericDoubleValues values = valuesSource.doubleValues(ctx);
72+
final CompensatedSum kahanSummation = new CompensatedSum(0, 0);
7273
return new LeafBucketCollectorBase(sub, values) {
7374
@Override
7475
public void collect(int doc, long bucket) throws IOException {
@@ -81,7 +82,7 @@ public void collect(int doc, long bucket) throws IOException {
8182
// accurate than naive summation.
8283
double sum = sums.get(bucket);
8384
double compensation = compensations.get(bucket);
84-
CompensatedSum kahanSummation = new CompensatedSum(sum, compensation);
85+
kahanSummation.reset(sum, compensation);
8586

8687
for (int i = 0; i < valuesCount; i++) {
8788
double value = values.nextValue();

server/src/main/java/org/elasticsearch/search/aggregations/metrics/WeightedAvgAggregator.java

+28-25
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@ class WeightedAvgAggregator extends NumericMetricsAggregator.SingleValue {
4646
private final MultiValuesSource.NumericMultiValuesSource valuesSources;
4747

4848
private DoubleArray weights;
49-
private DoubleArray sums;
50-
private DoubleArray sumCompensations;
49+
private DoubleArray valueSums;
50+
private DoubleArray valueCompensations;
5151
private DoubleArray weightCompensations;
5252
private DocValueFormat format;
5353

@@ -60,8 +60,8 @@ class WeightedAvgAggregator extends NumericMetricsAggregator.SingleValue {
6060
if (valuesSources != null) {
6161
final BigArrays bigArrays = context.bigArrays();
6262
weights = bigArrays.newDoubleArray(1, true);
63-
sums = bigArrays.newDoubleArray(1, true);
64-
sumCompensations = bigArrays.newDoubleArray(1, true);
63+
valueSums = bigArrays.newDoubleArray(1, true);
64+
valueCompensations = bigArrays.newDoubleArray(1, true);
6565
weightCompensations = bigArrays.newDoubleArray(1, true);
6666
}
6767
}
@@ -80,13 +80,15 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx,
8080
final BigArrays bigArrays = context.bigArrays();
8181
final SortedNumericDoubleValues docValues = valuesSources.getField(VALUE_FIELD.getPreferredName(), ctx);
8282
final SortedNumericDoubleValues docWeights = valuesSources.getField(WEIGHT_FIELD.getPreferredName(), ctx);
83+
final CompensatedSum compensatedValueSum = new CompensatedSum(0, 0);
84+
final CompensatedSum compensatedWeightSum = new CompensatedSum(0, 0);
8385

8486
return new LeafBucketCollectorBase(sub, docValues) {
8587
@Override
8688
public void collect(int doc, long bucket) throws IOException {
8789
weights = bigArrays.grow(weights, bucket + 1);
88-
sums = bigArrays.grow(sums, bucket + 1);
89-
sumCompensations = bigArrays.grow(sumCompensations, bucket + 1);
90+
valueSums = bigArrays.grow(valueSums, bucket + 1);
91+
valueCompensations = bigArrays.grow(valueCompensations, bucket + 1);
9092
weightCompensations = bigArrays.grow(weightCompensations, bucket + 1);
9193

9294
if (docValues.advanceExact(doc) && docWeights.advanceExact(doc)) {
@@ -102,42 +104,43 @@ public void collect(int doc, long bucket) throws IOException {
102104
final int numValues = docValues.docValueCount();
103105
assert numValues > 0;
104106

107+
double valueSum = valueSums.get(bucket);
108+
double valueCompensation = valueCompensations.get(bucket);
109+
compensatedValueSum.reset(valueSum, valueCompensation);
110+
111+
double weightSum = weights.get(bucket);
112+
double weightCompensation = weightCompensations.get(bucket);
113+
compensatedWeightSum.reset(weightSum, weightCompensation);
114+
105115
for (int i = 0; i < numValues; i++) {
106-
kahanSum(docValues.nextValue() * weight, sums, sumCompensations, bucket);
107-
kahanSum(weight, weights, weightCompensations, bucket);
116+
compensatedValueSum.add(docValues.nextValue() * weight);
117+
compensatedWeightSum.add(weight);
108118
}
119+
120+
valueSums.set(bucket, compensatedValueSum.value());
121+
valueCompensations.set(bucket, compensatedValueSum.delta());
122+
weights.set(bucket, compensatedWeightSum.value());
123+
weightCompensations.set(bucket, compensatedWeightSum.delta());
109124
}
110125
}
111126
};
112127
}
113128

114-
private static void kahanSum(double value, DoubleArray values, DoubleArray compensations, long bucket) {
115-
// Compute the sum of double values with Kahan summation algorithm which is more
116-
// accurate than naive summation.
117-
double sum = values.get(bucket);
118-
double compensation = compensations.get(bucket);
119-
120-
CompensatedSum kahanSummation = new CompensatedSum(sum, compensation)
121-
.add(value);
122-
123-
values.set(bucket, kahanSummation.value());
124-
compensations.set(bucket, kahanSummation.delta());
125-
}
126129

127130
@Override
128131
public double metric(long owningBucketOrd) {
129-
if (valuesSources == null || owningBucketOrd >= sums.size()) {
132+
if (valuesSources == null || owningBucketOrd >= valueSums.size()) {
130133
return Double.NaN;
131134
}
132-
return sums.get(owningBucketOrd) / weights.get(owningBucketOrd);
135+
return valueSums.get(owningBucketOrd) / weights.get(owningBucketOrd);
133136
}
134137

135138
@Override
136139
public InternalAggregation buildAggregation(long bucket) {
137-
if (valuesSources == null || bucket >= sums.size()) {
140+
if (valuesSources == null || bucket >= valueSums.size()) {
138141
return buildEmptyAggregation();
139142
}
140-
return new InternalWeightedAvg(name, sums.get(bucket), weights.get(bucket), format, pipelineAggregators(), metaData());
143+
return new InternalWeightedAvg(name, valueSums.get(bucket), weights.get(bucket), format, pipelineAggregators(), metaData());
141144
}
142145

143146
@Override
@@ -147,7 +150,7 @@ public InternalAggregation buildEmptyAggregation() {
147150

148151
@Override
149152
public void doClose() {
150-
Releasables.close(weights, sums, sumCompensations, weightCompensations);
153+
Releasables.close(weights, valueSums, valueCompensations, weightCompensations);
151154
}
152155

153156
}

0 commit comments

Comments
 (0)