Skip to content

[Transform] add support for extended_stats #120340

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jan 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/120340.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 120340
summary: Add support for `extended_stats`
area: Transform
type: enhancement
issues: []
1 change: 1 addition & 0 deletions docs/reference/rest-api/common-parms.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -808,6 +808,7 @@ currently supported:
* <<search-aggregations-pipeline-bucket-script-aggregation,Bucket script>>
* <<search-aggregations-pipeline-bucket-selector-aggregation,Bucket selector>>
* <<search-aggregations-metrics-cardinality-aggregation,Cardinality>>
* <<search-aggregations-metrics-extendedstats-aggregation,Extended Stats>>
* <<search-aggregations-bucket-filter-aggregation,Filter>>
* <<search-aggregations-metrics-geobounds-aggregation,Geo bounds>>
* <<search-aggregations-metrics-geocentroid-aggregation,Geo centroid>>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import java.io.IOException;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;

public class ExtendedStatsAggregationBuilder extends ValuesSourceAggregationBuilder.MetricsAggregationBuilder<
Expand Down Expand Up @@ -87,6 +88,11 @@ public Set<String> metricNames() {
return InternalExtendedStats.METRIC_NAMES;
}

@Override
public Optional<Set<String>> getOutputFieldNames() {
return Optional.of(InternalExtendedStats.Fields.OUTPUT_FORMAT);
}

@Override
protected ValuesSourceType defaultValueSourceType() {
return CoreValuesSourceType.NUMERIC;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
*/
package org.elasticsearch.search.aggregations.metrics;

import org.elasticsearch.common.TriConsumer;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.search.DocValueFormat;
Expand All @@ -19,6 +20,7 @@

import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
Expand Down Expand Up @@ -337,6 +339,67 @@ static class Fields {
public static final String LOWER_POPULATION = "lower_population";
public static final String UPPER_SAMPLING = "upper_sampling";
public static final String LOWER_SAMPLING = "lower_sampling";

static final Set<String> OUTPUT_FORMAT = Set.of(
Metrics.count.name(),
Metrics.sum.name(),
Metrics.min.name(),
Metrics.max.name(),
Metrics.avg.name(),
SUM_OF_SQRS,
VARIANCE,
VARIANCE_POPULATION,
VARIANCE_SAMPLING,
STD_DEVIATION,
STD_DEVIATION_POPULATION,
STD_DEVIATION_SAMPLING,
STD_DEVIATION_BOUNDS + "." + UPPER,
STD_DEVIATION_BOUNDS + "." + LOWER,
STD_DEVIATION_BOUNDS + "." + UPPER_POPULATION,
STD_DEVIATION_BOUNDS + "." + LOWER_POPULATION,
STD_DEVIATION_BOUNDS + "." + UPPER_SAMPLING,
STD_DEVIATION_BOUNDS + "." + LOWER_SAMPLING
);
}

public Map<String, Object> asIndexableMap() {
if (count != 0) {
// NumberFieldMapper will invalidate non-finite doubles
TriConsumer<Map<String, Object>, String, Double> putIfValidDouble = (map, key, value) -> {
if (Double.isFinite(value)) {
map.put(key, value);
}
};
var extendedStatsMap = new HashMap<String, Object>(13);
extendedStatsMap.put(Metrics.count.name(), getCount());
putIfValidDouble.apply(extendedStatsMap, Metrics.sum.name(), getSum());
putIfValidDouble.apply(extendedStatsMap, Metrics.min.name(), getMin());
putIfValidDouble.apply(extendedStatsMap, Metrics.max.name(), getMax());
putIfValidDouble.apply(extendedStatsMap, Metrics.avg.name(), getAvg());

putIfValidDouble.apply(extendedStatsMap, Fields.SUM_OF_SQRS, sumOfSqrs);
putIfValidDouble.apply(extendedStatsMap, Fields.VARIANCE, getVariance());
putIfValidDouble.apply(extendedStatsMap, Fields.VARIANCE_POPULATION, getVariancePopulation());
putIfValidDouble.apply(extendedStatsMap, Fields.VARIANCE_SAMPLING, getVarianceSampling());
putIfValidDouble.apply(extendedStatsMap, Fields.STD_DEVIATION, getStdDeviation());
putIfValidDouble.apply(extendedStatsMap, Fields.STD_DEVIATION_POPULATION, getStdDeviationPopulation());
putIfValidDouble.apply(extendedStatsMap, Fields.STD_DEVIATION_SAMPLING, getStdDeviationSampling());

var stdDevBounds = new HashMap<String, Object>(6);
putIfValidDouble.apply(stdDevBounds, Fields.UPPER, getStdDeviationBound(Bounds.UPPER));
putIfValidDouble.apply(stdDevBounds, Fields.LOWER, getStdDeviationBound(Bounds.LOWER));
putIfValidDouble.apply(stdDevBounds, Fields.UPPER_POPULATION, getStdDeviationBound(Bounds.UPPER_POPULATION));
putIfValidDouble.apply(stdDevBounds, Fields.LOWER_POPULATION, getStdDeviationBound(Bounds.LOWER_POPULATION));
putIfValidDouble.apply(stdDevBounds, Fields.UPPER_SAMPLING, getStdDeviationBound(Bounds.UPPER_SAMPLING));
putIfValidDouble.apply(stdDevBounds, Fields.LOWER_SAMPLING, getStdDeviationBound(Bounds.LOWER_SAMPLING));
if (stdDevBounds.isEmpty() == false) {
extendedStatsMap.put(Fields.STD_DEVIATION_BOUNDS, stdDevBounds);
}

return extendedStatsMap;
} else {
return Map.of();
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,17 @@
import org.elasticsearch.index.mapper.DateFieldMapper;
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.index.mapper.NumberFieldMapper;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregatorTestCase;
import org.elasticsearch.search.aggregations.InternalAggregation;
import org.elasticsearch.search.aggregations.support.AggregationInspectionHelper;

import java.io.IOException;
import java.util.Map;
import java.util.function.Consumer;

import static java.util.Collections.singleton;
import static org.elasticsearch.search.aggregations.AggregationBuilders.stats;
import static org.hamcrest.Matchers.is;

public class ExtendedStatsAggregatorTests extends AggregatorTestCase {
private static final double TOLERANCE = 1e-5;
Expand Down Expand Up @@ -304,6 +306,13 @@ public void testCase(
testCase(buildIndex, verify, new AggTestConfig(aggBuilder, ft));
}

@Override
protected <T extends AggregationBuilder, V extends InternalAggregation> void verifyOutputFieldNames(T aggregationBuilder, V agg)
throws IOException {
assertTrue(aggregationBuilder.getOutputFieldNames().isPresent());
assertThat(aggregationBuilder.getOutputFieldNames().get(), is(InternalExtendedStats.Fields.OUTPUT_FORMAT));
}

static class ExtendedSimpleStatsAggregator extends StatsAggregatorTests.SimpleStatsAggregator {
double sumOfSqrs = 0;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,29 @@

package org.elasticsearch.search.aggregations.metrics;

import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.util.Maps;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.aggregations.InternalAggregation;
import org.elasticsearch.search.aggregations.support.SamplingContext;
import org.elasticsearch.test.InternalAggregationTestCase;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentFactory;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;

import static org.hamcrest.Matchers.aMapWithSize;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasKey;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.isA;
import static org.hamcrest.Matchers.notNullValue;

public class InternalExtendedStatsTests extends InternalAggregationTestCase<InternalExtendedStats> {

Expand Down Expand Up @@ -209,4 +222,75 @@ private void verifySumOfSqrsOfDoubles(double[] values, double expectedSumOfSqrs,
InternalExtendedStats reduced = (InternalExtendedStats) InternalAggregationTestCase.reduce(aggregations, null);
assertEquals(expectedSumOfSqrs, reduced.getSumOfSquares(), delta);
}

@SuppressWarnings(value = "unchecked")
public void testAsMapMatchesXContent() throws IOException {
var stats = new InternalExtendedStats(
"testAsMapIsSameAsXContent",
randomLongBetween(1, 50),
randomDoubleBetween(1, 50, true),
randomDoubleBetween(1, 50, true),
randomDoubleBetween(1, 50, true),
randomDoubleBetween(1, 50, true),
sigma,
DocValueFormat.RAW,
Map.of()
);

var outputMap = stats.asIndexableMap();
assertThat(outputMap, notNullValue());

Map<String, Object> xContentMap;
try (var builder = XContentFactory.jsonBuilder()) {
builder.startObject();
stats.doXContentBody(builder, ToXContent.EMPTY_PARAMS);
builder.endObject();
xContentMap = XContentHelper.convertToMap(BytesReference.bytes(builder), false, builder.contentType()).v2();
}
assertThat(xContentMap, notNullValue());

// serializing -> deserializing converts the long to an int, so we convert it back to test
var countMetricName = InternalStats.Metrics.count.name();
var xContentCount = xContentMap.get(countMetricName);
assertThat(xContentCount, isA(Integer.class));
assertThat(((Integer) xContentCount).longValue(), equalTo(outputMap.get(countMetricName)));

// verify the entries in the bounds map are similar
var xContentStdDevBounds = (Map<String, Object>) xContentMap.get(InternalExtendedStats.Fields.STD_DEVIATION_BOUNDS);
var outputStdDevBounds = (Map<String, Object>) outputMap.get(InternalExtendedStats.Fields.STD_DEVIATION_BOUNDS);
xContentStdDevBounds.forEach((key, value) -> {
if (value instanceof String == false || Double.isFinite(Double.parseDouble(value.toString()))) {
assertThat(outputStdDevBounds.get(key), equalTo(value));
}
});

// verify all the other entries that are not "std_deviation_bounds" or "count"
Predicate<Map.Entry<String, Object>> notCountOrStdDevBounds = Predicate.not(
e -> e.getKey().equals(countMetricName) || e.getKey().equals(InternalExtendedStats.Fields.STD_DEVIATION_BOUNDS)
);
xContentMap.entrySet().stream().filter(notCountOrStdDevBounds).forEach(e -> {
if (e.getValue() instanceof String == false || Double.isFinite(Double.parseDouble(e.getValue().toString()))) {
assertThat(outputMap.get(e.getKey()), equalTo(e.getValue()));
}
});
}

public void testIndexableMapExcludesNaN() {
var stats = new InternalExtendedStats(
"testAsMapIsSameAsXContent",
randomLongBetween(1, 50),
Double.NaN,
Double.NaN,
Double.NaN,
Double.NaN,
sigma,
DocValueFormat.RAW,
Map.of()
);

var outputMap = stats.asIndexableMap();
assertThat(outputMap, is(aMapWithSize(1)));
assertThat(outputMap, hasKey(InternalStats.Metrics.count.name()));
assertThat(outputMap.get(InternalStats.Metrics.count.name()), is(stats.getCount()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2003,6 +2003,84 @@ public void testPivotWithTopMetrics() throws Exception {
assertEquals("business_3", actual);
}

@SuppressWarnings(value = "unchecked")
public void testPivotWithExtendedStats() throws Exception {
var transformId = "extended_stats_transform";
var transformIndex = "extended_stats_pivot_reviews";
setupDataAccessRole(DATA_ACCESS_ROLE, REVIEWS_INDEX_NAME, transformIndex);

var createTransformRequest = createRequestWithAuth(
"PUT",
getTransformEndpoint() + transformId,
BASIC_AUTH_VALUE_TRANSFORM_ADMIN_WITH_SOME_DATA_ACCESS
);

var config = Strings.format("""
{
"source": {
"index": "%s"
},
"dest": {
"index": "%s"
},
"pivot": {
"group_by": {
"reviewer": {
"terms": {
"field": "user_id"
}
}
},
"aggregations": {
"stars": {
"extended_stats": {
"field": "stars"
}
}
}
}
}""", REVIEWS_INDEX_NAME, transformIndex);

createTransformRequest.setJsonEntity(config);
var createTransformResponse = entityAsMap(client().performRequest(createTransformRequest));
assertThat(createTransformResponse.get("acknowledged"), equalTo(Boolean.TRUE));

startAndWaitForTransform(transformId, transformIndex, BASIC_AUTH_VALUE_TRANSFORM_ADMIN_WITH_SOME_DATA_ACCESS);
assertTrue(indexExists(transformIndex));

var searchResult = getAsMap(transformIndex + "/_search?q=reviewer:user_4");
assertEquals(1, XContentMapValues.extractValue("hits.total.value", searchResult));
var stdDevMap = (Map<String, Object>) ((List<?>) XContentMapValues.extractValue("hits.hits._source.stars", searchResult)).get(0);
assertThat(stdDevMap.get("count"), equalTo(41));
assertThat(
stdDevMap,
allOf(
hasEntry("sum", 159.0),
hasEntry("min", 1.0),
hasEntry("max", 5.0),
hasEntry("avg", 3.8780487804878048),
hasEntry("sum_of_squares", 711.0),
hasEntry("variance", 2.3022010707911953),
hasEntry("variance_population", 2.3022010707911953),
hasEntry("variance_sampling", 2.3597560975609753),
hasEntry("std_deviation", 1.5173005868288574),
hasEntry("std_deviation_sampling", 1.5361497640402693),
hasEntry("std_deviation_population", 1.5173005868288574)
)
);
assertThat(
(Map<String, ?>) stdDevMap.get("std_deviation_bounds"),
allOf(
hasEntry("upper", 6.91264995414552),
hasEntry("lower", 0.84344760683009),
hasEntry("upper_population", 6.91264995414552),
hasEntry("lower_population", 0.84344760683009),
hasEntry("upper_sampling", 6.950348308568343),
hasEntry("lower_sampling", 0.8057492524072662)
)
);
}

public void testPivotWithBoxplot() throws Exception {
String transformId = "boxplot_transform";
String transformIndex = "boxplot_pivot_reviews";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.elasticsearch.search.aggregations.bucket.range.Range;
import org.elasticsearch.search.aggregations.metrics.GeoBounds;
import org.elasticsearch.search.aggregations.metrics.GeoCentroid;
import org.elasticsearch.search.aggregations.metrics.InternalExtendedStats;
import org.elasticsearch.search.aggregations.metrics.MultiValueAggregation;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation.MultiValue;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation.SingleValue;
Expand Down Expand Up @@ -69,6 +70,7 @@ public final class AggregationResultUtils {
tempMap.put(GeoShapeMetricAggregation.class.getName(), new GeoShapeMetricAggExtractor());
tempMap.put(MultiValue.class.getName(), new NumericMultiValueAggExtractor());
tempMap.put(MultiValueAggregation.class.getName(), new MultiValueAggExtractor());
tempMap.put(InternalExtendedStats.class.getName(), new ExtendedStatsExtractor());
TYPE_VALUE_EXTRACTOR_MAP = Collections.unmodifiableMap(tempMap);
}

Expand Down Expand Up @@ -171,6 +173,9 @@ static AggValueExtractor getExtractor(Aggregation aggregation) {
// TODO: can the Range extractor be removed?
} else if (aggregation instanceof Range) {
return TYPE_VALUE_EXTRACTOR_MAP.get(Range.class.getName());
} else if (aggregation instanceof InternalExtendedStats) {
// note: extended stats is also a multi bucket agg, therefore check range first
return TYPE_VALUE_EXTRACTOR_MAP.get(InternalExtendedStats.class.getName());
} else if (aggregation instanceof MultiValue) {
return TYPE_VALUE_EXTRACTOR_MAP.get(MultiValue.class.getName());
} else if (aggregation instanceof MultiValueAggregation) {
Expand Down Expand Up @@ -281,6 +286,13 @@ public Object value(Aggregation agg, Map<String, String> fieldTypeMap, String lo
}
}

static class ExtendedStatsExtractor implements AggValueExtractor {
@Override
public Object value(Aggregation agg, Map<String, String> fieldTypeMap, String lookupFieldPrefix) {
return ((InternalExtendedStats) agg).asIndexableMap();
}
}

static class MultiValueAggExtractor implements AggValueExtractor {
@Override
public Object value(Aggregation agg, Map<String, String> fieldTypeMap, String lookupFieldPrefix) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ public final class TransformAggregations {
"date_histogram",
"date_range",
"diversified_sampler",
"extended_stats", // https://github.com/elastic/elasticsearch/issues/51925
"filters",
"geo_distance",
"geohash_grid",
Expand Down Expand Up @@ -120,7 +119,8 @@ enum AggregationType {
MISSING("missing", LONG),
TOP_METRICS("top_metrics", SOURCE),
STATS("stats", DOUBLE),
BOXPLOT("boxplot", DOUBLE);
BOXPLOT("boxplot", DOUBLE),
EXTENDED_STATS("extended_stats", DOUBLE);

private final String aggregationType;
private final String targetMapping;
Expand Down
Loading