
Commit 8805862

update tutorial reqs
1 parent 6400b3b commit 8805862

4 files changed (+44, -78 lines changed)

docs/requirements.txt

Lines changed: 3 additions & 0 deletions
@@ -6,6 +6,9 @@ docutils>=0.16
 sphinx-paramlinks>=0.4.0
 ipython[notebook]
 
+# temporarily included until hub available to evaluate finetuning_scheduler
+git+git://github.com/speediedan/pytorch-lightning.git@24d3e43568814ec381ac5be91627629808d62081#egg=pytorch-lightning
+
 https://github.com/PyTorchLightning/lightning_sphinx_theme/archive/master.zip#egg=pt-lightning-sphinx-theme
 
 -r ../.actions/requirements.txt

lightning_examples/finetuning-scheduler/.meta.yml

Lines changed: 6 additions & 5 deletions
@@ -7,13 +7,14 @@ build: 1
 tags:
   - finetuning
 description: |
-  This notebook introduces the FinetuningScheduler callback and demonstrates the use of FinetuningScheduler to finetune
-  a small foundational model on the [RTE](https://huggingface.co/datasets/viewer/?dataset=super_glue&config=rte) task
-  of [SuperGLUE](https://super.gluebenchmark.com/) with iterative earlystopping defined according to a user-specified
-  schedule. It uses HuggingFace's ``datasets`` and ``transformers`` libraries to retrieve the relevant benchmark data and
-  foundational model weights.
+  This notebook introduces the FinetuningScheduler callback and demonstrates the use of FinetuningScheduler to finetune
+  a small foundational model on the [RTE](https://huggingface.co/datasets/viewer/?dataset=super_glue&config=rte) task
+  of [SuperGLUE](https://super.gluebenchmark.com/) with iterative early-stopping defined according to a user-specified
+  schedule. It uses HuggingFace's ``datasets`` and ``transformers`` libraries to retrieve the relevant benchmark data
+  and foundational model weights.
 requirements:
   - transformers
   - datasets
+  - scikit-learn
 accelerator:
   - GPU
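For context on the description above: a minimal sketch of how the notebook's ``datasets`` and ``transformers`` dependencies retrieve the RTE benchmark data and foundational model weights. The model identifier below is a placeholder for illustration, not the checkpoint the tutorial actually configures.

    import datasets
    from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer

    # Pull the RTE task of SuperGLUE (the benchmark data referenced in the description).
    rte = datasets.load_dataset("super_glue", "rte")
    # "some-base-model" is a hypothetical checkpoint name used only for illustration.
    tokenizer = AutoTokenizer.from_pretrained("some-base-model", use_fast=True)
    config = AutoConfig.from_pretrained("some-base-model", num_labels=2)  # RTE is a two-label task
    model = AutoModelForSequenceClassification.from_pretrained("some-base-model", config=config)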

lightning_examples/finetuning-scheduler/finetuning-scheduler.py

Lines changed: 35 additions & 73 deletions
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
 # ---
 # jupyter:
 #   jupytext:
@@ -7,7 +6,7 @@
 #       extension: .py
 #       format_name: percent
 #       format_version: '1.3'
-#       jupytext_version: 1.13.1
+#       jupytext_version: 1.13.2
 #   kernelspec:
 #     display_name: 'Python 3.7.11 64-bit (''pldev_tutorials'': conda)'
 #     language: python
@@ -84,7 +83,7 @@
 # ```
 
 # %% [markdown]
-# ## EarlyStopping and Epoch-Driven Phase Transition Criteria
+# ## Early-Stopping and Epoch-Driven Phase Transition Criteria
 #
 #
 # By default, ``FTSEarlyStopping`` and epoch-driven
@@ -155,28 +154,26 @@
 # The following example demonstrates the use of ``FinetuningScheduler`` to finetune a small foundational model on the [RTE](https://huggingface.co/datasets/viewer/?dataset=super_glue&config=rte) task of [SuperGLUE](https://super.gluebenchmark.com/). Iterative early-stopping will be applied according to a user-specified schedule.
 #
 # ``FinetuningScheduler`` can be used to achieve non-trivial model performance improvements in both implicit and explicit scheduling contexts at an also non-trivial computational cost.
+#
 
 # %%
+import logging
 import os
 import warnings
 from datetime import datetime
-from typing import Any, Dict, List, Optional, Tuple, Union
 from importlib import import_module
-import logging
-
-import torch
-from torch.utils.data import DataLoader
+from typing import Any, Dict, List, Optional, Tuple, Union
 
+import datasets
 import pytorch_lightning as pl
+import torch
+from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
 from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.cli import CALLBACK_REGISTRY, _Registry
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
-
-import datasets
+from torch.utils.data import DataLoader
 from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer
 
-
 # %%
 # a couple helper functions to prepare code to work with the forthcoming hub and user module registry
 MOCK_HUB_REGISTRY = _Registry()
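The markdown cell in the hunk above notes that early-stopping follows a user-specified schedule and that ``FinetuningScheduler`` supports both implicit and explicit scheduling contexts. As a rough, hypothetical sketch of what that distinction looks like at the callback level (the schedule path below is illustrative and not taken from this commit):

    # Explicit scheduling: point ft_schedule at a user-authored schedule file and optionally
    # cap the number of phases with max_depth (mirrors the configure_callbacks() call later in this file).
    explicit_fts = FinetuningScheduler(ft_schedule="./example_ft_schedule.yaml", max_depth=2)

    # Implicit scheduling (as I understand the callback): omit ft_schedule and let the callback
    # derive a default, layer-by-layer schedule for the wrapped model.
    implicit_fts = FinetuningScheduler()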
@@ -195,17 +192,13 @@ def module_hub_mock(key: str, require_fqn: bool = False) -> List:
             globals()[f"{n}"] = c
         registered_list = ", ".join([n for n in MOCK_HUB_REGISTRY.names])
     else:
-        registered_list = ", ".join(
-            [c.__module__ + "." + c.__name__ for c in MOCK_HUB_REGISTRY.classes]
-        )
+        registered_list = ", ".join([c.__module__ + "." + c.__name__ for c in MOCK_HUB_REGISTRY.classes])
     print(f"Imported and registered the following callbacks: {registered_list}")
 
 
-def instantiate_registered_class(
-    init: Dict[str, Any], args: Optional[Union[Any, Tuple[Any, ...]]] = None
-) -> Any:
-    """Instantiates a class with the given args and init. Accepts class definitions in the form
-    of a "class_path" or "callback_key" associated with a _Registry
+def instantiate_registered_class(init: Dict[str, Any], args: Optional[Union[Any, Tuple[Any, ...]]] = None) -> Any:
+    """Instantiates a class with the given args and init. Accepts class definitions in the form of a "class_path"
+    or "callback_key" associated with a _Registry.
 
     Args:
         init: Dict of the form {"class_path":... or "callback_key":..., "init_args":...}.
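A brief usage sketch of ``instantiate_registered_class`` as the docstring above describes it; the optimizer dictionary and the ``param_groups`` variable are illustrative assumptions (in this file, ``configure_optimizers`` passes ``self.init_pgs()`` and ``self.optimizer_init``), not values taken from this commit.

    # Hypothetical "class_path"/"init_args" form from the docstring above.
    optimizer_init = {"class_path": "torch.optim.AdamW", "init_args": {"lr": 1e-05, "weight_decay": 1e-06}}
    # param_groups stands in for the parameter-group list the LightningModule builds via init_pgs().
    optimizer = instantiate_registered_class(args=param_groups, init=optimizer_init)

    # The "callback_key" form instead resolves the class through CALLBACK_REGISTRY / MOCK_HUB_REGISTRY;
    # the key string here is a placeholder for whatever key was registered.
    callback = instantiate_registered_class(init={"callback_key": "some_registered_key", "init_args": {}})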
@@ -225,17 +218,16 @@ def instantiate_registered_class(
         else:  # class is expected to be locally defined
             args_class = globals()[init["class_path"]]
     elif init.get("callback_key", None):
-        callback_path = CALLBACK_REGISTRY.get(
+        callback_path = CALLBACK_REGISTRY.get(init["callback_key"], None) or MOCK_HUB_REGISTRY.get(
             init["callback_key"], None
-        ) or MOCK_HUB_REGISTRY.get(init["callback_key"], None)
+        )
         assert callback_path, MisconfigurationException(
             f'specified callback_key {init["callback_key"]} has not been registered'
         )
         class_module, class_name = callback_path.__module__, callback_path.__name__
     else:
         raise MisconfigurationException(
-            "Neither a class_path nor callback_key were included in a configuration that"
-            "requires one"
+            "Neither a class_path nor callback_key were included in a configuration that" "requires one"
         )
     if not shortcircuit_local:
         module = __import__(class_module, fromlist=[class_name])
@@ -266,7 +258,7 @@ def instantiate_registered_class(
 
 # %%
 class RteBoolqDataModule(pl.LightningDataModule):
-    """A ``LightningDataModule`` for using either the RTE or BoolQ SuperGLUE Hugging Face datasets"""
+    """A ``LightningDataModule`` for using either the RTE or BoolQ SuperGLUE Hugging Face datasets."""
 
     task_text_field_map = {"rte": ["premise", "hypothesis"], "boolq": ["question", "passage"]}
     loader_columns = [
@@ -306,12 +298,8 @@ def __init__(
         self.text_fields = self.task_text_field_map[self.task_name]
         self.num_labels = TASK_NUM_LABELS[self.task_name]
         os.environ["TOKENIZERS_PARALLELISM"] = "true" if self.tokenizers_parallelism else "false"
-        self.tokenizer = AutoTokenizer.from_pretrained(
-            self.model_name_or_path, use_fast=True, local_files_only=False
-        )
-        if (
-            prep_on_init
-        ):  # useful if one wants to load datasets as soon as the ``LightningDataModule`` is
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True, local_files_only=False)
+        if prep_on_init:  # useful if one wants to load datasets as soon as the ``LightningDataModule`` is
             # instantiated
             self.prepare_data()
             self.setup("fit")
@@ -322,9 +310,7 @@ def setup(self, stage):
             self.dataset[split] = self.dataset[split].map(
                 self.convert_to_features, batched=True, remove_columns=["label"]
             )
-            self.columns = [
-                c for c in self.dataset[split].column_names if c in self.loader_columns
-            ]
+            self.columns = [c for c in self.dataset[split].column_names if c in self.loader_columns]
             self.dataset[split].set_format(type="torch", columns=self.columns)
 
         self.eval_splits = [x for x in self.dataset.keys() if "validation" in x]
@@ -335,9 +321,7 @@ def prepare_data(self):
         datasets.load_dataset("super_glue", self.task_name)
 
     def train_dataloader(self):
-        return DataLoader(
-            self.dataset["train"], batch_size=self.train_batch_size, **self.dataloader_kwargs
-        )
+        return DataLoader(self.dataset["train"], batch_size=self.train_batch_size, **self.dataloader_kwargs)
 
     def val_dataloader(self):
         if len(self.eval_splits) == 1:
@@ -348,29 +332,21 @@ def val_dataloader(self):
             )
         elif len(self.eval_splits) > 1:
             return [
-                DataLoader(
-                    self.dataset[x], batch_size=self.eval_batch_size, **self.dataloader_kwargs
-                )
+                DataLoader(self.dataset[x], batch_size=self.eval_batch_size, **self.dataloader_kwargs)
                 for x in self.eval_splits
             ]
 
     def test_dataloader(self):
         if len(self.eval_splits) == 1:
-            return DataLoader(
-                self.dataset["test"], batch_size=self.eval_batch_size, **self.dataloader_kwargs
-            )
+            return DataLoader(self.dataset["test"], batch_size=self.eval_batch_size, **self.dataloader_kwargs)
         elif len(self.eval_splits) > 1:
             return [
-                DataLoader(
-                    self.dataset[x], batch_size=self.eval_batch_size, **self.dataloader_kwargs
-                )
+                DataLoader(self.dataset[x], batch_size=self.eval_batch_size, **self.dataloader_kwargs)
                 for x in self.eval_splits
             ]
 
     def convert_to_features(self, example_batch):
-        text_pairs = list(
-            zip(example_batch[self.text_fields[0]], example_batch[self.text_fields[1]])
-        )
+        text_pairs = list(zip(example_batch[self.text_fields[0]], example_batch[self.text_fields[1]]))
         # Tokenize the text/text pairs
         features = self.tokenizer.batch_encode_plus(
             text_pairs, max_length=self.max_seq_length, padding="longest", truncation=True
@@ -382,10 +358,8 @@ def convert_to_features(self, example_batch):
 
 # %%
 class RteBoolqModule(pl.LightningModule):
-    """A ``LightningModule`` that can be used to finetune a foundational
-    model on either the RTE or BoolQ SuperGLUE tasks using Hugging Face
-    implementations of a given model and the `SuperGLUE Hugging Face dataset.
-    """
+    """A ``LightningModule`` that can be used to finetune a foundational model on either the RTE or BoolQ SuperGLUE
+    tasks using Hugging Face implementations of a given model and the `SuperGLUE Hugging Face dataset."""
 
     def __init__(
         self,
@@ -396,7 +370,6 @@ def __init__(
         model_cfg: Optional[Dict[str, Any]] = None,
         task_name: str = DEFAULT_TASK,
         experiment_tag: str = "default",
-        plot_liveloss: bool = False,
     ):
         """
         Args:
@@ -414,29 +387,20 @@ def __init__(
         super().__init__()
         self.optimizer_init = optimizer_init
         self.lr_scheduler_init = lr_scheduler_init
-        self.plot_liveloss = plot_liveloss
         self.pl_lrs_cfg = pl_lrs_cfg or {}
         if task_name in TASK_NUM_LABELS.keys():
             self.task_name = task_name
         else:
             self.task_name = DEFAULT_TASK
-            rank_zero_warn(
-                f"Invalid task_name '{task_name}'. Proceeding with the default task: '{DEFAULT_TASK}'"
-            )
+            rank_zero_warn(f"Invalid task_name '{task_name}'. Proceeding with the default task: '{DEFAULT_TASK}'")
         self.num_labels = TASK_NUM_LABELS[self.task_name]
         self.save_hyperparameters()
         self.experiment_id = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{experiment_tag}"
         self.model_cfg = model_cfg or {}
-        conf = AutoConfig.from_pretrained(
-            model_name_or_path, num_labels=self.num_labels, local_files_only=False
-        )
-        self.model = AutoModelForSequenceClassification.from_pretrained(
-            model_name_or_path, config=conf
-        )
+        conf = AutoConfig.from_pretrained(model_name_or_path, num_labels=self.num_labels, local_files_only=False)
+        self.model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, config=conf)
         self.model.config.update(self.model_cfg)  # apply model config overrides
-        self.metric = datasets.load_metric(
-            "super_glue", self.task_name, experiment_id=self.experiment_id
-        )
+        self.metric = datasets.load_metric("super_glue", self.task_name, experiment_id=self.experiment_id)
         self.no_decay = ["bias", "LayerNorm.weight"]
         self.finetuningscheduler_callback = None
 
@@ -476,8 +440,8 @@ def validation_epoch_end(self, outputs):
         return loss
 
     def init_pgs(self) -> List[Dict]:
-        """Initialize the parameter groups. Used to ensure weight_decay is not applied
-        to our specified bias parameters when we initialize the optimizer.
+        """Initialize the parameter groups. Used to ensure weight_decay is not applied to our specified bias
+        parameters when we initialize the optimizer.
 
         Returns:
             List[Dict]: A list of parameter group dictionaries.
@@ -510,9 +474,7 @@ def configure_optimizers(self):
         # performance difference)
         optimizer = instantiate_registered_class(args=self.init_pgs(), init=self.optimizer_init)
         scheduler = {
-            "scheduler": instantiate_registered_class(
-                args=optimizer, init=self.lr_scheduler_init
-            ),
+            "scheduler": instantiate_registered_class(args=optimizer, init=self.lr_scheduler_init),
             **self.pl_lrs_cfg,
         }
         return [optimizer], [scheduler]
@@ -568,7 +530,7 @@ def configure_callbacks(self):
         callbacks = [
             FinetuningScheduler(ft_schedule=ft_schedule_name, max_depth=2),  # type: ignore # noqa
             FTSEarlyStopping(monitor="val_loss", min_delta=0.001, patience=2),  # type: ignore # noqa
-            FTSCheckpoint(monitor="val_loss", save_top_k=5), # type: ignore # noqa
+            FTSCheckpoint(monitor="val_loss", save_top_k=5),  # type: ignore # noqa
         ]
         example_logdir = "lightning_logs"
         logger = TensorBoardLogger(example_logdir, name="fts_explicit")
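To round out the picture, a hedged sketch of how the pieces changed in this file are typically wired together. The constructor arguments and hyperparameter values below are illustrative assumptions (the real signatures and checkpoint name are defined elsewhere in the notebook), and the FTS callbacks are attached automatically via ``configure_callbacks()``.

    # Hypothetical wiring of the datamodule and module defined above.
    dm = RteBoolqDataModule(model_name_or_path="some-base-model", task_name="rte")
    model = RteBoolqModule(
        model_name_or_path="some-base-model",
        optimizer_init={"class_path": "torch.optim.AdamW", "init_args": {"lr": 1e-05}},
        lr_scheduler_init={
            "class_path": "torch.optim.lr_scheduler.CosineAnnealingWarmRestarts",
            "init_args": {"T_0": 1},
        },
    )
    # Callbacks (FinetuningScheduler, FTSEarlyStopping, FTSCheckpoint) come from configure_callbacks().
    trainer = pl.Trainer(gpus=1)
    trainer.fit(model, datamodule=dm)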