Skip to content

Commit bbed4ac

Browse files
committed
Embedding Projector: knn for non-normalized vectors
1 parent 4f0cd50 commit bbed4ac

File tree

3 files changed

+38
-50
lines changed

3 files changed

+38
-50
lines changed

tensorboard/plugins/projector/vz_projector/data.ts

+2-6
Original file line numberDiff line numberDiff line change
@@ -481,16 +481,12 @@ export class DataSet {
481481
MIN_NUM_KNN_NEIGHBORS
482482
);
483483
const result = await (knnGpuEnabled
484-
? knn.findKNNGPUCosDistNorm(
485-
data,
486-
numKnnNeighborsToCompute,
487-
(d) => d.vector
488-
)
484+
? knn.findKNNGPUCosDist(data, numKnnNeighborsToCompute, (d) => d.vector)
489485
: knn.findKNN(
490486
data,
491487
numKnnNeighborsToCompute,
492488
(d) => d.vector,
493-
(a, b) => vector.cosDistNorm(a, b)
489+
(a, b) => vector.cosDist(a, b)
494490
));
495491
this.nearest = result;
496492
return Promise.resolve(

tensorboard/plugins/projector/vz_projector/knn.ts

+9-6
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ const KNN_GPU_MSG_ID = 'knn-gpu';
4444
* @param k Number of nearest neighbors to find.
4545
* @param accessor A method that returns the vector, given the data point.
4646
*/
47-
export function findKNNGPUCosDistNorm<T>(
47+
export function findKNNGPUCosDist<T>(
4848
dataPoints: T[],
4949
k: number,
5050
accessor: (dataPoint: T) => Float32Array
@@ -61,6 +61,10 @@ export function findKNNGPUCosDistNorm<T>(
6161
// pair of points, which we sort using KMin data structure to obtain the
6262
// K nearest neighbors for each point.
6363
const nearest: NearestEntry[][] = new Array(N);
64+
const dpNorm: number[] = new Array(N);
65+
for (let i = 0; i < N; i++) {
66+
dpNorm[i] = Math.sqrt(vector.norm2(accessor(dataPoints[i])));
67+
}
6468
let numPieces = Math.ceil(N / OPTIMAL_GPU_BLOCK_SIZE);
6569
const actualPieceSize = Math.floor(N / numPieces);
6670
const modulo = N % actualPieceSize;
@@ -120,10 +124,9 @@ export function findKNNGPUCosDistNorm<T>(
120124
// Access i * N's row at `j` column.
121125
// Each row has N entries and j-th index has cosine distance
122126
// between iReal vs. j-th vectors.
123-
const cosDist = partial[i * N + j];
124-
if (cosDist >= 0) {
125-
kMin.add(cosDist, {index: j, dist: cosDist});
126-
}
127+
const cosDist =
128+
1 - (1 - partial[i * N + j]) / (dpNorm[i] * dpNorm[j]);
129+
kMin.add(cosDist, {index: j, dist: cosDist});
127130
}
128131
nearest[iReal] = kMin.getMinKItems();
129132
}
@@ -151,7 +154,7 @@ export function findKNNGPUCosDistNorm<T>(
151154
(error) => {
152155
// GPU failed. Reverting back to CPU.
153156
logging.setModalMessage(null!, KNN_GPU_MSG_ID);
154-
let distFunc = (a, b, limit) => vector.cosDistNorm(a, b);
157+
let distFunc = (a, b, limit) => vector.cosDist(a, b);
155158
findKNN(dataPoints, k, accessor, distFunc).then((nearest) => {
156159
resolve(nearest);
157160
});

tensorboard/plugins/projector/vz_projector/knn_test.ts

+27-38
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License.
1414
==============================================================================*/
15-
import {findKNN, findKNNGPUCosDistNorm, NearestEntry, TEST_ONLY} from './knn';
16-
import {cosDistNorm, unit} from './vector';
15+
import {findKNN, findKNNGPUCosDist, NearestEntry, TEST_ONLY} from './knn';
16+
import {cosDist} from './vector';
1717

1818
describe('projector knn test', () => {
1919
function getIndices(nearest: NearestEntry[][]): number[][] {
@@ -22,22 +22,16 @@ describe('projector knn test', () => {
2222
});
2323
}
2424

25-
function unitVector(vector: Float32Array): Float32Array {
26-
// `unit` method replaces the vector in-place.
27-
unit(vector);
28-
return vector;
29-
}
30-
31-
describe('#findKNNGPUCosDistNorm', () => {
25+
describe('#findKNNGPUCosDist', () => {
3226
it('finds n-nearest neighbor for each item', async () => {
33-
const values = await findKNNGPUCosDistNorm(
27+
const values = await findKNNGPUCosDist(
3428
[
35-
{a: unitVector(new Float32Array([1, 2, 0]))},
36-
{a: unitVector(new Float32Array([1, 1, 3]))},
37-
{a: unitVector(new Float32Array([100, 30, 0]))},
38-
{a: unitVector(new Float32Array([95, 23, 3]))},
39-
{a: unitVector(new Float32Array([100, 10, 0]))},
40-
{a: unitVector(new Float32Array([95, 23, 100]))},
29+
{a: new Float32Array([1, 2, 0])},
30+
{a: new Float32Array([1, 1, 3])},
31+
{a: new Float32Array([100, 30, 0])},
32+
{a: new Float32Array([95, 23, 3])},
33+
{a: new Float32Array([100, 10, 0])},
34+
{a: new Float32Array([95, 23, 100])},
4135
],
4236
4,
4337
(data) => data.a
@@ -54,11 +48,8 @@ describe('projector knn test', () => {
5448
});
5549

5650
it('returns less than N when number of item is lower', async () => {
57-
const values = await findKNNGPUCosDistNorm(
58-
[
59-
unitVector(new Float32Array([1, 2, 0])),
60-
unitVector(new Float32Array([1, 1, 3])),
61-
],
51+
const values = await findKNNGPUCosDist(
52+
[new Float32Array([1, 2, 0]), new Float32Array([1, 1, 3])],
6253
4,
6354
(a) => a
6455
);
@@ -68,10 +59,8 @@ describe('projector knn test', () => {
6859

6960
it('splits a large data into one that would fit into GPU memory', async () => {
7061
const size = TEST_ONLY.OPTIMAL_GPU_BLOCK_SIZE + 5;
71-
const data = new Array(size).fill(
72-
unitVector(new Float32Array([1, 1, 1]))
73-
);
74-
const values = await findKNNGPUCosDistNorm(data, 1, (a) => a);
62+
const data = new Array(size).fill(new Float32Array([1, 1, 1]));
63+
const values = await findKNNGPUCosDist(data, 1, (a) => a);
7564

7665
expect(getIndices(values)).toEqual([
7766
// Since distance to the diagonal entries (distance to self is 0) is
@@ -84,25 +73,25 @@ describe('projector knn test', () => {
8473
});
8574

8675
describe('#findKNN', () => {
87-
// Covered by equality tests below (#findKNNGPUCosDistNorm == #findKNN).
76+
// Covered by equality tests below (#findKNNGPUCosDist == #findKNN).
8877
});
8978

90-
describe('#findKNNGPUCosDistNorm and #findKNN', () => {
79+
describe('#findKNNGPUCosDist and #findKNN', () => {
9180
it('returns same value when dist metrics are cosine', async () => {
9281
const data = [
93-
unitVector(new Float32Array([1, 2, 0])),
94-
unitVector(new Float32Array([1, 1, 3])),
95-
unitVector(new Float32Array([100, 30, 0])),
96-
unitVector(new Float32Array([95, 23, 3])),
97-
unitVector(new Float32Array([100, 10, 0])),
98-
unitVector(new Float32Array([95, 23, 100])),
82+
new Float32Array([1, 2, 0]),
83+
new Float32Array([1, 1, 3]),
84+
new Float32Array([100, 30, 0]),
85+
new Float32Array([95, 23, 3]),
86+
new Float32Array([100, 10, 0]),
87+
new Float32Array([95, 23, 100]),
9988
];
100-
const findKnnGpuCosVal = await findKNNGPUCosDistNorm(data, 2, (a) => a);
89+
const findKnnGpuCosVal = await findKNNGPUCosDist(data, 2, (a) => a);
10190
const findKnnVal = await findKNN(
10291
data,
10392
2,
10493
(a) => a,
105-
(a, b, limit) => cosDistNorm(a, b)
94+
(a, b, limit) => cosDist(a, b)
10695
);
10796

10897
// Floating point precision makes it hard to test. Just assert indices.
@@ -112,15 +101,15 @@ describe('projector knn test', () => {
112101
it('splits a large data without the result being wrong', async () => {
113102
const size = TEST_ONLY.OPTIMAL_GPU_BLOCK_SIZE + 5;
114103
const data = Array.from(new Array(size)).map((_, index) => {
115-
return unitVector(new Float32Array([index + 1, index + 1]));
104+
return new Float32Array([index + 1, index + 2]);
116105
});
117106

118-
const findKnnGpuCosVal = await findKNNGPUCosDistNorm(data, 2, (a) => a);
107+
const findKnnGpuCosVal = await findKNNGPUCosDist(data, 2, (a) => a);
119108
const findKnnVal = await findKNN(
120109
data,
121110
2,
122111
(a) => a,
123-
(a, b, limit) => cosDistNorm(a, b)
112+
(a, b, limit) => cosDist(a, b)
124113
);
125114

126115
expect(getIndices(findKnnGpuCosVal)).toEqual(getIndices(findKnnVal));

0 commit comments

Comments
 (0)