Skip to content

Commit 62ca073

Browse files
carmocca
and
otaj
authored
Introduce base collective and main subclasses (#15016)
Co-authored-by: otaj <[email protected]>
1 parent 7e518ca commit 62ca073

26 files changed

+713
-42
lines changed

src/lightning_lite/plugins/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
1514
from lightning_lite.plugins.environments.cluster_environment import ClusterEnvironment
1615
from lightning_lite.plugins.io.checkpoint_plugin import CheckpointIO
1716
from lightning_lite.plugins.io.torch_plugin import TorchCheckpointIO
@@ -28,10 +27,10 @@
2827
"CheckpointIO",
2928
"TorchCheckpointIO",
3029
"XLACheckpointIO",
30+
"Precision",
3131
"DeepSpeedPrecision",
3232
"DoublePrecision",
3333
"NativeMixedPrecision",
34-
"Precision",
3534
"TPUPrecision",
3635
"TPUBf16Precision",
3736
]
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
# Public entry point for the collectives plugin subpackage.
from lightning_lite.plugins.collectives.collective import Collective
from lightning_lite.plugins.collectives.single_device_collective import SingleDeviceCollective
from lightning_lite.plugins.collectives.torch_collective import TorchCollective

__all__ = [
    "Collective",
    "TorchCollective",
    "SingleDeviceCollective",
]
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
from abc import ABC, abstractmethod
2+
from typing import Any, List, Optional
3+
4+
import torch
5+
from typing_extensions import Self
6+
7+
from lightning_lite.utilities.types import CollectibleGroup
8+
9+
10+
class Collective(ABC):
    """Interface for collective operations.

    Supports communications between multiple processes and multiple nodes. A collective owns a group.

    .. warning::
        This API is experimental and subject to change
    """

    def __init__(self) -> None:
        # The group owned by this collective; set by ``create_group`` and cleared by ``teardown``.
        self._group: Optional[CollectibleGroup] = None

    @property
    @abstractmethod
    def rank(self) -> int:
        """Rank of the current process within the group."""
        ...

    @property
    @abstractmethod
    def world_size(self) -> int:
        """Total number of processes in the group."""
        ...

    @property
    def group(self) -> CollectibleGroup:
        """The owned group; raises if ``create_group`` was never called."""
        if self._group is not None:
            return self._group
        raise RuntimeError(
            f"`{type(self).__name__}` does not own a group. HINT: try `collective.create_group().group`"
        )

    @abstractmethod
    def broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
        ...

    @abstractmethod
    def all_reduce(self, tensor: torch.Tensor, op: str) -> torch.Tensor:
        ...

    @abstractmethod
    def reduce(self, tensor: torch.Tensor, dst: int, op: str) -> torch.Tensor:
        ...

    @abstractmethod
    def all_gather(self, tensor_list: List[torch.Tensor], tensor: torch.Tensor) -> List[torch.Tensor]:
        ...

    @abstractmethod
    def gather(self, tensor: torch.Tensor, gather_list: List[torch.Tensor], dst: int = 0) -> List[torch.Tensor]:
        ...

    @abstractmethod
    def scatter(self, tensor: torch.Tensor, scatter_list: List[torch.Tensor], src: int = 0) -> torch.Tensor:
        ...

    @abstractmethod
    def reduce_scatter(self, output: torch.Tensor, input_list: List[torch.Tensor], op: str) -> torch.Tensor:
        ...

    @abstractmethod
    def all_to_all(
        self, output_tensor_list: List[torch.Tensor], input_tensor_list: List[torch.Tensor]
    ) -> List[torch.Tensor]:
        ...

    @abstractmethod
    def send(self, tensor: torch.Tensor, dst: int, tag: Optional[int] = 0) -> None:
        ...

    @abstractmethod
    def recv(self, tensor: torch.Tensor, src: Optional[int] = None, tag: Optional[int] = 0) -> torch.Tensor:
        ...

    @abstractmethod
    def barrier(self, device_ids: Optional[List[int]] = None) -> None:
        ...

    @classmethod
    @abstractmethod
    def is_available(cls) -> bool:
        """Whether this collective's backend can be used at all."""
        ...

    @classmethod
    @abstractmethod
    def is_initialized(cls) -> bool:
        """Whether the backend's global state has been set up."""
        ...

    @classmethod
    @abstractmethod
    def init_group(cls, **kwargs: Any) -> None:
        ...

    @classmethod
    @abstractmethod
    def new_group(cls, **kwargs: Any) -> CollectibleGroup:
        ...

    @classmethod
    @abstractmethod
    def destroy_group(cls, group: CollectibleGroup) -> None:
        ...

    @classmethod
    @abstractmethod
    def _convert_to_native_op(cls, op: str) -> Any:
        # Maps a string reduce-op (e.g. "sum") to the backend's native op object.
        ...

    def setup(self, **kwargs: Any) -> Self:  # type: ignore[valid-type]
        """Initialize the backend (via ``init_group``) unless it already is; returns ``self`` for chaining."""
        if self.is_initialized():
            return self
        self.init_group(**kwargs)
        return self

    def create_group(self, **kwargs: Any) -> Self:  # type: ignore[valid-type]
        """Create a group.

        This assumes that :meth:`~lightning_lite.plugins.collectives.Collective.init_group` has been
        called already by the user.
        """
        if self._group is not None:
            raise RuntimeError(f"`{type(self).__name__}` already owns a group.")
        self._group = self.new_group(**kwargs)
        return self

    def teardown(self) -> Self:  # type: ignore[valid-type]
        """Destroy the owned group and release the reference to it; returns ``self`` for chaining."""
        if self._group is None:
            raise RuntimeError(f"`{type(self).__name__}` does not own a group to destroy.")
        self.destroy_group(self._group)
        self._group = None
        return self
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
from typing import Any, List
2+
3+
import torch
4+
5+
from lightning_lite.plugins.collectives.collective import Collective
6+
from lightning_lite.utilities.types import CollectibleGroup
7+
8+
9+
class SingleDeviceCollective(Collective):
    """No-op collective for the single-process, single-device case.

    Every communication primitive degenerates to returning its input unchanged,
    since there are no peers to exchange data with.
    """

    @property
    def rank(self) -> int:
        # Only one process exists, so it is always rank 0.
        return 0

    @property
    def world_size(self) -> int:
        # A single process means a world of size 1.
        return 1

    def broadcast(self, tensor: torch.Tensor, *_: Any, **__: Any) -> torch.Tensor:
        return tensor

    def all_reduce(self, tensor: torch.Tensor, *_: Any, **__: Any) -> torch.Tensor:
        return tensor

    def reduce(self, tensor: torch.Tensor, *_: Any, **__: Any) -> torch.Tensor:
        return tensor

    def all_gather(self, tensor_list: List[torch.Tensor], tensor: torch.Tensor, **__: Any) -> List[torch.Tensor]:
        # Gathering across a world of one yields just the local tensor.
        return [tensor]

    def gather(self, tensor: torch.Tensor, *_: Any, **__: Any) -> List[torch.Tensor]:
        return [tensor]

    def scatter(
        self,
        tensor: torch.Tensor,
        scatter_list: List[torch.Tensor],
        *_: Any,
        **__: Any,
    ) -> torch.Tensor:
        # The single process receives the first (and only relevant) chunk.
        return scatter_list[0]

    def reduce_scatter(self, output: torch.Tensor, input_list: List[torch.Tensor], *_: Any, **__: Any) -> torch.Tensor:
        return input_list[0]

    def all_to_all(
        self, output_tensor_list: List[torch.Tensor], input_tensor_list: List[torch.Tensor], *_: Any, **__: Any
    ) -> List[torch.Tensor]:
        return input_tensor_list

    def send(self, *_: Any, **__: Any) -> None:
        # Nothing to send to — no peer processes exist.
        pass

    def recv(self, tensor: torch.Tensor, *_: Any, **__: Any) -> torch.Tensor:
        return tensor

    def barrier(self, *_: Any, **__: Any) -> None:
        # A single process is trivially synchronized with itself.
        pass

    @classmethod
    def is_available(cls) -> bool:
        return True  # vacuous truth

    @classmethod
    def is_initialized(cls) -> bool:
        return True  # vacuous truth

    @classmethod
    def init_group(cls, **_: Any) -> None:
        # No backend state to initialize.
        pass

    @classmethod
    def new_group(cls, **_: Any) -> CollectibleGroup:
        # Any sentinel object suffices as a "group" placeholder here.
        return object()  # type: ignore[return-value]

    @classmethod
    def destroy_group(cls, group: CollectibleGroup) -> None:
        # Nothing was allocated for the placeholder group.
        pass

    @classmethod
    def _convert_to_native_op(cls, op: str) -> str:
        # With no backend, the string op is already its own "native" form.
        return op

0 commit comments

Comments
 (0)