Skip to content

Commit 9dfec1e

Browse files
committed
Embedding Projector: fix tSNE tweaking hyperparams
1 parent 963afb7 commit 9dfec1e

File tree

3 files changed

+42
-6
lines changed

3 files changed

+42
-6
lines changed

tensorboard/plugins/projector/vz_projector/bh_tsne.ts

+10-2
Original file line numberDiff line numberDiff line change
@@ -289,8 +289,10 @@ export class TSNE {
289289
}
290290
// (re)initializes the solution to random
291291
initSolution() {
292-
// generate random solution to t-SNE
293-
this.Y = randnMatrix(this.N, this.dim, this.rng); // the solution
292+
if (!this.Y) {
293+
// generate random solution to t-SNE
294+
this.Y = randnMatrix(this.N, this.dim, this.rng); // the solution
295+
}
294296
this.gains = arrayofs(this.N, this.dim, 1); // step gains
295297
// to accelerate progress in unchanging directions
296298
this.ystep = arrayofs(this.N, this.dim, 0); // momentum accumulator
@@ -299,10 +301,16 @@ export class TSNE {
299301
getDim() {
300302
return this.dim;
301303
}
304+
getRng() {
305+
return this.rng;
306+
}
302307
// return pointer to current solution
303308
getSolution() {
304309
return this.Y;
305310
}
311+
setSolution(solution: Float64Array) {
312+
this.Y = solution;
313+
}
306314
// For each point, randomly offset point within a 5% hypersphere centered
307315
// around it, whilst remaining in the assumed t-SNE plot hypersphere
308316
perturb() {

tensorboard/plugins/projector/vz_projector/data.ts

+21-4
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,10 @@ export class DataSet {
334334
this.tSNEIteration = 0;
335335
let sampledIndices = this.shuffledDataIndices.slice(0, TSNE_SAMPLE_SIZE);
336336
let step = () => {
337+
if (!this.tsne.getSolution()) {
338+
this.initTsneSolutionFromCurrentProjection();
339+
return;
340+
}
337341
if (this.tSNEShouldStop) {
338342
this.projections['tsne'] = false;
339343
stepCallback(null!);
@@ -344,12 +348,11 @@ export class DataSet {
344348
if (!this.tSNEShouldPause) {
345349
this.tsne.step();
346350
let result = this.tsne.getSolution();
351+
const d = this.tsne.getDim();
347352
sampledIndices.forEach((index, i) => {
348353
let dataPoint = this.points[index];
349-
dataPoint.projections['tsne-0'] = result[i * tsneDim + 0];
350-
dataPoint.projections['tsne-1'] = result[i * tsneDim + 1];
351-
if (tsneDim === 3) {
352-
dataPoint.projections['tsne-2'] = result[i * tsneDim + 2];
354+
for (let j = 0; j < d; j++) {
355+
dataPoint.projections[`tsne-${j}`] = result[i * d + j];
353356
}
354357
});
355358
this.projections['tsne'] = true;
@@ -486,6 +489,20 @@ export class DataSet {
486489
return Promise.resolve(result);
487490
}
488491
}
492+
/* initialize the new tsne solution from current projection data */
493+
initTsneSolutionFromCurrentProjection() {
494+
const sampledIndices = this.shuffledDataIndices.slice(0, TSNE_SAMPLE_SIZE);
495+
const rng = this.tsne.getRng();
496+
const d = this.tsne.getDim();
497+
const solution = new Float64Array(sampledIndices.length * d);
498+
sampledIndices.forEach((index, i) => {
499+
const dataPoint = this.points[index];
500+
for (let j = 0; j < d; j++) {
501+
solution[i * d + j] = dataPoint.projections[`tsne-${j}`] ?? rng();
502+
}
503+
});
504+
this.tsne.setSolution(solution);
505+
}
489506
/* Perturb TSNE and update dataset point coordinates. */
490507
perturbTsne() {
491508
if (this.hasTSNERun && this.tsne) {

tensorboard/plugins/projector/vz_projector/vz-projector-projections-panel.ts

+11
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,10 @@ class ProjectionsPanel extends LegacyElementMixin(PolymerElement) {
154154
private updateTSNEPerplexityFromSliderChange() {
155155
if (this.perplexitySlider) {
156156
this.perplexity = +this.perplexitySlider.value;
157+
if (this.dataSet?.hasTSNERun) {
158+
this.dataSet.hasTSNERun = false;
159+
this.beginProjection(this.currentProjection);
160+
}
157161
}
158162
(this.$$('.tsne-perplexity span') as HTMLSpanElement).innerText =
159163
'' + this.perplexity;
@@ -295,6 +299,10 @@ class ProjectionsPanel extends LegacyElementMixin(PolymerElement) {
295299
}
296300
if (bookmark.selectedProjection != null) {
297301
this.showTab(bookmark.selectedProjection);
302+
if (this.currentProjection === 'tsne') {
303+
this.runTSNE();
304+
this.dataSet.initTsneSolutionFromCurrentProjection();
305+
}
298306
}
299307
this.enablePolymerChangesTriggerReprojection();
300308
}
@@ -380,6 +388,9 @@ class ProjectionsPanel extends LegacyElementMixin(PolymerElement) {
380388
}
381389
@observe('tSNEis3d')
382390
_tsneDimensionToggleObserver() {
391+
if (this.dataSet?.hasTSNERun) {
392+
this.dataSet.hasTSNERun = false;
393+
}
383394
this.beginProjection(this.currentProjection);
384395
}
385396
@observe('umapIs3d')

0 commit comments

Comments
 (0)