- # -*- coding: utf-8 -*-
# ---
# jupyter:
#   jupytext:
#     extension: .py
#     format_name: percent
#     format_version: '1.3'
- #     jupytext_version: 1.13.1
+ #     jupytext_version: 1.13.2
#   kernelspec:
#     display_name: 'Python 3.7.11 64-bit (''pldev_tutorials'': conda)'
#     language: python
@@ ... @@
# ```

# %% [markdown]
- # ## EarlyStopping and Epoch-Driven Phase Transition Criteria
+ # ## Early-Stopping and Epoch-Driven Phase Transition Criteria
#
#
# By default, ``FTSEarlyStopping`` and epoch-driven
@@ ... @@
# The following example demonstrates the use of ``FinetuningScheduler`` to finetune a small foundational model on the [RTE](https://huggingface.co/datasets/viewer/?dataset=super_glue&config=rte) task of [SuperGLUE](https://super.gluebenchmark.com/). Iterative early-stopping will be applied according to a user-specified schedule.
#
# ``FinetuningScheduler`` can be used to achieve non-trivial model performance improvements in both implicit and explicit scheduling contexts at an also non-trivial computational cost.
+ #

# %%
+ import logging
import os
import warnings
from datetime import datetime
- from typing import Any, Dict, List, Optional, Tuple, Union
from importlib import import_module
- import logging
-
- import torch
- from torch.utils.data import DataLoader
+ from typing import Any, Dict, List, Optional, Tuple, Union

+ import datasets
import pytorch_lightning as pl
+ import torch
+ from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.cli import CALLBACK_REGISTRY, _Registry
from pytorch_lightning.utilities.exceptions import MisconfigurationException
- from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
-
- import datasets
+ from torch.utils.data import DataLoader
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer

-
# %%
# a couple helper functions to prepare code to work with the forthcoming hub and user module registry
MOCK_HUB_REGISTRY = _Registry()
@@ -195,17 +192,13 @@ def module_hub_mock(key: str, require_fqn: bool = False) -> List:
            globals()[f"{n}"] = c
        registered_list = ", ".join([n for n in MOCK_HUB_REGISTRY.names])
    else:
-         registered_list = ", ".join(
-             [c.__module__ + "." + c.__name__ for c in MOCK_HUB_REGISTRY.classes]
-         )
+         registered_list = ", ".join([c.__module__ + "." + c.__name__ for c in MOCK_HUB_REGISTRY.classes])
    print(f"Imported and registered the following callbacks: {registered_list}")


- def instantiate_registered_class(
-     init: Dict[str, Any], args: Optional[Union[Any, Tuple[Any, ...]]] = None
- ) -> Any:
-     """Instantiates a class with the given args and init. Accepts class definitions in the form
-     of a "class_path" or "callback_key" associated with a _Registry
+ def instantiate_registered_class(init: Dict[str, Any], args: Optional[Union[Any, Tuple[Any, ...]]] = None) -> Any:
+     """Instantiates a class with the given args and init. Accepts class definitions in the form of a "class_path"
+     or "callback_key" associated with a _Registry.

    Args:
        init: Dict of the form {"class_path":... or "callback_key":..., "init_args":...}.
@@ -225,17 +218,16 @@ def instantiate_registered_class(
        else:  # class is expected to be locally defined
            args_class = globals()[init["class_path"]]
    elif init.get("callback_key", None):
-         callback_path = CALLBACK_REGISTRY.get(
+         callback_path = CALLBACK_REGISTRY.get(init["callback_key"], None) or MOCK_HUB_REGISTRY.get(
            init["callback_key"], None
-         ) or MOCK_HUB_REGISTRY.get(init["callback_key"], None)
+         )
        assert callback_path, MisconfigurationException(
            f'specified callback_key {init["callback_key"]} has not been registered'
        )
        class_module, class_name = callback_path.__module__, callback_path.__name__
    else:
        raise MisconfigurationException(
-             "Neither a class_path nor callback_key were included in a configuration that"
-             "requires one"
+             "Neither a class_path nor callback_key were included in a configuration that" "requires one"
        )
    if not shortcircuit_local:
        module = __import__(class_module, fromlist=[class_name])
@@ -266,7 +258,7 @@ def instantiate_registered_class(

# %%
class RteBoolqDataModule(pl.LightningDataModule):
-     """A ``LightningDataModule`` for using either the RTE or BoolQ SuperGLUE Hugging Face datasets"""
+     """A ``LightningDataModule`` for using either the RTE or BoolQ SuperGLUE Hugging Face datasets."""

    task_text_field_map = {"rte": ["premise", "hypothesis"], "boolq": ["question", "passage"]}
    loader_columns = [
@@ -306,12 +298,8 @@ def __init__(
        self.text_fields = self.task_text_field_map[self.task_name]
        self.num_labels = TASK_NUM_LABELS[self.task_name]
        os.environ["TOKENIZERS_PARALLELISM"] = "true" if self.tokenizers_parallelism else "false"
-         self.tokenizer = AutoTokenizer.from_pretrained(
-             self.model_name_or_path, use_fast=True, local_files_only=False
-         )
-         if (
-             prep_on_init
-         ):  # useful if one wants to load datasets as soon as the ``LightningDataModule`` is
+         self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True, local_files_only=False)
+         if prep_on_init:  # useful if one wants to load datasets as soon as the ``LightningDataModule`` is
            # instantiated
            self.prepare_data()
            self.setup("fit")
@@ -322,9 +310,7 @@ def setup(self, stage):
            self.dataset[split] = self.dataset[split].map(
                self.convert_to_features, batched=True, remove_columns=["label"]
            )
-             self.columns = [
-                 c for c in self.dataset[split].column_names if c in self.loader_columns
-             ]
+             self.columns = [c for c in self.dataset[split].column_names if c in self.loader_columns]
            self.dataset[split].set_format(type="torch", columns=self.columns)

        self.eval_splits = [x for x in self.dataset.keys() if "validation" in x]
@@ -335,9 +321,7 @@ def prepare_data(self):
        datasets.load_dataset("super_glue", self.task_name)

    def train_dataloader(self):
-         return DataLoader(
-             self.dataset["train"], batch_size=self.train_batch_size, **self.dataloader_kwargs
-         )
+         return DataLoader(self.dataset["train"], batch_size=self.train_batch_size, **self.dataloader_kwargs)

    def val_dataloader(self):
        if len(self.eval_splits) == 1:
@@ -348,29 +332,21 @@ def val_dataloader(self):
            )
        elif len(self.eval_splits) > 1:
            return [
-                 DataLoader(
-                     self.dataset[x], batch_size=self.eval_batch_size, **self.dataloader_kwargs
-                 )
+                 DataLoader(self.dataset[x], batch_size=self.eval_batch_size, **self.dataloader_kwargs)
                for x in self.eval_splits
            ]

    def test_dataloader(self):
        if len(self.eval_splits) == 1:
-             return DataLoader(
-                 self.dataset["test"], batch_size=self.eval_batch_size, **self.dataloader_kwargs
-             )
+             return DataLoader(self.dataset["test"], batch_size=self.eval_batch_size, **self.dataloader_kwargs)
        elif len(self.eval_splits) > 1:
            return [
-                 DataLoader(
-                     self.dataset[x], batch_size=self.eval_batch_size, **self.dataloader_kwargs
-                 )
+                 DataLoader(self.dataset[x], batch_size=self.eval_batch_size, **self.dataloader_kwargs)
                for x in self.eval_splits
            ]

    def convert_to_features(self, example_batch):
-         text_pairs = list(
-             zip(example_batch[self.text_fields[0]], example_batch[self.text_fields[1]])
-         )
+         text_pairs = list(zip(example_batch[self.text_fields[0]], example_batch[self.text_fields[1]]))
        # Tokenize the text/text pairs
        features = self.tokenizer.batch_encode_plus(
            text_pairs, max_length=self.max_seq_length, padding="longest", truncation=True
@@ -382,10 +358,8 @@ def convert_to_features(self, example_batch):

# %%
class RteBoolqModule(pl.LightningModule):
-     """A ``LightningModule`` that can be used to finetune a foundational
-     model on either the RTE or BoolQ SuperGLUE tasks using Hugging Face
-     implementations of a given model and the `SuperGLUE Hugging Face dataset.
-     """
+     """A ``LightningModule`` that can be used to finetune a foundational model on either the RTE or BoolQ SuperGLUE
+     tasks using Hugging Face implementations of a given model and the `SuperGLUE Hugging Face dataset."""

    def __init__(
        self,
@@ -396,7 +370,6 @@ def __init__(
        model_cfg: Optional[Dict[str, Any]] = None,
        task_name: str = DEFAULT_TASK,
        experiment_tag: str = "default",
-         plot_liveloss: bool = False,
    ):
        """
        Args:
@@ -414,29 +387,20 @@ def __init__(
        super().__init__()
        self.optimizer_init = optimizer_init
        self.lr_scheduler_init = lr_scheduler_init
-         self.plot_liveloss = plot_liveloss
        self.pl_lrs_cfg = pl_lrs_cfg or {}
        if task_name in TASK_NUM_LABELS.keys():
            self.task_name = task_name
        else:
            self.task_name = DEFAULT_TASK
-             rank_zero_warn(
-                 f"Invalid task_name '{task_name}'. Proceeding with the default task: '{DEFAULT_TASK}'"
-             )
+             rank_zero_warn(f"Invalid task_name '{task_name}'. Proceeding with the default task: '{DEFAULT_TASK}'")
        self.num_labels = TASK_NUM_LABELS[self.task_name]
        self.save_hyperparameters()
        self.experiment_id = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{experiment_tag}"
        self.model_cfg = model_cfg or {}
-         conf = AutoConfig.from_pretrained(
-             model_name_or_path, num_labels=self.num_labels, local_files_only=False
-         )
-         self.model = AutoModelForSequenceClassification.from_pretrained(
-             model_name_or_path, config=conf
-         )
+         conf = AutoConfig.from_pretrained(model_name_or_path, num_labels=self.num_labels, local_files_only=False)
+         self.model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, config=conf)
        self.model.config.update(self.model_cfg)  # apply model config overrides
-         self.metric = datasets.load_metric(
-             "super_glue", self.task_name, experiment_id=self.experiment_id
-         )
+         self.metric = datasets.load_metric("super_glue", self.task_name, experiment_id=self.experiment_id)
        self.no_decay = ["bias", "LayerNorm.weight"]
        self.finetuningscheduler_callback = None

@@ -476,8 +440,8 @@ def validation_epoch_end(self, outputs):
        return loss

    def init_pgs(self) -> List[Dict]:
-         """Initialize the parameter groups. Used to ensure weight_decay is not applied
-         to our specified bias parameters when we initialize the optimizer.
+         """Initialize the parameter groups. Used to ensure weight_decay is not applied to our specified bias
+         parameters when we initialize the optimizer.

        Returns:
            List[Dict]: A list of parameter group dictionaries.
@@ -510,9 +474,7 @@ def configure_optimizers(self):
        # performance difference)
        optimizer = instantiate_registered_class(args=self.init_pgs(), init=self.optimizer_init)
        scheduler = {
-             "scheduler": instantiate_registered_class(
-                 args=optimizer, init=self.lr_scheduler_init
-             ),
+             "scheduler": instantiate_registered_class(args=optimizer, init=self.lr_scheduler_init),
            **self.pl_lrs_cfg,
        }
        return [optimizer], [scheduler]
@@ -568,7 +530,7 @@ def configure_callbacks(self):
        callbacks = [
            FinetuningScheduler(ft_schedule=ft_schedule_name, max_depth=2),  # type: ignore # noqa
            FTSEarlyStopping(monitor="val_loss", min_delta=0.001, patience=2),  # type: ignore # noqa
-             FTSCheckpoint(monitor="val_loss", save_top_k=5),  # type: ignore # noqa
+             FTSCheckpoint(monitor="val_loss", save_top_k=5),  # type: ignore # noqa
        ]
        example_logdir = "lightning_logs"
        logger = TensorBoardLogger(example_logdir, name="fts_explicit")
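For orientation, the sketch below shows one way the classes defined in this file could be driven end to end. It is not part of the commit: the model name, the optimizer and LR-scheduler `class_path`/`init_args` values, and `max_epochs` are illustrative assumptions in the registry-dict format that `instantiate_registered_class` expects, and it presumes `ft_schedule_name` (referenced in `configure_callbacks`) is defined earlier in the notebook.

```python
# Hypothetical end-to-end driver for the classes shown above; all concrete
# values (model name, lr, scheduler settings, max_epochs) are assumptions.
import pytorch_lightning as pl

dm = RteBoolqDataModule(model_name_or_path="albert-base-v2", task_name="rte")
model = RteBoolqModule(
    model_name_or_path="albert-base-v2",
    optimizer_init={
        "class_path": "torch.optim.AdamW",
        "init_args": {"lr": 1e-05, "weight_decay": 1e-06},
    },
    lr_scheduler_init={
        "class_path": "torch.optim.lr_scheduler.CosineAnnealingWarmRestarts",
        "init_args": {"T_0": 1, "T_mult": 2},
    },
    task_name="rte",
)
# The FTS callbacks (FinetuningScheduler, FTSEarlyStopping, FTSCheckpoint) are
# attached via RteBoolqModule.configure_callbacks(), so none are passed here.
trainer = pl.Trainer(max_epochs=10)
trainer.fit(model, datamodule=dm)
```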