@@ -137,10 +137,11 @@ __global__ void SamplingFromProbKernel(DType* probs, DType* uniform_samples, IdT
137
137
const uint32_t bx = blockIdx .x , tx = threadIdx .x ;
138
138
const uint32_t row_idx = row_indices == nullptr ? bx : row_indices[bx];
139
139
140
- extern __shared__ __align__ (alignof (
141
- SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>)) uint8_t smem[];
140
+ extern __shared__ __align__ (
141
+ alignof (SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
142
+ uint8_t smem_sampling[];
142
143
auto & temp_storage = reinterpret_cast <
143
- SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem );
144
+ SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem_sampling );
144
145
temp_storage.data .sampled_id = d - 1 ;
145
146
__syncthreads ();
146
147
@@ -171,10 +172,11 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, DType* uniform_samples,
171
172
const uint32_t batch_size = gridDim .x ;
172
173
const uint32_t bx = blockIdx .x , tx = threadIdx .x ;
173
174
174
- extern __shared__ __align__ (alignof (
175
- SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>)) uint8_t smem[];
175
+ extern __shared__ __align__ (
176
+ alignof (SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
177
+ uint8_t smem_sampling[];
176
178
auto & temp_storage = reinterpret_cast <
177
- SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem );
179
+ SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem_sampling );
178
180
179
181
vec_t <DType, VEC_SIZE> probs_vec;
180
182
DType aggregate;
@@ -264,10 +266,11 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, DType* uniform_samples,
264
266
}
265
267
const uint32_t row_idx = row_indices == nullptr ? bx : row_indices[bx];
266
268
267
- extern __shared__ __align__ (alignof (
268
- SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>)) uint8_t smem[];
269
+ extern __shared__ __align__ (
270
+ alignof (SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
271
+ uint8_t smem_sampling[];
269
272
auto & temp_storage = reinterpret_cast <
270
- SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem );
273
+ SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem_sampling );
271
274
272
275
vec_t <DType, VEC_SIZE> probs_vec;
273
276
DType aggregate;
@@ -454,9 +457,9 @@ __global__ void TopPRenormProbKernel(DType* probs, IdType* renormed_prob, float
454
457
const uint32_t row_idx = bx;
455
458
456
459
extern __shared__ __align__ (alignof (RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>))
457
- uint8_t smem [];
460
+ uint8_t smem_renorm [];
458
461
auto & temp_storage =
459
- reinterpret_cast <RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>&>(smem );
462
+ reinterpret_cast <RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>&>(smem_renorm );
460
463
temp_storage.data .max_val = DType (0 );
461
464
vec_t <DType, VEC_SIZE> probs_vec;
462
465
DType probs_greater_than_pivot[VEC_SIZE]; // pivot initialized to 0
@@ -543,9 +546,9 @@ __global__ void TopKRenormProbKernel(DType* probs, IdType* renormed_prob, uint32
543
546
const uint32_t row_idx = bx;
544
547
545
548
extern __shared__ __align__ (alignof (RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>))
546
- uint8_t smem [];
549
+ uint8_t smem_renorm [];
547
550
auto & temp_storage =
548
- reinterpret_cast <RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>&>(smem );
551
+ reinterpret_cast <RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>&>(smem_renorm );
549
552
temp_storage.data .max_val = DType (0 );
550
553
vec_t <DType, VEC_SIZE> probs_vec;
551
554
DType probs_greater_than_pivot[VEC_SIZE]; // pivot initialized to 0
@@ -674,10 +677,11 @@ __global__ void ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token
674
677
const uint32_t bx = blockIdx .x , tx = threadIdx .x ;
675
678
const uint32_t row_idx = bx;
676
679
677
- extern __shared__ __align__ (alignof (
678
- SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>)) uint8_t smem[];
680
+ extern __shared__ __align__ (
681
+ alignof (SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
682
+ uint8_t smem_sampling[];
679
683
auto & temp_storage = reinterpret_cast <
680
- SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem );
684
+ SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem_sampling );
681
685
682
686
uint32_t pos = 0 ;
683
687
for (pos = 0 ; pos < num_speculative_tokens; ++pos) {
0 commit comments