Skip to content

Commit d7da388

Browse files
committed
add fts_superglue pl_example graceful exit and horovod test condition
1 parent 4c4c747 commit d7da388

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

pl_examples/basic_examples/fts_superglue.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,18 @@
3838
from datetime import datetime
3939
from typing import Any, Dict, List, Optional
4040

41-
import datasets
4241
import torch
4342
from torch.utils.data import DataLoader
44-
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer
4543

4644
import pytorch_lightning as pl
45+
from pl_examples import _HF_AVAILABLE
4746
from pytorch_lightning.utilities import rank_zero_warn
4847
from pytorch_lightning.utilities.cli import instantiate_class, LightningCLI
4948

49+
if _HF_AVAILABLE:
50+
import datasets
51+
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer
52+
5053
TASK_NUM_LABELS = {"boolq": 2, "rte": 2}
5154
DEFAULT_TASK = "rte"
5255

@@ -286,6 +289,9 @@ def add_arguments_to_parser(self, parser):
286289

287290

288291
def cli_main() -> None:
292+
if not _HF_AVAILABLE: # pragma: no cover
293+
print("Running the fts_superglue example requires the `transformers` and `datasets` packages from Hugging Face")
294+
return
289295
# every configuration of this example depends upon a shared set of defaults.
290296
default_config_file = os.path.join(os.path.dirname(__file__), "config", "fts", "fts_defaults.yaml")
291297
_ = CustLightningCLI(

tests/callbacks/test_finetuning_scheduler_callback.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -414,7 +414,7 @@ def test_finetuningscheduling_misconfiguration(tmpdir, callbacks: List[Callback]
414414
[
415415
pytest.param("ddp2", 2, None, marks=RunIf(min_gpus=2)),
416416
pytest.param("ddp_fully_sharded", 2, None, marks=RunIf(min_gpus=2)),
417-
("horovod", None, None),
417+
pytest.param("horovod", None, None, marks=RunIf(min_gpus=2)),
418418
pytest.param("ddp", 2, "deepspeed_stage_2", marks=RunIf(deepspeed=True, min_gpus=2)),
419419
],
420420
)

0 commit comments

Comments
 (0)