Skip to content

Commit 3c6669a

Browse files
committed
fix tests
1 parent c48500e commit 3c6669a

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ def _log_gpus_metrics(self) -> None:
223223
self.trainer.lightning_module.log(key, mem, prog_bar=False, logger=True)
224224
else:
225225
gpu_id = int(key.split("/")[0].split(":")[1])
226-
if gpu_id in self.trainer.data_parallel_device_ids:
226+
if self.trainer.data_parallel_device_ids and gpu_id in self.trainer.data_parallel_device_ids:
227227
self.trainer.lightning_module.log(
228228
key, mem, prog_bar=False, logger=True, on_step=True, on_epoch=False
229229
)

tests/callbacks/test_gpu_stats_monitor.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ def test_gpu_stats_monitor(tmpdir):
4747
logger=logger,
4848
)
4949

50-
trainer.fit(model)
50+
with pytest.deprecated_call(match="`Trainer.data_parallel_device_ids` was deprecated in v1.6."):
51+
trainer.fit(model)
5152
assert trainer.state.finished, f"Training failed with {trainer.state}"
5253

5354
path_csv = os.path.join(logger.log_dir, ExperimentWriter.NAME_METRICS_FILE)
@@ -84,7 +85,9 @@ def test_gpu_stats_monitor_no_queries(tmpdir):
8485
devices=1,
8586
callbacks=[gpu_stats],
8687
)
87-
with mock.patch("pytorch_lightning.loggers.tensorboard.TensorBoardLogger.log_metrics") as log_metrics_mock:
88+
with mock.patch(
89+
"pytorch_lightning.loggers.tensorboard.TensorBoardLogger.log_metrics"
90+
) as log_metrics_mock, pytest.deprecated_call(match="`Trainer.data_parallel_device_ids` was deprecated in v1.6."):
8891
trainer.fit(model)
8992

9093
assert log_metrics_mock.mock_calls[1:] == [

0 commit comments

Comments
 (0)