Lite's collectives feature #14996

Closed
wants to merge 48 commits into from
Commits (48). The diff below shows changes from 35 commits.
96dad36  Collectives initial commit (carmocca, Oct 4, 2022)
84a3fab  collective methods copy (Oct 4, 2022)
8c2c03e  collective method cm (Oct 4, 2022)
e129a3e  more collectives (Oct 5, 2022)
85d0d95  replace str by valid group (Oct 5, 2022)
334abfb  new design (Oct 5, 2022)
1b04a5a  apply suggestions (Oct 5, 2022)
92ff4ba  carlos behind my back (Oct 5, 2022)
7c11799  carlos is not here (Oct 5, 2022)
42f62a3  Add FIXME. Move type annotation (carmocca, Oct 5, 2022)
44b3d17  Merge branch 'master' into feat/collectives (carmocca, Oct 5, 2022)
776a39b  Remove object and some of the stranger ones (carmocca, Oct 5, 2022)
35d3eda  torch_collective (Oct 5, 2022)
1a1c6a6  create_group() (carmocca, Oct 5, 2022)
ee74638  Fix staticmethod (carmocca, Oct 5, 2022)
69f5f25  mypy changes (Oct 5, 2022)
5892def  Pull types into types file (carmocca, Oct 5, 2022)
e6d0e29  Suggestion to not create magically (carmocca, Oct 5, 2022)
cec0130  small reorg (Oct 5, 2022)
2c26ee4  move default_pg_timeout (carmocca, Oct 5, 2022)
43ceaff  is deepspeed really just different import? (Oct 5, 2022)
916eb6c  typo fix (Oct 5, 2022)
2a846c0  guarding (Oct 5, 2022)
373e69f  single device (Oct 5, 2022)
691b84e  single device with more carefullness (Oct 5, 2022)
36265da  fix imports (Oct 5, 2022)
20a12a8  horovod stub (Oct 5, 2022)
cc33633  bloody mypy (Oct 5, 2022)
abd8a3d  destroy_group (Oct 5, 2022)
53c3c34  circular import (Oct 5, 2022)
1d6ef13  rename methods (Oct 5, 2022)
377caea  XLACollective (carmocca, Oct 5, 2022)
c06ce37  Remove reduce from strategies (carmocca, Oct 5, 2022)
77410c6  Revert "circular import" (carmocca, Oct 5, 2022)
84611e7  Revert Ota's dep resolution in favor of this (carmocca, Oct 5, 2022)
4fdb716  convert op privately (Oct 5, 2022)
4b8e7d3  no PL integration yet (Oct 5, 2022)
f36d13b  Merge branch 'master' into feat/collectives (Oct 6, 2022)
2af9ded  add xla collective (Oct 6, 2022)
158f97f  send and recv (Oct 6, 2022)
464164e  fix failing ci (Oct 6, 2022)
5644f5b  add very basic test (Oct 6, 2022)
f5bbfac  fix broken test (Oct 6, 2022)
3e961f8  empty file for now (Oct 6, 2022)
dd648f3  Merge branch 'master' into feat/collectives (carmocca, Oct 8, 2022)
90d57c7  Merge branch 'master' into feat/collectives (carmocca, Oct 21, 2022)
d4753e3  Names (carmocca, Oct 21, 2022)
7b4da93  Fixes (carmocca, Oct 21, 2022)

3 changes: 2 additions & 1 deletion src/lightning_lite/connector.py
@@ -31,6 +31,7 @@
TPUBf16Precision,
TPUPrecision,
)
from lightning_lite.plugins.collectives import Collective
from lightning_lite.plugins.environments import (
ClusterEnvironment,
KubeflowEnvironment,
@@ -57,7 +58,7 @@
from lightning_lite.utilities.device_parser import determine_root_gpu_device
from lightning_lite.utilities.imports import _HPU_AVAILABLE, _IPU_AVAILABLE, _IS_INTERACTIVE

-_PLUGIN = Union[Precision, ClusterEnvironment, CheckpointIO]
+_PLUGIN = Union[Precision, ClusterEnvironment, CheckpointIO, Collective]
_PLUGIN_INPUT = Union[_PLUGIN, str]
_PRECISION_INPUT = Literal[16, 32, 64, "bf16"]

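Widening _PLUGIN means a Collective instance is now an accepted plugin type for Lite's connector, alongside precision, cluster-environment, and checkpoint-IO plugins. Below is a minimal, hypothetical sketch of that intent; whether the connector does anything with the instance beyond accepting the type is not shown in this diff.

from lightning_lite import LightningLite
from lightning_lite.plugins.collectives import TorchCollective

# Hypothetical: the widened Union above only makes this type-check;
# any further handling of the collective by the connector is assumed.
lite = LightningLite(accelerator="cpu", devices=2, strategy="ddp", plugins=TorchCollective())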
17 changes: 12 additions & 5 deletions src/lightning_lite/plugins/__init__.py
@@ -11,7 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from lightning_lite.plugins.collectives.collective import Collective
from lightning_lite.plugins.collectives.deepspeed_collective import DeepSpeedCollective
from lightning_lite.plugins.collectives.single_device_collective import SingleDeviceCollective
from lightning_lite.plugins.collectives.torch_collective import TorchCollective
from lightning_lite.plugins.environments.cluster_environment import ClusterEnvironment
from lightning_lite.plugins.io.checkpoint_plugin import CheckpointIO
from lightning_lite.plugins.io.torch_plugin import TorchCheckpointIO
@@ -24,14 +27,18 @@
from lightning_lite.plugins.precision.tpu_bf16 import TPUBf16Precision

__all__ = [
-   "ClusterEnvironment",
    "CheckpointIO",
-   "TorchCheckpointIO",
-   "XLACheckpointIO",
+   "ClusterEnvironment",
+   "Collective",
+   "DeepSpeedCollective",
    "DeepSpeedPrecision",
    "DoublePrecision",
    "NativeMixedPrecision",
    "Precision",
-   "TPUPrecision",
+   "SingleDeviceCollective",
+   "TorchCheckpointIO",
+   "TorchCollective",
    "TPUBf16Precision",
+   "TPUPrecision",
+   "XLACheckpointIO",
]
11 changes: 11 additions & 0 deletions src/lightning_lite/plugins/collectives/__init__.py
@@ -0,0 +1,11 @@
from lightning_lite.plugins.collectives.collective import Collective
from lightning_lite.plugins.collectives.deepspeed_collective import DeepSpeedCollective
from lightning_lite.plugins.collectives.single_device_collective import SingleDeviceCollective
from lightning_lite.plugins.collectives.torch_collective import TorchCollective

__all__ = [
"Collective",
"DeepSpeedCollective",
"TorchCollective",
"SingleDeviceCollective",
]
138 changes: 138 additions & 0 deletions src/lightning_lite/plugins/collectives/collective.py
@@ -0,0 +1,138 @@
from abc import ABC, abstractmethod
from typing import Any, List, Optional

import torch
from typing_extensions import Self

from lightning_lite.utilities.types import CollectibleGroup


class Collective(ABC):
def __init__(self, instantiate_group: bool = False, **group_kwargs: Any) -> None:
self._group_kwargs = group_kwargs
self._group: Optional[CollectibleGroup] = None
if instantiate_group:
self.create_group()

def create_group(self, **kwargs: Any) -> Self: # type: ignore[valid-type]
if self._group is not None:
raise RuntimeError(f"{type(self).__name__} already owns a group.")
self._group_kwargs.update(kwargs)
self._group = self.init_group(**self._group_kwargs)
return self

@property
def group(self) -> CollectibleGroup:
if self._group is None:
raise RuntimeError(
f"{type(self).__name__} does not own a group. HINT: try `collective.create_group().group`"
)
return self._group

@property
@abstractmethod
def rank(self) -> int:
pass

@property
@abstractmethod
def world_size(self) -> int:
pass

@staticmethod
@abstractmethod
def init_group(
**kwargs: Any,
) -> CollectibleGroup:
pass

def teardown(self) -> None:
if self._group is None:
raise RuntimeError(f"{type(self).__name__} does not own a group to destroy.")
self.destroy_group(self._group)
self._group = None

@staticmethod
@abstractmethod
def destroy_group(group: CollectibleGroup) -> None:
pass

@staticmethod
@abstractmethod
def convert_to_native_op(op: str) -> Any:
...

@abstractmethod
def broadcast(
self,
tensor: torch.Tensor,
src: int,
) -> torch.Tensor:
pass

@abstractmethod
def all_reduce(
self,
tensor: torch.Tensor,
op: Any,
) -> torch.Tensor:
pass

@abstractmethod
def reduce(
self,
tensor: torch.Tensor,
dst: int,
op: Any,
) -> torch.Tensor:
pass

@abstractmethod
def all_gather(
self,
tensor_list: List[torch.Tensor],
tensor: torch.Tensor,
) -> List[torch.Tensor]:
pass

@abstractmethod
def gather(
self,
tensor: torch.Tensor,
gather_list: Optional[List[torch.Tensor]] = None,
dst: int = 0,
) -> Optional[List[torch.Tensor]]:
pass

@abstractmethod
def scatter(
self,
tensor: torch.Tensor,
scatter_list: Optional[List[torch.Tensor]] = None,
src: int = 0,
) -> torch.Tensor:
pass

@abstractmethod
def reduce_scatter(
self,
output: torch.Tensor,
input_list: List[torch.Tensor],
op: Any,
) -> torch.Tensor:
pass

@abstractmethod
def all_to_all(
self,
output_tensor_list: List[torch.Tensor],
input_tensor_list: List[torch.Tensor],
) -> List[torch.Tensor]:
pass

@abstractmethod
def barrier(
self,
device_ids: Optional[List[int]] = None,
) -> None:
pass
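
Because every file in this PR is a new addition, the quickest way to read the base class above is through a concrete subclass. The following is a hypothetical single-process, pass-through implementation (not part of the PR) that satisfies each abstract method, followed by the lazy create_group()/teardown() lifecycle the base class enforces.

from typing import Any, List, Optional

import torch

from lightning_lite.plugins.collectives.collective import Collective
from lightning_lite.utilities.types import CollectibleGroup


class _PassThroughCollective(Collective):
    """Hypothetical single-process collective: every op returns its input unchanged."""

    @property
    def rank(self) -> int:
        return 0

    @property
    def world_size(self) -> int:
        return 1

    @staticmethod
    def init_group(**kwargs: Any) -> CollectibleGroup:
        # There is no real process group to create; any placeholder object will do.
        return object()  # type: ignore[return-value]

    @staticmethod
    def destroy_group(group: CollectibleGroup) -> None:
        pass

    @staticmethod
    def convert_to_native_op(op: str) -> Any:
        return op

    def broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
        return tensor

    def all_reduce(self, tensor: torch.Tensor, op: Any = "sum") -> torch.Tensor:
        return tensor

    def reduce(self, tensor: torch.Tensor, dst: int, op: Any = "sum") -> torch.Tensor:
        return tensor

    def all_gather(self, tensor_list: List[torch.Tensor], tensor: torch.Tensor) -> List[torch.Tensor]:
        tensor_list[:] = [tensor]
        return tensor_list

    def gather(self, tensor: torch.Tensor, gather_list: Optional[List[torch.Tensor]] = None, dst: int = 0) -> Optional[List[torch.Tensor]]:
        return [tensor]

    def scatter(self, tensor: torch.Tensor, scatter_list: Optional[List[torch.Tensor]] = None, src: int = 0) -> torch.Tensor:
        return tensor

    def reduce_scatter(self, output: torch.Tensor, input_list: List[torch.Tensor], op: Any = "sum") -> torch.Tensor:
        output.copy_(input_list[0])
        return output

    def all_to_all(self, output_tensor_list: List[torch.Tensor], input_tensor_list: List[torch.Tensor]) -> List[torch.Tensor]:
        return input_tensor_list

    def barrier(self, device_ids: Optional[List[int]] = None) -> None:
        pass


# Lifecycle enforced by the base class: the group is created lazily and owned by the instance.
collective = _PassThroughCollective()    # no group yet; accessing `collective.group` would raise
collective = collective.create_group()   # `create_group()` returns `self`, so calls can be chained
assert collective.rank == 0 and collective.world_size == 1
result = collective.all_reduce(torch.ones(1), op="sum")
collective.teardown()                    # destroys the owned group and clears the reference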
152 changes: 152 additions & 0 deletions src/lightning_lite/plugins/collectives/deepspeed_collective.py
@@ -0,0 +1,152 @@
import datetime
from typing import Any, List, Optional

import torch

from lightning_lite.plugins.collectives.collective import Collective
from lightning_lite.strategies.deepspeed import _DEEPSPEED_AVAILABLE
from lightning_lite.utilities.types import CollectibleGroup

if _DEEPSPEED_AVAILABLE:
Review comment (PR author): In the spirit of #12786, we should make these local imports

import deepspeed.comm as dist


class DeepSpeedCollective(Collective):
def __init__(self, instantiate_group: bool = False, **group_kwargs: Any) -> None:
if not _DEEPSPEED_AVAILABLE:
raise RuntimeError("Torch distributed is not available.")
super().__init__(instantiate_group, **group_kwargs)

@property
def rank(self) -> int:
return dist.get_rank(self.group)

@property
def world_size(self) -> int:
return dist.get_world_size(self.group)

@staticmethod
def init_group(
**kwargs: Any,
) -> CollectibleGroup:
return dist.init_process_group(**kwargs)

@staticmethod
def destroy_group(group: CollectibleGroup) -> None:
dist.destroy_process_group(group)

def broadcast(
self,
tensor: torch.Tensor,
src: int,
) -> torch.Tensor:
dist.broadcast(tensor, src, group=self.group)
return tensor

def all_reduce(
self,
tensor: torch.Tensor,
op: dist.ReduceOp = dist.ReduceOp.SUM,
) -> torch.Tensor:
dist.all_reduce(tensor, op=op, group=self.group)
return tensor

def reduce(
self,
tensor: torch.Tensor,
dst: int,
op: dist.ReduceOp = dist.ReduceOp.SUM,
) -> torch.Tensor:
dist.reduce(tensor, dst, op=op, group=self.group)
return tensor

def all_gather(
self,
tensor_list: List[torch.Tensor],
tensor: torch.Tensor,
) -> List[torch.Tensor]:
dist.all_gather(tensor_list, tensor, group=self.group)
return tensor_list

def gather(
self,
tensor: torch.Tensor,
gather_list: Optional[List[torch.Tensor]] = None,
dst: int = 0,
) -> Optional[List[torch.Tensor]]:
dist.gather(tensor, gather_list, dst, group=self.group)
return gather_list

def scatter(
self,
tensor: torch.Tensor,
scatter_list: Optional[List[torch.Tensor]] = None,
src: int = 0,
) -> torch.Tensor:
dist.scatter(tensor, scatter_list, src, group=self.group)
return tensor

def reduce_scatter(
self,
output: torch.Tensor,
input_list: List[torch.Tensor],
op: dist.ReduceOp = dist.ReduceOp.SUM,
) -> torch.Tensor:
dist.reduce_scatter(output, input_list, op=op, group=self.group)
return output

def all_to_all(
self,
output_tensor_list: List[torch.Tensor],
input_tensor_list: List[torch.Tensor],
) -> List[torch.Tensor]:
dist.all_to_all(output_tensor_list, input_tensor_list, group=self.group)
return output_tensor_list

def barrier(
self,
device_ids: Optional[List[int]] = None,
) -> None:
dist.barrier(group=self.group, device_ids=device_ids)

def all_gather_object(
self,
object_list: List[Any],
object: Any,
) -> List[Any]:
dist.all_gather_object(object_list, object, group=self.group)
return object_list

def broadcast_object_list(
self,
object_list: List[Any],
src: int,
device: Optional[torch.device] = None,
) -> List[Any]:
dist.broadcast_object_list(object_list, src, group=self.group, device=device)
return object_list

def gather_object(
self,
obj: Any,
object_gather_list: Optional[List[Any]] = None,
dst: int = 0,
) -> Optional[List[Any]]:
dist.gather_object(obj, object_gather_list, dst, group=self.group)
return object_gather_list

def scatter_object_list(
self,
scatter_object_output_list: List[Any],
scatter_object_input_list: Optional[List[Any]],
src: int = 0,
) -> List[Any]:
dist.scatter_object_list(scatter_object_output_list, scatter_object_input_list, src, group=self.group)
return scatter_object_output_list

def monitored_barrier(
self,
timeout: Optional[datetime.timedelta] = None,
wait_all_ranks: bool = False,
) -> None:
dist.monitored_barrier(group=self.group, timeout=timeout, wait_all_ranks=wait_all_ranks)
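
For orientation, here is a short hypothetical usage sketch of the class above. It assumes DeepSpeed's distributed environment is already bootstrapped for the current process (for example via deepspeed.init_distributed()) and that whatever keyword arguments your setup requires are the ones forwarded verbatim to dist.init_process_group by create_group(); the point is the lazy group lifecycle and the fact that the tensor collectives operate in place and return their inputs.

import torch

from lightning_lite.plugins.collectives import DeepSpeedCollective

collective = DeepSpeedCollective()             # lazy: no process group created yet
collective.create_group()                      # forwards the stored kwargs to dist.init_process_group
print(collective.rank, collective.world_size)  # queried from deepspeed.comm for the owned group
tensor = torch.ones(1)
out = collective.all_reduce(tensor)            # in place: `out is tensor`, now holding the sum across ranks
collective.teardown()                          # destroys the group and drops the reference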