Lite's collectives feature #14996
Closed
Commits (changes shown from 35 of 48 commits)
- 96dad36 Collectives initial commit
- 84a3fab collective methods copy (carmocca)
- 8c2c03e collective method cm
- e129a3e more collectives
- 85d0d95 replace str by valid group
- 334abfb new design
- 1b04a5a apply suggestions
- 92ff4ba carlos behind my back
- 7c11799 carlos is not here
- 42f62a3 Add FIXME. Move type annotation
- 44b3d17 Merge branch 'master' into feat/collectives (carmocca)
- 776a39b Remove object and some of the stranger ones (carmocca)
- 35d3eda torch_collective (carmocca)
- 1a1c6a6 create_group()
- ee74638 Fix staticmethod (carmocca)
- 69f5f25 mypy changes (carmocca)
- 5892def Pull types into types file
- e6d0e29 Suggestion to not create magically (carmocca)
- cec0130 small reorg (carmocca)
- 2c26ee4 move default_pg_timeout
- 43ceaff is deepspeed really just different import? (carmocca)
- 916eb6c typo fix
- 2a846c0 guarding
- 373e69f single device
- 691b84e single device with more carefullness
- 36265da fix imports
- 20a12a8 horovod stub
- cc33633 bloody mypy
- abd8a3d destroy_group
- 53c3c34 circular import
- 1d6ef13 rename methods
- 377caea XLACollective
- c06ce37 Remove reduce from strategies (carmocca)
- 77410c6 Revert "circular import" (carmocca)
- 84611e7 Revert Ota's dep resolution in favor of this (carmocca)
- 4fdb716 convert op privately (carmocca)
- 4b8e7d3 no PL integration yet
- f36d13b Merge branch 'master' into feat/collectives
- 2af9ded add xla collective
- 158f97f send and recv
- 464164e fix failing ci
- 5644f5b add very basic test
- f5bbfac fix broken test
- 3e961f8 empty file for now
- dd648f3 Merge branch 'master' into feat/collectives
- 90d57c7 Merge branch 'master' into feat/collectives (carmocca)
- d4753e3 Names (carmocca)
- 7b4da93 Fixes (carmocca)
src/lightning_lite/plugins/collectives/__init__.py (new file, 11 additions)
```python
from lightning_lite.plugins.collectives.collective import Collective
from lightning_lite.plugins.collectives.deepspeed_collective import DeepSpeedCollective
from lightning_lite.plugins.collectives.single_device_collective import SingleDeviceCollective
from lightning_lite.plugins.collectives.torch_collective import TorchCollective

__all__ = [
    "Collective",
    "DeepSpeedCollective",
    "TorchCollective",
    "SingleDeviceCollective",
]
```
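For orientation, here is a minimal sketch of how these exports are meant to be driven, based on the `Collective` interface shown below. `TorchCollective` itself is not part of this excerpt, so the constructor keyword argument used here (`backend`) is an illustrative assumption, and a properly initialized distributed environment (for example one launched with `torchrun`) is assumed.

```python
from lightning_lite.plugins.collectives import TorchCollective

# Group creation is lazy: constructor kwargs are stored and only forwarded to
# `init_group()` once `create_group()` is called (see `Collective` below).
# "backend" is an illustrative assumption; torch_collective.py is not shown
# in this diff excerpt.
collective = TorchCollective(backend="gloo")
collective.create_group()   # returns `self`, so calls can be chained
print(collective.rank, collective.world_size)
collective.teardown()       # destroys the group this plugin owns
```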
src/lightning_lite/plugins/collectives/collective.py (new file, 138 additions)
```python
from abc import ABC, abstractmethod
from typing import Any, List, Optional

import torch
from typing_extensions import Self

from lightning_lite.utilities.types import CollectibleGroup


class Collective(ABC):
    def __init__(self, instantiate_group: bool = False, **group_kwargs: Any) -> None:
        self._group_kwargs = group_kwargs
        self._group: Optional[CollectibleGroup] = None
        if instantiate_group:
            self.create_group()

    def create_group(self, **kwargs: Any) -> Self:  # type: ignore[valid-type]
        if self._group is not None:
            raise RuntimeError(f"{type(self).__name__} already owns a group.")
        self._group_kwargs.update(kwargs)
        self._group = self.init_group(**self._group_kwargs)
        return self

    @property
    def group(self) -> CollectibleGroup:
        if self._group is None:
            raise RuntimeError(
                f"{type(self).__name__} does not own a group. HINT: try `collective.create_group().group`"
            )
        return self._group

    @property
    @abstractmethod
    def rank(self) -> int:
        pass

    @property
    @abstractmethod
    def world_size(self) -> int:
        pass

    @staticmethod
    @abstractmethod
    def init_group(
        **kwargs: Any,
    ) -> CollectibleGroup:
        pass

    def teardown(self) -> None:
        if self._group is None:
            raise RuntimeError(f"{type(self).__name__} does not own a group to destroy.")
        self.destroy_group(self._group)
        self._group = None

    @staticmethod
    @abstractmethod
    def destroy_group(group: CollectibleGroup) -> None:
        pass

    @staticmethod
    @abstractmethod
    def convert_to_native_op(op: str) -> Any:
        ...

    @abstractmethod
    def broadcast(
        self,
        tensor: torch.Tensor,
        src: int,
    ) -> torch.Tensor:
        pass

    @abstractmethod
    def all_reduce(
        self,
        tensor: torch.Tensor,
        op: Any,
    ) -> torch.Tensor:
        pass

    @abstractmethod
    def reduce(
        self,
        tensor: torch.Tensor,
        dst: int,
        op: Any,
    ) -> torch.Tensor:
        pass

    @abstractmethod
    def all_gather(
        self,
        tensor_list: List[torch.Tensor],
        tensor: torch.Tensor,
    ) -> List[torch.Tensor]:
        pass

    @abstractmethod
    def gather(
        self,
        tensor: torch.Tensor,
        gather_list: Optional[List[torch.Tensor]] = None,
        dst: int = 0,
    ) -> Optional[List[torch.Tensor]]:
        pass

    @abstractmethod
    def scatter(
        self,
        tensor: torch.Tensor,
        scatter_list: Optional[List[torch.Tensor]] = None,
        src: int = 0,
    ) -> torch.Tensor:
        pass

    @abstractmethod
    def reduce_scatter(
        self,
        output: torch.Tensor,
        input_list: List[torch.Tensor],
        op: Any,
    ) -> torch.Tensor:
        pass

    @abstractmethod
    def all_to_all(
        self,
        output_tensor_list: List[torch.Tensor],
        input_tensor_list: List[torch.Tensor],
    ) -> List[torch.Tensor]:
        pass

    @abstractmethod
    def barrier(
        self,
        device_ids: Optional[List[int]] = None,
    ) -> None:
        pass
```
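To make the contract above concrete, here is a hypothetical single-process implementation (not the `SingleDeviceCollective` shipped by this PR, whose file is not shown in this excerpt). With a world size of one, every collective degenerates into returning its input, so the sketch is mostly no-ops.

```python
from typing import Any, List, Optional

import torch

from lightning_lite.plugins.collectives.collective import Collective
from lightning_lite.utilities.types import CollectibleGroup


class _EchoCollective(Collective):
    """Toy single-process collective: every op just echoes its input back."""

    @property
    def rank(self) -> int:
        return 0

    @property
    def world_size(self) -> int:
        return 1

    @staticmethod
    def init_group(**kwargs: Any) -> CollectibleGroup:
        # Any placeholder object can stand in for the group handle here.
        return object()

    @staticmethod
    def destroy_group(group: CollectibleGroup) -> None:
        pass

    @staticmethod
    def convert_to_native_op(op: str) -> Any:
        return op  # nothing to convert for a single process

    def broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
        return tensor

    def all_reduce(self, tensor: torch.Tensor, op: Any = "sum") -> torch.Tensor:
        return tensor

    def reduce(self, tensor: torch.Tensor, dst: int, op: Any = "sum") -> torch.Tensor:
        return tensor

    def all_gather(self, tensor_list: List[torch.Tensor], tensor: torch.Tensor) -> List[torch.Tensor]:
        return [tensor]

    def gather(
        self, tensor: torch.Tensor, gather_list: Optional[List[torch.Tensor]] = None, dst: int = 0
    ) -> Optional[List[torch.Tensor]]:
        return [tensor]

    def scatter(
        self, tensor: torch.Tensor, scatter_list: Optional[List[torch.Tensor]] = None, src: int = 0
    ) -> torch.Tensor:
        return scatter_list[0] if scatter_list else tensor

    def reduce_scatter(self, output: torch.Tensor, input_list: List[torch.Tensor], op: Any = "sum") -> torch.Tensor:
        return input_list[0] if input_list else output

    def all_to_all(
        self, output_tensor_list: List[torch.Tensor], input_tensor_list: List[torch.Tensor]
    ) -> List[torch.Tensor]:
        return input_tensor_list

    def barrier(self, device_ids: Optional[List[int]] = None) -> None:
        pass


# `instantiate_group=True` calls `create_group()` in the constructor, so the
# chained form suggested by the base class works immediately:
result = _EchoCollective(instantiate_group=True).all_reduce(torch.ones(2))
```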
src/lightning_lite/plugins/collectives/deepspeed_collective.py (new file, 152 additions)
```python
import datetime
from typing import Any, List, Optional

import torch

from lightning_lite.plugins.collectives.collective import Collective
from lightning_lite.strategies.deepspeed import _DEEPSPEED_AVAILABLE
from lightning_lite.utilities.types import CollectibleGroup

if _DEEPSPEED_AVAILABLE:
    import deepspeed.comm as dist


class DeepSpeedCollective(Collective):
    def __init__(self, instantiate_group: bool = False, **group_kwargs: Any) -> None:
        if not _DEEPSPEED_AVAILABLE:
            raise RuntimeError("Torch distributed is not available.")
        super().__init__(instantiate_group, **group_kwargs)

    @property
    def rank(self) -> int:
        return dist.get_rank(self.group)

    @property
    def world_size(self) -> int:
        return dist.get_world_size(self.group)

    @staticmethod
    def init_group(
        **kwargs: Any,
    ) -> CollectibleGroup:
        return dist.init_process_group(**kwargs)

    @staticmethod
    def destroy_group(group: CollectibleGroup) -> None:
        dist.destroy_process_group(group)

    def broadcast(
        self,
        tensor: torch.Tensor,
        src: int,
    ) -> torch.Tensor:
        dist.broadcast(tensor, src, group=self.group)
        return tensor

    def all_reduce(
        self,
        tensor: torch.Tensor,
        op: dist.ReduceOp = dist.ReduceOp.SUM,
    ) -> torch.Tensor:
        dist.all_reduce(tensor, op=op, group=self.group)
        return tensor

    def reduce(
        self,
        tensor: torch.Tensor,
        dst: int,
        op: dist.ReduceOp = dist.ReduceOp.SUM,
    ) -> torch.Tensor:
        dist.reduce(tensor, dst, op=op, group=self.group)
        return tensor

    def all_gather(
        self,
        tensor_list: List[torch.Tensor],
        tensor: torch.Tensor,
    ) -> List[torch.Tensor]:
        dist.all_gather(tensor_list, tensor, group=self.group)
        return tensor_list

    def gather(
        self,
        tensor: torch.Tensor,
        gather_list: Optional[List[torch.Tensor]] = None,
        dst: int = 0,
    ) -> Optional[List[torch.Tensor]]:
        dist.gather(tensor, gather_list, dst, group=self.group)
        return gather_list

    def scatter(
        self,
        tensor: torch.Tensor,
        scatter_list: Optional[List[torch.Tensor]] = None,
        src: int = 0,
    ) -> torch.Tensor:
        dist.scatter(tensor, scatter_list, src, group=self.group)
        return tensor

    def reduce_scatter(
        self,
        output: torch.Tensor,
        input_list: List[torch.Tensor],
        op: dist.ReduceOp = dist.ReduceOp.SUM,
    ) -> torch.Tensor:
        dist.reduce_scatter(output, input_list, op=op, group=self.group)
        return output

    def all_to_all(
        self,
        output_tensor_list: List[torch.Tensor],
        input_tensor_list: List[torch.Tensor],
    ) -> List[torch.Tensor]:
        dist.all_to_all(output_tensor_list, input_tensor_list, group=self.group)
        return output_tensor_list

    def barrier(
        self,
        device_ids: Optional[List[int]] = None,
    ) -> None:
        dist.barrier(group=self.group, device_ids=device_ids)

    def all_gather_object(
        self,
        object_list: List[Any],
        object: Any,
    ) -> List[Any]:
        dist.all_gather_object(object_list, object, group=self.group)
        return object_list

    def broadcast_object_list(
        self,
        object_list: List[Any],
        src: int,
        device: Optional[torch.device] = None,
    ) -> List[Any]:
        dist.broadcast_object_list(object_list, src, group=self.group, device=device)
        return object_list

    def gather_object(
        self,
        obj: Any,
        object_gather_list: Optional[List[Any]] = None,
        dst: int = 0,
    ) -> Optional[List[Any]]:
        dist.gather_object(obj, object_gather_list, dst, group=self.group)
        return object_gather_list

    def scatter_object_list(
        self,
        scatter_object_output_list: List[Any],
        scatter_object_input_list: Optional[List[Any]],
        src: int = 0,
    ) -> List[Any]:
        dist.scatter_object_list(scatter_object_output_list, scatter_object_input_list, src, group=self.group)
        return scatter_object_output_list

    def monitored_barrier(
        self,
        timeout: Optional[datetime.timedelta] = None,
        wait_all_ranks: bool = False,
    ) -> None:
        dist.monitored_barrier(group=self.group, timeout=timeout, wait_all_ranks=wait_all_ranks)
```

Review comment on the `if _DEEPSPEED_AVAILABLE:` import guard: "In the spirit of #12786, we should make these local imports."
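As a hedged usage sketch: `deepspeed.comm` mirrors the `torch.distributed` API, so driving this plugin follows the usual process-group flow. The keyword argument forwarded to `deepspeed.comm.init_process_group` below and the one-GPU-per-process device choice are assumptions for illustration; a distributed launcher (for example `deepspeed` or `torchrun`) is expected to have set the usual rank/world-size environment variables.

```python
import torch

from lightning_lite.plugins.collectives import DeepSpeedCollective

# `instantiate_group=True` makes the constructor call `create_group()` right
# away; extra kwargs are forwarded to `deepspeed.comm.init_process_group`.
# "nccl" is an illustrative assumption, not something this PR prescribes.
collective = DeepSpeedCollective(instantiate_group=True, backend="nccl")

# Assumes one visible GPU per process, as typical launchers arrange.
tensor = torch.ones(1, device="cuda")
collective.all_reduce(tensor)   # mutates in place, defaults to ReduceOp.SUM
collective.barrier()
print(collective.rank, collective.world_size, tensor.item())

collective.teardown()           # destroys the group owned by this plugin
```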