Skip to content

Commit ed0c0e6

Browse files
authored
SQL: Convert ST_Distance into query when possible (#40595)
* SQL: Convert ST_Distance into query when possible Adds additional optimization logic to convert ST_Distance function calls into geo_distance query when it is called in WHERE clauses.
1 parent b4231c9 commit ed0c0e6

File tree

7 files changed

+186
-29
lines changed

7 files changed

+186
-29
lines changed

x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/scalar/geo/GeoShape.java

+5
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import org.elasticsearch.common.io.stream.StreamOutput;
1515
import org.elasticsearch.common.xcontent.ToXContentFragment;
1616
import org.elasticsearch.common.xcontent.XContentBuilder;
17+
import org.elasticsearch.geo.geometry.Geometry;
1718
import org.elasticsearch.xpack.sql.SqlIllegalArgumentException;
1819

1920
import java.io.IOException;
@@ -58,6 +59,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
5859
return builder.value(shapeBuilder.toWKT());
5960
}
6061

62+
public Geometry toGeometry() {
63+
return shapeBuilder.buildGeometry();
64+
}
65+
6166
public static double distance(GeoShape shape1, GeoShape shape2) {
6267
if (shape1.shapeBuilder instanceof PointBuilder == false) {
6368
throw new SqlIllegalArgumentException("distance calculation is only supported for points; received [{}]", shape1);

x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/scalar/geo/StDistance.java

+12-20
Original file line numberDiff line numberDiff line change
@@ -9,45 +9,32 @@
99
import org.elasticsearch.xpack.sql.expression.Expression;
1010
import org.elasticsearch.xpack.sql.expression.Expressions;
1111
import org.elasticsearch.xpack.sql.expression.FieldAttribute;
12-
import org.elasticsearch.xpack.sql.expression.function.scalar.BinaryScalarFunction;
1312
import org.elasticsearch.xpack.sql.expression.gen.pipeline.Pipe;
1413
import org.elasticsearch.xpack.sql.expression.gen.script.ScriptTemplate;
14+
import org.elasticsearch.xpack.sql.expression.predicate.BinaryOperator;
1515
import org.elasticsearch.xpack.sql.tree.NodeInfo;
1616
import org.elasticsearch.xpack.sql.tree.Source;
1717
import org.elasticsearch.xpack.sql.type.DataType;
1818

1919
import static org.elasticsearch.xpack.sql.expression.TypeResolutions.isGeo;
20-
import static org.elasticsearch.xpack.sql.expression.function.scalar.geo.StDistanceProcessor.process;
2120
import static org.elasticsearch.xpack.sql.expression.gen.script.ParamsBuilder.paramsBuilder;
2221

2322
/**
2423
* Calculates the distance between two points
2524
*/
26-
public class StDistance extends BinaryScalarFunction {
25+
public class StDistance extends BinaryOperator<Object, Object, Double, StDistanceFunction> {
26+
27+
private static final StDistanceFunction FUNCTION = new StDistanceFunction();
2728

2829
public StDistance(Source source, Expression source1, Expression source2) {
29-
super(source, source1, source2);
30+
super(source, source1, source2, FUNCTION);
3031
}
3132

3233
@Override
3334
protected StDistance replaceChildren(Expression newLeft, Expression newRight) {
3435
return new StDistance(source(), newLeft, newRight);
3536
}
3637

37-
@Override
38-
protected TypeResolution resolveType() {
39-
if (!childrenResolved()) {
40-
return new TypeResolution("Unresolved children");
41-
}
42-
43-
TypeResolution resolution = isGeo(left(), functionName(), Expressions.ParamOrdinal.FIRST);
44-
if (resolution.unresolved()) {
45-
return resolution;
46-
}
47-
48-
return isGeo(right(), functionName(), Expressions.ParamOrdinal.SECOND);
49-
}
50-
5138
@Override
5239
public DataType dataType() {
5340
return DataType.DOUBLE;
@@ -66,8 +53,13 @@ public ScriptTemplate scriptWithField(FieldAttribute field) {
6653
}
6754

6855
@Override
69-
public Object fold() {
70-
return process(left().fold(), right().fold());
56+
protected TypeResolution resolveInputType(Expression e, Expressions.ParamOrdinal paramOrdinal) {
57+
return isGeo(e, sourceText(), paramOrdinal);
58+
}
59+
60+
@Override
61+
public StDistance swapLeftAndRight() {
62+
return new StDistance(source(), right(), left());
7163
}
7264

7365
@Override
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License;
4+
* you may not use this file except in compliance with the Elastic License.
5+
*/
6+
7+
package org.elasticsearch.xpack.sql.expression.function.scalar.geo;
8+
9+
import org.elasticsearch.xpack.sql.expression.predicate.PredicateBiFunction;
10+
11+
class StDistanceFunction implements PredicateBiFunction<Object, Object, Double> {
12+
13+
@Override
14+
public String name() {
15+
return "ST_DISTANCE";
16+
}
17+
18+
@Override
19+
public String symbol() {
20+
return "ST_DISTANCE";
21+
}
22+
23+
@Override
24+
public Double doApply(Object s1, Object s2) {
25+
return StDistanceProcessor.process(s1, s2);
26+
}
27+
}

x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/planner/QueryTranslator.java

+26
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
*/
66
package org.elasticsearch.xpack.sql.planner;
77

8+
import org.elasticsearch.geo.geometry.Geometry;
9+
import org.elasticsearch.geo.geometry.Point;
810
import org.elasticsearch.search.sort.SortOrder;
911
import org.elasticsearch.xpack.sql.SqlIllegalArgumentException;
1012
import org.elasticsearch.xpack.sql.expression.Attribute;
@@ -38,6 +40,8 @@
3840
import org.elasticsearch.xpack.sql.expression.function.scalar.ScalarFunction;
3941
import org.elasticsearch.xpack.sql.expression.function.scalar.datetime.DateTimeFunction;
4042
import org.elasticsearch.xpack.sql.expression.function.scalar.datetime.DateTimeHistogramFunction;
43+
import org.elasticsearch.xpack.sql.expression.function.scalar.geo.GeoShape;
44+
import org.elasticsearch.xpack.sql.expression.function.scalar.geo.StDistance;
4145
import org.elasticsearch.xpack.sql.expression.gen.script.ScriptTemplate;
4246
import org.elasticsearch.xpack.sql.expression.literal.Intervals;
4347
import org.elasticsearch.xpack.sql.expression.predicate.Range;
@@ -85,6 +89,7 @@
8589
import org.elasticsearch.xpack.sql.querydsl.agg.TopHitsAgg;
8690
import org.elasticsearch.xpack.sql.querydsl.query.BoolQuery;
8791
import org.elasticsearch.xpack.sql.querydsl.query.ExistsQuery;
92+
import org.elasticsearch.xpack.sql.querydsl.query.GeoDistanceQuery;
8893
import org.elasticsearch.xpack.sql.querydsl.query.MatchQuery;
8994
import org.elasticsearch.xpack.sql.querydsl.query.MultiMatchQuery;
9095
import org.elasticsearch.xpack.sql.querydsl.query.NestedQuery;
@@ -656,6 +661,24 @@ private static Query translateQuery(BinaryComparison bc) {
656661
Object value = valueOf(bc.right());
657662
String format = dateFormat(bc.left());
658663

664+
// Possible geo optimization
665+
if (bc.left() instanceof StDistance && value instanceof Number) {
666+
if (bc instanceof LessThan || bc instanceof LessThanOrEqual) {
667+
// Special case for ST_Distance translatable into geo_distance query
668+
StDistance stDistance = (StDistance) bc.left();
669+
if (stDistance.left() instanceof FieldAttribute && stDistance.right().foldable()) {
670+
Object geoShape = valueOf(stDistance.right());
671+
if (geoShape instanceof GeoShape) {
672+
Geometry geometry = ((GeoShape) geoShape).toGeometry();
673+
if (geometry instanceof Point) {
674+
String field = nameOf(stDistance.left());
675+
return new GeoDistanceQuery(source, field, ((Number) value).doubleValue(),
676+
((Point) geometry).getLat(), ((Point) geometry).getLon());
677+
}
678+
}
679+
}
680+
}
681+
}
659682
if (bc instanceof GreaterThan) {
660683
return new RangeQuery(source, name, value, false, null, false, format);
661684
}
@@ -954,6 +977,9 @@ public QueryTranslation translate(Expression exp, boolean onAggs) {
954977

955978
protected static Query handleQuery(ScalarFunction sf, Expression field, Supplier<Query> query) {
956979
Query q = query.get();
980+
if (field instanceof StDistance && q instanceof GeoDistanceQuery) {
981+
return wrapIfNested(q, ((StDistance) field).left());
982+
}
957983
if (field instanceof FieldAttribute) {
958984
return wrapIfNested(q, field);
959985
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License;
4+
* you may not use this file except in compliance with the Elastic License.
5+
*/
6+
package org.elasticsearch.xpack.sql.querydsl.query;
7+
8+
import org.elasticsearch.common.unit.DistanceUnit;
9+
import org.elasticsearch.index.query.QueryBuilder;
10+
import org.elasticsearch.index.query.QueryBuilders;
11+
import org.elasticsearch.xpack.sql.tree.Source;
12+
13+
import java.util.Objects;
14+
15+
public class GeoDistanceQuery extends LeafQuery {
16+
17+
private final String field;
18+
private final double lat;
19+
private final double lon;
20+
private final double distance;
21+
22+
public GeoDistanceQuery(Source source, String field, double distance, double lat, double lon) {
23+
super(source);
24+
this.field = field;
25+
this.distance = distance;
26+
this.lat = lat;
27+
this.lon = lon;
28+
}
29+
30+
public String field() {
31+
return field;
32+
}
33+
34+
public double lat() {
35+
return lat;
36+
}
37+
38+
public double lon() {
39+
return lon;
40+
}
41+
42+
public double distance() {
43+
return distance;
44+
}
45+
46+
@Override
47+
public QueryBuilder asBuilder() {
48+
return QueryBuilders.geoDistanceQuery(field).distance(distance, DistanceUnit.METERS).point(lat, lon);
49+
}
50+
51+
@Override
52+
public int hashCode() {
53+
return Objects.hash(field, distance, lat, lon);
54+
}
55+
56+
@Override
57+
public boolean equals(Object obj) {
58+
if (this == obj) {
59+
return true;
60+
}
61+
62+
if (obj == null || getClass() != obj.getClass()) {
63+
return false;
64+
}
65+
66+
GeoDistanceQuery other = (GeoDistanceQuery) obj;
67+
return Objects.equals(field, other.field) &&
68+
Objects.equals(distance, other.distance) &&
69+
Objects.equals(lat, other.lat) &&
70+
Objects.equals(lon, other.lon);
71+
}
72+
73+
@Override
74+
protected String innerToString() {
75+
return field + ":" + "(" + distance + "," + "(" + lat + ", " + lon + "))";
76+
}
77+
}

x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/optimizer/OptimizerTests.java

+10
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import org.elasticsearch.xpack.sql.expression.function.scalar.datetime.IsoWeekOfYear;
3232
import org.elasticsearch.xpack.sql.expression.function.scalar.datetime.MonthOfYear;
3333
import org.elasticsearch.xpack.sql.expression.function.scalar.datetime.Year;
34+
import org.elasticsearch.xpack.sql.expression.function.scalar.geo.StDistance;
3435
import org.elasticsearch.xpack.sql.expression.function.scalar.math.ACos;
3536
import org.elasticsearch.xpack.sql.expression.function.scalar.math.ASin;
3637
import org.elasticsearch.xpack.sql.expression.function.scalar.math.ATan;
@@ -622,6 +623,15 @@ public void testLiteralsOnTheRight() {
622623
assertEquals(FIVE, nullEquals.right());
623624
}
624625

626+
public void testLiteralsOnTheRightInStDistance() {
627+
Alias a = new Alias(EMPTY, "a", L(10));
628+
Expression result = new BooleanLiteralsOnTheRight().rule(new StDistance(EMPTY, FIVE, a));
629+
assertTrue(result instanceof StDistance);
630+
StDistance sd = (StDistance) result;
631+
assertEquals(a, sd.left());
632+
assertEquals(FIVE, sd.right());
633+
}
634+
625635
public void testBoolSimplifyNotIsNullAndNotIsNotNull() {
626636
BooleanSimplification simplification = new BooleanSimplification();
627637
assertTrue(simplification.rule(new Not(EMPTY, new IsNull(EMPTY, ONE))) instanceof IsNotNull);

x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/planner/QueryTranslatorTests.java

+29-9
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
import org.elasticsearch.xpack.sql.querydsl.agg.GroupByDateHistogram;
4040
import org.elasticsearch.xpack.sql.querydsl.query.BoolQuery;
4141
import org.elasticsearch.xpack.sql.querydsl.query.ExistsQuery;
42+
import org.elasticsearch.xpack.sql.querydsl.query.GeoDistanceQuery;
4243
import org.elasticsearch.xpack.sql.querydsl.query.NotQuery;
4344
import org.elasticsearch.xpack.sql.querydsl.query.Query;
4445
import org.elasticsearch.xpack.sql.querydsl.query.RangeQuery;
@@ -614,22 +615,41 @@ public void testTranslateStWktToSql() {
614615
assertEquals("[{v=keyword}, {v=point (10.0 20.0)}]", aggFilter.scriptTemplate().params().toString());
615616
}
616617

617-
public void testTranslateStDistance() {
618-
LogicalPlan p = plan("SELECT shape FROM test WHERE ST_Distance(shape, ST_WKTToSQL('point (10 20)')) > 20");
618+
public void testTranslateStDistanceToScript() {
619+
String operator = randomFrom(">", ">=");
620+
String operatorFunction = operator.equalsIgnoreCase(">") ? "gt" : "gte";
621+
LogicalPlan p = plan("SELECT shape FROM test WHERE ST_Distance(shape, ST_WKTToSQL('point (10 20)')) " + operator + " 20");
619622
assertThat(p, instanceOf(Project.class));
620623
assertThat(p.children().get(0), instanceOf(Filter.class));
621624
Expression condition = ((Filter) p.children().get(0)).condition();
622625
assertFalse(condition.foldable());
623-
QueryTranslation translation = QueryTranslator.toQuery(condition, true);
624-
assertNull(translation.query);
625-
AggFilter aggFilter = translation.aggFilter;
626-
626+
QueryTranslation translation = QueryTranslator.toQuery(condition, false);
627+
assertNull(translation.aggFilter);
628+
assertTrue(translation.query instanceof ScriptQuery);
629+
ScriptQuery sc = (ScriptQuery) translation.query;
627630
assertEquals("InternalSqlScriptUtils.nullSafeFilter(" +
628-
"InternalSqlScriptUtils.gt(" +
631+
"InternalSqlScriptUtils." + operatorFunction + "(" +
629632
"InternalSqlScriptUtils.stDistance(" +
630633
"InternalSqlScriptUtils.geoDocValue(doc,params.v0),InternalSqlScriptUtils.stWktToSql(params.v1)),params.v2))",
631-
aggFilter.scriptTemplate().toString());
632-
assertEquals("[{v=shape}, {v=point (10.0 20.0)}, {v=20}]", aggFilter.scriptTemplate().params().toString());
634+
sc.script().toString());
635+
assertEquals("[{v=shape}, {v=point (10.0 20.0)}, {v=20}]", sc.script().params().toString());
636+
}
637+
638+
public void testTranslateStDistanceToQuery() {
639+
String operator = randomFrom("<", "<=");
640+
LogicalPlan p = plan("SELECT shape FROM test WHERE ST_Distance(shape, ST_WKTToSQL('point (10 20)')) " + operator + " 25");
641+
assertThat(p, instanceOf(Project.class));
642+
assertThat(p.children().get(0), instanceOf(Filter.class));
643+
Expression condition = ((Filter) p.children().get(0)).condition();
644+
assertFalse(condition.foldable());
645+
QueryTranslation translation = QueryTranslator.toQuery(condition, false);
646+
assertNull(translation.aggFilter);
647+
assertTrue(translation.query instanceof GeoDistanceQuery);
648+
GeoDistanceQuery gq = (GeoDistanceQuery) translation.query;
649+
assertEquals("shape", gq.field());
650+
assertEquals(20.0, gq.lat(), 0.00001);
651+
assertEquals(10.0, gq.lon(), 0.00001);
652+
assertEquals(25.0, gq.distance(), 0.00001);
633653
}
634654

635655
public void testTranslateCoalesce_GroupBy_Painless() {

0 commit comments

Comments
 (0)