Skip to content

Commit 9c97416

Browse files
craig[bot]DrewKimball
craig[bot]
andcommitted
Merge #143886
143886: sql: don't encode Table/Index prefix twice for vector search r=mw5h a=DrewKimball This commit fixes an oversight in how vector search prefix keys are generated; namely, previously we would include the `/Table/Index` key prefix. The vector search library already adds this prefix during a search, so the resulting search keys were incorrect. Epic: CRDB-42943 Release note: None Co-authored-by: Drew Kimball <[email protected]>
2 parents 1b96d8c + 27f6a48 commit 9c97416

File tree

4 files changed

+100
-27
lines changed

4 files changed

+100
-27
lines changed

pkg/sql/logictest/testdata/logic_test/vector_index

+44-2
Original file line numberDiff line numberDiff line change
@@ -265,11 +265,53 @@ DROP TABLE alter_test
265265
statement ok
266266
CREATE TABLE exec_test (
267267
a INT PRIMARY KEY,
268+
b INT,
268269
vec1 VECTOR(3),
269-
VECTOR INDEX (vec1)
270+
VECTOR INDEX idx1 (vec1),
271+
VECTOR INDEX idx2 (b, vec1)
270272
)
271273

272-
# TODO(drewk): write these tests once execution is supported.
274+
statement ok
275+
INSERT INTO exec_test (a, b, vec1) VALUES
276+
(1, 1, '[1, 2, 3]'),
277+
(2, 1, '[4, 5, 6]'),
278+
(3, 2, '[7, 8, 9]'),
279+
(4, 2, '[10, 11, 12]'),
280+
(5, 2, '[13, 14, 15]'),
281+
(6, NULL, '[16, 17, 18]'),
282+
(7, NULL, '[1, 1, 1]');
283+
284+
# TODO(143209): write a full set of tests once we can make them deterministic.
285+
# For now, we can write tests that return every vector with a given prefix.
286+
query I rowsort
287+
SELECT a FROM exec_test@idx1 ORDER BY vec1 <-> '[1, 1, 2]' LIMIT 7;
288+
----
289+
7
290+
1
291+
2
292+
3
293+
4
294+
5
295+
6
296+
297+
query I rowsort
298+
SELECT a FROM exec_test@idx2 WHERE b = 1 ORDER BY vec1 <-> '[1, 1, 2]' LIMIT 2;
299+
----
300+
1
301+
2
302+
303+
query I rowsort
304+
SELECT a FROM exec_test@idx2 WHERE b = 2 ORDER BY vec1 <-> '[1, 1, 2]' LIMIT 3;
305+
----
306+
3
307+
4
308+
5
309+
310+
query I rowsort
311+
SELECT a FROM exec_test WHERE b IS NULL ORDER BY vec1 <-> '[1, 1, 2]' LIMIT 3;
312+
----
313+
7
314+
6
273315

274316
statement ok
275317
DROP TABLE exec_test

pkg/sql/span/BUILD.bazel

+2
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ go_test(
4141
embed = [":span"],
4242
deps = [
4343
"//pkg/base",
44+
"//pkg/keys",
4445
"//pkg/security/securityassets",
4546
"//pkg/security/securitytest",
4647
"//pkg/server",
@@ -51,6 +52,7 @@ go_test(
5152
"//pkg/sql/catalog/fetchpb",
5253
"//pkg/sql/catalog/systemschema",
5354
"//pkg/sql/opt/constraint",
55+
"//pkg/sql/rowenc",
5456
"//pkg/sql/sem/tree",
5557
"//pkg/testutils/serverutils",
5658
"//pkg/util/encoding",

pkg/sql/span/span_builder.go

+12-5
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,7 @@ func (s *Builder) appendSpansFromConstraintSpan(
299299
var err error
300300
var containsNull bool
301301
// Encode each logical part of the start key.
302-
span.Key, containsNull, err = s.encodeConstraintKey(cs.StartKey())
302+
span.Key, containsNull, err = s.encodeConstraintKey(cs.StartKey(), true /* includePrefix */)
303303
if err != nil {
304304
return nil, err
305305
}
@@ -312,7 +312,7 @@ func (s *Builder) appendSpansFromConstraintSpan(
312312
span.Key = span.Key.PrefixEnd()
313313
}
314314
// Encode each logical part of the end key.
315-
span.EndKey, _, err = s.encodeConstraintKey(cs.EndKey())
315+
span.EndKey, _, err = s.encodeConstraintKey(cs.EndKey(), true /* includePrefix */)
316316
if err != nil {
317317
return nil, err
318318
}
@@ -337,13 +337,18 @@ func (s *Builder) appendSpansFromConstraintSpan(
337337

338338
// encodeConstraintKey encodes each logical part of a constraint.Key into a
339339
// roachpb.Key.
340+
//
341+
// includePrefix is true if the KeyPrefix bytes should be included in the
342+
// returned key.
340343
func (s *Builder) encodeConstraintKey(
341-
ck constraint.Key,
344+
ck constraint.Key, includePrefix bool,
342345
) (key roachpb.Key, containsNull bool, _ error) {
343346
if ck.IsEmpty() {
344347
return key, containsNull, nil
345348
}
346-
key = append(key, s.KeyPrefix...)
349+
if includePrefix {
350+
key = append(key, s.KeyPrefix...)
351+
}
347352
for i := 0; i < ck.Length(); i++ {
348353
val := ck.Value(i)
349354
if val == tree.DNull {
@@ -494,8 +499,10 @@ func (s *Builder) KeysFromVectorPrefixConstraint(
494499
if !span.HasSingleKey(ctx, s.evalCtx) {
495500
return nil, errors.AssertionFailedf("constraint span %s does not have a single key", span)
496501
}
502+
// Do not include the /Table/Index prefix bytes - we only want the portion
503+
// of the prefix that corresponds to the prefix columns.
497504
var err error
498-
prefixKeys[i], _, err = s.encodeConstraintKey(span.StartKey())
505+
prefixKeys[i], _, err = s.encodeConstraintKey(span.StartKey(), false /* includePrefix */)
499506
if err != nil {
500507
return nil, err
501508
}

pkg/sql/span/span_builder_test.go

+42-20
Original file line numberDiff line numberDiff line change
@@ -10,21 +10,29 @@ import (
1010
"strings"
1111
"testing"
1212

13+
"github.com/cockroachdb/cockroach/pkg/keys"
1314
"github.com/cockroachdb/cockroach/pkg/sql/catalog/catenumpb"
15+
"github.com/cockroachdb/cockroach/pkg/sql/catalog/descpb"
1416
"github.com/cockroachdb/cockroach/pkg/sql/catalog/fetchpb"
1517
"github.com/cockroachdb/cockroach/pkg/sql/opt/constraint"
18+
"github.com/cockroachdb/cockroach/pkg/sql/rowenc"
1619
"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
1720
"github.com/cockroachdb/cockroach/pkg/util/encoding"
1821
"github.com/stretchr/testify/require"
1922
)
2023

2124
func TestBuilder_EncodeConstraintKey(t *testing.T) {
25+
const (
26+
tableID descpb.ID = 100
27+
indexID descpb.IndexID = 2
28+
)
2229
var (
23-
colDirs1 = []string{"asc", "asc", "asc"}
24-
colDirs2 = []string{"desc", "asc", "desc"}
25-
intDatum1 = tree.NewDInt(1)
26-
intDatum2 = tree.NewDInt(2)
27-
textDatum = tree.NewDString("foo")
30+
tableIndexBytes = rowenc.MakeIndexKeyPrefix(keys.SystemSQLCodec, tableID, indexID)
31+
colDirs1 = []string{"asc", "asc", "asc"}
32+
colDirs2 = []string{"desc", "asc", "desc"}
33+
intDatum1 = tree.NewDInt(1)
34+
intDatum2 = tree.NewDInt(2)
35+
textDatum = tree.NewDString("foo")
2836
)
2937
for tcIdx, tc := range []struct {
3038
dirs []string
@@ -63,22 +71,36 @@ func TestBuilder_EncodeConstraintKey(t *testing.T) {
6371
},
6472
} {
6573
t.Run(fmt.Sprintf("case %d", tcIdx+1), func(t *testing.T) {
66-
b := Builder{}
67-
b.keyAndPrefixCols = make([]fetchpb.IndexFetchSpec_KeyColumn, len(tc.dirs))
68-
valDirs := make([]encoding.Direction, len(tc.dirs))
69-
for i, dir := range tc.dirs {
70-
if dir == "asc" {
71-
b.keyAndPrefixCols[i].Direction = catenumpb.IndexColumn_ASC
72-
valDirs[i] = encoding.Ascending
73-
} else {
74-
b.keyAndPrefixCols[i].Direction = catenumpb.IndexColumn_DESC
75-
valDirs[i] = encoding.Descending
76-
}
74+
for _, usePrefix := range []bool{true, false} {
75+
t.Run(fmt.Sprintf("usePrefix=%t", usePrefix), func(t *testing.T) {
76+
b := Builder{
77+
KeyPrefix: tableIndexBytes,
78+
keyAndPrefixCols: make([]fetchpb.IndexFetchSpec_KeyColumn, len(tc.dirs)),
79+
}
80+
valDirs := make([]encoding.Direction, len(tc.dirs))
81+
for i, dir := range tc.dirs {
82+
if dir == "asc" {
83+
b.keyAndPrefixCols[i].Direction = catenumpb.IndexColumn_ASC
84+
valDirs[i] = encoding.Ascending
85+
} else {
86+
b.keyAndPrefixCols[i].Direction = catenumpb.IndexColumn_DESC
87+
valDirs[i] = encoding.Descending
88+
}
89+
}
90+
if usePrefix {
91+
prefixDirs := []encoding.Direction{encoding.Ascending, encoding.Ascending}
92+
valDirs = append(prefixDirs, valDirs...)
93+
}
94+
outKey, _, err := b.encodeConstraintKey(tc.in, usePrefix)
95+
require.NoError(t, err)
96+
vals, _ := encoding.PrettyPrintValuesWithTypes(valDirs, outKey)
97+
expected := tc.out
98+
if usePrefix && !tc.in.IsEmpty() {
99+
expected = fmt.Sprintf("/%d/%d%s", tableID, indexID, expected)
100+
}
101+
require.Equal(t, expected, "/"+strings.Join(vals, "/"))
102+
})
77103
}
78-
outKey, _, err := b.encodeConstraintKey(tc.in)
79-
require.NoError(t, err)
80-
vals, _ := encoding.PrettyPrintValuesWithTypes(valDirs, outKey)
81-
require.Equal(t, tc.out, "/"+strings.Join(vals, "/"))
82104
})
83105
}
84106
}

0 commit comments

Comments
 (0)