
Commit bb49fac

doc: remove misleading docstring about non_blocking (#966)
As noted in #965, the docstrings for the `non_blocking` option in the `plan` functions of the attention wrappers were misleading: the manual synchronization they described was only necessary in our old designs, which will be deprecated later. This PR fixes the docstrings and flips the default to `True`.
1 parent: 034fc18 · commit: bb49fac
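For background, here is a minimal standalone sketch of how `non_blocking` behaves with `torch.Tensor.copy_` (illustrative tensors only, not flashinfer code; assumes a CUDA device is available):

import torch

dst = torch.empty(1024, device="cuda")

# From a pinned host tensor, non_blocking=True yields a genuinely asynchronous
# host-to-device copy that can overlap with other work on the current stream.
pinned_src = torch.arange(1024, dtype=torch.float32).pin_memory()
dst.copy_(pinned_src, non_blocking=True)

# From an ordinary pageable host tensor, the flag has little practical effect:
# the transfer is staged synchronously with respect to the host, which is part
# of why True is a safe default for these plan() functions.
pageable_src = torch.arange(1024, dtype=torch.float32)
dst.copy_(pageable_src, non_blocking=True)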

File tree: 4 files changed, +22 −23 lines changed

flashinfer/decode.py (+5 −6)

@@ -746,7 +746,7 @@ def plan(
         sm_scale: Optional[float] = None,
         rope_scale: Optional[float] = None,
         rope_theta: Optional[float] = None,
-        non_blocking: bool = False,
+        non_blocking: bool = True,
     ) -> None:
         r"""Plan batch decode for given problem specification.
@@ -789,8 +789,7 @@ def plan(
             The data type of both the query and key/value tensors. Defaults to torch.float16.
             data_type is deprecated, please use q_data_type and kv_data_type instead.
         non_blocking : bool
-            Whether to copy the input tensors to the device asynchronously, defaults to ``False``.
-            If ``True``, user should synchronize before calling :meth:`run` or cuda graph replay.
+            Whether to copy the input tensors to the device asynchronously, defaults to ``True``.


         Note
@@ -823,12 +822,12 @@ def plan(
                     "The size of indices should be less than or equal to the allocated buffer"
                 )
                 self._paged_kv_indptr_buf.copy_(indptr, non_blocking=non_blocking)
-                self._paged_kv_indices_buf[: len(indices)].copy_(
-                    indices, non_blocking=non_blocking
-                )
                 self._paged_kv_last_page_len_buf.copy_(
                     last_page_len, non_blocking=non_blocking
                 )
+                self._paged_kv_indices_buf[: len(indices)].copy_(
+                    indices, non_blocking=(indices.device == self.device) and non_blocking
+                )
             else:
                 self._paged_kv_indptr_buf = indptr.to(
                     self.device, non_blocking=non_blocking
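The new guard on the indices copy follows a small reusable pattern: honor `non_blocking` only when source and destination already share a device, and fall back to a blocking copy otherwise. A sketch, assuming `buf` is a preallocated device buffer (the helper name is hypothetical):

import torch

def copy_into_buffer(buf: torch.Tensor, src: torch.Tensor, non_blocking: bool) -> None:
    # Honor non_blocking only for same-device copies; when src lives elsewhere
    # (typically the host), force a blocking copy so callers do not need to
    # synchronize before the buffer is consumed.
    buf[: len(src)].copy_(src, non_blocking=(src.device == buf.device) and non_blocking)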

flashinfer/pod.py (+5 −6)

@@ -298,7 +298,7 @@ def plan(
         sm_scale: Optional[float] = None,
         rope_scale: Optional[float] = None,
         rope_theta: Optional[float] = None,
-        non_blocking: bool = False,
+        non_blocking: bool = True,
     ) -> None:
         r"""Plan POD's batch decode for given problem specification.
@@ -335,8 +335,7 @@ def plan(
             The data type of both the query and key/value tensors. Defaults to torch.float16.
             data_type is deprecated, please use q_data_type and kv_data_type instead.
         non_blocking : bool
-            Whether to copy the input tensors to the device asynchronously, defaults to ``False``.
-            If ``True``, user should synchronize before calling :meth:`run` or cuda graph replay.
+            Whether to copy the input tensors to the device asynchronously, defaults to ``True``.


         Note
@@ -371,12 +370,12 @@ def plan(
                     "The size of indices should be less than or equal to the allocated buffer"
                 )
                 self._paged_kv_indptr_buf.copy_(indptr, non_blocking=non_blocking)
-                self._paged_kv_indices_buf[: len(indices)].copy_(
-                    indices, non_blocking=non_blocking
-                )
                 self._paged_kv_last_page_len_buf.copy_(
                     last_page_len, non_blocking=non_blocking
                 )
+                self._paged_kv_indices_buf[: len(indices)].copy_(
+                    indices, non_blocking=(indices.device == self.device) and non_blocking
+                )
             else:
                 self._paged_kv_indptr_buf = indptr.to(
                     self.device, non_blocking=non_blocking

flashinfer/prefill.py (+9 −7)

@@ -1195,7 +1195,7 @@ def plan(
         rope_theta: Optional[float] = None,
         q_data_type: Union[str, torch.dtype] = "float16",
         kv_data_type: Optional[Union[str, torch.dtype]] = None,
-        non_blocking: bool = False,
+        non_blocking: bool = True,
     ) -> None:
         r"""Plan batch prefill/append attention on Paged KV-Cache for given problem specification.
@@ -1269,8 +1269,7 @@ def plan(
         kv_data_type : Optional[Union[str, torch.dtype]]
             The data type of the key/value tensor. If None, will be set to :attr:`q_data_type`.
         non_blocking : bool
-            Whether to copy the input tensors to the device asynchronously, defaults to ``False``.
-            If ``True``, user should synchronize before calling :meth:`run` or cuda graph replay.
+            Whether to copy the input tensors to the device asynchronously, defaults to ``True``.

         Note
         ----
@@ -1349,12 +1348,13 @@ def plan(

             self._qo_indptr_buf.copy_(qo_indptr, non_blocking=non_blocking)
             self._paged_kv_indptr_buf.copy_(paged_kv_indptr, non_blocking=non_blocking)
-            self._paged_kv_indices_buf[: len(paged_kv_indices)].copy_(
-                paged_kv_indices, non_blocking=non_blocking
-            )
             self._paged_kv_last_page_len_buf.copy_(
                 paged_kv_last_page_len, non_blocking=non_blocking
             )
+            self._paged_kv_indices_buf[: len(paged_kv_indices)].copy_(
+                paged_kv_indices,
+                non_blocking=(paged_kv_indices.device == self.device) and non_blocking,
+            )

             if packed_custom_mask is not None:
                 if not torch.is_tensor(self._custom_mask_buf):
@@ -1366,7 +1366,9 @@ def plan(
                     "mask_indptr_buf must be initialized with a torch.Tensor in cuda graph mode if we use custom mask in attention computation."
                 )
                 self._custom_mask_buf[: len(packed_custom_mask)].copy_(
-                    packed_custom_mask, non_blocking=non_blocking
+                    packed_custom_mask,
+                    non_blocking=(packed_custom_mask.device == self.device)
+                    and non_blocking,
                 )
                 # NOTE(Zihao): mask_indptr has the same length as qo_indptr
                 self._mask_indptr_buf.copy_(mask_indptr, non_blocking=non_blocking)
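Note that these hunks guard only the variable-length copies (`paged_kv_indices`, `packed_custom_mask`); the fixed-size `indptr`-style copies still honor `non_blocking` as passed. A hypothetical caller-side sketch (wrapper construction and the full `plan()` argument list omitted): keeping that host-side metadata in pinned memory lets the unguarded copies remain genuinely asynchronous.

import torch

batch_size = 128
# Hypothetical host-side metadata; pinning lets the unguarded indptr and
# last-page-len copies in plan() stay asynchronous under non_blocking=True.
qo_indptr = torch.arange(batch_size + 1, dtype=torch.int32).pin_memory()
paged_kv_indptr = torch.arange(batch_size + 1, dtype=torch.int32).pin_memory()
# wrapper.plan(qo_indptr, paged_kv_indptr, ..., non_blocking=True)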

flashinfer/sparse.py (+3 −4)

@@ -207,7 +207,7 @@ def plan(
         rope_theta: Optional[float] = None,
         q_data_type: Union[str, torch.dtype] = "float16",
         kv_data_type: Optional[Union[str, torch.dtype]] = None,
-        non_blocking: bool = False,
+        non_blocking: bool = True,
     ) -> None:
         r"""Create auxiliary data structures for block sparse attention.
@@ -270,8 +270,7 @@ def plan(
         kv_data_type : Optional[Union[str, torch.dtype]]
             The data type of the key/value tensor. If None, will be set to :attr:`q_data_type`.
         non_blocking : bool
-            Whether to copy the input tensors to the device asynchronously, defaults to ``False``.
-            If ``True``, user should synchronize before calling :meth:`run` or cuda graph replay.
+            Whether to copy the input tensors to the device asynchronously, defaults to ``True``.


         The :meth:`plan` method should be called before any :meth:`run` or
@@ -414,7 +413,7 @@ def plan(

         kv_lens_arr_host = (kv_indptr_host[1:] - kv_indptr_host[:-1]) * self.C
         self._kv_lens_buffer[: len(kv_lens_arr_host)].copy_(
-            kv_lens_arr_host, non_blocking=non_blocking
+            kv_lens_arr_host,
         )

         if self._backend == "fa3":
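The sparse path differs slightly: `kv_lens_arr_host` is a temporary computed on the host immediately before the copy, so the `non_blocking` argument is dropped there entirely and the transfer is left blocking. Presumably this is because a non-blocking flag offers no real speedup for a pageable host source, while a blocking copy makes the temporary's lifetime trivially safe. A standalone sketch of that shape (illustrative, not flashinfer code):

import torch

kv_lens_host = torch.tensor([4, 8, 2], dtype=torch.int32)  # pageable temporary
kv_lens_buf = torch.empty(8, dtype=torch.int32, device="cuda")
# A plain blocking copy: the temporary's memory is fully consumed before the
# call returns, so no further lifetime or synchronization management is needed.
kv_lens_buf[: len(kv_lens_host)].copy_(kv_lens_host)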
