Introduce base collective and main subclasses #15016
Merged
Commits (58 total; the diff below shows changes from 2 commits)
cd0287a Introduce base collective and main subclasses (carmocca)
dbd2a20 Cherry-pick tag fix (carmocca)
e25e7ff Sort methods (carmocca)
34aaf00 Fix import (carmocca)
515e54d Support passing ReduceOp to TorchCollective (carmocca)
a2187a6 test
63f0399 working full test
6a6af88 Refactor impl (carmocca)
5298d9f Remove extra argument (carmocca)
d40a7d1 I think we can assume that whoever runs our tests has torch.distributed? (carmocca)
266069b Fix 1.9 compat (carmocca)
5c63779 convert_ops test
021604e mark
b6f9356 1.9 compatibility
37fc617 niceties
299e5ef teardown fixture
6109b90 teardown fixture
f3875f9 removing param
fc7ff69 create and destroy tests
5e832c6 Push current fixes before Ota gives me conflicts (carmocca)
76d92c9 All gather true test (carmocca)
5f5eab4 Fixing tests (carmocca)
814719f Single device tests (carmocca)
96fee6b Simplify tests (carmocca)
1354c08 Fix mypy (carmocca)
438b4fd Reduce true test (carmocca)
9fafbfd singledevice strategy test
293d0f2 remove unfinished comment
661abb8 Merge into 1 test to amortize launch (carmocca)
f368b89 Tiny docstring change (carmocca)
eccbd96 One fixture is enough (carmocca)
60b8741 Do we even need launch? (carmocca)
bcb4534 Assert not initialized in fixture (carmocca)
7c5b92a add test, that recreate is possible
308b27d Revert to launch (carmocca)
2b40b96 distributed tests are passing now
7c01765 fix other tests
4c51bff test two groups
4ecafd1 Wrapper (carmocca)
d8c31ab Move method (carmocca)
19fb7c3 merge
9766deb is_initialized/available (carmocca)
092dd0b remove init_kwargs
f1fc251 Replace RunIf (carmocca)
28d0d34 finalize tests
1df314e Docstring (carmocca)
31ae5fa Cleanup logic (carmocca)
00e4df7 we are passing now
7fb617d Typing (carmocca)
76c29ae Environ tests (carmocca)
ba313b9 Drop instantiate_group (carmocca)
f03cae8 Docstring (carmocca)
97b80fc warning in docstring (carmocca)
08858c3 Cleanup env right after (carmocca)
545e911 unify types
e556e5a Fix mypy (carmocca)
67a5ee6 remove debug file
d1e9209 test_two_groups is hanging in a job. try barrier (carmocca)
src/lightning_lite/plugins/collectives/__init__.py (9 additions)
@@ -0,0 +1,9 @@
from lightning_lite.plugins.collectives.collective import Collective
from lightning_lite.plugins.collectives.single_device_collective import SingleDeviceCollective
from lightning_lite.plugins.collectives.torch_collective import TorchCollective

__all__ = [
    "Collective",
    "TorchCollective",
    "SingleDeviceCollective",
]
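For context, these exports would be consumed through the package root. A minimal sketch (assuming lightning_lite is installed; TorchCollective is added in a separate file of this PR):

from lightning_lite.plugins.collectives import Collective, SingleDeviceCollective

# SingleDeviceCollective is one of the concrete subclasses introduced below
collective = SingleDeviceCollective()
assert isinstance(collective, Collective)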
src/lightning_lite/plugins/collectives/collective.py (146 additions)
@@ -0,0 +1,146 @@
from abc import ABC, abstractmethod
from typing import Any, List, Optional

import torch
from typing_extensions import Self

from lightning_lite.utilities.types import CollectibleGroup


class Collective(ABC):
    def __init__(self, instantiate_group: bool = False, **group_kwargs: Any) -> None:
        self._group_kwargs = group_kwargs
        self._group: Optional[CollectibleGroup] = None
        if instantiate_group:
            self.create_group()

    def create_group(self, **kwargs: Any) -> Self:  # type: ignore[valid-type]
        if self._group is not None:
            raise RuntimeError(f"{type(self).__name__} already owns a group.")
        self._group_kwargs.update(kwargs)
        self._group = self.init_group(**self._group_kwargs)
        return self

    @property
    def group(self) -> CollectibleGroup:
        if self._group is None:
            raise RuntimeError(
                f"{type(self).__name__} does not own a group. HINT: try `collective.create_group().group`"
            )
        return self._group

    @property
    @abstractmethod
    def rank(self) -> int:
        pass

    @property
    @abstractmethod
    def world_size(self) -> int:
        pass

    @staticmethod
    @abstractmethod
    def init_group(
        **kwargs: Any,
    ) -> CollectibleGroup:
        pass

    def teardown(self) -> None:
        if self._group is None:
            raise RuntimeError(f"{type(self).__name__} does not own a group to destroy.")
        self.destroy_group(self._group)
        self._group = None

    @staticmethod
    @abstractmethod
    def destroy_group(group: CollectibleGroup) -> None:
        pass

    @staticmethod
    @abstractmethod
    def _convert_to_native_op(op: str) -> Any:
        pass

    @abstractmethod
    def send(self, tensor: torch.Tensor, dst: int, tag: Optional[int] = 0) -> None:
        pass

    @abstractmethod
    def recv(self, tensor: torch.Tensor, src: Optional[int] = None, tag: Optional[int] = 0) -> torch.Tensor:
        pass

    @abstractmethod
    def broadcast(
        self,
        tensor: torch.Tensor,
        src: int,
    ) -> torch.Tensor:
        pass

    @abstractmethod
    def all_reduce(
        self,
        tensor: torch.Tensor,
        op: str,
    ) -> torch.Tensor:
        pass

    @abstractmethod
    def reduce(
        self,
        tensor: torch.Tensor,
        dst: int,
        op: str,
    ) -> torch.Tensor:
        pass

    @abstractmethod
    def all_gather(
        self,
        tensor_list: List[torch.Tensor],
        tensor: torch.Tensor,
    ) -> List[torch.Tensor]:
        pass

    @abstractmethod
    def gather(
        self,
        tensor: torch.Tensor,
        gather_list: Optional[List[torch.Tensor]] = None,
        dst: int = 0,
    ) -> Optional[List[torch.Tensor]]:
        pass

    @abstractmethod
    def scatter(
        self,
        tensor: torch.Tensor,
        scatter_list: Optional[List[torch.Tensor]] = None,
        src: int = 0,
    ) -> torch.Tensor:
        pass

    @abstractmethod
    def reduce_scatter(
        self,
        output: torch.Tensor,
        input_list: List[torch.Tensor],
        op: str,
    ) -> torch.Tensor:
        pass

    @abstractmethod
    def all_to_all(
        self,
        output_tensor_list: List[torch.Tensor],
        input_tensor_list: List[torch.Tensor],
    ) -> List[torch.Tensor]:
        pass

    @abstractmethod
    def barrier(
        self,
        device_ids: Optional[List[int]] = None,
    ) -> None:
        pass
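The base class fixes a create/use/teardown lifecycle: create_group forwards the stored group_kwargs to the subclass's init_group, the group property guards against use before creation, and teardown hands the group to destroy_group and clears the handle. A minimal sketch of that lifecycle, using the SingleDeviceCollective subclass added below (illustrative, not part of the diff):

from lightning_lite.plugins.collectives import SingleDeviceCollective

collective = SingleDeviceCollective()  # no group owned yet
collective.create_group()              # forwards group_kwargs to init_group
assert (collective.rank, collective.world_size) == (0, 1)
collective.teardown()                  # calls destroy_group and drops the handle
# accessing collective.group now raises RuntimeError until create_group() is called again

Since create_group returns Self, calls can also be chained exactly as the error hint suggests: SingleDeviceCollective().create_group().group.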
src/lightning_lite/plugins/collectives/single_device_collective.py (110 additions)
@@ -0,0 +1,110 @@
from typing import Any, List, Optional

import torch

from lightning_lite.plugins.collectives.collective import Collective
from lightning_lite.utilities.types import CollectibleGroup


class SingleDeviceCollective(Collective):
    @property
    def rank(self) -> int:
        return 0

    @property
    def world_size(self) -> int:
        return 1

    @staticmethod
    def init_group(
        **kwargs: Any,
    ) -> CollectibleGroup:
        return object()  # type: ignore[return-value]

    @staticmethod
    def _convert_to_native_op(op: str) -> str:
        return op

    @staticmethod
    def destroy_group(group: CollectibleGroup) -> None:
        pass

    def send(self, *_: Any, **__: Any) -> None:
        pass

    def recv(self, tensor: torch.Tensor, *_: Any, **__: Any) -> torch.Tensor:
        return tensor

    def broadcast(
        self,
        tensor: torch.Tensor,
        *_: Any,
        **__: Any,
    ) -> torch.Tensor:
        return tensor

    def all_reduce(
        self,
        tensor: torch.Tensor,
        *_: Any,
        **__: Any,
    ) -> torch.Tensor:
        return tensor

    def reduce(
        self,
        tensor: torch.Tensor,
        *_: Any,
        **__: Any,
    ) -> torch.Tensor:
        return tensor

    def all_gather(
        self,
        tensor_list: List[torch.Tensor],
        tensor: torch.Tensor,
        **__: Any,
    ) -> List[torch.Tensor]:
        return [tensor]

    def gather(
        self,
        tensor: torch.Tensor,
        *_: Any,
        **__: Any,
    ) -> Optional[List[torch.Tensor]]:
        return [tensor]

    def scatter(  # type: ignore[override]
        self,
        tensor: torch.Tensor,
        scatter_list: List[torch.Tensor],  # it doesn't make sense to have a None here for a single device
        *_: Any,
        **__: Any,
    ) -> torch.Tensor:
        return scatter_list[0]

    def reduce_scatter(
        self,
        output: torch.Tensor,
        input_list: List[torch.Tensor],
        *_: Any,
        **__: Any,
    ) -> torch.Tensor:
        return input_list[0]

    def all_to_all(
        self,
        output_tensor_list: List[torch.Tensor],
        input_tensor_list: List[torch.Tensor],
        *_: Any,
        **__: Any,
    ) -> List[torch.Tensor]:
        return input_tensor_list

    def barrier(
        self,
        *_: Any,
        **__: Any,
    ) -> None:
        pass
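On a single device every collective degenerates to a pass-through: tensors are returned as-is and synchronization points are no-ops. A short sketch of the resulting semantics (illustrative only):

import torch

from lightning_lite.plugins.collectives import SingleDeviceCollective

collective = SingleDeviceCollective(instantiate_group=True)
t = torch.tensor([1.0, 2.0])

assert collective.broadcast(t, src=0) is t      # broadcast from the only rank returns the input
assert collective.all_gather([], t)[0] is t     # "gathering" yields just the local tensor
assert collective.all_reduce(t, op="sum") is t  # reducing over one rank is the identity
collective.barrier()                            # nothing to wait for
collective.teardown()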