# ```python
# from pytorch_lightning import Trainer
# from pytorch_lightning.callbacks.finetuning_scheduler import FinetuningScheduler
- # trainer = Trainer(callbacks=[FinetuningScheduler()], resume_from_checkpoint="some/path/to/my_checkpoint.ckpt")
+ # trainer = Trainer(callbacks=[FinetuningScheduler()])
+ # trainer.fit(..., ckpt_path="some/path/to/my_checkpoint.ckpt")
# ```
#
# Training will resume at the depth/level of the provided checkpoint according to the specified schedule. Schedules can be altered between training sessions, but schedule compatibility is left to the user for maximal flexibility. If executing a user-defined schedule, the same schedule should typically be provided for the original and resumed training sessions (see the sketch below).
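#
# For instance, resuming with the same explicitly defined schedule in both sessions might look like the
# following sketch (the ``ft_schedule`` argument and the filename shown are illustrative assumptions,
# mirroring the schedule file written later in this tutorial):
#
# ```python
# trainer = Trainer(callbacks=[FinetuningScheduler(ft_schedule="RteBoolqModule_ft_schedule_albert_base.yaml")])
# trainer.fit(..., ckpt_path="some/path/to/my_checkpoint.ckpt")
# ```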
#
# By default (``FinetuningScheduler.restore_best`` is ``True``), ``FinetuningScheduler`` will attempt to restore the best available checkpoint before finetuning depth transitions.
#
# ```python
- # trainer = Trainer(
- #     callbacks=[FinetuningScheduler(new_incarnation_mode=True)],
- #     resume_from_checkpoint="some/path/to/my_kth_best_checkpoint.ckpt",
- # )
+ # trainer = Trainer(callbacks=[FinetuningScheduler(new_incarnation_mode=True)])
+ # trainer.fit(..., ckpt_path="some/path/to/my_kth_best_checkpoint.ckpt")
# ```
#
# To handle the edge case wherein one is resuming scheduled finetuning from a non-best checkpoint and the previous best checkpoints may not be accessible, setting ``FinetuningScheduler.new_incarnation_mode`` to
# %% [markdown]
# <div class="alert alert-warning">
#
- # **Note:** Currently, _FinetuningScheduler_ only supports the following _TrainingTypePlugins_:
- #
- # - ``DDPPlugin``
- # - ``DDPShardedPlugin``
- # - ``DDPSpawnPlugin``
- # - ``DDPSpawnShardedPlugin``
- # - ``DataParallelPlugin``
- # - ``SingleDevicePlugin``
+ # **Note:** Currently, _FinetuningScheduler_ only supports the following ``StrategyType``s (a brief pairing example follows below):
+ #
+ # - ``DP``
+ # - ``DDP``
+ # - ``DDP_SPAWN``
+ # - ``DDP_SHARDED``
+ # - ``DDP_SHARDED_SPAWN``
#
# </div>
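#
# For example, pairing the callback with one of the supported distributed strategies might look like the
# following sketch (the ``strategy``/``accelerator``/``devices`` arguments shown are illustrative and not part
# of the original tutorial):
#
# ```python
# trainer = Trainer(callbacks=[FinetuningScheduler()], strategy="ddp", accelerator="gpu", devices=2)
# ```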

#

# %%
- import logging
import os
import warnings
from datetime import datetime
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer

# %%
- # a couple helper functions to prepare code to work with the forthcoming hub and user module registry
- MOCK_HUB_REGISTRY = _Registry()
+ # a couple helper functions to prepare code to work with a user module registry
+ MOCK_REGISTRY = _Registry()


- def module_hub_mock(key: str, require_fqn: bool = False) -> List:
+ def mock_register_module(key: str, require_fqn: bool = False) -> List:
    if key.lower() == "finetuningscheduler":
        mod = import_module("pytorch_lightning.callbacks.finetuning_scheduler")
-         MOCK_HUB_REGISTRY.register_classes(mod, pl.callbacks.Callback)
+         MOCK_REGISTRY.register_classes(mod, pl.callbacks.Callback)
    else:
        raise MisconfigurationException(f"user module key '{key}' not found")
    registered_list = []
    # make registered class available by unqualified class name by default
    if not require_fqn:
-         for n, c in MOCK_HUB_REGISTRY.items():
+         for n, c in MOCK_REGISTRY.items():
            globals()[f"{n}"] = c
-         registered_list = ", ".join([n for n in MOCK_HUB_REGISTRY.names])
+         registered_list = ", ".join([n for n in MOCK_REGISTRY.names])
    else:
-         registered_list = ", ".join([c.__module__ + "." + c.__name__ for c in MOCK_HUB_REGISTRY.classes])
+         registered_list = ", ".join([c.__module__ + "." + c.__name__ for c in MOCK_REGISTRY.classes])
    print(f"Imported and registered the following callbacks: {registered_list}")


@@ -203,7 +199,7 @@ def instantiate_registered_class(init: Dict[str, Any], args: Optional[Union[Any,
        else:  # class is expected to be locally defined
            args_class = globals()[init["class_path"]]
    elif init.get("callback_key", None):
-         callback_path = CALLBACK_REGISTRY.get(init["callback_key"], None) or MOCK_HUB_REGISTRY.get(
+         callback_path = CALLBACK_REGISTRY.get(init["callback_key"], None) or MOCK_REGISTRY.get(
            init["callback_key"], None
        )
        assert callback_path, MisconfigurationException(
@@ -222,49 +218,37 @@ def instantiate_registered_class(init: Dict[str, Any], args: Optional[Union[Any,

# %%
# load the pl extension module we want to use. This will import all necessary callbacks.
- module_hub_mock("finetuningscheduler")
+ mock_register_module("finetuningscheduler")
# set notebook-level variables
- AVAIL_GPUS = min(1, torch.cuda.device_count())
+ AVAIL_GPUS = torch.cuda.device_count()
TASK_NUM_LABELS = {"boolq": 2, "rte": 2}
DEFAULT_TASK = "rte"

- # narrow our logging to adapt it for a notebook environment
- for l_key in logging.Logger.manager.loggerDict.keys():
-     if "pytorch_lightning" in l_key:
-         logging.getLogger(l_key).setLevel("INFO")
-     else:
-         logging.getLogger(l_key).setLevel("CRITICAL")
- pl_logger = logging.getLogger("pytorch_lightning")
- pl_logger.removeHandler(pl_logger.handlers[0])
- rz_logger = logging.getLogger("pytorch_lightning.utilities.distributed")
- rz_logger.addHandler(logging.StreamHandler())
- rz_logger.handlers[0].setLevel("INFO")
+ # ignore warnings related to the tokenizers_parallelism/DataLoader parallelism tradeoff and
+ # expected logging behavior
+ for warnf in [".*does not have many workers*", ".*The number of training samples.*"]:
+     warnings.filterwarnings("ignore", warnf)


# %%
class RteBoolqDataModule(pl.LightningDataModule):
    """A ``LightningDataModule`` for using either the RTE or BoolQ SuperGLUE Hugging Face datasets."""

-     task_text_field_map = {"rte": ["premise", "hypothesis"], "boolq": ["question", "passage"]}
-     loader_columns = [
+     TASK_TEXT_FIELD_MAP = {"rte": ("premise", "hypothesis"), "boolq": ("question", "passage")}
+     LOADER_COLUMNS = (
        "datasets_idx",
        "input_ids",
        "token_type_ids",
        "attention_mask",
        "start_positions",
        "end_positions",
        "labels",
-     ]
-     # ignore warnings related tokenizers_parallelism/DataLoader parallelism tradeoff and
-     # expected logging behavior
-     for warnf in [".*does not have many workers*", ".*The number of training samples.*"]:
-         warnings.filterwarnings("ignore", warnf)
+     )

    def __init__(
        self,
        model_name_or_path: str,
        task_name: str = DEFAULT_TASK,
-         prep_on_init: bool = False,
        max_seq_length: int = 128,
        train_batch_size: int = 32,
        eval_batch_size: int = 32,
@@ -278,23 +262,22 @@ def __init__(
        self.train_batch_size = train_batch_size
        self.eval_batch_size = eval_batch_size
        self.tokenizers_parallelism = tokenizers_parallelism
-         self.dataloader_kwargs = {"num_workers": num_workers, "pin_memory": pin_memory}
-         self.text_fields = self.task_text_field_map[self.task_name]
+         self.dataloader_kwargs = {
+             "num_workers": dataloader_kwargs.get("num_workers", 0),
+             "pin_memory": dataloader_kwargs.get("pin_memory", False),
+         }
+         self.text_fields = self.TASK_TEXT_FIELD_MAP[self.task_name]
        self.num_labels = TASK_NUM_LABELS[self.task_name]
        os.environ["TOKENIZERS_PARALLELISM"] = "true" if self.tokenizers_parallelism else "false"
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True, local_files_only=False)
-         if prep_on_init:  # useful if one wants to load datasets as soon as the ``LightningDataModule`` is
-             # instantiated
-             self.prepare_data()
-             self.setup("fit")

    def setup(self, stage):
        self.dataset = datasets.load_dataset("super_glue", self.task_name)
        for split in self.dataset.keys():
            self.dataset[split] = self.dataset[split].map(
-                 self.convert_to_features, batched=True, remove_columns=["label"]
+                 self._convert_to_features, batched=True, remove_columns=["label"]
            )
-             self.columns = [c for c in self.dataset[split].column_names if c in self.loader_columns]
+             self.columns = [c for c in self.dataset[split].column_names if c in self.LOADER_COLUMNS]
            self.dataset[split].set_format(type="torch", columns=self.columns)

        self.eval_splits = [x for x in self.dataset.keys() if "validation" in x]
@@ -329,7 +312,7 @@ def test_dataloader(self):
            for x in self.eval_splits
        ]

-     def convert_to_features(self, example_batch):
+     def _convert_to_features(self, example_batch):
        text_pairs = list(zip(example_batch[self.text_fields[0]], example_batch[self.text_fields[1]]))
        # Tokenize the text/text pairs
        features = self.tokenizer.batch_encode_plus(
@@ -412,15 +395,15 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0):
            preds = logits.squeeze()

        labels = batch["labels"]
+         self.log("val_loss", val_loss, prog_bar=True)
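        # logging ``val_loss`` here in ``validation_step`` (reduced per-epoch by Lightning's default
        # ``on_epoch=True`` behavior for validation) keeps the metric available to any callbacks monitoring it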
        return {"loss": val_loss, "preds": preds, "labels": labels}

    def validation_epoch_end(self, outputs):
        preds = torch.cat([x["preds"] for x in outputs]).detach().cpu().numpy()
        labels = torch.cat([x["labels"] for x in outputs]).detach().cpu().numpy()
        loss = torch.stack([x["loss"] for x in outputs]).mean()
-         self.log("val_loss", loss, prog_bar=True, sync_dist=True)
        metric_dict = self.metric.compute(predictions=preds, references=labels)
-         self.log_dict(metric_dict, prog_bar=True, sync_dist=True)
+         self.log_dict(metric_dict, prog_bar=True)
        return loss

    def _init_param_groups(self) -> List[Dict]:
@@ -455,7 +438,7 @@ def configure_optimizers(self):
        # but in this case we pass a list of parameter groups to ensure weight_decay is
        # not applied to the bias parameter (for completeness, in this case it won't make much
        # performance difference)
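        # (each parameter group is a dict of the form ``{"params": [...], "weight_decay": ...}``, the
        # per-parameter options format expected by ``torch.optim`` optimizers)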
-         optimizer = instantiate_registered_class(args=self.init_pgs(), init=self.optimizer_init)
+         optimizer = instantiate_registered_class(args=self._init_param_groups(), init=self.optimizer_init)
        scheduler = {
            "scheduler": instantiate_registered_class(args=optimizer, init=self.lr_scheduler_init),
            **self.pl_lrs_cfg,
@@ -485,24 +468,61 @@ def configure_callbacks(self):
- model.albert.encoder.*.ffn_output.*
"""
ft_schedule_name = "RteBoolqModule_ft_schedule_albert_base.yaml"
+ # Let's write the schedule to a file so we can simulate loading an explicitly defined finetuning
+ # schedule.
with open(ft_schedule_name, "w") as f:
    f.write(ft_schedule_yaml)

# %%
datasets.set_progress_bar_enabled(False)
pl.seed_everything(42)
dm = RteBoolqDataModule(model_name_or_path="albert-base-v2", tokenizers_parallelism=False)
- dm.setup("fit")
+
+ # %% [markdown]
+ # ### Optimizer Configuration
+ #
+ # <div id="a2">
+ #
+ # Though other optimizers can arguably yield some marginal advantage contingent on the context,
+ # the Adam optimizer (and the [AdamW version](https://pytorch.org/docs/stable/_modules/torch/optim/adamw.html#AdamW), which
+ # implements decoupled weight decay) remains robust to hyperparameter choices and is commonly used for finetuning
+ # foundational language models. See [(Sivaprasad et al., 2020)](#f2) and [(Mosbach, Andriushchenko & Klakow, 2020)](#f3) for theoretical and systematic empirical justifications of Adam and its use in finetuning
+ # large transformer-based language models. The values used here have some justification
+ # in the referenced literature but have been largely empirically determined; while a good
+ # starting point, they could be further tuned.
+ #
+ # </div>
+
+ # %%
optimizer_init = {
    "class_path": "torch.optim.AdamW",
    "init_args": {"weight_decay": 1e-05, "eps": 1e-07, "lr": 1e-05},
}
+
+ # %% [markdown]
+ # ### LR Scheduler Configuration
+ #
+ # <div id="a3">
+ #
+ # The [CosineAnnealingWarmRestarts scheduler](https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.CosineAnnealingWarmRestarts.html?highlight=cosineannealingwarm#torch.optim.lr_scheduler.CosineAnnealingWarmRestarts) fits nicely with our iterative finetuning since it does not depend upon a global max_epoch
+ # value. The importance of initial warmup is reduced due to the innate warmup effect of Adam bias correction [[5]](#f3)
+ # and the gradual thawing we are performing. Note that commonly used LR schedulers that depend on providing
+ # max_iterations/epochs (e.g. the
+ # [CosineWarmupScheduler](https://github.com/PyTorchLightning/lightning-tutorials/blob/0c325829101d5a6ebf32ed99bbf5b09badf04a59/course_UvA-DL/05-transformers-and-MH-attention/Transformers_MHAttention.py#L688)
+ # used in other pytorch-lightning tutorials) also work with FinetuningScheduler. Though the LR scheduler is theoretically
+ # justified [(Loshchilov & Hutter, 2016)](#f4), the particular values provided here are primarily empirically driven.
+ #
+ # </div>
+
+
+ # %%
lr_scheduler_init = {
    "class_path": "torch.optim.lr_scheduler.CosineAnnealingWarmRestarts",
    "init_args": {"T_0": 1, "T_mult": 2, "eta_min": 1e-07},
}
pl_lrs_cfg = {"interval": "epoch", "frequency": 1, "name": "CosineAnnealingWarmRestarts"}

+ # %%
model = RteBoolqModule(
    model_name_or_path="albert-base-v2",
    optimizer_init=optimizer_init,
@@ -528,15 +548,23 @@ def configure_callbacks(self):
enable_progress_bar = False

# %%
- trainer = pl.Trainer(
-     enable_progress_bar=enable_progress_bar,
-     precision=16,
-     accelerator="auto",
-     devices="auto",
-     callbacks=callbacks,
-     logger=logger,
- )
- trainer.fit(model, datamodule=dm)
+ def train() -> None:
+     trainer = pl.Trainer(
+         enable_progress_bar=enable_progress_bar,
+         precision=16,
+         gpus=1,
+         # accelerator="auto",
+         # devices="auto",
+         callbacks=callbacks,
+         logger=logger,
+     )
+     trainer.fit(model, datamodule=dm)
+
+
+ if AVAIL_GPUS > 0:
+     train()
+ else:
+     print("Given the multiple phases of finetuning demonstrated, this notebook is best used with a GPU")

# %% [markdown]
# ## Footnotes
@@ -560,5 +588,25 @@ def configure_callbacks(self):
# [Peters, M. E., Ruder, S., & Smith, N. A. (2019)](https://arxiv.org/pdf/1903.05987.pdf). To tune or not to
# tune? Adapting pretrained representations to diverse tasks. arXiv preprint arXiv:1903.05987. [↩](#a1)
#
- # </li>
- # </ul>
+ # </li>
+ # <li id="f2">
+ #
+ # [Sivaprasad, P. T., Mai, F., Vogels, T., Jaggi, M., & Fleuret, F. (2020)](https://arxiv.org/pdf/1910.11758.pdf).
+ # Optimizer benchmarking needs to account for hyperparameter tuning. In International Conference on Machine Learning
+ # (pp. 9036-9045). PMLR. [↩](#a2)
+ #
+ # </li>
+ # <li id="f3">
+ #
+ # [Mosbach, M., Andriushchenko, M., & Klakow, D. (2020)](https://arxiv.org/pdf/2006.04884.pdf). On the stability of
+ # fine-tuning BERT: Misconceptions, explanations, and strong baselines. arXiv preprint arXiv:2006.04884. [↩](#a2)
+ #
+ # </li>
+ # <li id="f4">
+ #
+ # [Loshchilov, I., & Hutter, F. (2016)](https://arxiv.org/pdf/1608.03983.pdf). SGDR: Stochastic gradient descent with
+ # warm restarts. arXiv preprint arXiv:1608.03983. [↩](#a3)
+ #
+ # </li>
+ #
+ # </ol>