25
25
26
26
using namespace flashinfer ;
27
27
28
// sampling_from_probs, shown as a unified diff: lines starting with '-' are the old
// version, lines starting with '+' are the new version, and bare numbers are the
// old/new line numbers of the rendered diff.
// The change renames the `samples` tensor parameter to `output` and adds an optional
// `maybe_indices` tensor that is forwarded to sampling::SamplingFromProb.
// `gen_` and `device` are not used in the lines visible here; presumably they feed the
// philox seed/offset setup hidden behind the hunk header below -- TODO confirm.
- void sampling_from_probs (at::Tensor probs, at::Tensor samples, bool deterministic,
28
+ void sampling_from_probs (at::Tensor probs, at::Tensor output,
29
+ std::optional<at::Tensor> maybe_indices, bool deterministic,
29
30
std::optional<at::Generator> gen_, int64_t cuda_stream) {
30
31
CHECK_INPUT (probs);
31
32
auto device = probs.device ();
32
33
CHECK_DIM (2 , probs); // probs: (batch_size, vocab_size)
33
// The batch size is now taken from `output` instead of `probs`.
// NOTE(review): no CHECK_INPUT / CHECK_DIM on `output` is visible in this function,
// unlike top_k_sampling_from_probs -- confirm `output` is validated before use.
- unsigned int batch_size = probs .size (0 );
34
+ unsigned int batch_size = output .size (0 );
34
35
unsigned int vocab_size = probs.size (1 );
35
36
36
37
uint64_t philox_seed, philox_offset;
// The hunk header below marks lines of the function that this diff does not show.
@@ -43,20 +44,22 @@ void sampling_from_probs(at::Tensor probs, at::Tensor samples, bool deterministi
43
44
44
45
// The caller passes the CUDA stream as an integer handle; it is cast back here.
cudaStream_t stream = reinterpret_cast <cudaStream_t>(cuda_stream);
45
46
cudaError_t status = sampling::SamplingFromProb (
46
- static_cast <float *>(probs.data_ptr ()), static_cast <int *>(samples.data_ptr ()), batch_size,
47
- vocab_size, deterministic, philox_seed, philox_offset, stream);
47
+ static_cast <float *>(probs.data_ptr ()), static_cast <int *>(output.data_ptr ()),
48
// A null pointer is passed when the caller supplies no indices tensor.
// NOTE(review): the int* cast assumes a 32-bit integer tensor; no dtype or device
// check on `maybe_indices` is visible here -- confirm.
+ maybe_indices.has_value () ? static_cast <int *>(maybe_indices->data_ptr ()) : nullptr ,
49
+ batch_size, vocab_size, deterministic, philox_seed, philox_offset, stream);
48
50
// Fails with the CUDA error string when the launcher does not return cudaSuccess.
TORCH_CHECK (status == cudaSuccess, " SamplingFromProbs failed with error code " +
49
51
std::string (cudaGetErrorString (status)));
50
52
}
51
53
52
// top_p_sampling_from_probs, shown as a unified diff ('-' = old version, '+' = new
// version, bare numbers = old/new line numbers of the rendered diff).
// The change renames `samples` to `output` and adds an optional `maybe_indices` tensor
// that is forwarded to sampling::TopPSamplingFromProb.
// The top-p threshold comes either from the optional tensor `maybe_top_p_arr` or from
// the scalar `top_p_val`; both are passed to the launcher.
- void top_p_sampling_from_probs (at::Tensor probs, at::Tensor samples,
54
+ void top_p_sampling_from_probs (at::Tensor probs, at::Tensor output,
55
+ std::optional<at::Tensor> maybe_indices,
53
56
std::optional<at::Tensor> maybe_top_p_arr, double top_p_val,
54
57
bool deterministic, std::optional<at::Generator> gen_,
55
58
int64_t cuda_stream) {
56
59
CHECK_INPUT (probs);
57
60
auto device = probs.device ();
58
61
CHECK_DIM (2 , probs); // probs: (batch_size, vocab_size)
59
// The batch size is now taken from `output` instead of `probs`.
// NOTE(review): no CHECK_INPUT / CHECK_DIM on `output` is visible in this function,
// unlike top_k_sampling_from_probs -- confirm `output` is validated before use.
- unsigned int batch_size = probs .size (0 );
62
+ unsigned int batch_size = output .size (0 );
60
63
unsigned int vocab_size = probs.size (1 );
61
64
bool has_top_p_arr = maybe_top_p_arr.has_value ();
62
65
uint64_t philox_seed, philox_offset;
// The hunk header below marks lines of the function that this diff does not show;
// presumably they fill philox_seed / philox_offset from `gen_` -- TODO confirm.
@@ -69,25 +72,26 @@ void top_p_sampling_from_probs(at::Tensor probs, at::Tensor samples,
69
72
70
73
// The caller passes the CUDA stream as an integer handle; it is cast back here.
cudaStream_t stream = reinterpret_cast <cudaStream_t>(cuda_stream);
71
74
cudaError_t status = sampling::TopPSamplingFromProb<float , int >(
72
- static_cast <float *>(probs.data_ptr ()), static_cast <int *>(samples.data_ptr ()),
75
+ static_cast <float *>(probs.data_ptr ()), static_cast <int *>(output.data_ptr ()),
76
// A null pointer is passed when the caller supplies no indices tensor.
// NOTE(review): no dtype or device check on `maybe_indices` is visible here -- confirm.
+ maybe_indices.has_value () ? static_cast <int *>(maybe_indices->data_ptr ()) : nullptr ,
73
77
// A null pointer is passed when no per-row top-p tensor is supplied.
has_top_p_arr ? static_cast <float *>(maybe_top_p_arr->data_ptr ()) : nullptr , batch_size,
74
78
top_p_val, vocab_size, deterministic, philox_seed, philox_offset, stream);
75
79
// Fails with the CUDA error string when the launcher does not return cudaSuccess.
TORCH_CHECK (status == cudaSuccess, " TopPSamplingFromProbs failed with error code " +
76
80
std::string (cudaGetErrorString (status)));
77
81
}
78
82
79
// top_k_sampling_from_probs, shown as a unified diff ('-' = old version, '+' = new
// version, bare numbers = old/new line numbers of the rendered diff).
// The change renames `samples` to `output` and adds an optional `maybe_indices` tensor
// that is forwarded to sampling::TopKSamplingFromProb.
// The top-k value comes either from the optional tensor `maybe_top_k_arr` or from the
// scalar `top_k_val`; both are passed to the launcher.
- void top_k_sampling_from_probs (at::Tensor probs, at::Tensor samples,
83
+ void top_k_sampling_from_probs (at::Tensor probs, at::Tensor output,
84
+ std::optional<at::Tensor> maybe_indices,
80
85
std::optional<at::Tensor> maybe_top_k_arr, int64_t top_k_val,
81
86
bool deterministic, std::optional<at::Generator> gen_,
82
87
int64_t cuda_stream) {
83
88
CHECK_INPUT (probs);
84
- CHECK_INPUT (samples );
89
+ CHECK_INPUT (output );
85
90
auto device = probs.device ();
86
// The old version required probs.size(0) == samples.size(0). The new version drops
// that check and reads the batch size from `output` alone.
// NOTE(review): nothing visible relates output.size(0) to probs.size(0) or to
// `maybe_indices` any more -- confirm the caller guarantees consistent sizes.
- CHECK_EQ (samples.device (), device);
87
- CHECK_DIM (2 , probs); // probs: (batch_size, vocab_size)
88
- CHECK_DIM (1 , samples); // samples: (batch_size)
89
- CHECK_EQ (probs.size (0 ), samples.size (0 ));
90
- unsigned int batch_size = probs.size (0 );
91
+ CHECK_EQ (output.device (), device);
92
+ CHECK_DIM (2 , probs); // probs: (batch_size, vocab_size)
93
+ CHECK_DIM (1 , output); // output: (batch_size)
94
+ unsigned int batch_size = output.size (0 );
91
95
unsigned int vocab_size = probs.size (1 );
92
96
bool has_top_k_arr = maybe_top_k_arr.has_value ();
93
97
uint64_t philox_seed, philox_offset;
// The hunk header below marks lines of the function that this diff does not show;
// presumably they fill philox_seed / philox_offset from `gen_` -- TODO confirm.
@@ -100,24 +104,26 @@ void top_k_sampling_from_probs(at::Tensor probs, at::Tensor samples,
100
104
101
105
// The caller passes the CUDA stream as an integer handle; it is cast back here.
cudaStream_t stream = reinterpret_cast <cudaStream_t>(cuda_stream);
102
106
cudaError_t status = sampling::TopKSamplingFromProb<float , int >(
103
- static_cast <float *>(probs.data_ptr ()), static_cast <int *>(samples.data_ptr ()),
107
+ static_cast <float *>(probs.data_ptr ()), static_cast <int *>(output.data_ptr ()),
108
// A null pointer is passed when the caller supplies no indices tensor.
// NOTE(review): no dtype or device check on `maybe_indices` is visible here -- confirm.
+ maybe_indices.has_value () ? static_cast <int *>(maybe_indices->data_ptr ()) : nullptr ,
104
109
// NOTE(review): the per-row top-k tensor is cast to float* here, while
// top_k_top_p_sampling_from_probs casts its top-k tensor to int* -- confirm the
// dtype the launcher expects.
has_top_k_arr ? static_cast <float *>(maybe_top_k_arr->data_ptr ()) : nullptr , batch_size,
105
110
top_k_val, vocab_size, deterministic, philox_seed, philox_offset, stream);
106
111
// Fails with the CUDA error string when the launcher does not return cudaSuccess.
TORCH_CHECK (status == cudaSuccess, " TopKSamplingFromProbs failed with error code " +
107
112
std::string (cudaGetErrorString (status)));
108
113
}
109
114
110
// min_p_sampling_from_probs, shown as a unified diff ('-' = old version, '+' = new
// version, bare numbers = old/new line numbers of the rendered diff).
// The change renames `samples` to `output` and adds an optional `maybe_indices` tensor
// that is forwarded to sampling::MinPSamplingFromProb.
// The min-p threshold comes either from the optional tensor `maybe_min_p_arr` or from
// the scalar `min_p_val`; both are passed to the launcher.
- void min_p_sampling_from_probs (at::Tensor probs, at::Tensor samples,
115
+ void min_p_sampling_from_probs (at::Tensor probs, at::Tensor output,
116
+ std::optional<at::Tensor> maybe_indices,
111
117
std::optional<at::Tensor> maybe_min_p_arr, double min_p_val,
112
118
bool deterministic, std::optional<at::Generator> gen_,
113
119
int64_t cuda_stream) {
114
120
CHECK_INPUT (probs);
115
- CHECK_INPUT (samples );
121
+ CHECK_INPUT (output );
116
122
auto device = probs.device ();
117
// The same device and rank checks are kept, now applied to `output`. The batch size
// is read from `output` instead of `probs`.
- CHECK_EQ (samples .device (), device);
118
- CHECK_DIM (2 , probs); // probs: (batch_size, vocab_size)
119
- CHECK_DIM (1 , samples ); // samples : (batch_size)
120
- unsigned int batch_size = probs .size (0 );
123
+ CHECK_EQ (output .device (), device);
124
+ CHECK_DIM (2 , probs); // probs: (batch_size, vocab_size)
125
+ CHECK_DIM (1 , output ); // output : (batch_size)
126
+ unsigned int batch_size = output .size (0 );
121
127
unsigned int vocab_size = probs.size (1 );
122
128
bool has_min_p_arr = maybe_min_p_arr.has_value ();
123
129
uint64_t philox_seed, philox_offset;
// The hunk header below marks lines of the function that this diff does not show;
// presumably they fill philox_seed / philox_offset and define `stream` -- TODO confirm.
@@ -132,24 +138,26 @@ void min_p_sampling_from_probs(at::Tensor probs, at::Tensor samples,
132
138
cudaError_t status = sampling::MinPSamplingFromProb<float , int >(
133
139
static_cast <float *>(probs.data_ptr ()),
134
140
// A null pointer is passed when no per-row min-p tensor is supplied.
has_min_p_arr ? static_cast <float *>(maybe_min_p_arr->data_ptr ()) : nullptr ,
135
- static_cast <int *>(samples.data_ptr ()), batch_size, min_p_val, vocab_size, deterministic,
136
- philox_seed, philox_offset, stream);
141
+ static_cast <int *>(output.data_ptr ()),
142
// A null pointer is passed when the caller supplies no indices tensor.
// NOTE(review): no dtype or device check on `maybe_indices` is visible here -- confirm.
+ maybe_indices.has_value () ? static_cast <int *>(maybe_indices->data_ptr ()) : nullptr ,
143
+ batch_size, min_p_val, vocab_size, deterministic, philox_seed, philox_offset, stream);
137
144
// Fails with the CUDA error string when the launcher does not return cudaSuccess.
TORCH_CHECK (status == cudaSuccess, " MinPSamplingFromProb failed with error code " +
138
145
std::string (cudaGetErrorString (status)));
139
146
}
140
147
141
// top_k_top_p_sampling_from_probs, shown as a unified diff ('-' = old version, '+' =
// new version, bare numbers = old/new line numbers of the rendered diff).
// The change renames `samples` to `output` and adds an optional `maybe_indices` tensor
// that is forwarded to the sampling launcher.
// Top-k and top-p each come from an optional tensor or a scalar value.
// NOTE(review): `top_k_val` is declared `double` here, whereas
// top_k_sampling_from_probs takes it as `int64_t` -- confirm this is intended.
- void top_k_top_p_sampling_from_probs (at::Tensor probs, at::Tensor samples,
148
+ void top_k_top_p_sampling_from_probs (at::Tensor probs, at::Tensor output,
149
+ std::optional<at::Tensor> maybe_indices,
142
150
std::optional<at::Tensor> maybe_top_k_arr, double top_k_val,
143
151
std::optional<at::Tensor> maybe_top_p_arr, double top_p_val,
144
152
bool deterministic, std::optional<at::Generator> gen_,
145
153
int64_t cuda_stream) {
146
154
CHECK_INPUT (probs);
147
- CHECK_INPUT (samples );
155
+ CHECK_INPUT (output );
148
156
auto device = probs.device ();
149
// The same device and rank checks are kept, now applied to `output`. The batch size
// is read from `output` instead of `probs`.
- CHECK_EQ (samples .device (), device);
150
- CHECK_DIM (2 , probs); // probs: (batch_size, vocab_size)
151
- CHECK_DIM (1 , samples ); // samples : (batch_size)
152
- unsigned int batch_size = probs .size (0 );
157
+ CHECK_EQ (output .device (), device);
158
+ CHECK_DIM (2 , probs); // probs: (batch_size, vocab_size)
159
+ CHECK_DIM (1 , output ); // output : (batch_size)
160
+ unsigned int batch_size = output .size (0 );
153
161
unsigned int vocab_size = probs.size (1 );
154
162
bool has_top_k_arr = maybe_top_k_arr.has_value ();
155
163
bool has_top_p_arr = maybe_top_p_arr.has_value ();
// The hunk header below marks lines of the function that this diff does not show,
// including the opening of the launcher call whose arguments follow and the
// declaration of `status`.
@@ -166,8 +174,10 @@ void top_k_top_p_sampling_from_probs(at::Tensor probs, at::Tensor samples,
166
174
static_cast <float *>(probs.data_ptr ()),
167
175
// Null pointers are passed for whichever per-row tensors the caller omitted.
has_top_k_arr ? static_cast <int *>(maybe_top_k_arr->data_ptr ()) : nullptr ,
168
176
has_top_p_arr ? static_cast <float *>(maybe_top_p_arr->data_ptr ()) : nullptr ,
169
- static_cast <int *>(samples.data_ptr ()), batch_size, top_k_val, top_p_val, vocab_size,
170
- deterministic, philox_seed, philox_offset, stream);
177
+ static_cast <int *>(output.data_ptr ()),
178
// A null pointer is passed when the caller supplies no indices tensor.
// NOTE(review): no dtype or device check on `maybe_indices` is visible here -- confirm.
+ maybe_indices.has_value () ? static_cast <int *>(maybe_indices->data_ptr ()) : nullptr ,
179
+ batch_size, top_k_val, top_p_val, vocab_size, deterministic, philox_seed, philox_offset,
180
+ stream);
171
181
// Fails with the CUDA error string when the launcher does not return cudaSuccess.
TORCH_CHECK (status == cudaSuccess, " TopKTopPSamplingFromProbs failed with error code " +
172
182
std::string (cudaGetErrorString (status)));
173
183
}
0 commit comments