
Commit dc2d25f

additional tuning literature references and several additional recommended improvements for finetuning_scheduler tutorial
1 parent 070eef9 commit dc2d25f

1 file changed: +115 -67 lines changed


lightning_examples/finetuning-scheduler/finetuning-scheduler.py

Lines changed: 115 additions & 67 deletions
@@ -100,18 +100,17 @@
 # ```python
 # from pytorch_lightning import Trainer
 # from pytorch_lightning.callbacks.finetuning_scheduler import FinetuningScheduler
-# trainer = Trainer(callbacks=[FinetuningScheduler()], resume_from_checkpoint="some/path/to/my_checkpoint.ckpt")
+# trainer = Trainer(callbacks=[FinetuningScheduler()])
+# trainer.fit(..., ckpt_path="some/path/to/my_checkpoint.ckpt")
 # ```
 #
 # Training will resume at the depth/level of the provided checkpoint according to the specified schedule. Schedules can be altered between training sessions, but schedule compatibility is left to the user for maximal flexibility. If executing a user-defined schedule, typically the same schedule should be provided for the original and resumed training sessions.
 #
 # By default (``FinetuningScheduler.restore_best`` is ``True``), ``FinetuningScheduler`` will attempt to restore the best available checkpoint before finetuning depth transitions.
 #
 # ```python
-# trainer = Trainer(
-#     callbacks=[FinetuningScheduler(new_incarnation_mode=True)],
-#     resume_from_checkpoint="some/path/to/my_kth_best_checkpoint.ckpt",
-# )
+# trainer = Trainer(callbacks=[FinetuningScheduler(new_incarnation_mode=True)])
+# trainer.fit(..., ckpt_path="some/path/to/my_kth_best_checkpoint.ckpt")
 # ```
 #
 # To handle the edge case wherein one is resuming scheduled finetuning from a non-best checkpoint and the previous best checkpoints may not be accessible, setting ``FinetuningScheduler.new_incarnation_mode`` to
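For orientation (this sketch is not part of the commit), the resume pattern above fits together roughly as follows, assuming ``MyModule`` and ``MyDataModule`` stand in for any ``LightningModule``/``LightningDataModule`` and the checkpoint path is purely illustrative:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.finetuning_scheduler import FinetuningScheduler

model, dm = MyModule(), MyDataModule()  # placeholders for your own module/datamodule

# resume scheduled finetuning; the current depth/level is recovered from the checkpoint
trainer = Trainer(callbacks=[FinetuningScheduler()])
trainer.fit(model, datamodule=dm, ckpt_path="some/path/to/my_checkpoint.ckpt")
```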
@@ -120,14 +119,12 @@
 # %% [markdown]
 # <div class="alert alert-warning">
 #
-# **Note:** Currently, _FinetuningScheduler_ only supports the following _TrainingTypePlugins_:
-#
-# - ``DDPPlugin``
-# - ``DDPShardedPlugin``
-# - ``DDPSpawnPlugin``
-# - ``DDPSpawnShardedPlugin``
-# - ``DataParallelPlugin``
-# - ``SingleDevicePlugin``
+# **Note:** Currently, _FinetuningScheduler_ only supports the following ``StrategyType``s:
+# - ``DP``
+# - ``DDP``
+# - ``DDP_SPAWN``
+# - ``DDP_SHARDED``
+# - ``DDP_SHARDED_SPAWN``
 #
 # </div>
 
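As a hedged illustration (not part of the diff), selecting one of the supported strategy types alongside the callback might look like the following; the exact ``Trainer`` flags are version-dependent:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.finetuning_scheduler import FinetuningScheduler

# "ddp" maps to the supported DDP strategy type listed above; the strategy/gpus
# arguments shown are illustrative and depend on the PyTorch Lightning version in use
trainer = Trainer(callbacks=[FinetuningScheduler()], strategy="ddp", gpus=2)
```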
@@ -142,7 +139,6 @@
 #
 
 # %%
-import logging
 import os
 import warnings
 from datetime import datetime
@@ -160,24 +156,24 @@
 from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer
 
 # %%
-# a couple helper functions to prepare code to work with the forthcoming hub and user module registry
-MOCK_HUB_REGISTRY = _Registry()
+# a couple helper functions to prepare code to work with a user module registry
+MOCK_REGISTRY = _Registry()
 
 
-def module_hub_mock(key: str, require_fqn: bool = False) -> List:
+def mock_register_module(key: str, require_fqn: bool = False) -> List:
     if key.lower() == "finetuningscheduler":
         mod = import_module("pytorch_lightning.callbacks.finetuning_scheduler")
-        MOCK_HUB_REGISTRY.register_classes(mod, pl.callbacks.Callback)
+        MOCK_REGISTRY.register_classes(mod, pl.callbacks.Callback)
     else:
         raise MisconfigurationException(f"user module key '{key}' not found")
     registered_list = []
     # make registered class available by unqualified class name by default
     if not require_fqn:
-        for n, c in MOCK_HUB_REGISTRY.items():
+        for n, c in MOCK_REGISTRY.items():
             globals()[f"{n}"] = c
-        registered_list = ", ".join([n for n in MOCK_HUB_REGISTRY.names])
+        registered_list = ", ".join([n for n in MOCK_REGISTRY.names])
     else:
-        registered_list = ", ".join([c.__module__ + "." + c.__name__ for c in MOCK_HUB_REGISTRY.classes])
+        registered_list = ", ".join([c.__module__ + "." + c.__name__ for c in MOCK_REGISTRY.classes])
     print(f"Imported and registered the following callbacks: {registered_list}")
 
 
@@ -203,7 +199,7 @@ def instantiate_registered_class(init: Dict[str, Any], args: Optional[Union[Any,
         else:  # class is expected to be locally defined
             args_class = globals()[init["class_path"]]
     elif init.get("callback_key", None):
-        callback_path = CALLBACK_REGISTRY.get(init["callback_key"], None) or MOCK_HUB_REGISTRY.get(
+        callback_path = CALLBACK_REGISTRY.get(init["callback_key"], None) or MOCK_REGISTRY.get(
             init["callback_key"], None
         )
         assert callback_path, MisconfigurationException(
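As an illustrative aside (not in the diff), the ``class_path``/``init_args`` convention this helper consumes lets a plain config dict be turned into an instance; the dict below mirrors the ``optimizer_init`` defined later in the notebook, while the stand-in parameter list is hypothetical:

```python
import torch

optimizer_init = {
    "class_path": "torch.optim.AdamW",
    "init_args": {"weight_decay": 1e-05, "eps": 1e-07, "lr": 1e-05},
}
params = [torch.nn.Parameter(torch.zeros(1))]  # stand-in parameters, illustration only
# instantiate_registered_class resolves the class path and applies init_args as kwargs
optimizer = instantiate_registered_class(args=params, init=optimizer_init)
```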
@@ -222,49 +218,37 @@ def instantiate_registered_class(init: Dict[str, Any], args: Optional[Union[Any,
 
 # %%
 # load the pl extension module we want to use. This will import all necessary callbacks.
-module_hub_mock("finetuningscheduler")
+mock_register_module("finetuningscheduler")
 # set notebook-level variables
-AVAIL_GPUS = min(1, torch.cuda.device_count())
+AVAIL_GPUS = torch.cuda.device_count()
 TASK_NUM_LABELS = {"boolq": 2, "rte": 2}
 DEFAULT_TASK = "rte"
 
-# narrow our logging to adapt it for a notebook environment
-for l_key in logging.Logger.manager.loggerDict.keys():
-    if "pytorch_lightning" in l_key:
-        logging.getLogger(l_key).setLevel("INFO")
-    else:
-        logging.getLogger(l_key).setLevel("CRITICAL")
-pl_logger = logging.getLogger("pytorch_lightning")
-pl_logger.removeHandler(pl_logger.handlers[0])
-rz_logger = logging.getLogger("pytorch_lightning.utilities.distributed")
-rz_logger.addHandler(logging.StreamHandler())
-rz_logger.handlers[0].setLevel("INFO")
+# ignore warnings related tokenizers_parallelism/DataLoader parallelism tradeoff and
+# expected logging behavior
+for warnf in [".*does not have many workers*", ".*The number of training samples.*"]:
+    warnings.filterwarnings("ignore", warnf)
 
 
 # %%
 class RteBoolqDataModule(pl.LightningDataModule):
     """A ``LightningDataModule`` for using either the RTE or BoolQ SuperGLUE Hugging Face datasets."""
 
-    task_text_field_map = {"rte": ["premise", "hypothesis"], "boolq": ["question", "passage"]}
-    loader_columns = [
+    TASK_TEXT_FIELD_MAP = {"rte": ("premise", "hypothesis"), "boolq": ("question", "passage")}
+    LOADER_COLUMNS = (
         "datasets_idx",
         "input_ids",
         "token_type_ids",
         "attention_mask",
         "start_positions",
         "end_positions",
         "labels",
-    ]
-    # ignore warnings related tokenizers_parallelism/DataLoader parallelism tradeoff and
-    # expected logging behavior
-    for warnf in [".*does not have many workers*", ".*The number of training samples.*"]:
-        warnings.filterwarnings("ignore", warnf)
+    )
 
     def __init__(
         self,
         model_name_or_path: str,
         task_name: str = DEFAULT_TASK,
-        prep_on_init: bool = False,
         max_seq_length: int = 128,
         train_batch_size: int = 32,
         eval_batch_size: int = 32,
@@ -278,23 +262,22 @@ def __init__(
         self.train_batch_size = train_batch_size
         self.eval_batch_size = eval_batch_size
         self.tokenizers_parallelism = tokenizers_parallelism
-        self.dataloader_kwargs = {"num_workers": num_workers, "pin_memory": pin_memory}
-        self.text_fields = self.task_text_field_map[self.task_name]
+        self.dataloader_kwargs = {
+            "num_workers": dataloader_kwargs.get("num_workers", 0),
+            "pin_memory": dataloader_kwargs.get("pin_memory", False),
+        }
+        self.text_fields = self.TASK_TEXT_FIELD_MAP[self.task_name]
         self.num_labels = TASK_NUM_LABELS[self.task_name]
         os.environ["TOKENIZERS_PARALLELISM"] = "true" if self.tokenizers_parallelism else "false"
         self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True, local_files_only=False)
-        if prep_on_init:  # useful if one wants to load datasets as soon as the ``LightningDataModule`` is
-            # instantiated
-            self.prepare_data()
-            self.setup("fit")
 
     def setup(self, stage):
         self.dataset = datasets.load_dataset("super_glue", self.task_name)
         for split in self.dataset.keys():
             self.dataset[split] = self.dataset[split].map(
-                self.convert_to_features, batched=True, remove_columns=["label"]
+                self._convert_to_features, batched=True, remove_columns=["label"]
             )
-            self.columns = [c for c in self.dataset[split].column_names if c in self.loader_columns]
+            self.columns = [c for c in self.dataset[split].column_names if c in self.LOADER_COLUMNS]
             self.dataset[split].set_format(type="torch", columns=self.columns)
 
         self.eval_splits = [x for x in self.dataset.keys() if "validation" in x]
@@ -329,7 +312,7 @@ def test_dataloader(self):
                 for x in self.eval_splits
             ]
 
-    def convert_to_features(self, example_batch):
+    def _convert_to_features(self, example_batch):
         text_pairs = list(zip(example_batch[self.text_fields[0]], example_batch[self.text_fields[1]]))
         # Tokenize the text/text pairs
         features = self.tokenizer.batch_encode_plus(
@@ -412,15 +395,15 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0):
             preds = logits.squeeze()
 
         labels = batch["labels"]
+        self.log("val_loss", val_loss, prog_bar=True)
         return {"loss": val_loss, "preds": preds, "labels": labels}
 
     def validation_epoch_end(self, outputs):
         preds = torch.cat([x["preds"] for x in outputs]).detach().cpu().numpy()
         labels = torch.cat([x["labels"] for x in outputs]).detach().cpu().numpy()
         loss = torch.stack([x["loss"] for x in outputs]).mean()
-        self.log("val_loss", loss, prog_bar=True, sync_dist=True)
         metric_dict = self.metric.compute(predictions=preds, references=labels)
-        self.log_dict(metric_dict, prog_bar=True, sync_dist=True)
+        self.log_dict(metric_dict, prog_bar=True)
         return loss
 
     def _init_param_groups(self) -> List[Dict]:
@@ -455,7 +438,7 @@ def configure_optimizers(self):
         # but in this case we pass a list of parameter groups to ensure weight_decay is
         # not applied to the bias parameter (for completeness, in this case it won't make much
         # performance difference)
-        optimizer = instantiate_registered_class(args=self.init_pgs(), init=self.optimizer_init)
+        optimizer = instantiate_registered_class(args=self._init_param_groups(), init=self.optimizer_init)
         scheduler = {
             "scheduler": instantiate_registered_class(args=optimizer, init=self.lr_scheduler_init),
             **self.pl_lrs_cfg,
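The body of ``_init_param_groups`` falls outside this hunk; purely as a sketch of the pattern the comment above describes (assumptions: the ``self.model`` attribute name and the ``no_decay`` match list are illustrative, not taken from the tutorial), a bias exclusion from weight decay typically looks like this:

```python
from typing import Dict, List


def _init_param_groups(self) -> List[Dict]:
    # sketch only: split parameters so that bias terms receive no weight decay;
    # ``self.model`` and the decay-value lookup are assumptions about the module's attributes
    no_decay = ("bias",)
    weight_decay = self.optimizer_init["init_args"].get("weight_decay", 0.0)
    decay_params = [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)]
    no_decay_params = [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)]
    return [
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]
```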
@@ -485,24 +468,61 @@ def configure_callbacks(self):
 - model.albert.encoder.*.ffn_output.*
 """
 ft_schedule_name = "RteBoolqModule_ft_schedule_albert_base.yaml"
+# Let's write the schedule to a file so we can simulate loading an explicitly defined finetuning
+# schedule.
 with open(ft_schedule_name, "w") as f:
     f.write(ft_schedule_yaml)
 
 # %%
 datasets.set_progress_bar_enabled(False)
 pl.seed_everything(42)
 dm = RteBoolqDataModule(model_name_or_path="albert-base-v2", tokenizers_parallelism=False)
-dm.setup("fit")
+
+# %% [markdown]
+# ### Optimizer Configuration
+#
+# <div id="a2">
+#
+# Though other optimizers can arguably yield some marginal advantage contingent on the context,
+# the Adam optimizer (and the [AdamW version](https://pytorch.org/docs/stable/_modules/torch/optim/adamw.html#AdamW) which
+# implements decoupled weight decay) remains robust to hyperparameter choices and is commonly used for finetuning
+# foundational language models. See [(Sivaprasad et al., 2020)](#f2) and [(Mosbach, Andriushchenko & Klakow, 2020)](#f3) for theoretical and systematic empirical justifications of Adam and its use in finetuning
+# large transformer-based language models. The values used here have some justification
+# in the referenced literature but have been largely empirically determined and, while a good
+# starting point, could be further tuned.
+#
+# </div>
+
+# %%
 optimizer_init = {
     "class_path": "torch.optim.AdamW",
     "init_args": {"weight_decay": 1e-05, "eps": 1e-07, "lr": 1e-05},
 }
+
+# %% [markdown]
+# ### LR Scheduler Configuration
+#
+# <div id="a3">
+#
+# The [CosineAnnealingWarmRestarts scheduler](https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.CosineAnnealingWarmRestarts.html?highlight=cosineannealingwarm#torch.optim.lr_scheduler.CosineAnnealingWarmRestarts) nicely fits with our iterative finetuning since it does not depend upon a global max_epoch
+# value. The importance of initial warmup is reduced due to the innate warmup effect of Adam bias correction [[5]](#f3)
+# and the gradual thawing we are performing. Note that commonly used LR schedulers that depend on providing
+# max_iterations/epochs (e.g. the
+# [CosineWarmupScheduler](https://github.com/PyTorchLightning/lightning-tutorials/blob/0c325829101d5a6ebf32ed99bbf5b09badf04a59/course_UvA-DL/05-transformers-and-MH-attention/Transformers_MHAttention.py#L688)
+# used in other pytorch-lightning tutorials) also work with FinetuningScheduler. Though the LR scheduler is theoretically
+# justified [(Loshchilov & Hutter, 2016)](#f4), the particular values provided here are primarily empirically driven.
+#
+# </div>
+
+
+# %%
 lr_scheduler_init = {
     "class_path": "torch.optim.lr_scheduler.CosineAnnealingWarmRestarts",
     "init_args": {"T_0": 1, "T_mult": 2, "eta_min": 1e-07},
 }
 pl_lrs_cfg = {"interval": "epoch", "frequency": 1, "name": "CosineAnnealingWarmRestarts"}
 
+# %%
 model = RteBoolqModule(
     model_name_or_path="albert-base-v2",
     optimizer_init=optimizer_init,
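To make the restart cadence of the configured scheduler concrete, here is a small standalone sketch (not part of the tutorial code; the throwaway parameter exists only so the optimizer has something to manage):

```python
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

dummy = torch.nn.Parameter(torch.zeros(1))  # stand-in parameter for illustration
optimizer = AdamW([dummy], lr=1e-05, weight_decay=1e-05, eps=1e-07)
# T_0=1 with T_mult=2 yields restart periods of 1, 2, 4, ... epochs, so no global
# max_epochs value is needed, which suits the gradual-thawing finetuning schedule
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=1, T_mult=2, eta_min=1e-07)

for epoch in range(7):
    print(f"epoch {epoch}: lr={optimizer.param_groups[0]['lr']:.2e}")
    optimizer.step()  # step the optimizer before the scheduler to avoid the ordering warning
    scheduler.step()
```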
@@ -528,15 +548,23 @@ def configure_callbacks(self):
 enable_progress_bar = False
 
 # %%
-trainer = pl.Trainer(
-    enable_progress_bar=enable_progress_bar,
-    precision=16,
-    accelerator="auto",
-    devices="auto",
-    callbacks=callbacks,
-    logger=logger,
-)
-trainer.fit(model, datamodule=dm)
+def train() -> None:
+    trainer = pl.Trainer(
+        enable_progress_bar=enable_progress_bar,
+        precision=16,
+        gpus=1,
+        # accelerator="auto",
+        # devices="auto",
+        callbacks=callbacks,
+        logger=logger,
+    )
+    trainer.fit(model, datamodule=dm)
+
+
+if AVAIL_GPUS > 0:
+    train()
+else:
+    print("Given the multiple phases of finetuning demonstrated, this notebook is best used with a GPU")
 
 # %% [markdown]
 # ## Footnotes
@@ -560,5 +588,25 @@ def configure_callbacks(self):
 # [Peters, M. E., Ruder, S., & Smith, N. A. (2019)](https://arxiv.org/pdf/1903.05987.pdf). To tune or not to
 # tune? Adapting pretrained representations to diverse tasks. arXiv preprint arXiv:1903.05987. [↩](#a1)
 #
-# </li>
-# </ul>
+# </li>
+# <li id="f2">
+#
+# [Sivaprasad, P. T., Mai, F., Vogels, T., Jaggi, M., & Fleuret, F. (2020)](https://arxiv.org/pdf/1910.11758.pdf).
+# Optimizer benchmarking needs to account for hyperparameter tuning. In International Conference on Machine Learning
+# (pp. 9036-9045). PMLR. [↩](#a2)
+#
+# </li>
+# <li id="f3">
+#
+# [Mosbach, M., Andriushchenko, M., & Klakow, D. (2020)](https://arxiv.org/pdf/2006.04884.pdf). On the stability of
+# fine-tuning BERT: Misconceptions, explanations, and strong baselines. arXiv preprint arXiv:2006.04884. [↩](#a2)
+#
+# </li>
+# <li id="f4">
+#
+# [Loshchilov, I., & Hutter, F. (2016)](https://arxiv.org/pdf/1608.03983.pdf). SGDR: Stochastic gradient descent with
+# warm restarts. arXiv preprint arXiv:1608.03983. [↩](#a3)
+#
+# </li>
+#
+# </ol>
