@@ -1195,7 +1195,7 @@ def plan(
        rope_theta: Optional[float] = None,
        q_data_type: Union[str, torch.dtype] = "float16",
        kv_data_type: Optional[Union[str, torch.dtype]] = None,
-        non_blocking: bool = False,
+        non_blocking: bool = True,
    ) -> None:
        r"""Plan batch prefill/append attention on Paged KV-Cache for given problem specification.
@@ -1269,8 +1269,7 @@ def plan(
        kv_data_type : Optional[Union[str, torch.dtype]]
            The data type of the key/value tensor. If None, will be set to :attr:`q_data_type`.
        non_blocking : bool
-            Whether to copy the input tensors to the device asynchronously, defaults to ``False``.
-            If ``True``, user should synchronize before calling :meth:`run` or cuda graph replay.
+            Whether to copy the input tensors to the device asynchronously, defaults to ``True``.

        Note
        ----
@@ -1349,12 +1348,13 @@ def plan(

            self._qo_indptr_buf.copy_(qo_indptr, non_blocking=non_blocking)
            self._paged_kv_indptr_buf.copy_(paged_kv_indptr, non_blocking=non_blocking)
-            self._paged_kv_indices_buf[: len(paged_kv_indices)].copy_(
-                paged_kv_indices, non_blocking=non_blocking
-            )
            self._paged_kv_last_page_len_buf.copy_(
                paged_kv_last_page_len, non_blocking=non_blocking
            )
+            self._paged_kv_indices_buf[: len(paged_kv_indices)].copy_(
+                paged_kv_indices,
+                non_blocking=(paged_kv_indices.device == self.device) and non_blocking,
+            )

            if packed_custom_mask is not None:
                if not torch.is_tensor(self._custom_mask_buf):
@@ -1366,7 +1366,9 @@ def plan(
                        "mask_indptr_buf must be initialized with a torch.Tensor in cuda graph mode if we use custom mask in attention computation."
                    )
                self._custom_mask_buf[: len(packed_custom_mask)].copy_(
-                    packed_custom_mask, non_blocking=non_blocking
+                    packed_custom_mask,
+                    non_blocking=(packed_custom_mask.device == self.device)
+                    and non_blocking,
                )
                # NOTE(Zihao): mask_indptr has the same length as qo_indptr
                self._mask_indptr_buf.copy_(mask_indptr, non_blocking=non_blocking)
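
For context, here is a minimal standalone sketch (not flashinfer code; the helper name `guarded_copy_` is hypothetical) of the guard pattern introduced in this diff: `non_blocking=True` is only honored when the source tensor already lives on the destination device, so copies from pageable CPU memory fall back to a blocking copy rather than an unsafe asynchronous one.

```python
import torch

def guarded_copy_(dst_buf: torch.Tensor, src: torch.Tensor, non_blocking: bool = True) -> None:
    """Copy src into the front of dst_buf, downgrading to a blocking copy across devices."""
    dst_buf[: len(src)].copy_(
        src,
        # Only allow an async copy when source and destination share a device;
        # otherwise force non_blocking=False for safety.
        non_blocking=(src.device == dst_buf.device) and non_blocking,
    )

if torch.cuda.is_available():
    device = torch.device("cuda:0")
    indices_buf = torch.empty(1024, dtype=torch.int32, device=device)  # persistent device buffer
    cpu_indices = torch.arange(10, dtype=torch.int32)                  # pageable host tensor
    guarded_copy_(indices_buf, cpu_indices)   # device mismatch -> blocking copy
    gpu_indices = cpu_indices.to(device)
    guarded_copy_(indices_buf, gpu_indices)   # same device -> async copy on the current stream
```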