Skip to content

Commit f2b6879

Browse files
waaeervitcpp
authored andcommitted
KNN support for spoints
1 parent 597600e commit f2b6879

File tree

6 files changed

+279
-3
lines changed

6 files changed

+279
-3
lines changed

Diff for: Makefile

+1-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ DATA_built = $(RELEASE_SQL) \
3939
DOCS = README.pg_sphere COPYRIGHT.pg_sphere
4040
TESTS = version tables points euler circle line ellipse poly path box \
4141
index contains_ops contains_ops_compat bounding_box_gist gnomo \
42-
epochprop contains overlaps spoint_brin sbox_brin selectivity
42+
epochprop contains overlaps spoint_brin sbox_brin selectivity knn
4343
REGRESS = init $(TESTS)
4444

4545
PG_CFLAGS += -DPGSPHERE_VERSION=$(PGSPHERE_VERSION)

Diff for: doc/indices.sgm

+12-1
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,11 @@
6868
</para>
6969
</listitem>
7070
</itemizedlist>
71+
<para>
72+
A GiST index can also be used to quickly find the points closest to a given one
72+
when the query orders its result by an expression that uses the <literal>&lt;-&gt;</literal>
73+
operator, as shown in the example below.
75+
</para>
7176
<para>
7277
BRIN indexing supports just spherical points (<type>spoint</type>)
7378
and spherical coordinates range (<type>sbox</type>) at the moment.
@@ -82,6 +87,13 @@
8287
<![CDATA[CREATE INDEX test_pos_idx ON test USING GIST (pos);]]>
8388
<![CDATA[VACUUM ANALYZE test;]]>
8489
</programlisting>
90+
<para>
91+
To find points closest to a given spherical position, use the <literal>&lt;-&gt;</literal> operator:
92+
</para>
93+
<programlisting>
94+
<![CDATA[SELECT * FROM test ORDER BY pos <-> spoint (0.2, 0.3) LIMIT 10;]]>
95+
</programlisting>
96+
8597
<para>
8698
BRIN index can be created through the following syntax:
8799
</para>
@@ -100,7 +112,6 @@
100112
<![CDATA[CREATE INDEX test_pos_idx USING BRIN ON test (pos) WITH (pages_per_range = 16);]]>
101113
</programlisting>
102114
</example>
103-
104115
</sect1>
105116

106117
<sect1 id="ind.smoc">

Diff for: expected/knn.out

+122
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
CREATE TABLE points (id int, p spoint, pos int);
2+
INSERT INTO points (id, p) SELECT x, spoint(random()*6.28, (2*random()-1)*1.57) FROM generate_series(1,314159) x;
3+
CREATE INDEX i ON points USING gist (p);
4+
SET enable_indexscan = true;
5+
EXPLAIN (costs off) SELECT p <-> spoint (0.2, 0.3) FROM points ORDER BY 1 LIMIT 100;
6+
QUERY PLAN
7+
-------------------------------------------------
8+
Limit
9+
-> Index Scan using i on points
10+
Order By: (p <-> '(0.2 , 0.3)'::spoint)
11+
(3 rows)
12+
13+
UPDATE points SET pos = n FROM
14+
(SELECT id, row_number() OVER (ORDER BY p <-> spoint (0.2, 0.3)) n FROM points ORDER BY p <-> spoint (0.2, 0.3) LIMIT 100) sel
15+
WHERE points.id = sel.id;
16+
SET enable_indexscan = false;
17+
SELECT pos, row_number() OVER (ORDER BY p <-> spoint (0.2, 0.3)) n FROM points ORDER BY p <-> spoint (0.2, 0.3) LIMIT 100;
18+
pos | n
19+
-----+-----
20+
1 | 1
21+
2 | 2
22+
3 | 3
23+
4 | 4
24+
5 | 5
25+
6 | 6
26+
7 | 7
27+
8 | 8
28+
9 | 9
29+
10 | 10
30+
11 | 11
31+
12 | 12
32+
13 | 13
33+
14 | 14
34+
15 | 15
35+
16 | 16
36+
17 | 17
37+
18 | 18
38+
19 | 19
39+
20 | 20
40+
21 | 21
41+
22 | 22
42+
23 | 23
43+
24 | 24
44+
25 | 25
45+
26 | 26
46+
27 | 27
47+
28 | 28
48+
29 | 29
49+
30 | 30
50+
31 | 31
51+
32 | 32
52+
33 | 33
53+
34 | 34
54+
35 | 35
55+
36 | 36
56+
37 | 37
57+
38 | 38
58+
39 | 39
59+
40 | 40
60+
41 | 41
61+
42 | 42
62+
43 | 43
63+
44 | 44
64+
45 | 45
65+
46 | 46
66+
47 | 47
67+
48 | 48
68+
49 | 49
69+
50 | 50
70+
51 | 51
71+
52 | 52
72+
53 | 53
73+
54 | 54
74+
55 | 55
75+
56 | 56
76+
57 | 57
77+
58 | 58
78+
59 | 59
79+
60 | 60
80+
61 | 61
81+
62 | 62
82+
63 | 63
83+
64 | 64
84+
65 | 65
85+
66 | 66
86+
67 | 67
87+
68 | 68
88+
69 | 69
89+
70 | 70
90+
71 | 71
91+
72 | 72
92+
73 | 73
93+
74 | 74
94+
75 | 75
95+
76 | 76
96+
77 | 77
97+
78 | 78
98+
79 | 79
99+
80 | 80
100+
81 | 81
101+
82 | 82
102+
83 | 83
103+
84 | 84
104+
85 | 85
105+
86 | 86
106+
87 | 87
107+
88 | 88
108+
89 | 89
109+
90 | 90
110+
91 | 91
111+
92 | 92
112+
93 | 93
113+
94 | 94
114+
95 | 95
115+
96 | 96
116+
97 | 97
117+
98 | 98
118+
99 | 99
119+
100 | 100
120+
(100 rows)
121+
122+
DROP TABLE points;

Diff for: pgs_gist.sql.in

+6-1
Original file line numberDiff line numberDiff line change
@@ -98,12 +98,15 @@ CREATE FUNCTION g_spoint_compress(internal)
9898
AS 'MODULE_PATHNAME', 'g_spoint_compress'
9999
LANGUAGE 'c';
100100

101-
102101
CREATE FUNCTION g_spoint_consistent(internal, internal, int4, oid, internal)
103102
RETURNS internal
104103
AS 'MODULE_PATHNAME', 'g_spoint_consistent'
105104
LANGUAGE 'c';
106105

106+
-- GiST distance support function (support function 8) for the spoint
-- operator class; it backs "ORDER BY spoint <-> spoint" index scans.
-- The C implementation returns its result with PG_RETURN_FLOAT8, and the
-- GiST API requires the distance function to be declared as returning
-- float8, so the declared result type must not be "internal".
CREATE FUNCTION g_spoint_distance(internal, spoint, smallint, oid, internal)
   RETURNS float8
   AS 'MODULE_PATHNAME', 'g_spoint_distance'
   LANGUAGE 'c';
107110

108111
CREATE OPERATOR CLASS spoint
109112
DEFAULT FOR TYPE spoint USING gist AS
@@ -114,6 +117,7 @@ CREATE OPERATOR CLASS spoint
114117
OPERATOR 14 @ (spoint, spoly),
115118
OPERATOR 15 @ (spoint, sellipse),
116119
OPERATOR 16 @ (spoint, sbox),
120+
OPERATOR 17 <-> (spoint, spoint) FOR ORDER BY float_ops,
117121
OPERATOR 37 <@ (spoint, scircle),
118122
OPERATOR 38 <@ (spoint, sline),
119123
OPERATOR 39 <@ (spoint, spath),
@@ -127,6 +131,7 @@ CREATE OPERATOR CLASS spoint
127131
FUNCTION 5 g_spherekey_penalty (internal, internal, internal),
128132
FUNCTION 6 g_spherekey_picksplit (internal, internal),
129133
FUNCTION 7 g_spherekey_same (spherekey, spherekey, internal),
134+
FUNCTION 8 g_spoint_distance (internal, spoint, smallint, oid, internal),
130135
STORAGE spherekey;
131136

132137

Diff for: sql/knn.sql

+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
CREATE TABLE points (id int, p spoint, pos int);
2+
INSERT INTO points (id, p) SELECT x, spoint(random()*6.28, (2*random()-1)*1.57) FROM generate_series(1,314159) x;
3+
CREATE INDEX i ON points USING gist (p);
4+
SET enable_indexscan = true;
5+
EXPLAIN (costs off) SELECT p <-> spoint (0.2, 0.3) FROM points ORDER BY 1 LIMIT 100;
6+
UPDATE points SET pos = n FROM
7+
(SELECT id, row_number() OVER (ORDER BY p <-> spoint (0.2, 0.3)) n FROM points ORDER BY p <-> spoint (0.2, 0.3) LIMIT 100) sel
8+
WHERE points.id = sel.id;
9+
SET enable_indexscan = false;
10+
SELECT pos, row_number() OVER (ORDER BY p <-> spoint (0.2, 0.3)) n FROM points ORDER BY p <-> spoint (0.2, 0.3) LIMIT 100;
11+
DROP TABLE points;
12+

Diff for: src/gist.c

+126
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ PG_FUNCTION_INFO_V1(g_spherekey_penalty);
3636
PG_FUNCTION_INFO_V1(g_spherekey_picksplit);
3737
PG_FUNCTION_INFO_V1(g_spoint3_penalty);
3838
PG_FUNCTION_INFO_V1(g_spoint3_picksplit);
39+
PG_FUNCTION_INFO_V1(g_spoint_distance);
3940
PG_FUNCTION_INFO_V1(g_spoint3_distance);
4041
PG_FUNCTION_INFO_V1(g_spoint3_fetch);
4142

@@ -681,6 +682,10 @@ g_spoint3_consistent(PG_FUNCTION_ARGS)
681682
PG_RETURN_BOOL(false);
682683
}
683684

685+
/*
 * Angle (in radians) between the vector *v and the direction of the 3D
 * point (x, y, z).
 *
 * *v is expected to have length 1 by design, so only (x, y, z) has to be
 * normalised before taking the arc cosine of the dot product.
 *
 * The result is always a finite value in [0, pi]:
 *  - a zero-length (x, y, z) has no direction; 0 is returned, which is the
 *    most conservative value for a lower-bound distance estimate;
 *  - the cosine is clamped to [-1, 1], because floating-point rounding can
 *    push it marginally outside that range and acos() would then yield NaN.
 */
static double
distance_vector_point_3d(Vector3D *v, double x, double y, double z)
{
	double		len = sqrt(x * x + y * y + z * z);
	double		cosine;

	if (len == 0.0)
		return 0.0;

	cosine = (v->x * x + v->y * y + v->z * z) / len;

	if (cosine > 1.0)
		cosine = 1.0;
	else if (cosine < -1.0)
		cosine = -1.0;

	return acos(cosine);
}
688+
684689
Datum
685690
g_spoint3_distance(PG_FUNCTION_ARGS)
686691
{
@@ -1672,6 +1677,127 @@ fallbackSplit(Box3D *boxes, OffsetNumber maxoff, GIST_SPLITVEC *v)
16721677
v->spl_ldatum_exists = v->spl_rdatum_exists = false;
16731678
}
16741679

1680+
1681+
Datum
1682+
g_spoint_distance(PG_FUNCTION_ARGS)
1683+
{
1684+
GISTENTRY *entry = (GISTENTRY *) PG_GETARG_POINTER(0);
1685+
StrategyNumber strategy = (StrategyNumber) PG_GETARG_UINT16(2);
1686+
Box3D* box = (Box3D *) DatumGetPointer(entry->key);
1687+
double retval;
1688+
SPoint *point = (SPoint *) PG_GETARG_POINTER(1);
1689+
Vector3D v_point, v_low, v_high;
1690+
1691+
switch (strategy)
1692+
{
1693+
case 17:
1694+
// Prepare data for calculation
1695+
spoint_vector3d(&v_point, point);
1696+
v_low.x = (double)box->low.coord[0] / MAXCVALUE;
1697+
v_low.y = (double)box->low.coord[1] / MAXCVALUE;
1698+
v_low.z = (double)box->low.coord[2] / MAXCVALUE;
1699+
v_high.x = (double)box->high.coord[0] / MAXCVALUE;
1700+
v_high.y = (double)box->high.coord[1] / MAXCVALUE;
1701+
v_high.z = (double)box->high.coord[2] / MAXCVALUE;
1702+
// a box splits space into 27 subspaces (6+12+8+1) with different distance calculation
1703+
if(v_point.x < v_low.x) {
1704+
if(v_point.y < v_low.y) {
1705+
if(v_point.z < v_low.z) {
1706+
retval = distance_vector_point_3d (&v_point, v_low.x, v_low.y, v_low.z); //point2point distance
1707+
} else if (v_point.z < v_high.z) {
1708+
retval = distance_vector_point_3d (&v_point, v_low.x, v_low.y, v_point.z); //point2line distance
1709+
} else {
1710+
retval = distance_vector_point_3d (&v_point, v_low.x, v_low.y, v_high.z); //point2point distance
1711+
}
1712+
} else if(v_point.y < v_high.y) {
1713+
if(v_point.z < v_low.z) {
1714+
retval = distance_vector_point_3d (&v_point, v_low.x, v_point.y , v_low.z); //point2line distance
1715+
} else if (v_point.z < v_high.z) {
1716+
retval = distance_vector_point_3d (&v_point, v_low.x, v_point.y , v_point.z); //point2plane distance
1717+
} else {
1718+
retval = distance_vector_point_3d (&v_point, v_low.x, v_point.y, v_high.z); //point2line distance
1719+
}
1720+
} else {
1721+
if(v_point.z < v_low.z) {
1722+
retval = distance_vector_point_3d (&v_point, v_low.x, v_high.y, v_low.z); //point2point distance
1723+
} else if (v_point.z < v_high.z) {
1724+
retval = distance_vector_point_3d (&v_point, v_low.x, v_high.y, v_point.z); //point2line distance
1725+
} else {
1726+
retval = distance_vector_point_3d (&v_point, v_low.x, v_high.y, v_high.z); //point2point distance
1727+
}
1728+
}
1729+
} else if(v_point.x < v_high.x) {
1730+
if(v_point.y < v_low.y) {
1731+
if(v_point.z < v_low.z) {
1732+
retval = distance_vector_point_3d (&v_point, v_point.x, v_low.y, v_low.z); //p2line distance
1733+
} else if (v_point.z < v_high.z) {
1734+
retval = distance_vector_point_3d (&v_point, v_point.x, v_low.y, v_point.z); //point2plane distance
1735+
} else {
1736+
retval = distance_vector_point_3d (&v_point, v_point.x, v_low.y, v_high.z); //point2line distance
1737+
}
1738+
} else if(v_point.y < v_high.y) {
1739+
if(v_point.z < v_low.z) {
1740+
retval = distance_vector_point_3d (&v_point, v_point.x, v_point.y , v_low.z); //point2plane distance
1741+
} else if (v_point.z < v_high.z) {
1742+
retval = 0; // inside cube
1743+
} else {
1744+
retval = distance_vector_point_3d (&v_point, v_point.x, v_point.y, v_high.z); //point2plane distance
1745+
}
1746+
} else {
1747+
if(v_point.z < v_low.z) {
1748+
retval = distance_vector_point_3d (&v_point, v_point.x, v_high.y, v_low.z); //point2line distance
1749+
} else if (v_point.z < v_high.z) {
1750+
retval = distance_vector_point_3d (&v_point, v_point.x, v_high.y, v_point.z); //point2plane distance
1751+
} else {
1752+
retval = distance_vector_point_3d (&v_point, v_point.x, v_high.y, v_high.z); //point2line distance
1753+
}
1754+
}
1755+
} else {
1756+
if(v_point.y < v_low.y) {
1757+
if(v_point.z < v_low.z) {
1758+
retval = distance_vector_point_3d (&v_point, v_high.x, v_low.y, v_low.z); //p2p distance
1759+
} else if (v_point.z < v_high.z) {
1760+
retval = distance_vector_point_3d (&v_point, v_high.x, v_low.y, v_point.z); //point2line distance
1761+
} else {
1762+
retval = distance_vector_point_3d (&v_point, v_high.x, v_low.y, v_high.z); //point2point distance
1763+
}
1764+
} else if(v_point.y < v_high.y) {
1765+
if(v_point.z < v_low.z) {
1766+
retval = distance_vector_point_3d (&v_point, v_high.x, v_point.y , v_low.z); //point2line distance
1767+
} else if (v_point.z < v_high.z) {
1768+
retval = distance_vector_point_3d (&v_point, v_high.x, v_point.y , v_point.z); //point2plane distance
1769+
} else {
1770+
retval = distance_vector_point_3d (&v_point, v_high.x, v_point.y, v_high.z); //point2line distance
1771+
}
1772+
} else {
1773+
if(v_point.z < v_low.z) {
1774+
retval = distance_vector_point_3d (&v_point, v_high.x, v_high.y, v_low.z); //point2point distance
1775+
} else if (v_point.z < v_high.z) {
1776+
retval = distance_vector_point_3d (&v_point, v_high.x, v_high.y, v_point.z); //point2line distance
1777+
} else {
1778+
retval = distance_vector_point_3d (&v_point, v_high.x, v_high.y, v_high.z); //point2point distance
1779+
}
1780+
}
1781+
}
1782+
1783+
elog(DEBUG1, "distance (%lg,%lg,%lg %lg,%lg,%lg) <-> (%lg,%lg) = %lg",
1784+
v_low.x, v_low.y, v_low.z,
1785+
v_high.x, v_high.y, v_high.z,
1786+
point->lng, point->lat,
1787+
retval
1788+
);
1789+
break;
1790+
1791+
default:
1792+
elog(ERROR, "unrecognized cube strategy number: %d", strategy);
1793+
retval = 0; /* keep compiler quiet */
1794+
break;
1795+
}
1796+
PG_RETURN_FLOAT8(retval);
1797+
}
1798+
1799+
1800+
16751801
/*
16761802
* Represents information about an entry that can be placed to either group
16771803
* without affecting overlap over selected axis ("common entry").

0 commit comments

Comments
 (0)