Skip to content

[ML][Data Frame] allow null values for aggs with sparse data #42966

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,13 @@

package org.elasticsearch.xpack.dataframe.transforms.pivot;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.common.Numbers;
import org.elasticsearch.search.aggregations.Aggregation;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.search.aggregations.bucket.composite.CompositeAggregation;
import org.elasticsearch.search.aggregations.metrics.GeoCentroid;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation.SingleValue;
import org.elasticsearch.search.aggregations.metrics.ScriptedMetric;
import org.elasticsearch.xpack.core.dataframe.DataFrameField;
Expand All @@ -23,6 +21,7 @@
import org.elasticsearch.xpack.dataframe.transforms.IDGenerator;

import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand All @@ -32,7 +31,15 @@
import static org.elasticsearch.xpack.dataframe.transforms.pivot.SchemaUtil.isNumericType;

public final class AggregationResultUtils {
private static final Logger logger = LogManager.getLogger(AggregationResultUtils.class);

private static final Map<String, AggValueExtractor> TYPE_VALUE_EXTRACTOR_MAP;
static {
Map<String, AggValueExtractor> tempMap = new HashMap<>();
tempMap.put(SingleValue.class.getName(), new SingleValueAggExtractor());
tempMap.put(ScriptedMetric.class.getName(), new ScriptedMetricAggExtractor());
tempMap.put(GeoCentroid.class.getName(), new GeoCentroidAggExtractor());
TYPE_VALUE_EXTRACTOR_MAP = Collections.unmodifiableMap(tempMap);
}

/**
* Extracts aggregation results from a composite aggregation and puts it into a map.
Expand Down Expand Up @@ -73,27 +80,8 @@ public static Stream<Map<String, Object>> extractCompositeAggregationResults(Com
// TODO: support other aggregation types
Aggregation aggResult = bucket.getAggregations().get(aggName);

if (aggResult instanceof NumericMetricsAggregation.SingleValue) {
NumericMetricsAggregation.SingleValue aggResultSingleValue = (SingleValue) aggResult;
// If the type is numeric or if the formatted string is the same as simply making the value a string,
// gather the `value` type, otherwise utilize `getValueAsString` so we don't lose formatted outputs.
if (isNumericType(fieldType) ||
(aggResultSingleValue.getValueAsString().equals(String.valueOf(aggResultSingleValue.value())))) {
updateDocument(document, aggName, aggResultSingleValue.value());
} else {
updateDocument(document, aggName, aggResultSingleValue.getValueAsString());
}
} else if (aggResult instanceof ScriptedMetric) {
updateDocument(document, aggName, ((ScriptedMetric) aggResult).aggregation());
} else if (aggResult instanceof GeoCentroid) {
updateDocument(document, aggName, ((GeoCentroid) aggResult).centroid().toString());
} else {
// Execution should never reach this point!
// Creating transforms with unsupported aggregations shall not be possible
throw new AggregationExtractionException("unsupported aggregation [{}] with name [{}]",
aggResult.getType(),
aggResult.getName());
}
AggValueExtractor extractor = getExtractor(aggResult);
updateDocument(document, aggName, extractor.value(aggResult, fieldType));
}

document.put(DataFrameField.DOCUMENT_ID_FIELD, idGen.getID());
Expand All @@ -102,6 +90,23 @@ public static Stream<Map<String, Object>> extractCompositeAggregationResults(Com
});
}

static AggValueExtractor getExtractor(Aggregation aggregation) {
if (aggregation instanceof SingleValue) {
return TYPE_VALUE_EXTRACTOR_MAP.get(SingleValue.class.getName());
} else if (aggregation instanceof ScriptedMetric) {
return TYPE_VALUE_EXTRACTOR_MAP.get(ScriptedMetric.class.getName());
} else if (aggregation instanceof GeoCentroid) {
return TYPE_VALUE_EXTRACTOR_MAP.get(GeoCentroid.class.getName());
} else {
// Execution should never reach this point!
// Creating transforms with unsupported aggregations shall not be possible
throw new AggregationExtractionException("unsupported aggregation [{}] with name [{}]",
aggregation.getType(),
aggregation.getName());
}
}


@SuppressWarnings("unchecked")
static void updateDocument(Map<String, Object> document, String fieldName, Object value) {
String[] fieldTokens = fieldName.split("\\.");
Expand Down Expand Up @@ -147,4 +152,44 @@ public static class AggregationExtractionException extends ElasticsearchExceptio
super(msg, args);
}
}

private interface AggValueExtractor {
Object value(Aggregation aggregation, String fieldType);
}

private static class SingleValueAggExtractor implements AggValueExtractor {
@Override
public Object value(Aggregation agg, String fieldType) {
SingleValue aggregation = (SingleValue)agg;
// If the double is invalid, this indicates sparse data
if (Numbers.isValidDouble(aggregation.value()) == false) {
return null;
}
// If the type is numeric or if the formatted string is the same as simply making the value a string,
// gather the `value` type, otherwise utilize `getValueAsString` so we don't lose formatted outputs.
if (isNumericType(fieldType) ||
aggregation.getValueAsString().equals(String.valueOf(aggregation.value()))){
return aggregation.value();
} else {
return aggregation.getValueAsString();
}
}
}

private static class ScriptedMetricAggExtractor implements AggValueExtractor {
@Override
public Object value(Aggregation agg, String fieldType) {
ScriptedMetric aggregation = (ScriptedMetric)agg;
return aggregation.aggregation();
}
}

private static class GeoCentroidAggExtractor implements AggValueExtractor {
@Override
public Object value(Aggregation agg, String fieldType) {
GeoCentroid aggregation = (GeoCentroid)agg;
// if the account is `0` iff there is no contained centroid
return aggregation.count() > 0 ? aggregation.centroid().toString() : null;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,8 @@ aggTypedName, asMap(
KEY, asMap(
targetField, "ID3"),
aggTypedName, asMap(
"value", 12.55),
DOC_COUNT, 9)
"value", Double.NaN),
DOC_COUNT, 0)
));

List<Map<String, Object>> expected = asList(
Expand All @@ -150,14 +150,14 @@ aggTypedName, asMap(
),
asMap(
targetField, "ID3",
aggName, 12.55
aggName, null
)
);
Map<String, String> fieldTypeMap = asStringMap(
targetField, "keyword",
aggName, "double"
);
executeTest(groupBy, aggregationBuilders, Collections.emptyList(), input, fieldTypeMap, expected, 20);
executeTest(groupBy, aggregationBuilders, Collections.emptyList(), input, fieldTypeMap, expected, 11);
}

public void testExtractCompositeAggregationResultsMultipleGroups() throws IOException {
Expand Down Expand Up @@ -212,8 +212,8 @@ KEY, asMap(
targetField2, "ID2_2"
),
aggTypedName, asMap(
"value", 12.55),
DOC_COUNT, 4)
"value", Double.NaN),
DOC_COUNT, 0)
));

List<Map<String, Object>> expected = asList(
Expand All @@ -235,15 +235,15 @@ aggTypedName, asMap(
asMap(
targetField, "ID3",
targetField2, "ID2_2",
aggName, 12.55
aggName, null
)
);
Map<String, String> fieldTypeMap = asStringMap(
aggName, "double",
targetField, "keyword",
targetField2, "keyword"
);
executeTest(groupBy, aggregationBuilders, Collections.emptyList(), input, fieldTypeMap, expected, 10);
executeTest(groupBy, aggregationBuilders, Collections.emptyList(), input, fieldTypeMap, expected, 6);
}

public void testExtractCompositeAggregationResultsMultiAggregations() throws IOException {
Expand Down Expand Up @@ -287,7 +287,7 @@ KEY, asMap(
aggTypedName, asMap(
"value", 12.55),
aggTypedName2, asMap(
"value", -2.44),
"value", Double.NaN),
DOC_COUNT, 1)
));

Expand All @@ -305,7 +305,7 @@ aggTypedName2, asMap(
asMap(
targetField, "ID3",
aggName, 12.55,
aggName2, -2.44
aggName2, null
)
);
Map<String, String> fieldTypeMap = asStringMap(
Expand Down Expand Up @@ -383,8 +383,8 @@ KEY, asMap(
aggTypedName, asMap(
"value", 12.55),
aggTypedName2, asMap(
"value", -100.44,
"value_as_string", "-100.44F"),
"value", Double.NaN,
"value_as_string", "NaN"),
DOC_COUNT, 4)
));

Expand All @@ -411,7 +411,7 @@ aggTypedName2, asMap(
targetField, "ID3",
targetField2, "ID2_2",
aggName, 12.55,
aggName2, "-100.44F"
aggName2, null
)
);
Map<String, String> fieldTypeMap = asStringMap(
Expand Down Expand Up @@ -476,8 +476,8 @@ KEY, asMap(
targetField2, "ID2_2"
),
aggTypedName, asMap(
"value", asMap("field", 12.0)),
DOC_COUNT, 4)
"value", null),
DOC_COUNT, 0)
));

List<Map<String, Object>> expected = asList(
Expand All @@ -499,14 +499,14 @@ aggName, asMap("field", 2.13)
asMap(
targetField, "ID3",
targetField2, "ID2_2",
aggName, asMap("field", 12.0)
aggName, null
)
);
Map<String, String> fieldTypeMap = asStringMap(
targetField, "keyword",
targetField2, "keyword"
);
executeTest(groupBy, aggregationBuilders, Collections.emptyList(), input, fieldTypeMap, expected, 10);
executeTest(groupBy, aggregationBuilders, Collections.emptyList(), input, fieldTypeMap, expected, 6);
}

public void testExtractCompositeAggregationResultsWithPipelineAggregation() throws IOException {
Expand Down Expand Up @@ -576,7 +576,7 @@ KEY, asMap(
aggTypedName, asMap(
"value", 12.0),
pipelineAggTypedName, asMap(
"value", 12.0),
"value", Double.NaN),
DOC_COUNT, 4)
));

Expand All @@ -603,7 +603,7 @@ pipelineAggTypedName, asMap(
targetField, "ID3",
targetField2, "ID2_2",
aggName, 12.0,
pipelineAggName, 12.0
pipelineAggName, null
)
);
Map<String, String> fieldTypeMap = asStringMap(
Expand Down