Skip to content

Fail the test when a DeprecationWarning is raised #9940

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 51 commits into from
Nov 17, 2021
Merged
Show file tree
Hide file tree
Changes from 42 commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
b9e35e6
small fixes for device stats revamp
daniellepintz Oct 16, 2021
318a4f0
fix doc
daniellepintz Oct 16, 2021
8f2bf06
fix test
daniellepintz Oct 16, 2021
455ec02
small fix
daniellepintz Oct 16, 2021
db042f1
fix docs
daniellepintz Oct 17, 2021
a4fb139
Fail the test when a `DeprecationWarning` is raised
carmocca Oct 15, 2021
40a8ebf
Fixes for LightningCLI
carmocca Oct 15, 2021
31cf068
Fixes
carmocca Oct 15, 2021
e759ce9
Undo changes
carmocca Oct 18, 2021
cbed898
Merge remote-tracking branch 'daniellepintz/gpu_metrics' into ci/fail…
carmocca Oct 18, 2021
ec6eca6
Fix one
carmocca Oct 18, 2021
d4ac3b8
Merge branch 'master' into ci/fail-on-deprecation
carmocca Oct 19, 2021
5ee5902
Remove comet
carmocca Oct 19, 2021
27ce421
Remove mlflow
carmocca Oct 19, 2021
bd61df9
Only ours
carmocca Oct 19, 2021
86de0ba
Fix accelerator connector tests
carmocca Oct 19, 2021
f56f3b6
Fix
carmocca Oct 19, 2021
d822097
task_idx -> local_rank
carmocca Oct 19, 2021
9eb0661
Merge branch 'master' into ci/fail-on-deprecation
carmocca Oct 26, 2021
a0300aa
Merge branch 'master' into ci/fail-on-deprecation
carmocca Oct 26, 2021
4737071
Merge branch 'master' into ci/fail-on-deprecation
carmocca Nov 9, 2021
e379338
undo changge
carmocca Nov 9, 2021
c9149c9
Merge branch 'master' into ci/fail-on-deprecation
carmocca Nov 10, 2021
343a9c3
Remove ignore after merging master
carmocca Nov 10, 2021
df8f512
Update test
carmocca Nov 10, 2021
b1c78af
Merge branch 'master' into ci/fail-on-deprecation
carmocca Nov 12, 2021
ae5227f
Merge 10470
carmocca Nov 12, 2021
aae8046
Fix tests
carmocca Nov 15, 2021
e7f0fb0
Merge branch 'master' into ci/fail-on-deprecation
carmocca Nov 15, 2021
1c8fe9f
Merge branch 'master' into ci/fail-on-deprecation
carmocca Nov 15, 2021
1340242
Remove deprecated usage for enum
carmocca Nov 16, 2021
167a362
Disable propagate
carmocca Nov 16, 2021
afe98e6
Update accel connector tests
carmocca Nov 16, 2021
bfb18da
Fix a few
carmocca Nov 16, 2021
af0eb7c
Updatte
carmocca Nov 16, 2021
ce60e93
Try fixing IPUs
carmocca Nov 16, 2021
4188cc7
Fix last few
carmocca Nov 16, 2021
96ddbb7
Fix
carmocca Nov 16, 2021
84f000e
Merge branch 'master' into ci/fail-on-deprecation
carmocca Nov 16, 2021
952cd21
Fix
carmocca Nov 16, 2021
2ed3067
Whitespace
carmocca Nov 16, 2021
a5e13ba
Merge branch 'master' into ci/fail-on-deprecation
carmocca Nov 16, 2021
f8135a5
Update code
carmocca Nov 16, 2021
59a1473
Add workaround
carmocca Nov 17, 2021
3b4f7c4
Merge branch 'master' into ci/fail-on-deprecation
carmocca Nov 17, 2021
bfed87c
Merge branch 'master' into ci/fail-on-deprecation
carmocca Nov 17, 2021
26dfd56
Bad merge
carmocca Nov 17, 2021
952429a
Merge branch 'master' into ci/fail-on-deprecation
carmocca Nov 17, 2021
459187f
Update benchmarks
carmocca Nov 17, 2021
0c4fbf7
Add hanging fix
carmocca Nov 17, 2021
4d4dbdb
skip
carmocca Nov 17, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion pytorch_lightning/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# if root logger has handlers, propagate messages up and let root logger process them
if not _root_logger.hasHandlers():
_logger.addHandler(logging.StreamHandler())
_logger.propagate = False

_PACKAGE_ROOT = os.path.dirname(__file__)
_PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT)
Expand Down
7 changes: 0 additions & 7 deletions pytorch_lightning/utilities/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,6 @@ def supported_types() -> List[str]:
class DistributedType(LightningEnum, metaclass=_OnAccessEnumMeta):
"""Define type of training strategy.

>>> # you can match the type with string
>>> DistributedType.DDP == 'ddp'
True
>>> # which is case invariant
>>> DistributedType.DDP2 in ('ddp2', )
True

Deprecated since v1.6.0 and will be removed in v1.8.0.

Use `_StrategyType` instead.
Expand Down
9 changes: 8 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,17 @@ python_files =
test_*.py
# doctest_plus = disabled
addopts =
--strict
--strict-markers
--doctest-modules
--color=yes
--disable-pytest-warnings
filterwarnings =
# error out on our deprecation warnings - ensures the code and tests are kept up-to-date
error::pytorch_lightning.utilities.warnings.LightningDeprecationWarning
# warnings from deprecated modules on import
# TODO: remove in 1.7
ignore::pytorch_lightning.utilities.warnings.LightningDeprecationWarning:pytorch_lightning.core.decorators
ignore::pytorch_lightning.utilities.warnings.LightningDeprecationWarning:pytorch_lightning.core.memory

junit_duration_report = call

Expand Down
2 changes: 1 addition & 1 deletion tests/accelerators/test_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def setup_optimizers_in_pre_dispatch(self) -> bool:
return delay_dispatch

model = TestModel()
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, plugins=CustomPlugin(device=torch.device("cpu")))
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, strategy=CustomPlugin(device=torch.device("cpu")))
trainer.fit(model)


Expand Down
22 changes: 13 additions & 9 deletions tests/callbacks/test_gpu_stats_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@
def test_gpu_stats_monitor(tmpdir):
"""Test GPU stats are logged using a logger."""
model = BoringModel()
gpu_stats = GPUStatsMonitor(intra_step_time=True)
with pytest.deprecated_call(match="GPUStatsMonitor` callback was deprecated in v1.5"):
gpu_stats = GPUStatsMonitor(intra_step_time=True)
logger = CSVLogger(tmpdir)
log_every_n_steps = 2

Expand Down Expand Up @@ -65,12 +66,13 @@ def test_gpu_stats_monitor(tmpdir):
def test_gpu_stats_monitor_no_queries(tmpdir):
"""Test GPU logger doesn't fail if no "nvidia-smi" queries are to be performed."""
model = BoringModel()
gpu_stats = GPUStatsMonitor(
memory_utilization=False,
gpu_utilization=False,
intra_step_time=True,
inter_step_time=True,
)
with pytest.deprecated_call(match="GPUStatsMonitor` callback was deprecated in v1.5"):
gpu_stats = GPUStatsMonitor(
memory_utilization=False,
gpu_utilization=False,
intra_step_time=True,
inter_step_time=True,
)
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
Expand Down Expand Up @@ -101,7 +103,8 @@ def test_gpu_stats_monitor_cpu_machine(tmpdir):
def test_gpu_stats_monitor_no_logger(tmpdir):
"""Test GPUStatsMonitor with no logger in Trainer."""
model = BoringModel()
gpu_stats = GPUStatsMonitor()
with pytest.deprecated_call(match="GPUStatsMonitor` callback was deprecated in v1.5"):
gpu_stats = GPUStatsMonitor()

trainer = Trainer(default_root_dir=tmpdir, callbacks=[gpu_stats], max_epochs=1, gpus=1, logger=False)

Expand All @@ -113,7 +116,8 @@ def test_gpu_stats_monitor_no_logger(tmpdir):
def test_gpu_stats_monitor_no_gpu_warning(tmpdir):
"""Test GPUStatsMonitor raises a warning when not training on GPU device."""
model = BoringModel()
gpu_stats = GPUStatsMonitor()
with pytest.deprecated_call(match="GPUStatsMonitor` callback was deprecated in v1.5"):
gpu_stats = GPUStatsMonitor()

trainer = Trainer(default_root_dir=tmpdir, callbacks=[gpu_stats], max_steps=1, gpus=None)

Expand Down
14 changes: 10 additions & 4 deletions tests/callbacks/test_lambda_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# limitations under the License.
from functools import partial

import pytest

from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.callbacks import Callback, LambdaCallback
from tests.helpers.boring_model import BoringModel
Expand Down Expand Up @@ -46,7 +48,8 @@ def call(hook, *_, **__):
limit_val_batches=1,
callbacks=[LambdaCallback(**hooks_args)],
)
trainer.fit(model)
with pytest.deprecated_call(match="on_keyboard_interrupt` callback hook was deprecated in v1.5"):
trainer.fit(model)

ckpt_path = trainer.checkpoint_callback.best_model_path

Expand All @@ -60,8 +63,11 @@ def call(hook, *_, **__):
limit_predict_batches=1,
callbacks=[LambdaCallback(**hooks_args)],
)
trainer.fit(model, ckpt_path=ckpt_path)
trainer.test(model)
trainer.predict(model)
with pytest.deprecated_call(match="on_keyboard_interrupt` callback hook was deprecated in v1.5"):
trainer.fit(model, ckpt_path=ckpt_path)
with pytest.deprecated_call(match="on_keyboard_interrupt` callback hook was deprecated in v1.5"):
trainer.test(model)
with pytest.deprecated_call(match="on_keyboard_interrupt` callback hook was deprecated in v1.5"):
trainer.predict(model)

assert checker == hooks
23 changes: 14 additions & 9 deletions tests/callbacks/test_stochastic_weight_avg.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def test_swa_callback_scheduler_step(tmpdir, interval: str):

def test_swa_warns(tmpdir, caplog):
model = SwaTestModel(interval="step")
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, stochastic_weight_avg=True)
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, callbacks=StochasticWeightAveraging())
with caplog.at_level(level=logging.INFO), pytest.warns(UserWarning, match="SWA is currently only supported"):
trainer.fit(model)
assert "Swapping scheduler `StepLR` for `SWALR`" in caplog.text
Expand Down Expand Up @@ -199,14 +199,19 @@ def configure_optimizers(self):
return optimizer

model = TestModel()
trainer = Trainer(
default_root_dir=tmpdir,
callbacks=StochasticWeightAveraging(swa_lrs=1e-3) if use_callbacks else None,
stochastic_weight_avg=stochastic_weight_avg,
limit_train_batches=4,
limit_val_batches=4,
max_epochs=2,
)
kwargs = {
"default_root_dir": tmpdir,
"callbacks": StochasticWeightAveraging(swa_lrs=1e-3) if use_callbacks else None,
"stochastic_weight_avg": stochastic_weight_avg,
"limit_train_batches": 4,
"limit_val_batches": 4,
"max_epochs": 2,
}
if stochastic_weight_avg:
with pytest.deprecated_call(match=r"stochastic_weight_avg=True\)` is deprecated in v1.5"):
trainer = Trainer(**kwargs)
else:
trainer = Trainer(**kwargs)
trainer.fit(model)
if use_callbacks or stochastic_weight_avg:
assert sum(1 for cb in trainer.callbacks if isinstance(cb, StochasticWeightAveraging)) == 1
Expand Down
15 changes: 10 additions & 5 deletions tests/core/test_datamodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,20 +442,24 @@ def test_hyperparameters_saving():


def test_define_as_dataclass():
class BoringDataModule(LightningDataModule):
def __init__(self, foo=None):
super().__init__()

# makes sure that no functionality is broken and the user can still manually make
# super().__init__ call with parameters
# also tests all the dataclass features that can be enabled without breaking anything
@dataclass(init=True, repr=True, eq=True, order=True, unsafe_hash=True, frozen=False)
class BoringDataModule1(LightningDataModule):
class BoringDataModule1(BoringDataModule):
batch_size: int
dims: int = 2
foo: int = 2

def __post_init__(self):
super().__init__(dims=self.dims)
super().__init__(foo=self.foo)

# asserts for the different dunder methods added by dataclass, when __init__ is implemented, i.e.
# __repr__, __eq__, __lt__, __le__, etc.
assert BoringDataModule1(batch_size=64).dims == 2
assert BoringDataModule1(batch_size=64).foo == 2
assert BoringDataModule1(batch_size=32)
assert hasattr(BoringDataModule1, "__repr__")
assert BoringDataModule1(batch_size=32) == BoringDataModule1(batch_size=32)
Expand All @@ -477,7 +481,8 @@ def test_inconsistent_prepare_data_per_node(tmpdir):
with pytest.raises(MisconfigurationException, match="Inconsistent settings found for `prepare_data_per_node`."):
model = BoringModel()
dm = BoringDataModule()
trainer = Trainer(prepare_data_per_node=False)
with pytest.deprecated_call(match="prepare_data_per_node` with the trainer flag is deprecated"):
trainer = Trainer(prepare_data_per_node=False)
trainer.model = model
trainer.datamodule = dm
trainer._data_connector.prepare_data()
22 changes: 12 additions & 10 deletions tests/deprecated_api/test_remove_1-7.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,13 +410,6 @@ def test_v1_7_0_deprecated_max_steps_none(tmpdir):


def test_v1_7_0_resume_from_checkpoint_trainer_constructor(tmpdir):
with pytest.deprecated_call(match=r"Setting `Trainer\(resume_from_checkpoint=\)` is deprecated in v1.5"):
trainer = Trainer(resume_from_checkpoint="a")
with pytest.deprecated_call(
match=r"trainer.resume_from_checkpoint` is deprecated in v1.5 and will be removed in v1.7."
):
_ = trainer.resume_from_checkpoint

# test resume_from_checkpoint still works until v1.7 deprecation
model = BoringModel()
callback = OldStatefulCallback(state=111)
Expand All @@ -425,14 +418,22 @@ def test_v1_7_0_resume_from_checkpoint_trainer_constructor(tmpdir):
ckpt_path = trainer.checkpoint_callback.best_model_path

callback = OldStatefulCallback(state=222)
trainer = Trainer(default_root_dir=tmpdir, max_steps=2, callbacks=[callback], resume_from_checkpoint=ckpt_path)
with pytest.deprecated_call(match=r"Setting `Trainer\(resume_from_checkpoint=\)` is deprecated in v1.5"):
trainer = Trainer(default_root_dir=tmpdir, max_steps=2, callbacks=[callback], resume_from_checkpoint=ckpt_path)
with pytest.deprecated_call(
match=r"trainer.resume_from_checkpoint` is deprecated in v1.5 and will be removed in v1.7."
):
_ = trainer.resume_from_checkpoint
assert trainer.checkpoint_connector.resume_checkpoint_path is None
assert trainer.checkpoint_connector.resume_from_checkpoint_fit_path == ckpt_path
trainer.validate(model=model, ckpt_path=ckpt_path)
assert callback.state == 222
assert trainer.checkpoint_connector.resume_checkpoint_path is None
assert trainer.checkpoint_connector.resume_from_checkpoint_fit_path == ckpt_path
trainer.fit(model)
with pytest.deprecated_call(
match=r"trainer.resume_from_checkpoint` is deprecated in v1.5 and will be removed in v1.7."
):
trainer.fit(model)
assert callback.state == 111
assert trainer.checkpoint_connector.resume_checkpoint_path is None
assert trainer.checkpoint_connector.resume_from_checkpoint_fit_path is None
Expand All @@ -445,7 +446,8 @@ def test_v1_7_0_resume_from_checkpoint_trainer_constructor(tmpdir):

# test fit(ckpt_path=) precedence over Trainer(resume_from_checkpoint=) path
model = BoringModel()
trainer = Trainer(resume_from_checkpoint="trainer_arg_path")
with pytest.deprecated_call(match=r"Setting `Trainer\(resume_from_checkpoint=\)` is deprecated in v1.5"):
trainer = Trainer(resume_from_checkpoint="trainer_arg_path")
with pytest.raises(FileNotFoundError, match="Checkpoint at fit_arg_ckpt_path not found. Aborting training."):
trainer.fit(model, ckpt_path="fit_arg_ckpt_path")

Expand Down
24 changes: 19 additions & 5 deletions tests/loggers/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@ def test_loggers_fit_test_all(tmpdir, monkeypatch):
with mock.patch("pytorch_lightning.loggers.neptune.neptune", new_callable=create_neptune_mock):
_test_loggers_fit_test(tmpdir, NeptuneLogger)

with mock.patch("pytorch_lightning.loggers.test_tube.Experiment"):
with mock.patch("pytorch_lightning.loggers.test_tube.Experiment"), pytest.deprecated_call(
match="TestTubeLogger is deprecated since v1.5"
):
_test_loggers_fit_test(tmpdir, TestTubeLogger)

with mock.patch("pytorch_lightning.loggers.wandb.wandb") as wandb:
Expand Down Expand Up @@ -176,7 +178,9 @@ def test_loggers_save_dir_and_weights_save_path_all(tmpdir, monkeypatch):
):
_test_loggers_save_dir_and_weights_save_path(tmpdir, MLFlowLogger)

with mock.patch("pytorch_lightning.loggers.test_tube.Experiment"):
with mock.patch("pytorch_lightning.loggers.test_tube.Experiment"), pytest.deprecated_call(
match="TestTubeLogger is deprecated since v1.5"
):
_test_loggers_save_dir_and_weights_save_path(tmpdir, TestTubeLogger)

with mock.patch("pytorch_lightning.loggers.wandb.wandb"):
Expand Down Expand Up @@ -247,7 +251,11 @@ def test_loggers_pickle_all(tmpdir, monkeypatch, logger_class):
"""
_patch_comet_atexit(monkeypatch)
try:
_test_loggers_pickle(tmpdir, monkeypatch, logger_class)
if logger_class is TestTubeLogger:
with pytest.deprecated_call(match="TestTubeLogger is deprecated since v1.5"):
_test_loggers_pickle(tmpdir, monkeypatch, logger_class)
else:
_test_loggers_pickle(tmpdir, monkeypatch, logger_class)
except (ImportError, ModuleNotFoundError):
pytest.xfail(f"pickle test requires {logger_class.__class__} dependencies to be installed.")

Expand Down Expand Up @@ -327,7 +335,11 @@ def test_logger_created_on_rank_zero_only(tmpdir, monkeypatch, logger_class):
"""Test that loggers get replaced by dummy loggers on global rank > 0."""
_patch_comet_atexit(monkeypatch)
try:
_test_logger_created_on_rank_zero_only(tmpdir, logger_class)
if logger_class is TestTubeLogger:
with pytest.deprecated_call(match="TestTubeLogger is deprecated since v1.5"):
_test_logger_created_on_rank_zero_only(tmpdir, logger_class)
else:
_test_logger_created_on_rank_zero_only(tmpdir, logger_class)
except (ImportError, ModuleNotFoundError):
pytest.xfail(f"multi-process test requires {logger_class.__class__} dependencies to be installed.")

Expand Down Expand Up @@ -385,7 +397,9 @@ def test_logger_with_prefix_all(tmpdir, monkeypatch):
logger.experiment.add_scalar.assert_called_once_with("tmp-test", 1.0, 0)

# TestTube
with mock.patch("pytorch_lightning.loggers.test_tube.Experiment"):
with mock.patch("pytorch_lightning.loggers.test_tube.Experiment"), pytest.deprecated_call(
match="TestTubeLogger is deprecated since v1.5"
):
logger = _instantiate_logger(TestTubeLogger, save_dir=tmpdir, prefix=prefix)
logger.log_metrics({"test": 1.0}, step=0)
logger.experiment.log.assert_called_once_with({"tmp-test": 1.0}, global_step=0)
Expand Down
15 changes: 10 additions & 5 deletions tests/models/test_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,8 @@ def training_step(self, batch, batch_idx):
dict(name="Callback.on_init_start", args=(trainer,)),
dict(name="Callback.on_init_end", args=(trainer,)),
]
trainer.fit(model)
with pytest.deprecated_call(match="on_train_dataloader` is deprecated in v1.5"):
trainer.fit(model)
saved_ckpt = {
"callbacks": ANY,
"epoch": 1,
Expand Down Expand Up @@ -589,7 +590,8 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir):
enable_model_summary=False,
callbacks=[HookedCallback([])],
)
trainer.fit(model)
with pytest.deprecated_call(match="on_keyboard_interrupt` callback hook was deprecated in v1.5"):
trainer.fit(model)
best_model_path = trainer.checkpoint_callback.best_model_path

# resume from checkpoint with HookedModel
Expand All @@ -611,7 +613,8 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir):
dict(name="Callback.on_init_start", args=(trainer,)),
dict(name="Callback.on_init_end", args=(trainer,)),
]
trainer.fit(model, ckpt_path=best_model_path)
with pytest.deprecated_call(match="on_train_dataloader` is deprecated in v1.5"):
trainer.fit(model, ckpt_path=best_model_path)
saved_ckpt = {
"callbacks": ANY,
"epoch": 2, # TODO: wrong saved epoch
Expand Down Expand Up @@ -706,7 +709,8 @@ def test_trainer_model_hook_system_eval(tmpdir, batches, verb, noun, dataloader,
dict(name="Callback.on_init_end", args=(trainer,)),
]
fn = getattr(trainer, verb)
fn(model, verbose=False)
with pytest.deprecated_call(match=f"on_{dataloader}_dataloader` is deprecated in v1.5"):
fn(model, verbose=False)
hooks = [
dict(name="train", args=(False,)),
dict(name=f"on_{noun}_model_eval"),
Expand Down Expand Up @@ -750,7 +754,8 @@ def test_trainer_model_hook_system_predict(tmpdir):
dict(name="Callback.on_init_start", args=(trainer,)),
dict(name="Callback.on_init_end", args=(trainer,)),
]
trainer.predict(model)
with pytest.deprecated_call(match="on_predict_dataloader` is deprecated in v1.5"):
trainer.predict(model)
expected = [
dict(name="Callback.on_init_start", args=(trainer,)),
dict(name="Callback.on_init_end", args=(trainer,)),
Expand Down
2 changes: 1 addition & 1 deletion tests/plugins/test_amp_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def test_precision_selection_raises(monkeypatch):
with mock.patch("torch.cuda.device_count", return_value=1), pytest.raises(
MisconfigurationException, match="Sharded plugins are not supported with apex"
):
Trainer(amp_backend="apex", precision=16, gpus=1, accelerator="ddp_fully_sharded")
Trainer(amp_backend="apex", precision=16, gpus=1, strategy="ddp_fully_sharded")

import pytorch_lightning.plugins.precision.apex_amp as apex

Expand Down
Loading