#include <cuda_runtime.h>

#include <iostream>
+ #include <array>
#include <limits>
#include <map>
#include <unordered_map>

namespace vllm {

- constexpr int kMaxBlocks = 64;
- // note: we don't want to use atomics for signals because peer atomics are not
- // supported on PCIe links
+ constexpr int kMaxBlocks = 36;
+ // Counter may overflow, but it's fine since unsigned int overflow is
+ // well-defined behavior.
+ using FlagType = uint32_t;
struct Signal {
-   alignas(128) uint32_t start[kMaxBlocks][8];
-   alignas(128) uint32_t end[kMaxBlocks][8];
+   alignas(128) FlagType self_counter[kMaxBlocks][8];
+   // Two sets of peer counters are needed for the two syncs: it's possible for
+   // a peer GPU block to arrive at the second sync point while the current GPU
+   // block hasn't passed the first sync point. The peer GPU may then write
+   // counter+1 while the current GPU is still busy waiting on counter, so we
+   // alternate between two counter arrays to avoid this.
+   alignas(128) FlagType peer_counter[2][kMaxBlocks][8];
};

struct __align__(16) RankData { const void* __restrict__ ptrs[8]; };

- struct __align__(16) RankSignals { volatile Signal* signals[8]; };
+ struct __align__(16) RankSignals { Signal* signals[8]; };

// like std::array, but aligned
template <typename T, int sz>
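
// Aside: a minimal host-side sketch of the alternating-counter barrier described
// in the Signal comment above. It is not from this patch; two std::threads stand
// in for two GPUs and std::atomic stands in for the PTX load/store helpers. A
// rank that races ahead to the next barrier writes the *other* parity slot, so
// it never clobbers the flag its peer is still spinning on.
#include <atomic>
#include <cstdint>
#include <cstdio>
#include <thread>

using FlagType = uint32_t;

FlagType self_counter[2];                  // one private counter per rank
std::atomic<FlagType> peer_counter[2][2];  // [parity][rank], written by the peer

void barrier(int rank) {
  FlagType val = ++self_counter[rank];
  // Publish the expected value into the peer's slot for this parity...
  peer_counter[val % 2][rank ^ 1].store(val, std::memory_order_release);
  // ...and wait until the peer has done the same for us.
  while (peer_counter[val % 2][rank].load(std::memory_order_acquire) != val) {
  }
}

void worker(int rank) {
  for (int i = 0; i < 10000; ++i) barrier(rank);
  std::printf("rank %d passed all barriers\n", rank);
}

int main() {
  std::thread t0(worker, 0), t1(worker, 1);
  t0.join();
  t1.join();
  return 0;
}
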
@@ -123,47 +130,60 @@ DINLINE O downcast(array_t<float, O::size> val) {
  }
}

- // This function is meant to be used as the first synchronization in the all
- // reduce kernel. Thus, it doesn't need to make any visibility guarantees for
- // prior memory accesses. Note: volatile writes will not be reordered against
- // other volatile writes.
- template <int ngpus>
- DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg,
-                         int rank) {
-   if (threadIdx.x < ngpus) {
-     // reset flag for next time
-     self_sg->end[blockIdx.x][threadIdx.x] = 0;
-     // simultaneously write to the corresponding flag of all ranks.
-     // Latency = 1 p2p write
-     sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1;
-     // wait until we got true from all ranks
-     while (!self_sg->start[blockIdx.x][threadIdx.x]);
-   }
-   __syncthreads();
+ static DINLINE void st_flag_release(FlagType* flag_addr, FlagType flag) {
+   asm volatile("st.release.sys.global.u32 [%1], %0;" ::"r"(flag),
+                "l"(flag_addr));
+ }
+
+ static DINLINE FlagType ld_flag_acquire(FlagType* flag_addr) {
+   FlagType flag;
+   asm volatile("ld.acquire.sys.global.u32 %0, [%1];"
+                : "=r"(flag)
+                : "l"(flag_addr));
+   return flag;
+ }
+
+ static DINLINE void st_flag_volatile(FlagType* flag_addr, FlagType flag) {
+   asm volatile("st.volatile.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
+ }
+
+ static DINLINE FlagType ld_flag_volatile(FlagType* flag_addr) {
+   FlagType flag;
+   asm volatile("ld.volatile.global.u32 %0, [%1];"
+                : "=r"(flag)
+                : "l"(flag_addr));
+   return flag;
}

- // This function is meant to be used as the second or the final synchronization
- // barrier in the all reduce kernel. If it's the final synchronization barrier,
- // we don't need to make any visibility guarantees for prior memory accesses.
- template <int ngpus, bool final_sync = false>
- DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg,
-                       int rank) {
-   __syncthreads();
-   // eliminate the case that prior writes are not visible after signals become
-   // visible. Note that I did not manage to make this happen through a lot of
-   // testing. Might be the case that hardware provides a stronger guarantee
-   // than the memory model.
-   if constexpr (!final_sync) __threadfence_system();
+ // is_start: whether this is the very first synchronization barrier.
+ // need_fence: whether a memory fence is needed. If true, release-acquire
+ // semantics are used to enforce memory access order before and after this
+ // barrier.
+ template <int ngpus, bool is_start, bool need_fence = false>
+ DINLINE void multi_gpu_barrier(const RankSignals& sg, Signal* self_sg,
+                                int rank) {
+   if constexpr (!is_start) __syncthreads();
+   static_assert(
+       !(is_start && need_fence));  // Start barrier shouldn't need fence.
  if (threadIdx.x < ngpus) {
-     // reset flag for next time
-     self_sg->start[blockIdx.x][threadIdx.x] = 0;
-     // simultaneously write to the corresponding flag of all ranks.
-     // Latency = 1 p2p write
-     sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1;
-     // wait until we got true from all ranks
-     while (!self_sg->end[blockIdx.x][threadIdx.x]);
+     // Increment the counter. Technically we only need one counter, but we use
+     // multiple per block to eliminate the need to share the counter via smem.
+     auto val = self_sg->self_counter[blockIdx.x][threadIdx.x] += 1;
+     // Write the expected counter value to the peer and wait for the correct
+     // value from the peer.
+     auto peer_counter_ptr =
+         &sg.signals[threadIdx.x]->peer_counter[val % 2][blockIdx.x][rank];
+     auto self_counter_ptr =
+         &self_sg->peer_counter[val % 2][blockIdx.x][threadIdx.x];
+     if constexpr (need_fence) {
+       st_flag_release(peer_counter_ptr, val);
+       while (ld_flag_acquire(self_counter_ptr) != val);
+     } else {
+       st_flag_volatile(peer_counter_ptr, val);
+       while (ld_flag_volatile(self_counter_ptr) != val);
+     }
  }
-   if constexpr (!final_sync) __syncthreads();
+   if constexpr (is_start || need_fence) __syncthreads();
}

template <typename P, int ngpus, typename A>
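
// Aside: on toolkits that ship libcu++'s cuda::atomic_ref (a version-dependent
// assumption worth checking), the release/acquire helpers above could be
// written without inline PTX. A sketch only; the patch keeps the PTX, which
// also provides the plain .volatile flavors used when no fence is needed.
#include <cuda/atomic>
#include <cstdint>

__device__ void st_flag_release_alt(uint32_t* flag_addr, uint32_t flag) {
  cuda::atomic_ref<uint32_t, cuda::thread_scope_system> ref(*flag_addr);
  ref.store(flag, cuda::std::memory_order_release);
}

__device__ uint32_t ld_flag_acquire_alt(uint32_t* flag_addr) {
  cuda::atomic_ref<uint32_t, cuda::thread_scope_system> ref(*flag_addr);
  return ref.load(cuda::std::memory_order_acquire);
}
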
@@ -178,33 +198,31 @@ DINLINE P packed_reduce(const P* ptrs[], int idx) {
}

template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
-     cross_device_reduce_1stage(RankData* _dp, RankSignals sg,
-                                volatile Signal* self_sg, T* __restrict__ result,
-                                int rank, int size) {
+     cross_device_reduce_1stage(RankData* _dp, RankSignals sg, Signal* self_sg,
+                                T* __restrict__ result, int rank, int size) {
  using P = typename packed_t<T>::P;
  using A = typename packed_t<T>::A;
  // note: we don't reorder the address so the accumulation order is the same
  // for all ranks, ensuring bitwise identical results
  auto dp = *_dp;
-   start_sync<ngpus>(sg, self_sg, rank);
+   multi_gpu_barrier<ngpus, true>(sg, self_sg, rank);
  // do the actual reduction
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
       idx += gridDim.x * blockDim.x) {
    ((P*)result)[idx] = packed_reduce<P, ngpus, A>((const P**)&dp.ptrs[0], idx);
  }
-   end_sync<ngpus, true>(sg, self_sg, rank);
+   multi_gpu_barrier<ngpus, false>(sg, self_sg, rank);
}

template <typename P>
- DINLINE P* get_tmp_buf(volatile Signal* sg) {
+ DINLINE P* get_tmp_buf(Signal* sg) {
  return (P*)(((Signal*)sg) + 1);
}

template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
-     cross_device_reduce_2stage(RankData* _dp, RankSignals sg,
-                                volatile Signal* self_sg, T* __restrict__ result,
-                                int rank, int size) {
+     cross_device_reduce_2stage(RankData* _dp, RankSignals sg, Signal* self_sg,
+                                T* __restrict__ result, int rank, int size) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = gridDim.x * blockDim.x;
  using P = typename packed_t<T>::P;
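
// Aside: get_tmp_buf() assumes each rank's buffer is one allocation laid out as
// [ Signal | scratch payload ]. A hedged host-side sketch of that layout; the
// helper name and the plain cudaMalloc are illustrative only (the real buffers
// are set up elsewhere and shared across processes), and the Signal definition
// comes from the first hunk above.
#include <cstddef>

void* alloc_rank_buffer(std::size_t payload_bytes, void** tmp_out) {
  void* buf = nullptr;
  // One allocation: Signal header followed by payload. Zero it so that all
  // counters start at 0.
  cudaMalloc(&buf, sizeof(vllm::Signal) + payload_bytes);
  cudaMemset(buf, 0, sizeof(vllm::Signal) + payload_bytes);
  // Same pointer arithmetic as get_tmp_buf: skip past the Signal header.
  *tmp_out = reinterpret_cast<vllm::Signal*>(buf) + 1;
  return buf;
}
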
@@ -222,12 +240,12 @@ __global__ void __launch_bounds__(512, 1)
    tmps[i] = get_tmp_buf<P>(sg.signals[target]);
  }
  auto tmp_out = tmps[0];
-   start_sync<ngpus>(sg, self_sg, rank);
+   multi_gpu_barrier<ngpus, true>(sg, self_sg, rank);
  // stage 1: reduce scatter
  for (int idx = start + tid; idx < end; idx += stride) {
    tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
  }
-   end_sync<ngpus>(sg, self_sg, rank);
+   multi_gpu_barrier<ngpus, false, true>(sg, self_sg, rank);

  // stage 2: allgather. Note: it's important to match the tid between
  // the two stages, because visibility across devices is only guaranteed
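
// Aside: a tiny host-side illustration (not from the patch) of the two-stage
// schedule the kernel follows: every rank reduces only its own slice
// (reduce-scatter), the fenced barrier makes those partial sums visible to
// peers, then every rank gathers all slices (allgather).
#include <array>
#include <cstdio>

int main() {
  constexpr int ngpus = 4, size = 8, part = size / ngpus;
  // Rank r contributes the value r+1 to every element.
  std::array<std::array<int, size>, ngpus> input{};
  for (int r = 0; r < ngpus; ++r) input[r].fill(r + 1);

  // Stage 1: reduce-scatter. Rank r owns elements [r*part, (r+1)*part).
  std::array<int, size> result{};
  for (int r = 0; r < ngpus; ++r)
    for (int idx = r * part; idx < (r + 1) * part; ++idx)
      for (int peer = 0; peer < ngpus; ++peer) result[idx] += input[peer][idx];

  // On the GPU, multi_gpu_barrier<ngpus, false, true> sits here so the partial
  // sums each rank wrote to its tmp buffer are visible before peers read them.

  // Stage 2: allgather. In this single-process sketch the result array already
  // holds every slice; on the GPU each rank copies the peers' slices out of
  // their tmp buffers using the same tid it used in stage 1.
  for (int v : result) std::printf("%d ", v);  // prints 10 eight times (1+2+3+4)
  std::printf("\n");
  return 0;
}
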
@@ -437,6 +455,8 @@ class CustomAllreduce {
#define KL(ngpus, name)                                                       \
  name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
                                                 rank_, size);
+ // TODO(hanzhi713): Threshold is different for A100 and H100.
+ // Add per device threshold.
#define REDUCE_CASE(ngpus)                                                    \
  case ngpus: {                                                               \
    if (world_size_ == 2) {                                                   \