|
7 | 7 | import dataclasses
|
8 | 8 | import datetime
|
9 | 9 | import pickle
|
| 10 | +import socket |
10 | 11 | import time
|
11 | 12 | from collections import deque
|
12 | 13 | from typing import Any, Deque, Dict, Optional, Sequence, Tuple
|
@@ -123,6 +124,10 @@ class StatelessProcessGroup:
|
123 | 124 | rank: int
|
124 | 125 | world_size: int
|
125 | 126 | store: torch._C._distributed_c10d.Store
|
| 127 | + |
| 128 | + # stores a reference to the socket so that the file descriptor stays alive |
| 129 | + socket: Optional[socket.socket] |
| 130 | + |
126 | 131 | data_expiration_seconds: int = 3600 # 1 hour
|
127 | 132 |
|
128 | 133 | # dst rank -> counter
|
@@ -234,18 +239,33 @@ def create(
|
234 | 239 | can call `StatelessProcessGroup.create` to form a group, and then process A, B,
|
235 | 240 | C, and D can call `StatelessProcessGroup.create` to form another group.
|
236 | 241 | """ # noqa
|
| 242 | + launch_server = rank == 0 |
| 243 | + if launch_server: |
| 244 | + # listen on the specified interface (instead of 0.0.0.0) |
| 245 | + listen_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) |
| 246 | + listen_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) |
| 247 | + listen_socket.bind((host, port)) |
| 248 | + listen_socket.listen() |
| 249 | + listen_fd = listen_socket.fileno() |
| 250 | + else: |
| 251 | + listen_socket = None |
| 252 | + listen_fd = None |
| 253 | + |
237 | 254 | store = TCPStore(
|
238 | 255 | host_name=host,
|
239 | 256 | port=port,
|
240 | 257 | world_size=world_size,
|
241 |
| - is_master=(rank == 0), |
| 258 | + is_master=launch_server, |
242 | 259 | timeout=datetime.timedelta(seconds=store_timeout),
|
| 260 | + use_libuv=False, # for now: github.com/pytorch/pytorch/pull/150215 |
| 261 | + master_listen_fd=listen_fd, |
243 | 262 | )
|
244 | 263 |
|
245 | 264 | return StatelessProcessGroup(
|
246 | 265 | rank=rank,
|
247 | 266 | world_size=world_size,
|
248 | 267 | store=store,
|
| 268 | + socket=listen_socket, |
249 | 269 | data_expiration_seconds=data_expiration_seconds)
|
250 | 270 |
|
251 | 271 |
|
|
0 commit comments