Skip to content

Commit c417891

Browse files
authored
Embedding Projector: fix tSNE tweaking hyperparams (#6320)
## Motivation for features / changes #6289 ## Technical description of changes Restart tsne with current solution when dimension or perplexity is changed ## Screenshots of UI changes N/A ## Detailed steps to verify changes work correctly (as executed by you) 1. open t-SNE projection. 2. toggle 2D/3D will result in a new t-SNE run with the existing projection ## Alternate designs / implementations considered
1 parent 01391fc commit c417891

File tree

3 files changed

+42
-6
lines changed

3 files changed

+42
-6
lines changed

Diff for: tensorboard/plugins/projector/vz_projector/bh_tsne.ts

+10-2
Original file line numberDiff line numberDiff line change
@@ -289,8 +289,10 @@ export class TSNE {
289289
}
290290
// (re)initializes the solution to random
291291
initSolution() {
292-
// generate random solution to t-SNE
293-
this.Y = randnMatrix(this.N, this.dim, this.rng); // the solution
292+
if (!this.Y) {
293+
// generate random solution to t-SNE
294+
this.Y = randnMatrix(this.N, this.dim, this.rng); // the solution
295+
}
294296
this.gains = arrayofs(this.N, this.dim, 1); // step gains
295297
// to accelerate progress in unchanging directions
296298
this.ystep = arrayofs(this.N, this.dim, 0); // momentum accumulator
@@ -299,10 +301,16 @@ export class TSNE {
299301
getDim() {
300302
return this.dim;
301303
}
304+
getRng() {
305+
return this.rng;
306+
}
302307
// return pointer to current solution
303308
getSolution() {
304309
return this.Y;
305310
}
311+
setSolution(solution: Float64Array) {
312+
this.Y = solution;
313+
}
306314
// For each point, randomly offset point within a 5% hypersphere centered
307315
// around it, whilst remaining in the assumed t-SNE plot hypersphere
308316
perturb() {

Diff for: tensorboard/plugins/projector/vz_projector/data.ts

+21-4
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,10 @@ export class DataSet {
336336
this.tSNEIteration = 0;
337337
let sampledIndices = this.shuffledDataIndices.slice(0, TSNE_SAMPLE_SIZE);
338338
let step = () => {
339+
if (!this.tsne.getSolution()) {
340+
this.initTsneSolutionFromCurrentProjection();
341+
return;
342+
}
339343
if (this.tSNEShouldStop) {
340344
this.projections['tsne'] = false;
341345
stepCallback(null!);
@@ -346,12 +350,11 @@ export class DataSet {
346350
if (!this.tSNEShouldPause) {
347351
this.tsne.step();
348352
let result = this.tsne.getSolution();
353+
const d = this.tsne.getDim();
349354
sampledIndices.forEach((index, i) => {
350355
let dataPoint = this.points[index];
351-
dataPoint.projections['tsne-0'] = result[i * tsneDim + 0];
352-
dataPoint.projections['tsne-1'] = result[i * tsneDim + 1];
353-
if (tsneDim === 3) {
354-
dataPoint.projections['tsne-2'] = result[i * tsneDim + 2];
356+
for (let j = 0; j < d; j++) {
357+
dataPoint.projections[`tsne-${j}`] = result[i * d + j];
355358
}
356359
});
357360
this.projections['tsne'] = true;
@@ -493,6 +496,20 @@ export class DataSet {
493496
);
494497
}
495498
}
499+
/* initialize the new tsne solution from current projection data */
500+
initTsneSolutionFromCurrentProjection() {
501+
const sampledIndices = this.shuffledDataIndices.slice(0, TSNE_SAMPLE_SIZE);
502+
const rng = this.tsne.getRng();
503+
const d = this.tsne.getDim();
504+
const solution = new Float64Array(sampledIndices.length * d);
505+
sampledIndices.forEach((index, i) => {
506+
const dataPoint = this.points[index];
507+
for (let j = 0; j < d; j++) {
508+
solution[i * d + j] = dataPoint.projections[`tsne-${j}`] ?? rng();
509+
}
510+
});
511+
this.tsne.setSolution(solution);
512+
}
496513
/* Perturb TSNE and update dataset point coordinates. */
497514
perturbTsne() {
498515
if (this.hasTSNERun && this.tsne) {

Diff for: tensorboard/plugins/projector/vz_projector/vz-projector-projections-panel.ts

+11
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,10 @@ class ProjectionsPanel extends LegacyElementMixin(PolymerElement) {
156156
private updateTSNEPerplexityFromSliderChange() {
157157
if (this.perplexitySlider) {
158158
this.perplexity = +this.perplexitySlider.value;
159+
if (this.dataSet?.hasTSNERun) {
160+
this.dataSet.hasTSNERun = false;
161+
this.beginProjection(this.currentProjection);
162+
}
159163
}
160164
(this.$$('.tsne-perplexity span') as HTMLSpanElement).innerText =
161165
'' + this.perplexity;
@@ -301,6 +305,10 @@ class ProjectionsPanel extends LegacyElementMixin(PolymerElement) {
301305
}
302306
if (bookmark.selectedProjection != null) {
303307
this.showTab(bookmark.selectedProjection);
308+
if (this.currentProjection === 'tsne') {
309+
this.runTSNE();
310+
this.dataSet.initTsneSolutionFromCurrentProjection();
311+
}
304312
}
305313
this.enablePolymerChangesTriggerReprojection();
306314
}
@@ -389,6 +397,9 @@ class ProjectionsPanel extends LegacyElementMixin(PolymerElement) {
389397
}
390398
@observe('tSNEis3d')
391399
_tsneDimensionToggleObserver() {
400+
if (this.dataSet?.hasTSNERun) {
401+
this.dataSet.hasTSNERun = false;
402+
}
392403
this.beginProjection(this.currentProjection);
393404
}
394405
@observe('umapIs3d')

0 commit comments

Comments
 (0)