Skip to content

Commit 1250b68

Browse files
authored
bugfix: suppress alignment warning of sampling kernels (#297)
We declare multiple kernels inside the `sampling.cuh` and they use dynamic shared memory (with the same extern variable name) with different alignment requirements (e.g. some are alignof 4, some are alignof 64). In this PR we use different names for extern variable that have different alignment requirements to suppress the warning.
1 parent aff4cf0 commit 1250b68

File tree

1 file changed

+20
-16
lines changed

1 file changed

+20
-16
lines changed

include/flashinfer/sampling.cuh

+20-16
Original file line numberDiff line numberDiff line change
@@ -137,10 +137,11 @@ __global__ void SamplingFromProbKernel(DType* probs, DType* uniform_samples, IdT
137137
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
138138
const uint32_t row_idx = row_indices == nullptr ? bx : row_indices[bx];
139139

140-
extern __shared__ __align__(alignof(
141-
SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>)) uint8_t smem[];
140+
extern __shared__ __align__(
141+
alignof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
142+
uint8_t smem_sampling[];
142143
auto& temp_storage = reinterpret_cast<
143-
SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem);
144+
SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem_sampling);
144145
temp_storage.data.sampled_id = d - 1;
145146
__syncthreads();
146147

@@ -171,10 +172,11 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, DType* uniform_samples,
171172
const uint32_t batch_size = gridDim.x;
172173
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
173174

174-
extern __shared__ __align__(alignof(
175-
SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>)) uint8_t smem[];
175+
extern __shared__ __align__(
176+
alignof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
177+
uint8_t smem_sampling[];
176178
auto& temp_storage = reinterpret_cast<
177-
SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem);
179+
SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem_sampling);
178180

179181
vec_t<DType, VEC_SIZE> probs_vec;
180182
DType aggregate;
@@ -264,10 +266,11 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, DType* uniform_samples,
264266
}
265267
const uint32_t row_idx = row_indices == nullptr ? bx : row_indices[bx];
266268

267-
extern __shared__ __align__(alignof(
268-
SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>)) uint8_t smem[];
269+
extern __shared__ __align__(
270+
alignof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
271+
uint8_t smem_sampling[];
269272
auto& temp_storage = reinterpret_cast<
270-
SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem);
273+
SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem_sampling);
271274

272275
vec_t<DType, VEC_SIZE> probs_vec;
273276
DType aggregate;
@@ -454,9 +457,9 @@ __global__ void TopPRenormProbKernel(DType* probs, IdType* renormed_prob, float
454457
const uint32_t row_idx = bx;
455458

456459
extern __shared__ __align__(alignof(RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>))
457-
uint8_t smem[];
460+
uint8_t smem_renorm[];
458461
auto& temp_storage =
459-
reinterpret_cast<RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>&>(smem);
462+
reinterpret_cast<RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>&>(smem_renorm);
460463
temp_storage.data.max_val = DType(0);
461464
vec_t<DType, VEC_SIZE> probs_vec;
462465
DType probs_greater_than_pivot[VEC_SIZE]; // pivot initialized to 0
@@ -543,9 +546,9 @@ __global__ void TopKRenormProbKernel(DType* probs, IdType* renormed_prob, uint32
543546
const uint32_t row_idx = bx;
544547

545548
extern __shared__ __align__(alignof(RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>))
546-
uint8_t smem[];
549+
uint8_t smem_renorm[];
547550
auto& temp_storage =
548-
reinterpret_cast<RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>&>(smem);
551+
reinterpret_cast<RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>&>(smem_renorm);
549552
temp_storage.data.max_val = DType(0);
550553
vec_t<DType, VEC_SIZE> probs_vec;
551554
DType probs_greater_than_pivot[VEC_SIZE]; // pivot initialized to 0
@@ -674,10 +677,11 @@ __global__ void ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token
674677
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
675678
const uint32_t row_idx = bx;
676679

677-
extern __shared__ __align__(alignof(
678-
SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>)) uint8_t smem[];
680+
extern __shared__ __align__(
681+
alignof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
682+
uint8_t smem_sampling[];
679683
auto& temp_storage = reinterpret_cast<
680-
SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem);
684+
SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem_sampling);
681685

682686
uint32_t pos = 0;
683687
for (pos = 0; pos < num_speculative_tokens; ++pos) {

0 commit comments

Comments
 (0)