
Commit 070eef9

speediedan and tchaton committed
Update lightning_examples/finetuning-scheduler/finetuning-scheduler.py
Co-authored-by: thomas chaton <[email protected]>
1 parent de2c5b2 commit 070eef9

File tree

1 file changed (+10 lines, -11 lines)


lightning_examples/finetuning-scheduler/finetuning-scheduler.py

Lines changed: 10 additions & 11 deletions
@@ -24,7 +24,7 @@
 # %% [markdown]
 # ## Basic Usage
 #
-# If no finetuning schedule is user-provided, ``FinetuningScheduler`` will generate a
+# If no finetuning schedule is provided by the user, ``FinetuningScheduler`` will generate a
 # [default schedule](#The-Default-Finetuning-Schedule) and proceed to finetune according to the generated schedule, using default ``FTSEarlyStopping`` and ``FTSCheckpoint`` callbacks with ``monitor=val_loss``.
 #
 # ```python
@@ -268,9 +268,8 @@ def __init__(
         max_seq_length: int = 128,
         train_batch_size: int = 32,
         eval_batch_size: int = 32,
-        pin_memory: bool = False,
         tokenizers_parallelism: bool = True,
-        num_workers: int = 0,
+        **dataloader_kwargs: Any,
     ):
         super().__init__()
         self.model_name_or_path = model_name_or_path
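The signature change above swaps the two explicit loader options for a catch-all ``**dataloader_kwargs``, so any ``DataLoader`` option can now be passed through the datamodule. A minimal sketch of that pattern, using a hypothetical class name and a toy dataset rather than the tutorial's actual datamodule:

from typing import Any

import torch
from torch.utils.data import DataLoader, TensorDataset


class PassthroughDataModule:  # hypothetical name, illustrating the **dataloader_kwargs pattern
    def __init__(self, dataset, train_batch_size: int = 32, **dataloader_kwargs: Any):
        # any extra keyword arguments (num_workers, pin_memory, ...) are kept for the DataLoader
        self.dataset = dataset
        self.train_batch_size = train_batch_size
        self.dataloader_kwargs = dataloader_kwargs

    def train_dataloader(self) -> DataLoader:
        # forward the stored options straight to the DataLoader
        return DataLoader(self.dataset, batch_size=self.train_batch_size, shuffle=True, **self.dataloader_kwargs)


# usage: options previously exposed as named parameters now pass through unchanged
dm = PassthroughDataModule(TensorDataset(torch.randn(8, 4)), num_workers=0, pin_memory=False)
loader = dm.train_dataloader()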
@@ -399,7 +398,7 @@ def training_step(self, batch, batch_idx):
 
     def training_epoch_end(self, outputs: List[Any]) -> None:
         loss = torch.stack([x["loss"] for x in outputs]).mean()
-        self.log("train_loss", loss, prog_bar=True, sync_dist=True)
+        self.log("train_loss", loss, prog_bar=True)
         if self.finetuningscheduler_callback:
             self.log("finetuning_schedule_depth", float(self.finetuningscheduler_callback.curr_depth))
 
@@ -424,14 +423,14 @@ def validation_epoch_end(self, outputs):
         self.log_dict(metric_dict, prog_bar=True, sync_dist=True)
         return loss
 
-    def init_pgs(self) -> List[Dict]:
+    def _init_param_groups(self) -> List[Dict]:
         """Initialize the parameter groups. Used to ensure weight_decay is not applied to our specified bias
         parameters when we initialize the optimizer.
 
         Returns:
             List[Dict]: A list of parameter group dictionaries.
         """
-        pgs = [
+        return [
             {
                 "params": [
                     p
@@ -449,7 +448,6 @@ def init_pgs(self) -> List[Dict]:
                 "weight_decay": 0.0,
             },
         ]
-        return pgs
 
     def configure_optimizers(self):
         # the phase 0 parameters will have been set to require gradients during setup
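For reference, the per-group ``weight_decay`` values returned by ``_init_param_groups`` are what the optimizer ultimately consumes in ``configure_optimizers``. A self-contained sketch of that pattern with a toy model and illustrative hyperparameters (not the tutorial's exact values):

import torch
from torch.optim import AdamW

toy_model = torch.nn.Linear(4, 2)

# mirror the _init_param_groups idea: bias parameters get no weight decay
param_groups = [
    {
        "params": [p for n, p in toy_model.named_parameters() if not n.endswith("bias")],
        "weight_decay": 0.01,
    },
    {
        "params": [p for n, p in toy_model.named_parameters() if n.endswith("bias")],
        "weight_decay": 0.0,
    },
]
optimizer = AdamW(param_groups, lr=1e-5)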
@@ -472,7 +470,7 @@ def configure_callbacks(self):
 
 
 # %%
-# let's create a finetuning schedule for our model and run an explicitly scheduled finetuning training scenario with it
+# Let's create a finetuning schedule for our model and run an explicitly scheduled finetuning training scenario with it
 # Please see the documentation for a full description of the schedule format
 ft_schedule_yaml = """
 0:
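The schedule literal is truncated in this hunk. For orientation, a hypothetical two-phase schedule and one way it could be written to disk and handed to the callback (the parameter patterns and the ``ft_schedule`` argument below follow the finetuning-scheduler documentation but are not part of this commit):

example_schedule = """
0:
  params:
  - model.classifier.bias
  - model.classifier.weight
1:
  params:
  - model.pooler.dense.bias
  - model.pooler.dense.weight
"""

ft_schedule_path = "example_schedule.yaml"
with open(ft_schedule_path, "w") as f:
    f.write(example_schedule)

# the callback would then be given the schedule path, e.g.:
# from finetuning_scheduler import FinetuningScheduler
# callbacks = [FinetuningScheduler(ft_schedule=ft_schedule_path)]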
@@ -527,13 +525,14 @@ def configure_callbacks(self):
 # # %load_ext tensorboard
 # # %tensorboard --logdir example_logdir
 # disable progress bar by default to focus on multi-phase training logs. Set to True to re-enable if desired
-show_progress = False
+enable_progress_bar = False
 
 # %%
 trainer = pl.Trainer(
-    enable_progress_bar=show_progress,
+    enable_progress_bar=enable_progress_bar,
     precision=16,
-    gpus=AVAIL_GPUS,
+    accelerator="auto",
+    devices="auto",
     callbacks=callbacks,
     logger=logger,
 )
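With ``accelerator="auto"`` and ``devices="auto"``, the Trainer now discovers the available hardware itself rather than relying on a precomputed ``AVAIL_GPUS`` count. A short sketch contrasting the two styles (illustrative only; the explicit variant assumes exactly one GPU is present):

import pytorch_lightning as pl

# auto form, as adopted in this commit: Lightning picks the accelerator and device count
trainer = pl.Trainer(accelerator="auto", devices="auto")

# explicit form this replaces, if a single GPU were known to be available:
# trainer = pl.Trainer(accelerator="gpu", devices=1)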
