Skip to content

Commit 82bcdbf

Browse files
committed
add small test for add/get queue
1 parent 0cd331f commit 82bcdbf

File tree

2 files changed

+10
-1
lines changed

2 files changed

+10
-1
lines changed

pytorch_lightning/core/lightning.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1964,7 +1964,7 @@ def add_to_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:
19641964
queue.put(callback_metrics)
19651965

19661966
def get_from_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:
1967-
"""Retrieve the `trainer.callback_metrics` dictionary from the given queue.
1967+
"""Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue.
19681968
19691969
To preserve consistency, we cast back the data to ``torch.Tensor``.
19701970

tests/plugins/test_ddp_spawn_plugin.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,14 @@ def validation_step(self, batch, batch_idx):
3737
self.log(self.name, self.val)
3838
return super().validation_step(batch, batch_idx)
3939

40+
def add_to_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:
41+
queue.put("test_val")
42+
return super().add_to_queue(queue)
43+
44+
def get_from_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:
45+
self.test_val = queue.get()
46+
return super().get_from_queue(queue)
47+
4048

4149
@RunIf(skip_windows=True)
4250
def test_ddp_cpu():
@@ -70,3 +78,4 @@ def test_ddp_spawn_extra_parameters(tmpdir):
7078

7179
trainer.fit(model, datamodule=dm)
7280
assert trainer.callback_metrics[val_name] == torch.tensor(val)
81+
assert model.test_val == "test_val"

0 commit comments

Comments
 (0)