Commit 41dddab

Add option to disable duplicates in topk (vllm-project#464)
The current implementation of the optimized top-p/top-k calculation for the scalar case handles duplicates of the value sitting on the k-th border. Unfortunately, analyzing those duplicates requires a synchronization with the CPU, which makes multi-step scheduling useless when combined with top-p/top-k. This PR makes duplicate handling optional via `VLLM_HANDLE_TOPK_DUPLICATES` (`false` by default): when the variable is not set, duplicate handling is skipped and the CPU synchronization is avoided. It also removes the synchronization that was previously done in the Sampler, by saving scalar values of `top_k` and `top_p` instead of reading them back from device tensors. This should give a performance gain for all benchmarks using these sampling parameters, especially together with multi-step scheduling. Skipping duplicate handling may cause small accuracy differences; the best solution would be to handle duplicates without synchronizing with the CPU, but that is not a trivial problem, so I will try to provide such a solution later.
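
For anyone who wants the old, more accurate behaviour back, here is a minimal sketch of enabling the flag from Python (the model name and prompt are placeholders; exporting the variable in the shell before launching the server works the same way):

import os

# VLLM_HANDLE_TOPK_DUPLICATES is parsed when the sampler module is imported,
# so it has to be set before vLLM is imported ("0"/unset skips duplicate handling).
os.environ["VLLM_HANDLE_TOPK_DUPLICATES"] = "1"

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # placeholder model
params = SamplingParams(temperature=0.8, top_p=0.9, top_k=50)
outputs = llm.generate(["Hello, my name is"], params)
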
1 parent e818cf3 commit 41dddab

4 files changed (+34, -19 lines)

README_GAUDI.md

Lines changed: 1 addition & 0 deletions
@@ -277,6 +277,7 @@ INFO 08-02 17:38:43 hpu_executor.py:91] init_cache_engine took 37.92 GiB of devi
 - block size min (`VLLM_DECODE_BLOCK_BUCKET_MIN`): `block_size`
 - block size step (`VLLM_DECODE_BLOCK_BUCKET_STEP`): `block_size`
 - block size max (`VLLM_DECODE_BLOCK_BUCKET_MAX`): `max(128, (max_num_seqs*max_model_len)/block_size)`
+- ``VLLM_HANDLE_TOPK_DUPLICATES``: if ``true``, will handle duplicates that are outside of top-k, ``false`` by default
 
 Additionally, there are HPU PyTorch Bridge environment variables impacting vLLM execution:

docs/source/getting_started/gaudi-installation.rst

Lines changed: 1 addition & 1 deletion
@@ -378,7 +378,7 @@ Environment variables
 - sequence length min (``VLLM_DECODE_BLOCK_BUCKET_MIN``): ``block_size``
 - sequence length step (``VLLM_DECODE_BLOCK_BUCKET_STEP``): ``block_size``
 - sequence length max (``VLLM_DECODE_BLOCK_BUCKET_MAX``): ``max(128, (max_num_seqs*max_model_len)/block_size)``
-
+- ``VLLM_HANDLE_TOPK_DUPLICATES``: if ``true``, will handle duplicates that are outside of top-k, ``false`` by default
 
 Additionally, there are HPU PyTorch Bridge environment variables impacting vLLM execution:

vllm/model_executor/layers/sampler.py

Lines changed: 23 additions & 16 deletions
@@ -1,6 +1,7 @@
 """A layer that samples the next tokens from the model's outputs."""
 import itertools
 import math
+import os
 import warnings
 from dataclasses import dataclass
 from importlib.util import find_spec
@@ -195,19 +196,16 @@ def _init_sampling_tensors(
         self._sampling_tensors = None
 
         # Initialize new sampling tensors
-        (sampling_tensors, do_penalties, do_top_p_top_k,
-         do_min_p) = SamplingTensors.from_sampling_metadata(
+        (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p,
+         top_k_scalar, top_p_scalar) = SamplingTensors.from_sampling_metadata(
             sampling_metadata, vocab_size, logits.device, logits.dtype)
 
         self._sampling_tensors = sampling_tensors
         self._do_penalties = do_penalties
         self._do_top_p_top_k = do_top_p_top_k
         self._do_min_p = do_min_p
-        self._top_p_scalar = sampling_tensors.top_ps[0]
-        self._top_k_scalar = sampling_tensors.top_ks[0]
-        scalar_p = torch.all(sampling_tensors.top_ps == self._top_p_scalar)
-        scalar_k = torch.all(sampling_tensors.top_ks == self._top_k_scalar)
-        self._scalar_p_and_k = torch.logical_and(scalar_p, scalar_k)
+        self._top_k_scalar = top_k_scalar
+        self._top_p_scalar = top_p_scalar
 
         self._apply_top_k_top_p_opt = ApplyToppTopkScalar(5)
@@ -270,10 +268,10 @@ def forward(
 
         if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None:
             # If we have a scalar p and k, we can use the optimized version.
-            if self._scalar_p_and_k.any():
+            if self._top_k_scalar and self._top_p_scalar:
                 logits = self._apply_top_k_top_p_opt(logits,
-                                                     self._top_p_scalar.item(),
-                                                     self._top_k_scalar.item())
+                                                     self._top_p_scalar,
+                                                     self._top_k_scalar)
             else:
                 logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
                                             sampling_tensors.top_ks)
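
The forward-path change above is what removes the per-step synchronization: the old code pulled `top_p`/`top_k` out of device tensors with `.item()`, which blocks the host until the device catches up, while the new code receives plain Python scalars. A rough standalone illustration of the two patterns (on a CPU-only build both are effectively free; the stall matters on HPU/GPU):

import torch

# Old pattern: the per-request values live in a device tensor; reading one
# back with .item() forces the host to wait for the device.
top_ks_tensor = torch.full((4,), 50)      # stand-in for sampling_tensors.top_ks
k_old = int(top_ks_tensor[0].item())      # device -> host copy (sync point)

# New pattern: the scalar comes from the Python-side request list inside
# SamplingTensors.from_sampling_metadata, so the hot path never reads a tensor back.
request_top_ks = [50, 50, 50, 50]
k_new = request_top_ks[0] if all(k == request_top_ks[0] for k in request_top_ks) else None

assert k_old == k_new == 50
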
@@ -386,8 +384,13 @@ class ApplyToppTopkScalar:
     The main logic of this is in __call__
     This is a class instead of a function, just to keep track of
     the monotonic non-decreasing state _padded_k
+
+    To enable the duplicates that are outside of kth border,
+    set VLLM_HANDLE_TOPK_DUPLICATES to 1 or true.
     """
     _padded_k = 0
+    _handle_duplicates = os.getenv('VLLM_HANDLE_TOPK_DUPLICATES',
+                                   '0').lower() in ['1', 'true']
 
     def __init__(self, increment: int):
         self._increment = increment
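
One caveat worth noting: `_handle_duplicates` is a class attribute, so the environment variable is parsed once, when `sampler.py` is imported, and changing it later in the process has no effect. A tiny sketch of the same parsing rule in isolation:

import os

def handle_duplicates_enabled() -> bool:
    # Mirrors the class-attribute parsing: only "1" or "true"
    # (case-insensitive) enable duplicate handling.
    return os.getenv("VLLM_HANDLE_TOPK_DUPLICATES", "0").lower() in ("1", "true")

os.environ.pop("VLLM_HANDLE_TOPK_DUPLICATES", None)
assert handle_duplicates_enabled() is False        # default: skip duplicates
os.environ["VLLM_HANDLE_TOPK_DUPLICATES"] = "True"
assert handle_duplicates_enabled() is True
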
@@ -397,12 +400,15 @@ def __call__(self, logits: torch.Tensor, p: float, k: int):
             ApplyToppTopkScalar._padded_k = min(k + self._increment,
                                                 logits.shape[1])
 
-        vals, idx = torch.topk(logits, k=ApplyToppTopkScalar._padded_k, \
-                               dim=1, sorted=True)
+        vals, idx = torch.topk(logits,
+                               k=ApplyToppTopkScalar._padded_k,
+                               dim=1,
+                               sorted=True)
 
         # this "if" checks if we have bucketed so much that
         # we have padded k upto shape of logits
-        if ApplyToppTopkScalar._padded_k != logits.shape[1]:
+        if self._handle_duplicates and \
+            ApplyToppTopkScalar._padded_k != logits.shape[1]:
             smallest_of_top_k = vals[:, k - 1]
             num_duplicates_of_smallest_of_topk = torch.sum(
                 logits == smallest_of_top_k.unsqueeze(1), 1)
@@ -427,9 +433,10 @@ def __call__(self, logits: torch.Tensor, p: float, k: int):
                     ApplyToppTopkScalar._padded_k + incr, logits.shape[1])
 
                 # recompute topk with expanded padded_k
-                vals, idx = torch.topk(logits, \
-                                       k=ApplyToppTopkScalar._padded_k, \
-                                       dim=1, sorted=True)
+                vals, idx = torch.topk(logits,
+                                       k=ApplyToppTopkScalar._padded_k,
+                                       dim=1,
+                                       sorted=True)
 
         idx = torch.fliplr(idx)
         vals = torch.fliplr(vals)
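
To see what the guarded branch is for, consider a row whose k-th value is tied with values just outside the padded window: without growing `_padded_k`, some of the tied candidates can never be returned by `topk`. A simplified standalone toy example (it skips the increment rounding the real code does; the host-side read of the duplicate count is exactly the synchronization the new flag avoids):

import torch

logits = torch.tensor([[5.0, 4.0, 3.0, 3.0, 3.0, 1.0]])
k = 3
padded_k = 4                                 # k + increment, still < vocab size

vals, idx = torch.topk(logits, k=padded_k, dim=1, sorted=True)
smallest_of_top_k = vals[:, k - 1]           # the value on the k-th border (3.0)
num_dups = (logits == smallest_of_top_k.unsqueeze(1)).sum(1)

# Reading the duplicate count on the host is the synchronization point.
if int(num_dups.max()) - 1 > padded_k - k:
    padded_k = min(padded_k + int(num_dups.max()) - 1, logits.shape[1])
    vals, idx = torch.topk(logits, k=padded_k, dim=1, sorted=True)

print(padded_k)  # 6: the window grew so all three tied 3.0 values are visible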

vllm/model_executor/sampling_metadata.py

Lines changed: 9 additions & 2 deletions
@@ -389,7 +389,8 @@ def from_sampling_metadata(
         vocab_size: int,
         device: torch.device,
         dtype: torch.dtype,
-    ) -> Tuple["SamplingTensors", bool, bool, bool]:
+    ) -> Tuple["SamplingTensors", bool, bool, bool, Optional[int],
+               Optional[float]]:
         prompt_tokens: List[array] = []
         output_tokens: List[array] = []
         top_ks: List[int] = []
@@ -476,6 +477,11 @@ def from_sampling_metadata(
                 prompt_tokens.append(seq_data.prompt_token_ids_array)
                 output_tokens.append(seq_data.output_token_ids_array)
 
+        top_k_scalar = top_ks[0] if do_top_p_top_k and all(
+            k == top_ks[0] for k in top_ks) else None
+        top_p_scalar = top_ps[0] if do_top_p_top_k and all(
+            p == top_ps[0] for p in top_ps) else None
+
         sampling_tensors = SamplingTensors.from_lists(
             temperatures,
             top_ps,
@@ -490,7 +496,8 @@ def from_sampling_metadata(
             device,
             dtype,
         )
-        return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p)
+        return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p,
+                top_k_scalar, top_p_scalar)
 
     @classmethod
     def from_lists(
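
The scalar detection itself is a pure Python comparison over the per-request lists, so it adds no device work; a standalone sketch of the same logic and its two outcomes:

from typing import List, Optional, Tuple

def detect_scalars(top_ks: List[int], top_ps: List[float],
                   do_top_p_top_k: bool) -> Tuple[Optional[int], Optional[float]]:
    # A scalar is only reported when every request in the batch uses the same value.
    top_k_scalar = top_ks[0] if do_top_p_top_k and all(
        k == top_ks[0] for k in top_ks) else None
    top_p_scalar = top_ps[0] if do_top_p_top_k and all(
        p == top_ps[0] for p in top_ps) else None
    return top_k_scalar, top_p_scalar

print(detect_scalars([50, 50, 50], [0.9, 0.9, 0.9], True))   # (50, 0.9) -> optimized path
print(detect_scalars([50, 20, 50], [0.9, 0.9, 0.9], True))   # (None, 0.9) -> fallback path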
