Skip to content

Commit f7cd046

Browse files
authored
[ML] Changes default destination index field mapping and adds scripted_metric agg (#40750)
* [ML] Allows destination index mappings to have dynamic types, adds scripted_metric agg * Makes dynamic|source mapping explicit
1 parent 3ea36b7 commit f7cd046

File tree

8 files changed

+200
-15
lines changed

8 files changed

+200
-15
lines changed

x-pack/plugin/data-frame/qa/single-node-tests/src/test/java/org/elasticsearch/xpack/dataframe/integration/DataFramePivotRestIT.java

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,59 @@ public void testPivotWithMaxOnDateField() throws Exception {
314314
assertThat(actual, containsString("2017-01-15T"));
315315
}
316316

317+
/**
 * Verifies that a pivot transform can use a {@code scripted_metric} aggregation next to a regular
 * {@code avg} aggregation.
 *
 * The transform groups the reviews index by {@code user_id} (exposed as {@code reviewer}) and computes:
 * {@code avg_rating}, the average of {@code stars}, and {@code squared_sum}, a scripted metric that
 * collects {@code stars * stars} per document and sums all collected values in the reduce script.
 *
 * The test checks that the transform is acknowledged, that the destination index exists, that it holds
 * the expected number of documents after the transform ran, and that both aggregation values for
 * {@code user_4} match the expected numbers.
 */
public void testPivotWithScriptedMetricAgg() throws Exception {
318+
String transformId = "scriptedMetricPivot";
319+
String dataFrameIndex = "scripted_metric_pivot_reviews";
320+
setupDataAccessRole(DATA_ACCESS_ROLE, REVIEWS_INDEX_NAME, dataFrameIndex);
321+
322+
final Request createDataframeTransformRequest = createRequestWithAuth("PUT", DATAFRAME_ENDPOINT + transformId,
323+
BASIC_AUTH_VALUE_DATA_FRAME_ADMIN_WITH_SOME_DATA_ACCESS);
324+
325+
String config = "{"
326+
+ " \"source\": {\"index\":\"" + REVIEWS_INDEX_NAME + "\"},"
327+
+ " \"dest\": {\"index\":\"" + dataFrameIndex + "\"},";
328+
329+
// pivot definition: one terms group-by plus an avg and a scripted_metric aggregation
config += " \"pivot\": {"
330+
+ " \"group_by\": {"
331+
+ " \"reviewer\": {"
332+
+ " \"terms\": {"
333+
+ " \"field\": \"user_id\""
334+
+ " } } },"
335+
+ " \"aggregations\": {"
336+
+ " \"avg_rating\": {"
337+
+ " \"avg\": {"
338+
+ " \"field\": \"stars\""
339+
+ " } },"
340+
+ " \"squared_sum\": {"
341+
+ " \"scripted_metric\": {"
342+
+ " \"init_script\": \"state.reviews_sqrd = []\","
343+
+ " \"map_script\": \"state.reviews_sqrd.add(doc.stars.value * doc.stars.value)\","
344+
+ " \"combine_script\": \"state.reviews_sqrd\","
345+
+ " \"reduce_script\": \"def sum = 0.0; for(l in states){ for(a in l) { sum += a}} return sum\""
346+
+ " } }"
347+
+ " } }"
348+
+ "}";
349+
350+
createDataframeTransformRequest.setJsonEntity(config);
351+
Map<String, Object> createDataframeTransformResponse = entityAsMap(client().performRequest(createDataframeTransformRequest));
352+
assertThat(createDataframeTransformResponse.get("acknowledged"), equalTo(Boolean.TRUE));
353+
assertTrue(indexExists(dataFrameIndex));
354+
355+
startAndWaitForTransform(transformId, dataFrameIndex, BASIC_AUTH_VALUE_DATA_FRAME_ADMIN_WITH_SOME_DATA_ACCESS);
356+
357+
// we expect 27 documents, one per distinct user_id in the source data
358+
Map<String, Object> indexStats = getAsMap(dataFrameIndex + "/_stats");
359+
assertEquals(27, XContentMapValues.extractValue("_all.total.docs.count", indexStats));
360+
361+
// get and check some users
362+
Map<String, Object> searchResult = getAsMap(dataFrameIndex + "/_search?q=reviewer:user_4");
363+
assertEquals(1, XContentMapValues.extractValue("hits.total.value", searchResult));
364+
Number actual = (Number) ((List<?>) XContentMapValues.extractValue("hits.hits._source.avg_rating", searchResult)).get(0);
365+
assertEquals(3.878048780, actual.doubleValue(), 0.000001);
366+
// the scripted metric result is read back from _source as a plain number
actual = (Number) ((List<?>) XContentMapValues.extractValue("hits.hits._source.squared_sum", searchResult)).get(0);
367+
assertEquals(711.0, actual.doubleValue(), 0.000001);
368+
}
369+
317370
private void assertOnePivotValue(String query, double expected) throws IOException {
318371
Map<String, Object> searchResult = getAsMap(query);
319372

x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/action/TransportPreviewDataFrameTransformAction.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,5 @@ private void getPreview(Pivot pivot, ActionListener<List<Map<String, Object>>> l
9595
},
9696
listener::onFailure
9797
));
98-
9998
}
10099
}

x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/pivot/AggregationResultUtils.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import org.elasticsearch.search.aggregations.bucket.composite.CompositeAggregation;
1414
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
1515
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation.SingleValue;
16+
import org.elasticsearch.search.aggregations.metrics.ScriptedMetric;
1617
import org.elasticsearch.xpack.core.dataframe.DataFrameField;
1718
import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameIndexerTransformStats;
1819
import org.elasticsearch.xpack.core.dataframe.transforms.pivot.GroupConfig;
@@ -73,6 +74,8 @@ public static Stream<Map<String, Object>> extractCompositeAggregationResults(Com
7374
} else {
7475
document.put(aggName, aggResultSingleValue.getValueAsString());
7576
}
77+
} else if (aggResult instanceof ScriptedMetric) {
78+
document.put(aggName, ((ScriptedMetric) aggResult).aggregation());
7679
} else {
7780
// Execution should never reach this point!
7881
// Creating transforms with unsupported aggregations shall not be possible

x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/pivot/Aggregations.java

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@
1212
import java.util.stream.Stream;
1313

1414
public final class Aggregations {
15+
16+
// the field mapping is not set explicitly, which allows ES to determine the mapping dynamically from the data.
17+
private static final String DYNAMIC = "_dynamic";
18+
// the field mapping should be determined explicitly from the source field mapping if possible.
19+
private static final String SOURCE = "_source";
1520
private Aggregations() {}
1621

1722
/**
@@ -27,9 +32,10 @@ enum AggregationType {
2732
AVG("avg", "double"),
2833
CARDINALITY("cardinality", "long"),
2934
VALUE_COUNT("value_count", "long"),
30-
MAX("max", null),
31-
MIN("min", null),
32-
SUM("sum", null);
35+
MAX("max", SOURCE),
36+
MIN("min", SOURCE),
37+
SUM("sum", SOURCE),
38+
SCRIPTED_METRIC("scripted_metric", DYNAMIC);
3339

3440
private final String aggregationType;
3541
private final String targetMapping;
@@ -55,8 +61,12 @@ public static boolean isSupportedByDataframe(String aggregationType) {
5561
return aggregationSupported.contains(aggregationType.toUpperCase(Locale.ROOT));
5662
}
5763

64+
/**
 * Tells whether a resolved target mapping is the marker for dynamic mapping, i.e. it equals the
 * {@code DYNAMIC} constant ({@code "_dynamic"}).
 *
 * @param targetMapping the mapping value to check; {@code null} is accepted and yields {@code false}
 * @return {@code true} if {@code targetMapping} equals the dynamic-mapping marker
 */
public static boolean isDynamicMapping(String targetMapping) {
65+
// constant is the receiver, so a null argument does not throw
return DYNAMIC.equals(targetMapping);
66+
}
67+
5868
/**
 * Resolves the destination mapping for the result of an aggregation.
 *
 * The aggregation name is upper-cased with {@link Locale#ROOT} and looked up in {@link AggregationType}.
 * If that type's target mapping is the {@code SOURCE} marker, the mapping of the source field is used;
 * otherwise the type's own target mapping is returned (a fixed type such as {@code "double"}, or the
 * {@code DYNAMIC} marker, which callers can detect with {@link #isDynamicMapping(String)}).
 *
 * @param aggregationType name of the aggregation, e.g. {@code "avg"} or {@code "scripted_metric"}
 * @param sourceType mapping type of the source field; may be {@code null}, in which case {@code null}
 *                   is returned for aggregations that take their mapping from the source
 * @return the resolved target mapping, the dynamic marker, or {@code null} as described above
 * @throws IllegalArgumentException if no {@link AggregationType} constant matches the name
 */
public static String resolveTargetMapping(String aggregationType, String sourceType) {
5969
AggregationType agg = AggregationType.valueOf(aggregationType.toUpperCase(Locale.ROOT));
60-
return agg.getTargetMapping() == null ? sourceType : agg.getTargetMapping();
70+
// NOTE(review): relies on every AggregationType having a non-null target mapping
return agg.getTargetMapping().equals(SOURCE) ? sourceType : agg.getTargetMapping();
6171
}
6272
}

x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/pivot/SchemaUtil.java

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import org.elasticsearch.client.Client;
1616
import org.elasticsearch.index.mapper.NumberFieldMapper;
1717
import org.elasticsearch.search.aggregations.AggregationBuilder;
18+
import org.elasticsearch.search.aggregations.metrics.ScriptedMetricAggregationBuilder;
1819
import org.elasticsearch.search.aggregations.support.ValuesSourceAggregationBuilder;
1920
import org.elasticsearch.xpack.core.ClientHelper;
2021
import org.elasticsearch.xpack.core.dataframe.transforms.pivot.PivotConfig;
@@ -75,6 +76,8 @@ public static void deduceMappings(final Client client,
7576
ValuesSourceAggregationBuilder<?, ?> valueSourceAggregation = (ValuesSourceAggregationBuilder<?, ?>) agg;
7677
aggregationSourceFieldNames.put(valueSourceAggregation.getName(), valueSourceAggregation.field());
7778
aggregationTypes.put(valueSourceAggregation.getName(), valueSourceAggregation.getType());
79+
} else if(agg instanceof ScriptedMetricAggregationBuilder) {
80+
aggregationTypes.put(agg.getName(), agg.getType());
7881
} else {
7982
// execution should not reach this point
8083
listener.onFailure(new RuntimeException("Unsupported aggregation type [" + agg.getType() + "]"));
@@ -127,15 +130,17 @@ private static Map<String, String> resolveMappings(Map<String, String> aggregati
127130

128131
aggregationTypes.forEach((targetFieldName, aggregationName) -> {
129132
String sourceFieldName = aggregationSourceFieldNames.get(targetFieldName);
130-
String destinationMapping = Aggregations.resolveTargetMapping(aggregationName, sourceMappings.get(sourceFieldName));
133+
String sourceMapping = sourceFieldName == null ? null : sourceMappings.get(sourceFieldName);
134+
String destinationMapping = Aggregations.resolveTargetMapping(aggregationName, sourceMapping);
131135

132136
logger.debug(
133137
"Deduced mapping for: [" + targetFieldName + "], agg type [" + aggregationName + "] to [" + destinationMapping + "]");
134-
if (destinationMapping != null) {
138+
if (Aggregations.isDynamicMapping(destinationMapping)) {
139+
logger.info("Dynamic target mapping set for field ["+ targetFieldName +"] and aggregation [" + aggregationName +"]");
140+
} else if (destinationMapping != null) {
135141
targetMapping.put(targetFieldName, destinationMapping);
136142
} else {
137-
logger.warn("Failed to deduce mapping for [" + targetFieldName + "], fall back to double.");
138-
targetMapping.put(targetFieldName, "double");
143+
logger.warn("Failed to deduce mapping for [" + targetFieldName + "], fall back to dynamic mapping.");
139144
}
140145
});
141146

x-pack/plugin/data-frame/src/test/java/org/elasticsearch/xpack/dataframe/transforms/pivot/AggregationResultUtilsTests.java

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,11 @@
3535
import org.elasticsearch.search.aggregations.metrics.ParsedExtendedStats;
3636
import org.elasticsearch.search.aggregations.metrics.ParsedMax;
3737
import org.elasticsearch.search.aggregations.metrics.ParsedMin;
38+
import org.elasticsearch.search.aggregations.metrics.ParsedScriptedMetric;
3839
import org.elasticsearch.search.aggregations.metrics.ParsedStats;
3940
import org.elasticsearch.search.aggregations.metrics.ParsedSum;
4041
import org.elasticsearch.search.aggregations.metrics.ParsedValueCount;
42+
import org.elasticsearch.search.aggregations.metrics.ScriptedMetricAggregationBuilder;
4143
import org.elasticsearch.search.aggregations.metrics.StatsAggregationBuilder;
4244
import org.elasticsearch.search.aggregations.metrics.SumAggregationBuilder;
4345
import org.elasticsearch.search.aggregations.metrics.ValueCountAggregationBuilder;
@@ -76,6 +78,7 @@ public class AggregationResultUtilsTests extends ESTestCase {
7678
map.put(MaxAggregationBuilder.NAME, (p, c) -> ParsedMax.fromXContent(p, (String) c));
7779
map.put(SumAggregationBuilder.NAME, (p, c) -> ParsedSum.fromXContent(p, (String) c));
7880
map.put(AvgAggregationBuilder.NAME, (p, c) -> ParsedAvg.fromXContent(p, (String) c));
81+
map.put(ScriptedMetricAggregationBuilder.NAME, (p, c) -> ParsedScriptedMetric.fromXContent(p, (String) c));
7982
map.put(ValueCountAggregationBuilder.NAME, (p, c) -> ParsedValueCount.fromXContent(p, (String) c));
8083
map.put(StatsAggregationBuilder.NAME, (p, c) -> ParsedStats.fromXContent(p, (String) c));
8184
map.put(StatsBucketPipelineAggregationBuilder.NAME, (p, c) -> ParsedStatsBucket.fromXContent(p, (String) c));
@@ -409,6 +412,92 @@ aggTypedName2, asMap(
409412
executeTest(groupBy, aggregationBuilders, input, fieldTypeMap, expected, 10);
410413
}
411414

415+
/**
 * Checks result extraction for a {@code scripted_metric} aggregation whose value is an object
 * rather than a single number.
 *
 * The input consists of four composite buckets keyed by two terms group-by fields. Each bucket carries
 * a scripted metric result of the form {@code {"field": <number>}}. The expected output documents hold
 * the two group-by values and that same map under the aggregation name. Only the group-by fields get an
 * entry in the field type map; none is supplied for the scripted metric field.
 */
public void testExtractCompositeAggregationResultsWithDynamicType() throws IOException {
416+
String targetField = randomAlphaOfLengthBetween(5, 10);
417+
String targetField2 = randomAlphaOfLengthBetween(5, 10) + "_2";
418+
419+
GroupConfig groupBy = parseGroupConfig("{"
420+
+ "\"" + targetField + "\" : {"
421+
+ " \"terms\" : {"
422+
+ " \"field\" : \"doesn't_matter_for_this_test\""
423+
+ " } },"
424+
+ "\"" + targetField2 + "\" : {"
425+
+ " \"terms\" : {"
426+
+ " \"field\" : \"doesn't_matter_for_this_test\""
427+
+ " } }"
428+
+ "}");
429+
430+
String aggName = randomAlphaOfLengthBetween(5, 10);
431+
// key of the aggregation in the parsed response: "<type>#<name>"
String aggTypedName = "scripted_metric#" + aggName;
432+
433+
Collection<AggregationBuilder> aggregationBuilders = asList(AggregationBuilders.scriptedMetric(aggName));
434+
435+
// each bucket reports the scripted metric result under "value" as a map, not a single number
Map<String, Object> input = asMap(
436+
"buckets",
437+
asList(
438+
asMap(
439+
KEY, asMap(
440+
targetField, "ID1",
441+
targetField2, "ID1_2"
442+
),
443+
aggTypedName, asMap(
444+
"value", asMap("field", 123.0)),
445+
DOC_COUNT, 1),
446+
asMap(
447+
KEY, asMap(
448+
targetField, "ID1",
449+
targetField2, "ID2_2"
450+
),
451+
aggTypedName, asMap(
452+
"value", asMap("field", 1.0)),
453+
DOC_COUNT, 2),
454+
asMap(
455+
KEY, asMap(
456+
targetField, "ID2",
457+
targetField2, "ID1_2"
458+
),
459+
aggTypedName, asMap(
460+
"value", asMap("field", 2.13)),
461+
DOC_COUNT, 3),
462+
asMap(
463+
KEY, asMap(
464+
targetField, "ID3",
465+
targetField2, "ID2_2"
466+
),
467+
aggTypedName, asMap(
468+
"value", asMap("field", 12.0)),
469+
DOC_COUNT, 4)
470+
));
471+
472+
// the map is expected to appear unchanged in the output document under the aggregation name
List<Map<String, Object>> expected = asList(
473+
asMap(
474+
targetField, "ID1",
475+
targetField2, "ID1_2",
476+
aggName, asMap("field", 123.0)
477+
),
478+
asMap(
479+
targetField, "ID1",
480+
targetField2, "ID2_2",
481+
aggName, asMap("field", 1.0)
482+
),
483+
asMap(
484+
targetField, "ID2",
485+
targetField2, "ID1_2",
486+
aggName, asMap("field", 2.13)
487+
),
488+
asMap(
489+
targetField, "ID3",
490+
targetField2, "ID2_2",
491+
aggName, asMap("field", 12.0)
492+
)
493+
);
494+
// types are given for the group-by fields only
Map<String, String> fieldTypeMap = asStringMap(
495+
targetField, "keyword",
496+
targetField2, "keyword"
497+
);
498+
executeTest(groupBy, aggregationBuilders, input, fieldTypeMap, expected, 10);
499+
}
500+
412501
public void testExtractCompositeAggregationResultsDocIDs() throws IOException {
413502
String targetField = randomAlphaOfLengthBetween(5, 10);
414503
String targetField2 = randomAlphaOfLengthBetween(5, 10) + "_2";

x-pack/plugin/data-frame/src/test/java/org/elasticsearch/xpack/dataframe/transforms/pivot/AggregationsTests.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,31 @@ public void testResolveTargetMapping() {
1515
assertEquals("double", Aggregations.resolveTargetMapping("avg", "int"));
1616
assertEquals("double", Aggregations.resolveTargetMapping("avg", "double"));
1717

18+
// cardinality
19+
assertEquals("long", Aggregations.resolveTargetMapping("cardinality", "int"));
20+
assertEquals("long", Aggregations.resolveTargetMapping("cardinality", "double"));
21+
22+
// value_count
23+
assertEquals("long", Aggregations.resolveTargetMapping("value_count", "int"));
24+
assertEquals("long", Aggregations.resolveTargetMapping("value_count", "double"));
25+
1826
// max
1927
assertEquals("int", Aggregations.resolveTargetMapping("max", "int"));
2028
assertEquals("double", Aggregations.resolveTargetMapping("max", "double"));
2129
assertEquals("half_float", Aggregations.resolveTargetMapping("max", "half_float"));
30+
31+
// min
32+
assertEquals("int", Aggregations.resolveTargetMapping("min", "int"));
33+
assertEquals("double", Aggregations.resolveTargetMapping("min", "double"));
34+
assertEquals("half_float", Aggregations.resolveTargetMapping("min", "half_float"));
35+
36+
// sum
37+
assertEquals("int", Aggregations.resolveTargetMapping("sum", "int"));
38+
assertEquals("double", Aggregations.resolveTargetMapping("sum", "double"));
39+
assertEquals("half_float", Aggregations.resolveTargetMapping("sum", "half_float"));
40+
41+
// scripted_metric
42+
assertEquals("_dynamic", Aggregations.resolveTargetMapping("scripted_metric", null));
43+
assertEquals("_dynamic", Aggregations.resolveTargetMapping("scripted_metric", "int"));
2244
}
2345
}

x-pack/plugin/data-frame/src/test/java/org/elasticsearch/xpack/dataframe/transforms/pivot/PivotTests.java

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,7 @@
3737

3838
import java.io.IOException;
3939
import java.util.ArrayList;
40-
import java.util.Collections;
4140
import java.util.List;
42-
import java.util.Map;
4341
import java.util.Set;
4442
import java.util.concurrent.CountDownLatch;
4543
import java.util.concurrent.TimeUnit;
@@ -176,14 +174,20 @@ private AggregationConfig getValidAggregationConfig() throws IOException {
176174
}
177175

178176
/**
 * Builds an {@link AggregationConfig} containing a single aggregation of the given type.
 *
 * For {@code scripted_metric} a fixed configuration with init, map, combine and reduce scripts is
 * parsed, under the name {@code pivot_scripted_metric}. For every other type a generic configuration
 * named {@code pivot_<agg>} is parsed, with the aggregation operating on the field {@code values}.
 *
 * @param agg the aggregation type name
 * @return the parsed aggregation configuration
 * @throws IOException declared because {@code parseAggregations} declares it
 */
private AggregationConfig getAggregationConfig(String agg) throws IOException {
177+
// scripted_metric is configured through scripts here, so the generic "field" template below is not used for it
if (agg.equals(AggregationType.SCRIPTED_METRIC.getName())) {
178+
return parseAggregations("{\"pivot_scripted_metric\": {\n" +
179+
"\"scripted_metric\": {\n" +
180+
" \"init_script\" : \"state.transactions = []\",\n" +
181+
" \"map_script\" : \"state.transactions.add(doc.type.value == 'sale' ? doc.amount.value : -1 * doc.amount.value)\", \n" +
182+
" \"combine_script\" : \"double profit = 0; for (t in state.transactions) { profit += t } return profit\",\n" +
183+
" \"reduce_script\" : \"double profit = 0; for (a in states) { profit += a } return profit\"\n" +
184+
" }\n" +
185+
"}}");
186+
}
179187
return parseAggregations("{\n" + " \"pivot_" + agg + "\": {\n" + " \"" + agg + "\": {\n" + " \"field\": \"values\"\n"
180188
+ " }\n" + " }" + "}");
181189
}
182190

183-
private Map<String, String> getFieldMappings() {
184-
return Collections.singletonMap("values", "double");
185-
}
186-
187191
private AggregationConfig parseAggregations(String json) throws IOException {
188192
final XContentParser parser = XContentType.JSON.xContent().createParser(xContentRegistry(),
189193
DeprecationHandler.THROW_UNSUPPORTED_OPERATION, json);

0 commit comments

Comments
 (0)