diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index c0e9ff0286d..d5ed615ab23 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -264,9 +264,12 @@ def add_request( self.top_p_cpu[req_index] = sampling_params.top_p if sampling_params.top_p < 1: self.top_p_reqs.add(req_id) - self.top_k_cpu[req_index] = sampling_params.top_k - if sampling_params.top_k > 0: + top_k = sampling_params.top_k + if 0 < top_k < self.vocab_size: self.top_k_reqs.add(req_id) + else: + top_k = self.vocab_size + self.top_k_cpu[req_index] = top_k self.min_p_cpu[req_index] = sampling_params.min_p self.frequency_penalties_cpu[ req_index] = sampling_params.frequency_penalty