Skip to content

Commit 3a71b91

Browse files
authored
[ML][Data Frame] add support for geo_bounds aggregation (#44441) (#45281)
This adds support for `geo_bounds` aggregation inside the `pivot.aggregations` configuration. The two points returned from the `geo_bounds` aggregation are transformed into `geo_shape` whose types are dynamic given the point's similarity. * `point` if the two points are identical * `linestring` if the two points share either a latitude or longitude * `polygon` if the two points are completely different The automatically deduced mapping for the resulting field is a `geo_shape`.
1 parent 95d3a8e commit 3a71b91

File tree

5 files changed

+267
-4
lines changed

5 files changed

+267
-4
lines changed

x-pack/plugin/data-frame/qa/single-node-tests/src/test/java/org/elasticsearch/xpack/dataframe/integration/DataFramePivotRestIT.java

+54
Original file line numberDiff line numberDiff line change
@@ -677,6 +677,60 @@ public void testPivotWithBucketScriptAgg() throws Exception {
677677
assertEquals(3.878048780, actual.doubleValue(), 0.000001);
678678
}
679679

680+
@SuppressWarnings("unchecked")
681+
public void testPivotWithGeoBoundsAgg() throws Exception {
682+
String transformId = "geo_bounds_pivot";
683+
String dataFrameIndex = "geo_bounds_pivot_reviews";
684+
setupDataAccessRole(DATA_ACCESS_ROLE, REVIEWS_INDEX_NAME, dataFrameIndex);
685+
686+
final Request createDataframeTransformRequest = createRequestWithAuth("PUT", DATAFRAME_ENDPOINT + transformId,
687+
BASIC_AUTH_VALUE_DATA_FRAME_ADMIN_WITH_SOME_DATA_ACCESS);
688+
689+
String config = "{"
690+
+ " \"source\": {\"index\":\"" + REVIEWS_INDEX_NAME + "\"},"
691+
+ " \"dest\": {\"index\":\"" + dataFrameIndex + "\"},";
692+
693+
config += " \"pivot\": {"
694+
+ " \"group_by\": {"
695+
+ " \"reviewer\": {"
696+
+ " \"terms\": {"
697+
+ " \"field\": \"user_id\""
698+
+ " } } },"
699+
+ " \"aggregations\": {"
700+
+ " \"avg_rating\": {"
701+
+ " \"avg\": {"
702+
+ " \"field\": \"stars\""
703+
+ " } },"
704+
+ " \"boundary\": {"
705+
+ " \"geo_bounds\": {\"field\": \"location\"}"
706+
+ " } } }"
707+
+ "}";
708+
709+
createDataframeTransformRequest.setJsonEntity(config);
710+
Map<String, Object> createDataframeTransformResponse = entityAsMap(client().performRequest(createDataframeTransformRequest));
711+
assertThat(createDataframeTransformResponse.get("acknowledged"), equalTo(Boolean.TRUE));
712+
713+
startAndWaitForTransform(transformId, dataFrameIndex, BASIC_AUTH_VALUE_DATA_FRAME_ADMIN_WITH_SOME_DATA_ACCESS);
714+
assertTrue(indexExists(dataFrameIndex));
715+
716+
// we expect 27 documents as there shall be 27 user_id's
717+
Map<String, Object> indexStats = getAsMap(dataFrameIndex + "/_stats");
718+
assertEquals(27, XContentMapValues.extractValue("_all.total.docs.count", indexStats));
719+
720+
// get and check some users
721+
Map<String, Object> searchResult = getAsMap(dataFrameIndex + "/_search?q=reviewer:user_4");
722+
assertEquals(1, XContentMapValues.extractValue("hits.total.value", searchResult));
723+
Number actual = (Number) ((List<?>) XContentMapValues.extractValue("hits.hits._source.avg_rating", searchResult)).get(0);
724+
assertEquals(3.878048780, actual.doubleValue(), 0.000001);
725+
Map<String, Object> actualObj = (Map<String, Object>) ((List<?>) XContentMapValues.extractValue("hits.hits._source.boundary",
726+
searchResult))
727+
.get(0);
728+
assertThat(actualObj.get("type"), equalTo("point"));
729+
List<Double> coordinates = (List<Double>)actualObj.get("coordinates");
730+
assertEquals((4 + 10), coordinates.get(1), 0.000001);
731+
assertEquals((4 + 15), coordinates.get(0), 0.000001);
732+
}
733+
680734
public void testPivotWithGeoCentroidAgg() throws Exception {
681735
String transformId = "geo_centroid_pivot";
682736
String dataFrameIndex = "geo_centroid_pivot_reviews";

x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/pivot/AggregationResultUtils.java

+52-4
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,16 @@
88

99
import org.elasticsearch.ElasticsearchException;
1010
import org.elasticsearch.common.Numbers;
11+
import org.elasticsearch.common.geo.GeoPoint;
12+
import org.elasticsearch.common.geo.builders.LineStringBuilder;
13+
import org.elasticsearch.common.geo.builders.PointBuilder;
14+
import org.elasticsearch.common.geo.builders.PolygonBuilder;
15+
import org.elasticsearch.common.geo.parsers.ShapeParser;
1116
import org.elasticsearch.search.aggregations.Aggregation;
1217
import org.elasticsearch.search.aggregations.AggregationBuilder;
1318
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
1419
import org.elasticsearch.search.aggregations.bucket.composite.CompositeAggregation;
20+
import org.elasticsearch.search.aggregations.metrics.GeoBounds;
1521
import org.elasticsearch.search.aggregations.metrics.GeoCentroid;
1622
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation.SingleValue;
1723
import org.elasticsearch.search.aggregations.metrics.ScriptedMetric;
@@ -20,6 +26,7 @@
2026
import org.elasticsearch.xpack.core.dataframe.transforms.pivot.GroupConfig;
2127
import org.elasticsearch.xpack.dataframe.transforms.IDGenerator;
2228

29+
import java.util.Arrays;
2330
import java.util.Collection;
2431
import java.util.Collections;
2532
import java.util.HashMap;
@@ -38,6 +45,7 @@ public final class AggregationResultUtils {
3845
tempMap.put(SingleValue.class.getName(), new SingleValueAggExtractor());
3946
tempMap.put(ScriptedMetric.class.getName(), new ScriptedMetricAggExtractor());
4047
tempMap.put(GeoCentroid.class.getName(), new GeoCentroidAggExtractor());
48+
tempMap.put(GeoBounds.class.getName(), new GeoBoundsAggExtractor());
4149
TYPE_VALUE_EXTRACTOR_MAP = Collections.unmodifiableMap(tempMap);
4250
}
4351

@@ -99,6 +107,8 @@ static AggValueExtractor getExtractor(Aggregation aggregation) {
99107
return TYPE_VALUE_EXTRACTOR_MAP.get(ScriptedMetric.class.getName());
100108
} else if (aggregation instanceof GeoCentroid) {
101109
return TYPE_VALUE_EXTRACTOR_MAP.get(GeoCentroid.class.getName());
110+
} else if (aggregation instanceof GeoBounds) {
111+
return TYPE_VALUE_EXTRACTOR_MAP.get(GeoBounds.class.getName());
102112
} else {
103113
// Execution should never reach this point!
104114
// Creating transforms with unsupported aggregations shall not be possible
@@ -155,11 +165,11 @@ public static class AggregationExtractionException extends ElasticsearchExceptio
155165
}
156166
}
157167

158-
private interface AggValueExtractor {
168+
interface AggValueExtractor {
159169
Object value(Aggregation aggregation, String fieldType);
160170
}
161171

162-
private static class SingleValueAggExtractor implements AggValueExtractor {
172+
static class SingleValueAggExtractor implements AggValueExtractor {
163173
@Override
164174
public Object value(Aggregation agg, String fieldType) {
165175
SingleValue aggregation = (SingleValue)agg;
@@ -178,20 +188,58 @@ public Object value(Aggregation agg, String fieldType) {
178188
}
179189
}
180190

181-
private static class ScriptedMetricAggExtractor implements AggValueExtractor {
191+
static class ScriptedMetricAggExtractor implements AggValueExtractor {
182192
@Override
183193
public Object value(Aggregation agg, String fieldType) {
184194
ScriptedMetric aggregation = (ScriptedMetric)agg;
185195
return aggregation.aggregation();
186196
}
187197
}
188198

189-
private static class GeoCentroidAggExtractor implements AggValueExtractor {
199+
static class GeoCentroidAggExtractor implements AggValueExtractor {
190200
@Override
191201
public Object value(Aggregation agg, String fieldType) {
192202
GeoCentroid aggregation = (GeoCentroid)agg;
193203
// if the account is `0` iff there is no contained centroid
194204
return aggregation.count() > 0 ? aggregation.centroid().toString() : null;
195205
}
196206
}
207+
208+
static class GeoBoundsAggExtractor implements AggValueExtractor {
209+
@Override
210+
public Object value(Aggregation agg, String fieldType) {
211+
GeoBounds aggregation = (GeoBounds)agg;
212+
if (aggregation.bottomRight() == null || aggregation.topLeft() == null) {
213+
return null;
214+
}
215+
final Map<String, Object> geoShape = new HashMap<>();
216+
// If the two geo_points are equal, it is a point
217+
if (aggregation.topLeft().equals(aggregation.bottomRight())) {
218+
geoShape.put(ShapeParser.FIELD_TYPE.getPreferredName(), PointBuilder.TYPE.shapeName());
219+
geoShape.put(ShapeParser.FIELD_COORDINATES.getPreferredName(),
220+
Arrays.asList(aggregation.topLeft().getLon(), aggregation.bottomRight().getLat()));
221+
// If only the lat or the lon of the two geo_points are equal, than we know it should be a line
222+
} else if (Double.compare(aggregation.topLeft().getLat(), aggregation.bottomRight().getLat()) == 0
223+
|| Double.compare(aggregation.topLeft().getLon(), aggregation.bottomRight().getLon()) == 0) {
224+
geoShape.put(ShapeParser.FIELD_TYPE.getPreferredName(), LineStringBuilder.TYPE.shapeName());
225+
geoShape.put(ShapeParser.FIELD_COORDINATES.getPreferredName(),
226+
Arrays.asList(
227+
new Double[]{aggregation.topLeft().getLon(), aggregation.topLeft().getLat()},
228+
new Double[]{aggregation.bottomRight().getLon(), aggregation.bottomRight().getLat()}));
229+
} else {
230+
// neither points are equal, we have a polygon that is a square
231+
geoShape.put(ShapeParser.FIELD_TYPE.getPreferredName(), PolygonBuilder.TYPE.shapeName());
232+
final GeoPoint tl = aggregation.topLeft();
233+
final GeoPoint br = aggregation.bottomRight();
234+
geoShape.put(ShapeParser.FIELD_COORDINATES.getPreferredName(),
235+
Collections.singletonList(Arrays.asList(
236+
new Double[]{tl.getLon(), tl.getLat()},
237+
new Double[]{br.getLon(), tl.getLat()},
238+
new Double[]{br.getLon(), br.getLat()},
239+
new Double[]{tl.getLon(), br.getLat()},
240+
new Double[]{tl.getLon(), tl.getLat()})));
241+
}
242+
return geoShape;
243+
}
244+
}
197245
}

x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/pivot/Aggregations.java

+1
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ enum AggregationType {
3636
MIN("min", SOURCE),
3737
SUM("sum", "double"),
3838
GEO_CENTROID("geo_centroid", "geo_point"),
39+
GEO_BOUNDS("geo_bounds", "geo_shape"),
3940
SCRIPTED_METRIC("scripted_metric", DYNAMIC),
4041
WEIGHTED_AVG("weighted_avg", DYNAMIC),
4142
BUCKET_SELECTOR("bucket_selector", DYNAMIC),

x-pack/plugin/data-frame/src/test/java/org/elasticsearch/xpack/dataframe/transforms/pivot/AggregationResultUtilsTests.java

+156
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
package org.elasticsearch.xpack.dataframe.transforms.pivot;
88

99
import org.elasticsearch.common.ParseField;
10+
import org.elasticsearch.common.geo.GeoPoint;
1011
import org.elasticsearch.common.xcontent.ContextParser;
1112
import org.elasticsearch.common.xcontent.DeprecationHandler;
1213
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
@@ -31,8 +32,11 @@
3132
import org.elasticsearch.search.aggregations.metrics.AvgAggregationBuilder;
3233
import org.elasticsearch.search.aggregations.metrics.CardinalityAggregationBuilder;
3334
import org.elasticsearch.search.aggregations.metrics.ExtendedStatsAggregationBuilder;
35+
import org.elasticsearch.search.aggregations.metrics.GeoBounds;
36+
import org.elasticsearch.search.aggregations.metrics.GeoCentroid;
3437
import org.elasticsearch.search.aggregations.metrics.MaxAggregationBuilder;
3538
import org.elasticsearch.search.aggregations.metrics.MinAggregationBuilder;
39+
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
3640
import org.elasticsearch.search.aggregations.metrics.ParsedAvg;
3741
import org.elasticsearch.search.aggregations.metrics.ParsedCardinality;
3842
import org.elasticsearch.search.aggregations.metrics.ParsedExtendedStats;
@@ -42,6 +46,7 @@
4246
import org.elasticsearch.search.aggregations.metrics.ParsedStats;
4347
import org.elasticsearch.search.aggregations.metrics.ParsedSum;
4448
import org.elasticsearch.search.aggregations.metrics.ParsedValueCount;
49+
import org.elasticsearch.search.aggregations.metrics.ScriptedMetric;
4550
import org.elasticsearch.search.aggregations.metrics.ScriptedMetricAggregationBuilder;
4651
import org.elasticsearch.search.aggregations.metrics.StatsAggregationBuilder;
4752
import org.elasticsearch.search.aggregations.metrics.SumAggregationBuilder;
@@ -56,6 +61,7 @@
5661
import org.elasticsearch.xpack.core.dataframe.transforms.pivot.GroupConfig;
5762

5863
import java.io.IOException;
64+
import java.util.Arrays;
5965
import java.util.Collection;
6066
import java.util.Collections;
6167
import java.util.HashMap;
@@ -67,6 +73,11 @@
6773

6874
import static java.util.Arrays.asList;
6975
import static org.hamcrest.CoreMatchers.equalTo;
76+
import static org.hamcrest.CoreMatchers.hasItem;
77+
import static org.hamcrest.CoreMatchers.is;
78+
import static org.hamcrest.CoreMatchers.nullValue;
79+
import static org.mockito.Mockito.mock;
80+
import static org.mockito.Mockito.when;
7081

7182
public class AggregationResultUtilsTests extends ESTestCase {
7283

@@ -781,6 +792,151 @@ public void testUpdateDocumentWithObjectAndNotObject() {
781792
equalTo("mixed object types of nested and non-nested fields [foo.bar]"));
782793
}
783794

795+
private NumericMetricsAggregation.SingleValue createSingleMetricAgg(Double value, String valueAsString) {
796+
NumericMetricsAggregation.SingleValue agg = mock(NumericMetricsAggregation.SingleValue.class);
797+
when(agg.value()).thenReturn(value);
798+
when(agg.getValueAsString()).thenReturn(valueAsString);
799+
return agg;
800+
}
801+
802+
public void testSingleValueAggExtractor() {
803+
Aggregation agg = createSingleMetricAgg(Double.NaN, "NaN");
804+
assertThat(AggregationResultUtils.getExtractor(agg).value(agg, "double"), is(nullValue()));
805+
806+
agg = createSingleMetricAgg(Double.POSITIVE_INFINITY, "NaN");
807+
assertThat(AggregationResultUtils.getExtractor(agg).value(agg, "double"), is(nullValue()));
808+
809+
agg = createSingleMetricAgg(100.0, "100.0");
810+
assertThat(AggregationResultUtils.getExtractor(agg).value(agg, "double"), equalTo(100.0));
811+
812+
agg = createSingleMetricAgg(100.0, "one_hundred");
813+
assertThat(AggregationResultUtils.getExtractor(agg).value(agg, "double"), equalTo(100.0));
814+
815+
agg = createSingleMetricAgg(100.0, "one_hundred");
816+
assertThat(AggregationResultUtils.getExtractor(agg).value(agg, "string"), equalTo("one_hundred"));
817+
}
818+
819+
private ScriptedMetric createScriptedMetric(Object returnValue) {
820+
ScriptedMetric agg = mock(ScriptedMetric.class);
821+
when(agg.aggregation()).thenReturn(returnValue);
822+
return agg;
823+
}
824+
825+
@SuppressWarnings("unchecked")
826+
public void testScriptedMetricAggExtractor() {
827+
Aggregation agg = createScriptedMetric(null);
828+
assertThat(AggregationResultUtils.getExtractor(agg).value(agg, "object"), is(nullValue()));
829+
830+
agg = createScriptedMetric(Collections.singletonList("values"));
831+
Object val = AggregationResultUtils.getExtractor(agg).value(agg, "object");
832+
assertThat((List<String>)val, hasItem("values"));
833+
834+
agg = createScriptedMetric(Collections.singletonMap("key", 100));
835+
val = AggregationResultUtils.getExtractor(agg).value(agg, "object");
836+
assertThat(((Map<String, Object>)val).get("key"), equalTo(100));
837+
}
838+
839+
private GeoCentroid createGeoCentroid(GeoPoint point, long count) {
840+
GeoCentroid agg = mock(GeoCentroid.class);
841+
when(agg.centroid()).thenReturn(point);
842+
when(agg.count()).thenReturn(count);
843+
return agg;
844+
}
845+
846+
public void testGeoCentroidAggExtractor() {
847+
Aggregation agg = createGeoCentroid(null, 0);
848+
assertThat(AggregationResultUtils.getExtractor(agg).value(agg, "geo_point"), is(nullValue()));
849+
850+
agg = createGeoCentroid(new GeoPoint(100.0, 101.0), 0);
851+
assertThat(AggregationResultUtils.getExtractor(agg).value(agg, "geo_point"), is(nullValue()));
852+
853+
agg = createGeoCentroid(new GeoPoint(100.0, 101.0), randomIntBetween(1, 100));
854+
assertThat(AggregationResultUtils.getExtractor(agg).value(agg, "geo_point"), equalTo("100.0, 101.0"));
855+
}
856+
857+
private GeoBounds createGeoBounds(GeoPoint tl, GeoPoint br) {
858+
GeoBounds agg = mock(GeoBounds.class);
859+
when(agg.bottomRight()).thenReturn(br);
860+
when(agg.topLeft()).thenReturn(tl);
861+
return agg;
862+
}
863+
864+
@SuppressWarnings("unchecked")
865+
public void testGeoBoundsAggExtractor() {
866+
final int numberOfRuns = 25;
867+
Aggregation agg = createGeoBounds(null, new GeoPoint(100.0, 101.0));
868+
assertThat(AggregationResultUtils.getExtractor(agg).value(agg, "geo_shape"), is(nullValue()));
869+
870+
agg = createGeoBounds(new GeoPoint(100.0, 101.0), null);
871+
assertThat(AggregationResultUtils.getExtractor(agg).value(agg, "geo_shape"), is(nullValue()));
872+
873+
String type = "point";
874+
for (int i = 0; i < numberOfRuns; i++) {
875+
Map<String, Object> expectedObject = new HashMap<>();
876+
expectedObject.put("type", type);
877+
double lat = randomDoubleBetween(-90.0, 90.0, false);
878+
double lon = randomDoubleBetween(-180.0, 180.0, false);
879+
expectedObject.put("coordinates", Arrays.asList(lon, lat));
880+
agg = createGeoBounds(new GeoPoint(lat, lon), new GeoPoint(lat, lon));
881+
assertThat(AggregationResultUtils.getExtractor(agg).value(agg, "geo_shape"), equalTo(expectedObject));
882+
}
883+
884+
type = "linestring";
885+
for (int i = 0; i < numberOfRuns; i++) {
886+
double lat = randomDoubleBetween(-90.0, 90.0, false);
887+
double lon = randomDoubleBetween(-180.0, 180.0, false);
888+
double lat2 = lat;
889+
double lon2 = lon;
890+
if (randomBoolean()) {
891+
lat2 = randomDoubleBetween(-90.0, 90.0, false);
892+
} else {
893+
lon2 = randomDoubleBetween(-180.0, 180.0, false);
894+
}
895+
agg = createGeoBounds(new GeoPoint(lat, lon), new GeoPoint(lat2, lon2));
896+
Object val = AggregationResultUtils.getExtractor(agg).value(agg, "geo_shape");
897+
Map<String, Object> geoJson = (Map<String, Object>)val;
898+
assertThat(geoJson.get("type"), equalTo(type));
899+
List<Double[]> coordinates = (List<Double[]>)geoJson.get("coordinates");
900+
for(Double[] coor : coordinates) {
901+
assertThat(coor.length, equalTo(2));
902+
}
903+
assertThat(coordinates.get(0)[0], equalTo(lon));
904+
assertThat(coordinates.get(0)[1], equalTo(lat));
905+
assertThat(coordinates.get(1)[0], equalTo(lon2));
906+
assertThat(coordinates.get(1)[1], equalTo(lat2));
907+
}
908+
909+
type = "polygon";
910+
for (int i = 0; i < numberOfRuns; i++) {
911+
double lat = randomDoubleBetween(-90.0, 90.0, false);
912+
double lon = randomDoubleBetween(-180.0, 180.0, false);
913+
double lat2 = randomDoubleBetween(-90.0, 90.0, false);
914+
double lon2 = randomDoubleBetween(-180.0, 180.0, false);
915+
while (Double.compare(lat, lat2) == 0 || Double.compare(lon, lon2) == 0) {
916+
lat2 = randomDoubleBetween(-90.0, 90.0, false);
917+
lon2 = randomDoubleBetween(-180.0, 180.0, false);
918+
}
919+
agg = createGeoBounds(new GeoPoint(lat, lon), new GeoPoint(lat2, lon2));
920+
Object val = AggregationResultUtils.getExtractor(agg).value(agg, "geo_shape");
921+
Map<String, Object> geoJson = (Map<String, Object>)val;
922+
assertThat(geoJson.get("type"), equalTo(type));
923+
List<List<Double[]>> coordinates = (List<List<Double[]>>)geoJson.get("coordinates");
924+
assertThat(coordinates.size(), equalTo(1));
925+
assertThat(coordinates.get(0).size(), equalTo(5));
926+
List<List<Double>> expected = Arrays.asList(
927+
Arrays.asList(lon, lat),
928+
Arrays.asList(lon2, lat),
929+
Arrays.asList(lon2, lat2),
930+
Arrays.asList(lon, lat2),
931+
Arrays.asList(lon, lat));
932+
for(int j = 0; j < 5; j++) {
933+
Double[] coordinate = coordinates.get(0).get(j);
934+
assertThat(coordinate.length, equalTo(2));
935+
assertThat(coordinate[0], equalTo(expected.get(j).get(0)));
936+
assertThat(coordinate[1], equalTo(expected.get(j).get(1)));
937+
}
938+
}
939+
}
784940

785941
private void executeTest(GroupConfig groups,
786942
Collection<AggregationBuilder> aggregationBuilders,

0 commit comments

Comments
 (0)