 from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
-                        GiB_bytes, LazyLoader, cdiv, check_use_alibi,
-                        is_pin_memory_available)
+                        GiB_bytes, LazyLoader, async_tensor_h2d, cdiv,
+                        check_use_alibi, is_pin_memory_available)
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
@@ -281,7 +281,7 @@ def __init__(
     def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> bool:
         """
         Update the order of requests in the batch based on the attention
-        backend's needs. For example, some attention backends (namely MLA) may
+        backend's needs. For example, some attention backends (namely MLA) may
         want to separate requests based on if the attention computation will be
         compute-bound or memory-bound.
 
@@ -1360,9 +1360,10 @@ def execute_model(
                                scheduler_output.num_scheduled_tokens[req_id])
                     next_token_id = req_state.get_token_id(seq_len)
                 next_token_ids.append(next_token_id)
-            next_token_ids = torch.tensor(next_token_ids,
-                                          dtype=torch.int32,
-                                          device=self.device)
+            next_token_ids = async_tensor_h2d(next_token_ids,
+                                              dtype=torch.int32,
+                                              target_device=self.device,
+                                              pin_memory=True)
             eagle_attn_metadata = attn_metadata[self.drafter.attn_layer_name]
 
             # NOTE: deepseek_mtp uses MLA which does not have `block_table`
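The hunk above replaces a blocking `torch.tensor(..., device=self.device)` construction with `async_tensor_h2d`. As a minimal sketch (assumed behaviour, not code taken from this diff), a helper with that signature would stage the values in pinned host memory and issue a non-blocking host-to-device copy, letting the CPU continue while the upload is in flight:

```python
import torch

def async_tensor_h2d_sketch(data, dtype, target_device, pin_memory):
    """Rough sketch of the behaviour assumed for vllm.utils.async_tensor_h2d."""
    # Materialize the Python list in (optionally pinned) CPU memory first.
    cpu_tensor = torch.tensor(data, dtype=dtype, pin_memory=pin_memory,
                              device="cpu")
    # Pinned memory allows the copy to be truly asynchronous on CUDA devices.
    return cpu_tensor.to(device=target_device, non_blocking=True)

# Illustrative call mirroring the hunk above (device string is an example):
# next_token_ids = async_tensor_h2d_sketch(next_token_ids, torch.int32,
#                                          "cuda:0", pin_memory=True)
```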
@@ -1390,14 +1391,16 @@ def execute_model(
                     n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
                     for i, n in enumerate(num_draft_tokens)
                 ]
-                num_rejected_tokens = torch.tensor(
+                num_rejected_tokens_tensor = async_tensor_h2d(
                     num_rejected_tokens,
                     dtype=torch.int32,
-                    device=self.device,
-                )
+                    target_device=self.device,
+                    pin_memory=True)
+                num_tokens = num_scheduled_tokens - sum(num_rejected_tokens)
                 cu_num_tokens, token_indices = self.drafter.prepare_inputs(
                     eagle_attn_metadata.query_start_loc,
-                    num_rejected_tokens,
+                    num_rejected_tokens_tensor,
+                    num_tokens,
                 )
                 target_token_ids = self.input_ids[token_indices]
                 target_positions = positions[token_indices]
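Note that the new `num_tokens` value is computed from the host-side Python list, not from the tensor that was just uploaded. A brief illustration of the distinction (the values and the commented-out alternative are hypothetical, not from this diff): summing the device tensor would force a synchronizing read back to the CPU, which the non-blocking upload is meant to avoid, whereas summing the list is plain host arithmetic.

```python
# Hypothetical values; the names mirror the diff but the numbers are made up.
num_scheduled_tokens = 10
num_rejected_tokens = [2, 0, 1]  # host-side Python ints

# Host-only arithmetic, as in the diff above: no GPU synchronization.
num_tokens = num_scheduled_tokens - sum(num_rejected_tokens)  # -> 7

# The avoided alternative would read the device tensor back and stall:
# num_tokens = num_scheduled_tokens - int(num_rejected_tokens_tensor.sum())
```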
@@ -1408,7 +1411,6 @@ def execute_model(
                 target_hidden_states = hidden_states[token_indices]
                 target_slot_mapping = eagle_attn_metadata.slot_mapping[
                     token_indices]
-
                 draft_token_ids = self.drafter.propose(
                     target_token_ids=target_token_ids,
                     target_positions=target_positions,