Skip to content

Commit 92a820b

Browse files
authored
[ML] Add bucket_script agg support to data frames (elastic#41594) (elastic#41639)
1 parent a01f451 commit 92a820b

File tree

9 files changed

+248
-30
lines changed

9 files changed

+248
-30
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/pivot/AggregationConfig.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import org.elasticsearch.common.xcontent.XContentType;
2222
import org.elasticsearch.search.aggregations.AggregationBuilder;
2323
import org.elasticsearch.search.aggregations.AggregatorFactories;
24+
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
2425
import org.elasticsearch.xpack.core.dataframe.DataFrameMessages;
2526

2627
import java.io.IOException;
@@ -66,6 +67,10 @@ public Collection<AggregationBuilder> getAggregatorFactories() {
6667
return aggregations.getAggregatorFactories();
6768
}
6869

70+
public Collection<PipelineAggregationBuilder> getPipelineAggregatorFactories() {
71+
return aggregations.getPipelineAggregatorFactories();
72+
}
73+
6974
public static AggregationConfig fromXContent(final XContentParser parser, boolean lenient) throws IOException {
7075
NamedXContentRegistry registry = parser.getXContentRegistry();
7176
Map<String, Object> source = parser.mapOrdered();

x-pack/plugin/data-frame/qa/single-node-tests/src/test/java/org/elasticsearch/xpack/dataframe/integration/DataFramePivotRestIT.java

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,57 @@ public void testPivotWithScriptedMetricAgg() throws Exception {
368368
assertEquals(711.0, actual.doubleValue(), 0.000001);
369369
}
370370

371+
public void testPivotWithBucketScriptAgg() throws Exception {
372+
String transformId = "bucketScriptPivot";
373+
String dataFrameIndex = "bucket_script_pivot_reviews";
374+
setupDataAccessRole(DATA_ACCESS_ROLE, REVIEWS_INDEX_NAME, dataFrameIndex);
375+
376+
final Request createDataframeTransformRequest = createRequestWithAuth("PUT", DATAFRAME_ENDPOINT + transformId,
377+
BASIC_AUTH_VALUE_DATA_FRAME_ADMIN_WITH_SOME_DATA_ACCESS);
378+
379+
String config = "{"
380+
+ " \"source\": {\"index\":\"" + REVIEWS_INDEX_NAME + "\"},"
381+
+ " \"dest\": {\"index\":\"" + dataFrameIndex + "\"},";
382+
383+
config += " \"pivot\": {"
384+
+ " \"group_by\": {"
385+
+ " \"reviewer\": {"
386+
+ " \"terms\": {"
387+
+ " \"field\": \"user_id\""
388+
+ " } } },"
389+
+ " \"aggregations\": {"
390+
+ " \"avg_rating\": {"
391+
+ " \"avg\": {"
392+
+ " \"field\": \"stars\""
393+
+ " } },"
394+
+ " \"avg_rating_again\": {"
395+
+ " \"bucket_script\": {"
396+
+ " \"buckets_path\": {\"param_1\": \"avg_rating\"},"
397+
+ " \"script\": \"return params.param_1\""
398+
+ " } }"
399+
+ " } }"
400+
+ "}";
401+
402+
createDataframeTransformRequest.setJsonEntity(config);
403+
Map<String, Object> createDataframeTransformResponse = entityAsMap(client().performRequest(createDataframeTransformRequest));
404+
assertThat(createDataframeTransformResponse.get("acknowledged"), equalTo(Boolean.TRUE));
405+
406+
startAndWaitForTransform(transformId, dataFrameIndex, BASIC_AUTH_VALUE_DATA_FRAME_ADMIN_WITH_SOME_DATA_ACCESS);
407+
assertTrue(indexExists(dataFrameIndex));
408+
409+
// we expect 27 documents as there shall be 27 user_id's
410+
Map<String, Object> indexStats = getAsMap(dataFrameIndex + "/_stats");
411+
assertEquals(27, XContentMapValues.extractValue("_all.total.docs.count", indexStats));
412+
413+
// get and check some users
414+
Map<String, Object> searchResult = getAsMap(dataFrameIndex + "/_search?q=reviewer:user_4");
415+
assertEquals(1, XContentMapValues.extractValue("hits.total.value", searchResult));
416+
Number actual = (Number) ((List<?>) XContentMapValues.extractValue("hits.hits._source.avg_rating", searchResult)).get(0);
417+
assertEquals(3.878048780, actual.doubleValue(), 0.000001);
418+
actual = (Number) ((List<?>) XContentMapValues.extractValue("hits.hits._source.avg_rating_again", searchResult)).get(0);
419+
assertEquals(3.878048780, actual.doubleValue(), 0.000001);
420+
}
421+
371422
private void assertOnePivotValue(String query, double expected) throws IOException {
372423
Map<String, Object> searchResult = getAsMap(query);
373424

x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/pivot/AggregationResultUtils.java

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import org.apache.logging.log4j.Logger;
1111
import org.elasticsearch.search.aggregations.Aggregation;
1212
import org.elasticsearch.search.aggregations.AggregationBuilder;
13+
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
1314
import org.elasticsearch.search.aggregations.bucket.composite.CompositeAggregation;
1415
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
1516
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation.SingleValue;
@@ -21,7 +22,9 @@
2122

2223
import java.util.Collection;
2324
import java.util.HashMap;
25+
import java.util.List;
2426
import java.util.Map;
27+
import java.util.stream.Collectors;
2528
import java.util.stream.Stream;
2629

2730
import static org.elasticsearch.xpack.dataframe.transforms.pivot.SchemaUtil.isNumericType;
@@ -42,6 +45,7 @@ final class AggregationResultUtils {
4245
public static Stream<Map<String, Object>> extractCompositeAggregationResults(CompositeAggregation agg,
4346
GroupConfig groups,
4447
Collection<AggregationBuilder> aggregationBuilders,
48+
Collection<PipelineAggregationBuilder> pipelineAggs,
4549
Map<String, String> fieldTypeMap,
4650
DataFrameIndexerTransformStats stats) {
4751
return agg.getBuckets().stream().map(bucket -> {
@@ -58,18 +62,21 @@ public static Stream<Map<String, Object>> extractCompositeAggregationResults(Com
5862
document.put(destinationFieldName, value);
5963
});
6064

61-
for (AggregationBuilder aggregationBuilder : aggregationBuilders) {
62-
String aggName = aggregationBuilder.getName();
65+
List<String> aggNames = aggregationBuilders.stream().map(AggregationBuilder::getName).collect(Collectors.toList());
66+
aggNames.addAll(pipelineAggs.stream().map(PipelineAggregationBuilder::getName).collect(Collectors.toList()));
67+
68+
for (String aggName: aggNames) {
6369
final String fieldType = fieldTypeMap.get(aggName);
6470

6571
// TODO: support other aggregation types
6672
Aggregation aggResult = bucket.getAggregations().get(aggName);
6773

6874
if (aggResult instanceof NumericMetricsAggregation.SingleValue) {
6975
NumericMetricsAggregation.SingleValue aggResultSingleValue = (SingleValue) aggResult;
70-
// If the type is numeric, simply gather the `value` type, otherwise utilize `getValueAsString` so we don't lose
71-
// formatted outputs.
72-
if (isNumericType(fieldType)) {
76+
// If the type is numeric or if the formatted string is the same as simply making the value a string,
77+
// gather the `value` type, otherwise utilize `getValueAsString` so we don't lose formatted outputs.
78+
if (isNumericType(fieldType) ||
79+
(aggResultSingleValue.getValueAsString().equals(String.valueOf(aggResultSingleValue.value())))) {
7380
document.put(aggName, aggResultSingleValue.value());
7481
} else {
7582
document.put(aggName, aggResultSingleValue.getValueAsString());

x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/pivot/Aggregations.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ enum AggregationType {
3535
MAX("max", SOURCE),
3636
MIN("min", SOURCE),
3737
SUM("sum", SOURCE),
38-
SCRIPTED_METRIC("scripted_metric", DYNAMIC);
38+
SCRIPTED_METRIC("scripted_metric", DYNAMIC),
39+
BUCKET_SCRIPT("bucket_script", DYNAMIC);
3940

4041
private final String aggregationType;
4142
private final String targetMapping;

x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/pivot/Pivot.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import org.elasticsearch.index.query.QueryBuilder;
2020
import org.elasticsearch.rest.RestStatus;
2121
import org.elasticsearch.search.aggregations.AggregationBuilder;
22+
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
2223
import org.elasticsearch.search.aggregations.bucket.composite.CompositeAggregation;
2324
import org.elasticsearch.search.aggregations.bucket.composite.CompositeAggregationBuilder;
2425
import org.elasticsearch.search.builder.SearchSourceBuilder;
@@ -102,10 +103,12 @@ public Stream<Map<String, Object>> extractResults(CompositeAggregation agg,
102103

103104
GroupConfig groups = config.getGroupConfig();
104105
Collection<AggregationBuilder> aggregationBuilders = config.getAggregationConfig().getAggregatorFactories();
106+
Collection<PipelineAggregationBuilder> pipelineAggregationBuilders = config.getAggregationConfig().getPipelineAggregatorFactories();
105107

106108
return AggregationResultUtils.extractCompositeAggregationResults(agg,
107109
groups,
108110
aggregationBuilders,
111+
pipelineAggregationBuilders,
109112
fieldTypeMap,
110113
dataFrameIndexerTransformStats);
111114
}
@@ -148,6 +151,7 @@ private static CompositeAggregationBuilder createCompositeAggregation(PivotConfi
148151
LoggingDeprecationHandler.INSTANCE, BytesReference.bytes(builder).streamInput());
149152
compositeAggregation = CompositeAggregationBuilder.parse(COMPOSITE_AGGREGATION_NAME, parser);
150153
config.getAggregationConfig().getAggregatorFactories().forEach(agg -> compositeAggregation.subAggregation(agg));
154+
config.getAggregationConfig().getPipelineAggregatorFactories().forEach(agg -> compositeAggregation.subAggregation(agg));
151155
} catch (IOException e) {
152156
throw new RuntimeException(DataFrameMessages.DATA_FRAME_TRANSFORM_PIVOT_FAILED_TO_CREATE_COMPOSITE_AGGREGATION, e);
153157
}

x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/pivot/SchemaUtil.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import org.elasticsearch.client.Client;
1616
import org.elasticsearch.index.mapper.NumberFieldMapper;
1717
import org.elasticsearch.search.aggregations.AggregationBuilder;
18+
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
1819
import org.elasticsearch.search.aggregations.metrics.ScriptedMetricAggregationBuilder;
1920
import org.elasticsearch.search.aggregations.support.ValuesSourceAggregationBuilder;
2021
import org.elasticsearch.xpack.core.ClientHelper;
@@ -85,6 +86,12 @@ public static void deduceMappings(final Client client,
8586
}
8687
}
8788

89+
// For pipeline aggs, since they are referencing other aggregations in the payload, they do not have any
90+
// sourcefieldnames to put into the payload. Though, certain ones, i.e. avg_bucket, do have determinant value types
91+
for (PipelineAggregationBuilder agg : config.getAggregationConfig().getPipelineAggregatorFactories()) {
92+
aggregationTypes.put(agg.getName(), agg.getType());
93+
}
94+
8895
Map<String, String> allFieldNames = new HashMap<>();
8996
allFieldNames.putAll(aggregationSourceFieldNames);
9097
allFieldNames.putAll(fieldNamesForGrouping);

0 commit comments

Comments
 (0)