Skip to content

Commit 9be329a

Browse files
committed
facepalm, forgot to include quantizer parameters. recon loss now looks normal
1 parent dc22504 commit 9be329a

File tree

3 files changed

+4
-3
lines changed

3 files changed

+4
-3
lines changed

magvit2_pytorch/magvit2_pytorch.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1353,6 +1353,7 @@ def parameters(self):
13531353
*self.decoder_layers.parameters(),
13541354
*self.encoder_cond_in.parameters(),
13551355
*self.decoder_cond_in.parameters(),
1356+
*self.quantizers.parameters()
13561357
]
13571358

13581359
def discr_parameters(self):

magvit2_pytorch/trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ def train_step(self, dl_iter):
253253

254254
self.accelerator.backward(loss / self.grad_accum_every)
255255

256-
self.print(f'loss: {loss.item():.3f}')
256+
self.print(f'recon loss: {loss_breakdown.recon_loss.item():.3f}')
257257

258258
if exists(self.max_grad_norm):
259259
self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
@@ -336,7 +336,7 @@ def valid_step(
336336
valid_videos.append(valid_video)
337337
recon_videos.append(recon_video)
338338

339-
self.print(f'validation loss {recon_loss:.3f}')
339+
self.print(f'validation recon loss {recon_loss:.3f}')
340340

341341
if not save_recons:
342342
return

magvit2_pytorch/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '0.0.70'
1+
__version__ = '0.1.0'

0 commit comments

Comments
 (0)