Commit 052f654

hanzhi713 authored and Alvant committed
[Bugfix] Fix potentially unsafe custom allreduce synchronization (vllm-project#8558)
Signed-off-by: Alvant <[email protected]>
1 parent 735cd4b commit 052f654

2 files changed: 83 additions & 59 deletions

csrc/custom_all_reduce.cuh

Lines changed: 74 additions & 54 deletions
@@ -6,6 +6,7 @@
 #include <cuda_runtime.h>
 
 #include <iostream>
+#include <array>
 #include <limits>
 #include <map>
 #include <unordered_map>
@@ -23,17 +24,23 @@
 
 namespace vllm {
 
-constexpr int kMaxBlocks = 64;
-// note: we don't want to use atomics for signals because peer atomics are not
-// supported on PCIe links
+constexpr int kMaxBlocks = 36;
+// Counter may overflow, but it's fine since unsigned int overflow is
+// well-defined behavior.
+using FlagType = uint32_t;
 struct Signal {
-  alignas(128) uint32_t start[kMaxBlocks][8];
-  alignas(128) uint32_t end[kMaxBlocks][8];
+  alignas(128) FlagType self_counter[kMaxBlocks][8];
+  // Two sets of peer counters are needed for two syncs. The reason is that
+  // it's possible for a peer GPU block to arrive at the second sync point
+  // while the current GPU block hasn't passed the first sync point. Thus, the
+  // peer GPU may write counter+1 while the current GPU is busy waiting for
+  // counter. We use an alternating counter array to avoid this possibility.
+  alignas(128) FlagType peer_counter[2][kMaxBlocks][8];
 };
 
 struct __align__(16) RankData { const void* __restrict__ ptrs[8]; };
 
-struct __align__(16) RankSignals { volatile Signal* signals[8]; };
+struct __align__(16) RankSignals { Signal* signals[8]; };
 
 // like std::array, but aligned
 template <typename T, int sz>
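
The alternating peer_counter arrays are the core of the fix: with a single set of flags, a peer block that has already reached the next barrier could overwrite this round's value with counter+1 while the slower block is still spinning on counter, and the slower block would then wait forever. The short host-side sketch below reproduces the same handshake with two std::thread workers standing in for two GPU blocks; it is an illustration only, and the names barrier and peer_flag are invented for this sketch, not taken from the vLLM sources.

// Host-side analogy of the alternating-counter barrier (illustrative only).
#include <atomic>
#include <cstdint>
#include <thread>

static std::atomic<uint32_t> peer_flag[2][2] = {};  // [round parity][rank]

void barrier(int rank, uint32_t& my_counter) {
  uint32_t val = ++my_counter;
  // Publish my arrival for this round into the peer's slot...
  peer_flag[val % 2][rank ^ 1].store(val, std::memory_order_release);
  // ...then wait until the peer publishes the same round into my slot. If both
  // rounds shared one slot, the peer's next-round store (val + 1) could land
  // before this load and the loop would never observe val.
  while (peer_flag[val % 2][rank].load(std::memory_order_acquire) != val) {
  }
}

int main() {
  uint32_t c0 = 0, c1 = 0;
  std::thread t0([&] { for (int i = 0; i < 1000; ++i) barrier(0, c0); });
  std::thread t1([&] { for (int i = 0; i < 1000; ++i) barrier(1, c1); });
  t0.join();
  t1.join();
  return 0;
}
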
@@ -123,47 +130,60 @@ DINLINE O downcast(array_t<float, O::size> val) {
 }
 }
 
-// This function is meant to be used as the first synchronization in the all
-// reduce kernel. Thus, it doesn't need to make any visibility guarantees for
-// prior memory accesses. Note: volatile writes will not be reordered against
-// other volatile writes.
-template <int ngpus>
-DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg,
-                        int rank) {
-  if (threadIdx.x < ngpus) {
-    // reset flag for next time
-    self_sg->end[blockIdx.x][threadIdx.x] = 0;
-    // simultaneously write to the corresponding flag of all ranks.
-    // Latency = 1 p2p write
-    sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1;
-    // wait until we got true from all ranks
-    while (!self_sg->start[blockIdx.x][threadIdx.x]);
-  }
-  __syncthreads();
+static DINLINE void st_flag_release(FlagType* flag_addr, FlagType flag) {
+  asm volatile("st.release.sys.global.u32 [%1], %0;" ::"r"(flag),
+               "l"(flag_addr));
+}
+
+static DINLINE FlagType ld_flag_acquire(FlagType* flag_addr) {
+  FlagType flag;
+  asm volatile("ld.acquire.sys.global.u32 %0, [%1];"
+               : "=r"(flag)
+               : "l"(flag_addr));
+  return flag;
+}
+
+static DINLINE void st_flag_volatile(FlagType* flag_addr, FlagType flag) {
+  asm volatile("st.volatile.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
+}
+
+static DINLINE FlagType ld_flag_volatile(FlagType* flag_addr) {
+  FlagType flag;
+  asm volatile("ld.volatile.global.u32 %0, [%1];"
+               : "=r"(flag)
+               : "l"(flag_addr));
+  return flag;
 }
 
-// This function is meant to be used as the second or the final synchronization
-// barrier in the all reduce kernel. If it's the final synchronization barrier,
-// we don't need to make any visibility guarantees for prior memory accesses.
-template <int ngpus, bool final_sync = false>
-DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg,
-                      int rank) {
-  __syncthreads();
-  // eliminate the case that prior writes are not visible after signals become
-  // visible. Note that I did not manage to make this happen through a lot of
-  // testing. Might be the case that hardware provides a stronger guarantee than
-  // the memory model.
-  if constexpr (!final_sync) __threadfence_system();
+// is_start: whether this is the very first synchronization barrier.
+// need_fence: whether a memory fence is needed. If true, a release-acquire
+// semantic is used to enforce memory access order before and after this
+// barrier.
+template <int ngpus, bool is_start, bool need_fence = false>
+DINLINE void multi_gpu_barrier(const RankSignals& sg, Signal* self_sg,
+                               int rank) {
+  if constexpr (!is_start) __syncthreads();
+  static_assert(
+      !(is_start && need_fence));  // Start barrier shouldn't need a fence.
   if (threadIdx.x < ngpus) {
-    // reset flag for next time
-    self_sg->start[blockIdx.x][threadIdx.x] = 0;
-    // simultaneously write to the corresponding flag of all ranks.
-    // Latency = 1 p2p write
-    sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1;
-    // wait until we got true from all ranks
-    while (!self_sg->end[blockIdx.x][threadIdx.x]);
+    // Increment the counter. Technically we only need one counter, but we use
+    // multiple per block to eliminate the need to share the counter via smem.
+    auto val = self_sg->self_counter[blockIdx.x][threadIdx.x] += 1;
+    // Write the expected counter value to the peer and wait for the correct
+    // value from the peer.
+    auto peer_counter_ptr =
+        &sg.signals[threadIdx.x]->peer_counter[val % 2][blockIdx.x][rank];
+    auto self_counter_ptr =
+        &self_sg->peer_counter[val % 2][blockIdx.x][threadIdx.x];
+    if constexpr (need_fence) {
+      st_flag_release(peer_counter_ptr, val);
+      while (ld_flag_acquire(self_counter_ptr) != val);
+    } else {
+      st_flag_volatile(peer_counter_ptr, val);
+      while (ld_flag_volatile(self_counter_ptr) != val);
+    }
   }
-  if constexpr (!final_sync) __syncthreads();
+  if constexpr (is_start || need_fence) __syncthreads();
 }
 
 template <typename P, int ngpus, typename A>
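
For readers less familiar with inline PTX, the st_flag_release / ld_flag_acquire helpers above are system-scope release/acquire operations. The sketch below shows roughly equivalent code written with libcu++ atomics, assuming a CUDA toolkit recent enough to provide cuda::atomic_ref; it is a comparison only, not what this commit compiles (the commit keeps the explicit st.release.sys / ld.acquire.sys instructions, which avoid the libcu++ dependency).

#include <cstdint>
#include <cuda/atomic>  // libcu++

using FlagType = uint32_t;

// thread_scope_system: the ordering must hold across GPUs (and the host),
// which is what the .sys qualifier in the PTX above requests.
__device__ inline void st_flag_release_ref(FlagType* flag_addr, FlagType flag) {
  cuda::atomic_ref<FlagType, cuda::thread_scope_system> ref(*flag_addr);
  ref.store(flag, cuda::memory_order_release);
}

__device__ inline FlagType ld_flag_acquire_ref(FlagType* flag_addr) {
  cuda::atomic_ref<FlagType, cuda::thread_scope_system> ref(*flag_addr);
  return ref.load(cuda::memory_order_acquire);
}
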
@@ -178,33 +198,31 @@ DINLINE P packed_reduce(const P* ptrs[], int idx) {
 
 template <typename T, int ngpus>
 __global__ void __launch_bounds__(512, 1)
-    cross_device_reduce_1stage(RankData* _dp, RankSignals sg,
-                               volatile Signal* self_sg, T* __restrict__ result,
-                               int rank, int size) {
+    cross_device_reduce_1stage(RankData* _dp, RankSignals sg, Signal* self_sg,
+                               T* __restrict__ result, int rank, int size) {
   using P = typename packed_t<T>::P;
   using A = typename packed_t<T>::A;
   // note: we don't reorder the address so the accumulation order is the same
   // for all ranks, ensuring bitwise identical results
   auto dp = *_dp;
-  start_sync<ngpus>(sg, self_sg, rank);
+  multi_gpu_barrier<ngpus, true>(sg, self_sg, rank);
   // do the actual reduction
   for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
        idx += gridDim.x * blockDim.x) {
     ((P*)result)[idx] = packed_reduce<P, ngpus, A>((const P**)&dp.ptrs[0], idx);
   }
-  end_sync<ngpus, true>(sg, self_sg, rank);
+  multi_gpu_barrier<ngpus, false>(sg, self_sg, rank);
 }
 
 template <typename P>
-DINLINE P* get_tmp_buf(volatile Signal* sg) {
+DINLINE P* get_tmp_buf(Signal* sg) {
   return (P*)(((Signal*)sg) + 1);
 }
 
 template <typename T, int ngpus>
 __global__ void __launch_bounds__(512, 1)
-    cross_device_reduce_2stage(RankData* _dp, RankSignals sg,
-                               volatile Signal* self_sg, T* __restrict__ result,
-                               int rank, int size) {
+    cross_device_reduce_2stage(RankData* _dp, RankSignals sg, Signal* self_sg,
+                               T* __restrict__ result, int rank, int size) {
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
   int stride = gridDim.x * blockDim.x;
   using P = typename packed_t<T>::P;
@@ -222,12 +240,12 @@ __global__ void __launch_bounds__(512, 1)
     tmps[i] = get_tmp_buf<P>(sg.signals[target]);
   }
   auto tmp_out = tmps[0];
-  start_sync<ngpus>(sg, self_sg, rank);
+  multi_gpu_barrier<ngpus, true>(sg, self_sg, rank);
   // stage 1: reduce scatter
   for (int idx = start + tid; idx < end; idx += stride) {
     tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
   }
-  end_sync<ngpus>(sg, self_sg, rank);
+  multi_gpu_barrier<ngpus, false, true>(sg, self_sg, rank);
 
   // stage 2: allgather. Note: it's important to match the tid between
   // the two stages, because visibility across devices is only guaranteed
@@ -437,6 +455,8 @@ class CustomAllreduce {
 #define KL(ngpus, name) \
   name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
                                                  rank_, size);
+  // TODO(hanzhi713): Threshold is different for A100 and H100.
+  // Add per device threshold.
 #define REDUCE_CASE(ngpus) \
   case ngpus: { \
     if (world_size_ == 2) { \
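
To recap how the kernels in this file use the new barrier, the three instantiations that appear in the hunks above are listed below with the reasoning from the comments in this diff:

// Start of a kernel: is_start = true, no fence; the barrier does not need to
// make any prior writes visible yet.
multi_gpu_barrier<ngpus, true>(sg, self_sg, rank);

// Between reduce-scatter and allgather in the 2-stage kernel: need_fence =
// true, because peers are about to read the tmp buffer this rank just wrote,
// so the release/acquire flag path is used.
multi_gpu_barrier<ngpus, false, true>(sg, self_sg, rank);

// Final barrier of the 1-stage kernel: no visibility guarantees for prior
// writes are needed here, so the cheaper volatile flag path is enough.
multi_gpu_barrier<ngpus, false>(sg, self_sg, rank);
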

csrc/custom_all_reduce_test.cu

Lines changed: 9 additions & 5 deletions
@@ -1,15 +1,15 @@
 /**
  * This is a standalone test for custom allreduce.
  * To compile, make sure you have MPI and NCCL installed in your system.
- * export MPI_HOME=XXX
+ * export MPI_HOME=xxx
  * nvcc -O2 -arch=native -std=c++17 custom_all_reduce_test.cu -o
- * custom_all_reduce_test -lnccl -I${MPI_HOME}/include -lmpi
+ * custom_all_reduce_test -lnccl -I${MPI_HOME} -lmpi
  *
  * Warning: this C++ test is not designed to be very readable and was used
  * during the rapid prototyping process.
  *
  * To run:
- * mpirun -np 8 ./custom_all_reduce_test
+ * mpirun --allow-run-as-root -np 8 ./custom_all_reduce_test
  */
 #include <cuda.h>
 #include <curand_kernel.h>
@@ -302,15 +302,19 @@ int main(int argc, char** argv) {
 
   bool performance_test = true;
   cudaProfilerStart();
-  // for (int threads : {256, 512}) {
+  // Uncomment to scan through different block size configs.
+  // for (int threads : {256, 512, 1024}) {
   //   for (int block_limit = 16; block_limit < 112; block_limit += 4) {
-  //     run<half>(myRank, nRanks, comm, threads, block_limit, 4096 * 1024);
+  //     run<half>(myRank, nRanks, comm, threads, block_limit, 1024 * 1024,
+  //     performance_test);
   //   }
   // }
+  // Scan through different sizes to test performance.
   for (int sz = 512; sz <= (8 << 20); sz *= 2) {
     run<half>(myRank, nRanks, comm, 512, 36, sz + 8 * 47, performance_test);
   }
 
   cudaProfilerStop();
+  MPICHECK(MPI_Finalize());
   return EXIT_SUCCESS;
 }
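
On the test harness side, the added MPICHECK(MPI_Finalize()) gives the process a clean MPI shutdown. For orientation, a stripped-down skeleton of such an MPI-driven CUDA test is sketched below; the MPICHECK macro here is a stand-in written for this sketch (custom_all_reduce_test.cu defines its own), and the per-rank CUDA/allreduce body is elided.

#include <mpi.h>

#include <cstdio>
#include <cstdlib>

#define MPICHECK(cmd)                                                         \
  do {                                                                        \
    int e = (cmd);                                                            \
    if (e != MPI_SUCCESS) {                                                   \
      std::fprintf(stderr, "MPI error %d at %s:%d\n", e, __FILE__, __LINE__); \
      std::exit(EXIT_FAILURE);                                                \
    }                                                                         \
  } while (0)

int main(int argc, char** argv) {
  MPICHECK(MPI_Init(&argc, &argv));
  int myRank = 0, nRanks = 0;
  MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &myRank));
  MPICHECK(MPI_Comm_size(MPI_COMM_WORLD, &nRanks));

  // ... per-rank CUDA setup and allreduce runs would go here ...

  // The call added by this commit: shut MPI down cleanly so mpirun does not
  // complain about processes exiting without MPI_Finalize.
  MPICHECK(MPI_Finalize());
  return EXIT_SUCCESS;
}
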
