@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
See the License for the specific language governing permissions and
13
13
limitations under the License.
14
14
==============================================================================*/
15
- import { findKNN , findKNNGPUCosDistNorm , NearestEntry , TEST_ONLY } from './knn' ;
16
- import { cosDistNorm , unit } from './vector' ;
15
+ import { findKNN , findKNNGPUCosDist , NearestEntry , TEST_ONLY } from './knn' ;
16
+ import { cosDist } from './vector' ;
17
17
18
18
describe ( 'projector knn test' , ( ) => {
19
19
function getIndices ( nearest : NearestEntry [ ] [ ] ) : number [ ] [ ] {
@@ -22,22 +22,16 @@ describe('projector knn test', () => {
22
22
} ) ;
23
23
}
24
24
25
- function unitVector ( vector : Float32Array ) : Float32Array {
26
- // `unit` method replaces the vector in-place.
27
- unit ( vector ) ;
28
- return vector ;
29
- }
30
-
31
- describe ( '#findKNNGPUCosDistNorm' , ( ) => {
25
+ describe ( '#findKNNGPUCosDist' , ( ) => {
32
26
it ( 'finds n-nearest neighbor for each item' , async ( ) => {
33
- const values = await findKNNGPUCosDistNorm (
27
+ const values = await findKNNGPUCosDist (
34
28
[
35
- { a : unitVector ( new Float32Array ( [ 1 , 2 , 0 ] ) ) } ,
36
- { a : unitVector ( new Float32Array ( [ 1 , 1 , 3 ] ) ) } ,
37
- { a : unitVector ( new Float32Array ( [ 100 , 30 , 0 ] ) ) } ,
38
- { a : unitVector ( new Float32Array ( [ 95 , 23 , 3 ] ) ) } ,
39
- { a : unitVector ( new Float32Array ( [ 100 , 10 , 0 ] ) ) } ,
40
- { a : unitVector ( new Float32Array ( [ 95 , 23 , 100 ] ) ) } ,
29
+ { a : new Float32Array ( [ 1 , 2 , 0 ] ) } ,
30
+ { a : new Float32Array ( [ 1 , 1 , 3 ] ) } ,
31
+ { a : new Float32Array ( [ 100 , 30 , 0 ] ) } ,
32
+ { a : new Float32Array ( [ 95 , 23 , 3 ] ) } ,
33
+ { a : new Float32Array ( [ 100 , 10 , 0 ] ) } ,
34
+ { a : new Float32Array ( [ 95 , 23 , 100 ] ) } ,
41
35
] ,
42
36
4 ,
43
37
( data ) => data . a
@@ -54,11 +48,8 @@ describe('projector knn test', () => {
54
48
} ) ;
55
49
56
50
it ( 'returns less than N when number of item is lower' , async ( ) => {
57
- const values = await findKNNGPUCosDistNorm (
58
- [
59
- unitVector ( new Float32Array ( [ 1 , 2 , 0 ] ) ) ,
60
- unitVector ( new Float32Array ( [ 1 , 1 , 3 ] ) ) ,
61
- ] ,
51
+ const values = await findKNNGPUCosDist (
52
+ [ new Float32Array ( [ 1 , 2 , 0 ] ) , new Float32Array ( [ 1 , 1 , 3 ] ) ] ,
62
53
4 ,
63
54
( a ) => a
64
55
) ;
@@ -68,10 +59,8 @@ describe('projector knn test', () => {
68
59
69
60
it ( 'splits a large data into one that would fit into GPU memory' , async ( ) => {
70
61
const size = TEST_ONLY . OPTIMAL_GPU_BLOCK_SIZE + 5 ;
71
- const data = new Array ( size ) . fill (
72
- unitVector ( new Float32Array ( [ 1 , 1 , 1 ] ) )
73
- ) ;
74
- const values = await findKNNGPUCosDistNorm ( data , 1 , ( a ) => a ) ;
62
+ const data = new Array ( size ) . fill ( new Float32Array ( [ 1 , 1 , 1 ] ) ) ;
63
+ const values = await findKNNGPUCosDist ( data , 1 , ( a ) => a ) ;
75
64
76
65
expect ( getIndices ( values ) ) . toEqual ( [
77
66
// Since distance to the diagonal entries (distance to self is 0) is
@@ -84,25 +73,25 @@ describe('projector knn test', () => {
84
73
} ) ;
85
74
86
75
describe ( '#findKNN' , ( ) => {
87
- // Covered by equality tests below (#findKNNGPUCosDistNorm == #findKNN).
76
+ // Covered by equality tests below (#findKNNGPUCosDist == #findKNN).
88
77
} ) ;
89
78
90
- describe ( '#findKNNGPUCosDistNorm and #findKNN' , ( ) => {
79
+ describe ( '#findKNNGPUCosDist and #findKNN' , ( ) => {
91
80
it ( 'returns same value when dist metrics are cosine' , async ( ) => {
92
81
const data = [
93
- unitVector ( new Float32Array ( [ 1 , 2 , 0 ] ) ) ,
94
- unitVector ( new Float32Array ( [ 1 , 1 , 3 ] ) ) ,
95
- unitVector ( new Float32Array ( [ 100 , 30 , 0 ] ) ) ,
96
- unitVector ( new Float32Array ( [ 95 , 23 , 3 ] ) ) ,
97
- unitVector ( new Float32Array ( [ 100 , 10 , 0 ] ) ) ,
98
- unitVector ( new Float32Array ( [ 95 , 23 , 100 ] ) ) ,
82
+ new Float32Array ( [ 1 , 2 , 0 ] ) ,
83
+ new Float32Array ( [ 1 , 1 , 3 ] ) ,
84
+ new Float32Array ( [ 100 , 30 , 0 ] ) ,
85
+ new Float32Array ( [ 95 , 23 , 3 ] ) ,
86
+ new Float32Array ( [ 100 , 10 , 0 ] ) ,
87
+ new Float32Array ( [ 95 , 23 , 100 ] ) ,
99
88
] ;
100
- const findKnnGpuCosVal = await findKNNGPUCosDistNorm ( data , 2 , ( a ) => a ) ;
89
+ const findKnnGpuCosVal = await findKNNGPUCosDist ( data , 2 , ( a ) => a ) ;
101
90
const findKnnVal = await findKNN (
102
91
data ,
103
92
2 ,
104
93
( a ) => a ,
105
- ( a , b , limit ) => cosDistNorm ( a , b )
94
+ ( a , b , limit ) => cosDist ( a , b )
106
95
) ;
107
96
108
97
// Floating point precision makes it hard to test. Just assert indices.
@@ -112,15 +101,15 @@ describe('projector knn test', () => {
112
101
it ( 'splits a large data without the result being wrong' , async ( ) => {
113
102
const size = TEST_ONLY . OPTIMAL_GPU_BLOCK_SIZE + 5 ;
114
103
const data = Array . from ( new Array ( size ) ) . map ( ( _ , index ) => {
115
- return unitVector ( new Float32Array ( [ index + 1 , index + 1 ] ) ) ;
104
+ return new Float32Array ( [ index + 1 , index + 1 ] ) ;
116
105
} ) ;
117
106
118
- const findKnnGpuCosVal = await findKNNGPUCosDistNorm ( data , 2 , ( a ) => a ) ;
107
+ const findKnnGpuCosVal = await findKNNGPUCosDist ( data , 2 , ( a ) => a ) ;
119
108
const findKnnVal = await findKNN (
120
109
data ,
121
110
2 ,
122
111
( a ) => a ,
123
- ( a , b , limit ) => cosDistNorm ( a , b )
112
+ ( a , b , limit ) => cosDist ( a , b )
124
113
) ;
125
114
126
115
expect ( getIndices ( findKnnGpuCosVal ) ) . toEqual ( getIndices ( findKnnVal ) ) ;
0 commit comments