15
15
"""
16
16
17
17
import math
18
- from typing import Optional , Tuple
18
+ from typing import Optional , Tuple , List
19
19
import torch
20
20
21
21
# mypy: disable-error-code="attr-defined"
@@ -273,7 +273,7 @@ def __init__(
273
273
def reset_workspace_buffer (
274
274
self ,
275
275
float_workspace_buffer : torch .Tensor ,
276
- int_workspace_buffers : list [torch .Tensor ],
276
+ int_workspace_buffers : List [torch .Tensor ],
277
277
) -> None :
278
278
r"""Reset the workspace buffer.
279
279
@@ -283,8 +283,8 @@ def reset_workspace_buffer(
283
283
The new float workspace buffer, the device of the new float workspace buffer should
284
284
be the same as the device of the input tensors.
285
285
286
- int_workspace_buffer : torch.Tensor
287
- The new int workspace buffer, the device of the new int workspace buffer should
286
+ int_workspace_buffers : List[ torch.Tensor]
287
+ The array of new int workspace buffer, the device of the new int workspace buffer should
288
288
be the same as the device of the input tensors.
289
289
"""
290
290
for wrapper , int_workspace_buffer in zip (
@@ -294,10 +294,10 @@ def reset_workspace_buffer(
294
294
295
295
def plan (
296
296
self ,
297
- qo_indptr_arr : list [torch .Tensor ],
298
- paged_kv_indptr_arr : list [torch .Tensor ],
299
- paged_kv_indices_arr : list [torch .Tensor ],
300
- paged_kv_last_page_len : list [torch .Tensor ],
297
+ qo_indptr_arr : List [torch .Tensor ],
298
+ paged_kv_indptr_arr : List [torch .Tensor ],
299
+ paged_kv_indices_arr : List [torch .Tensor ],
300
+ paged_kv_last_page_len : List [torch .Tensor ],
301
301
num_qo_heads : int ,
302
302
num_kv_heads : int ,
303
303
head_dim : int ,
@@ -318,17 +318,17 @@ def plan(
318
318
319
319
Parameters
320
320
----------
321
- qo_indptr_arr : list [torch.Tensor]
321
+ qo_indptr_arr : List [torch.Tensor]
322
322
An array of qo indptr tensors for each level, the array length should be equal to
323
323
the number of levels.
324
324
The last element of each tensor should be the total number of queries/outputs.
325
- paged_kv_indptr_arr : list [torch.Tensor]
325
+ paged_kv_indptr_arr : List [torch.Tensor]
326
326
An array of paged kv-cache indptr tensors for each level, the array length should be
327
327
equal to the number of levels.
328
- paged_kv_indices_arr : list [torch.Tensor]
328
+ paged_kv_indices_arr : List [torch.Tensor]
329
329
An array of paged kv-cache indices tensors for each level, the array length should be
330
330
equal to the number of levels.
331
- paged_kv_last_page_len : list [torch.Tensor]
331
+ paged_kv_last_page_len : List [torch.Tensor]
332
332
An array of paged kv-cache last page length tensors for each level, the array length
333
333
should be equal to the number of levels.
334
334
num_qo_heads : int
0 commit comments