Skip to content

Commit 4bd2c8f

Browse files
committed
Embedding Projector: fix tSNE tweaking hyperparams
1 parent 963afb7 commit 4bd2c8f

File tree

3 files changed

+29
-4
lines changed

3 files changed

+29
-4
lines changed

Diff for: tensorboard/plugins/projector/vz_projector/bh_tsne.ts

+6
Original file line numberDiff line numberDiff line change
@@ -299,10 +299,16 @@ export class TSNE {
299299
getDim() {
300300
return this.dim;
301301
}
302+
getRng() {
303+
return this.rng;
304+
}
302305
// return pointer to current solution
303306
getSolution() {
304307
return this.Y;
305308
}
309+
setSolution(solution: Float64Array) {
310+
this.Y = solution;
311+
}
306312
// For each point, randomly offset point within a 5% hypersphere centered
307313
// around it, whilst remaining in the assumed t-SNE plot hypersphere
308314
perturb() {

Diff for: tensorboard/plugins/projector/vz_projector/data.ts

+16-4
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,19 @@ export class DataSet {
334334
this.tSNEIteration = 0;
335335
let sampledIndices = this.shuffledDataIndices.slice(0, TSNE_SAMPLE_SIZE);
336336
let step = () => {
337+
if (!this.tsne.getSolution()) {
338+
const rng = this.tsne.getRng();
339+
const d = this.tsne.getDim();
340+
const solution = new Float64Array(sampledIndices.length * d);
341+
sampledIndices.forEach((index, i) => {
342+
let dataPoint = this.points[index];
343+
for (let j = 0; j < d; j++) {
344+
solution[i * d + j] = dataPoint.projections[`tsne-${j}`] ?? rng();
345+
}
346+
});
347+
this.tsne.setSolution(solution);
348+
return;
349+
}
337350
if (this.tSNEShouldStop) {
338351
this.projections['tsne'] = false;
339352
stepCallback(null!);
@@ -344,12 +357,11 @@ export class DataSet {
344357
if (!this.tSNEShouldPause) {
345358
this.tsne.step();
346359
let result = this.tsne.getSolution();
360+
const d = this.tsne.getDim();
347361
sampledIndices.forEach((index, i) => {
348362
let dataPoint = this.points[index];
349-
dataPoint.projections['tsne-0'] = result[i * tsneDim + 0];
350-
dataPoint.projections['tsne-1'] = result[i * tsneDim + 1];
351-
if (tsneDim === 3) {
352-
dataPoint.projections['tsne-2'] = result[i * tsneDim + 2];
363+
for (let j = 0; j < d; j++) {
364+
dataPoint.projections[`tsne-${j}`] = result[i * d + j];
353365
}
354366
});
355367
this.projections['tsne'] = true;

Diff for: tensorboard/plugins/projector/vz_projector/vz-projector-projections-panel.ts

+7
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,10 @@ class ProjectionsPanel extends LegacyElementMixin(PolymerElement) {
154154
private updateTSNEPerplexityFromSliderChange() {
155155
if (this.perplexitySlider) {
156156
this.perplexity = +this.perplexitySlider.value;
157+
if (this.dataSet?.hasTSNERun) {
158+
this.dataSet.hasTSNERun = false;
159+
this.beginProjection(this.currentProjection);
160+
}
157161
}
158162
(this.$$('.tsne-perplexity span') as HTMLSpanElement).innerText =
159163
'' + this.perplexity;
@@ -380,6 +384,9 @@ class ProjectionsPanel extends LegacyElementMixin(PolymerElement) {
380384
}
381385
@observe('tSNEis3d')
382386
_tsneDimensionToggleObserver() {
387+
if (this.dataSet?.hasTSNERun) {
388+
this.dataSet.hasTSNERun = false;
389+
}
383390
this.beginProjection(this.currentProjection);
384391
}
385392
@observe('umapIs3d')

0 commit comments

Comments
 (0)