Skip to content

Commit 19ce9ce

Browse files
authored
[ML][Data Frame] allow null values for aggs with sparse data (#42966)
* [ML][Data Frame] allow null values for aggs with sparse data
* Making classes static, memory allocation optimization
1 parent fcaef00 commit 19ce9ce

File tree

2 files changed

+89
-44
lines changed

2 files changed

+89
-44
lines changed

x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/pivot/AggregationResultUtils.java

Lines changed: 70 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,13 @@
66

77
package org.elasticsearch.xpack.dataframe.transforms.pivot;
88

9-
import org.apache.logging.log4j.LogManager;
10-
import org.apache.logging.log4j.Logger;
119
import org.elasticsearch.ElasticsearchException;
10+
import org.elasticsearch.common.Numbers;
1211
import org.elasticsearch.search.aggregations.Aggregation;
1312
import org.elasticsearch.search.aggregations.AggregationBuilder;
1413
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
1514
import org.elasticsearch.search.aggregations.bucket.composite.CompositeAggregation;
1615
import org.elasticsearch.search.aggregations.metrics.GeoCentroid;
17-
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
1816
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation.SingleValue;
1917
import org.elasticsearch.search.aggregations.metrics.ScriptedMetric;
2018
import org.elasticsearch.xpack.core.dataframe.DataFrameField;
@@ -23,6 +21,7 @@
2321
import org.elasticsearch.xpack.dataframe.transforms.IDGenerator;
2422

2523
import java.util.Collection;
24+
import java.util.Collections;
2625
import java.util.HashMap;
2726
import java.util.List;
2827
import java.util.Map;
@@ -32,7 +31,15 @@
3231
import static org.elasticsearch.xpack.dataframe.transforms.pivot.SchemaUtil.isNumericType;
3332

3433
public final class AggregationResultUtils {
35-
private static final Logger logger = LogManager.getLogger(AggregationResultUtils.class);
34+
35+
private static final Map<String, AggValueExtractor> TYPE_VALUE_EXTRACTOR_MAP;
36+
static {
37+
Map<String, AggValueExtractor> tempMap = new HashMap<>();
38+
tempMap.put(SingleValue.class.getName(), new SingleValueAggExtractor());
39+
tempMap.put(ScriptedMetric.class.getName(), new ScriptedMetricAggExtractor());
40+
tempMap.put(GeoCentroid.class.getName(), new GeoCentroidAggExtractor());
41+
TYPE_VALUE_EXTRACTOR_MAP = Collections.unmodifiableMap(tempMap);
42+
}
3643

3744
/**
3845
* Extracts aggregation results from a composite aggregation and puts it into a map.
@@ -73,27 +80,8 @@ public static Stream<Map<String, Object>> extractCompositeAggregationResults(Com
7380
// TODO: support other aggregation types
7481
Aggregation aggResult = bucket.getAggregations().get(aggName);
7582

76-
if (aggResult instanceof NumericMetricsAggregation.SingleValue) {
77-
NumericMetricsAggregation.SingleValue aggResultSingleValue = (SingleValue) aggResult;
78-
// If the type is numeric or if the formatted string is the same as simply making the value a string,
79-
// gather the `value` type, otherwise utilize `getValueAsString` so we don't lose formatted outputs.
80-
if (isNumericType(fieldType) ||
81-
(aggResultSingleValue.getValueAsString().equals(String.valueOf(aggResultSingleValue.value())))) {
82-
updateDocument(document, aggName, aggResultSingleValue.value());
83-
} else {
84-
updateDocument(document, aggName, aggResultSingleValue.getValueAsString());
85-
}
86-
} else if (aggResult instanceof ScriptedMetric) {
87-
updateDocument(document, aggName, ((ScriptedMetric) aggResult).aggregation());
88-
} else if (aggResult instanceof GeoCentroid) {
89-
updateDocument(document, aggName, ((GeoCentroid) aggResult).centroid().toString());
90-
} else {
91-
// Execution should never reach this point!
92-
// Creating transforms with unsupported aggregations shall not be possible
93-
throw new AggregationExtractionException("unsupported aggregation [{}] with name [{}]",
94-
aggResult.getType(),
95-
aggResult.getName());
96-
}
83+
AggValueExtractor extractor = getExtractor(aggResult);
84+
updateDocument(document, aggName, extractor.value(aggResult, fieldType));
9785
}
9886

9987
document.put(DataFrameField.DOCUMENT_ID_FIELD, idGen.getID());
@@ -102,6 +90,23 @@ public static Stream<Map<String, Object>> extractCompositeAggregationResults(Com
10290
});
10391
}
10492

93+
static AggValueExtractor getExtractor(Aggregation aggregation) {
94+
if (aggregation instanceof SingleValue) {
95+
return TYPE_VALUE_EXTRACTOR_MAP.get(SingleValue.class.getName());
96+
} else if (aggregation instanceof ScriptedMetric) {
97+
return TYPE_VALUE_EXTRACTOR_MAP.get(ScriptedMetric.class.getName());
98+
} else if (aggregation instanceof GeoCentroid) {
99+
return TYPE_VALUE_EXTRACTOR_MAP.get(GeoCentroid.class.getName());
100+
} else {
101+
// Execution should never reach this point!
102+
// Creating transforms with unsupported aggregations shall not be possible
103+
throw new AggregationExtractionException("unsupported aggregation [{}] with name [{}]",
104+
aggregation.getType(),
105+
aggregation.getName());
106+
}
107+
}
108+
109+
105110
@SuppressWarnings("unchecked")
106111
static void updateDocument(Map<String, Object> document, String fieldName, Object value) {
107112
String[] fieldTokens = fieldName.split("\\.");
@@ -147,4 +152,44 @@ public static class AggregationExtractionException extends ElasticsearchExceptio
147152
super(msg, args);
148153
}
149154
}
155+
156+
private interface AggValueExtractor {
157+
Object value(Aggregation aggregation, String fieldType);
158+
}
159+
160+
private static class SingleValueAggExtractor implements AggValueExtractor {
161+
@Override
162+
public Object value(Aggregation agg, String fieldType) {
163+
SingleValue aggregation = (SingleValue)agg;
164+
// If the double is invalid, this indicates sparse data
165+
if (Numbers.isValidDouble(aggregation.value()) == false) {
166+
return null;
167+
}
168+
// If the type is numeric or if the formatted string is the same as simply making the value a string,
169+
// gather the `value` type, otherwise utilize `getValueAsString` so we don't lose formatted outputs.
170+
if (isNumericType(fieldType) ||
171+
aggregation.getValueAsString().equals(String.valueOf(aggregation.value()))){
172+
return aggregation.value();
173+
} else {
174+
return aggregation.getValueAsString();
175+
}
176+
}
177+
}
178+
179+
private static class ScriptedMetricAggExtractor implements AggValueExtractor {
180+
@Override
181+
public Object value(Aggregation agg, String fieldType) {
182+
ScriptedMetric aggregation = (ScriptedMetric)agg;
183+
return aggregation.aggregation();
184+
}
185+
}
186+
187+
private static class GeoCentroidAggExtractor implements AggValueExtractor {
188+
@Override
189+
public Object value(Aggregation agg, String fieldType) {
190+
GeoCentroid aggregation = (GeoCentroid)agg;
191+
// the count is `0` iff there is no contained centroid
192+
return aggregation.count() > 0 ? aggregation.centroid().toString() : null;
193+
}
194+
}
150195
}

x-pack/plugin/data-frame/src/test/java/org/elasticsearch/xpack/dataframe/transforms/pivot/AggregationResultUtilsTests.java

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,8 @@ aggTypedName, asMap(
135135
KEY, asMap(
136136
targetField, "ID3"),
137137
aggTypedName, asMap(
138-
"value", 12.55),
139-
DOC_COUNT, 9)
138+
"value", Double.NaN),
139+
DOC_COUNT, 0)
140140
));
141141

142142
List<Map<String, Object>> expected = asList(
@@ -150,14 +150,14 @@ aggTypedName, asMap(
150150
),
151151
asMap(
152152
targetField, "ID3",
153-
aggName, 12.55
153+
aggName, null
154154
)
155155
);
156156
Map<String, String> fieldTypeMap = asStringMap(
157157
targetField, "keyword",
158158
aggName, "double"
159159
);
160-
executeTest(groupBy, aggregationBuilders, Collections.emptyList(), input, fieldTypeMap, expected, 20);
160+
executeTest(groupBy, aggregationBuilders, Collections.emptyList(), input, fieldTypeMap, expected, 11);
161161
}
162162

163163
public void testExtractCompositeAggregationResultsMultipleGroups() throws IOException {
@@ -212,8 +212,8 @@ KEY, asMap(
212212
targetField2, "ID2_2"
213213
),
214214
aggTypedName, asMap(
215-
"value", 12.55),
216-
DOC_COUNT, 4)
215+
"value", Double.NaN),
216+
DOC_COUNT, 0)
217217
));
218218

219219
List<Map<String, Object>> expected = asList(
@@ -235,15 +235,15 @@ aggTypedName, asMap(
235235
asMap(
236236
targetField, "ID3",
237237
targetField2, "ID2_2",
238-
aggName, 12.55
238+
aggName, null
239239
)
240240
);
241241
Map<String, String> fieldTypeMap = asStringMap(
242242
aggName, "double",
243243
targetField, "keyword",
244244
targetField2, "keyword"
245245
);
246-
executeTest(groupBy, aggregationBuilders, Collections.emptyList(), input, fieldTypeMap, expected, 10);
246+
executeTest(groupBy, aggregationBuilders, Collections.emptyList(), input, fieldTypeMap, expected, 6);
247247
}
248248

249249
public void testExtractCompositeAggregationResultsMultiAggregations() throws IOException {
@@ -287,7 +287,7 @@ KEY, asMap(
287287
aggTypedName, asMap(
288288
"value", 12.55),
289289
aggTypedName2, asMap(
290-
"value", -2.44),
290+
"value", Double.NaN),
291291
DOC_COUNT, 1)
292292
));
293293

@@ -305,7 +305,7 @@ aggTypedName2, asMap(
305305
asMap(
306306
targetField, "ID3",
307307
aggName, 12.55,
308-
aggName2, -2.44
308+
aggName2, null
309309
)
310310
);
311311
Map<String, String> fieldTypeMap = asStringMap(
@@ -383,8 +383,8 @@ KEY, asMap(
383383
aggTypedName, asMap(
384384
"value", 12.55),
385385
aggTypedName2, asMap(
386-
"value", -100.44,
387-
"value_as_string", "-100.44F"),
386+
"value", Double.NaN,
387+
"value_as_string", "NaN"),
388388
DOC_COUNT, 4)
389389
));
390390

@@ -411,7 +411,7 @@ aggTypedName2, asMap(
411411
targetField, "ID3",
412412
targetField2, "ID2_2",
413413
aggName, 12.55,
414-
aggName2, "-100.44F"
414+
aggName2, null
415415
)
416416
);
417417
Map<String, String> fieldTypeMap = asStringMap(
@@ -476,8 +476,8 @@ KEY, asMap(
476476
targetField2, "ID2_2"
477477
),
478478
aggTypedName, asMap(
479-
"value", asMap("field", 12.0)),
480-
DOC_COUNT, 4)
479+
"value", null),
480+
DOC_COUNT, 0)
481481
));
482482

483483
List<Map<String, Object>> expected = asList(
@@ -499,14 +499,14 @@ aggName, asMap("field", 2.13)
499499
asMap(
500500
targetField, "ID3",
501501
targetField2, "ID2_2",
502-
aggName, asMap("field", 12.0)
502+
aggName, null
503503
)
504504
);
505505
Map<String, String> fieldTypeMap = asStringMap(
506506
targetField, "keyword",
507507
targetField2, "keyword"
508508
);
509-
executeTest(groupBy, aggregationBuilders, Collections.emptyList(), input, fieldTypeMap, expected, 10);
509+
executeTest(groupBy, aggregationBuilders, Collections.emptyList(), input, fieldTypeMap, expected, 6);
510510
}
511511

512512
public void testExtractCompositeAggregationResultsWithPipelineAggregation() throws IOException {
@@ -576,7 +576,7 @@ KEY, asMap(
576576
aggTypedName, asMap(
577577
"value", 12.0),
578578
pipelineAggTypedName, asMap(
579-
"value", 12.0),
579+
"value", Double.NaN),
580580
DOC_COUNT, 4)
581581
));
582582

@@ -603,7 +603,7 @@ pipelineAggTypedName, asMap(
603603
targetField, "ID3",
604604
targetField2, "ID2_2",
605605
aggName, 12.0,
606-
pipelineAggName, 12.0
606+
pipelineAggName, null
607607
)
608608
);
609609
Map<String, String> fieldTypeMap = asStringMap(

0 commit comments

Comments
 (0)