Fixes access to callback_metrics in ddp_spawn #7916
Changes from 40 commits
pytorch_lightning/core/lightning.py

@@ -27,6 +27,7 @@
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union

+import numpy as np
 import torch
 from torch import ScriptModule, Tensor
 from torch.nn import Module
@@ -387,11 +388,11 @@ def log_dict(
         on_step: Optional[bool] = None,
         on_epoch: Optional[bool] = None,
         reduce_fx: Union[str, Callable] = 'default',  # TODO: change to 'mean' when `sync_dist_op` is removed in 1.6
-        tbptt_reduce_fx: Optional = None,  # noqa: Remove in 1.6
-        tbptt_pad_token: Optional = None,  # noqa: Remove in 1.6
+        tbptt_reduce_fx: Optional[Any] = None,  # noqa: Remove in 1.6
+        tbptt_pad_token: Optional[Any] = None,  # noqa: Remove in 1.6
         enable_graph: bool = False,
         sync_dist: bool = False,
-        sync_dist_op: Optional = None,  # noqa: Remove in 1.6
+        sync_dist_op: Optional[Any] = None,  # noqa: Remove in 1.6
         sync_dist_group: Optional[Any] = None,
         add_dataloader_idx: bool = True,
     ) -> None:
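The log_dict hunk above only tightens type annotations: the deprecated keyword arguments change from a bare Optional to Optional[Any], the explicit spelling of "any value or None". A minimal illustration with a generic stub (the name log_stub is an assumption for illustration, not PR code):

from typing import Any, Optional

# `Optional[Any]` is shorthand for `Union[Any, None]`: the argument may be
# any value, including None, which fits these deprecated no-op parameters.
def log_stub(sync_dist_op: Optional[Any] = None) -> None:
    if sync_dist_op is not None:
        pass  # deprecated path: value accepted, scheduled for removal in 1.6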
@@ -1948,3 +1949,30 @@ def model_size(self) -> float:
         size_mb = os.path.getsize(tmp_name) / 1e6
         os.remove(tmp_name)
         return size_mb
+
+    def add_to_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:
+        """Appends the :attr:`trainer.callback_metrics` dictionary to the given queue.
+
+        To avoid issues with memory sharing, we cast the data to numpy.
+
+        Args:
+            queue: the instance of the queue to append the data.
+        """
+        callback_metrics: dict = apply_to_collection(
+            self.trainer.callback_metrics, torch.Tensor, lambda x: x.cpu().numpy()
+        )  # send as numpy to avoid issues with memory sharing
+        queue.put(callback_metrics)
+
+    def get_from_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:
+        """Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue.
+
+        To preserve consistency, we cast back the data to ``torch.Tensor``.
+
+        Args:
+            queue: the instance of the queue from where to get the data.
+        """
+        # NOTE: this must be called in the right order to get the `callback_metrics`
+        callback_metrics: dict = queue.get()
+        self.trainer.callback_metrics.update(
+            apply_to_collection(callback_metrics, np.ndarray, lambda x: torch.tensor(x))
+        )

Review thread on the `callback_metrics.update` call:

Reviewer: why do we have to update the callback metrics here?

carmocca: I'll answer for Edgar: the purpose of this PR was to provide a mechanism for users to add items to consume from callbacks in the […]. Hence why we update callback metrics here. Callbacks read metrics off that dictionary.

Reviewer: @carmocca thanks so much!!
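Taken together, the two new hooks implement a simple numpy round-trip over a torch.multiprocessing.SimpleQueue: add_to_queue casts callback_metrics tensors to numpy before putting them on the queue, and get_from_queue casts them back. The sketch below is illustrative only; the roundtrip helper and the sample metrics dict are assumptions, not PR code:

import numpy as np
import torch
import torch.multiprocessing as mp
from pytorch_lightning.utilities.apply_func import apply_to_collection

def roundtrip(callback_metrics: dict) -> dict:
    queue = mp.SimpleQueue()
    # what `add_to_queue` does: cast tensors to numpy before enqueueing,
    # to avoid memory-sharing issues between the spawned worker and the parent
    queue.put(apply_to_collection(callback_metrics, torch.Tensor, lambda x: x.cpu().numpy()))
    # what `get_from_queue` does: read the dict back and cast numpy arrays to tensors
    return apply_to_collection(queue.get(), np.ndarray, lambda x: torch.tensor(x))

if __name__ == "__main__":
    metrics = {"val_loss": torch.tensor(0.25), "epoch": torch.tensor(3)}
    print(roundtrip(metrics))  # the tensors survive the trip through the queue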
pytorch_lightning/plugins/training_type/tpu_spawn.py

@@ -202,6 +202,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, results):
         self.mp_queue.put(best_model_path)
         self.mp_queue.put(last_path)
         self.mp_queue.put(results)
+        self.lightning_module.add_to_queue(self.mp_queue)  # adds the `callback_metrics` to the queue

     def save(self, state_dict: Dict, path: str) -> None:
         xm.save(state_dict, path)

Review thread on the `add_to_queue` call:

Reviewer: @edgarriba is there a reason you add_to_q in tpu_spawn, but dont get_from_q?

Reviewer: ah nevermind, I think its just bc tpu_spawn doesnt override post_dispatch
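For context, here is a hedged sketch of the symmetric pattern a spawn-type plugin follows: the spawned worker enqueues its results and then the callback metrics, and the main process reads everything back in the same order once the worker has finished. (As noted in the comments above, tpu_spawn only does the worker-side half here because it does not override post_dispatch.) The class and the _put_results / _collect_results names are illustrative assumptions, not actual plugin code:

class SpawnPluginSketch:
    """Illustrative skeleton of the enqueue/dequeue ordering; not real plugin code."""

    def __init__(self, lightning_module, mp_queue):
        self.lightning_module = lightning_module
        self.mp_queue = mp_queue

    def _put_results(self, best_model_path, last_path, results):
        # worker side (cf. transfer_distrib_spawn_state_on_fit_end)
        self.mp_queue.put(best_model_path)
        self.mp_queue.put(last_path)
        self.mp_queue.put(results)
        # the call added by this PR: callback_metrics go onto the queue last
        self.lightning_module.add_to_queue(self.mp_queue)

    def _collect_results(self):
        # main-process side, after the spawned worker exits
        best_model_path = self.mp_queue.get()
        last_path = self.mp_queue.get()
        results = self.mp_queue.get()
        # must be read in the same order it was written, hence after `results`
        self.lightning_module.get_from_queue(self.mp_queue)
        return best_model_path, last_path, results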