Skip to content

Commit 77bff3f

Browse files
authored
bugfix: fix the python 3.8 type error (#486)
Bugfix to #484, thanks @wzhao18 for spotting this error.
1 parent eebbea0 commit 77bff3f

File tree

1 file changed

+12
-12
lines changed

1 file changed

+12
-12
lines changed

python/flashinfer/cascade.py

+12-12
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
"""
1616

1717
import math
18-
from typing import Optional, Tuple
18+
from typing import Optional, Tuple, List
1919
import torch
2020

2121
# mypy: disable-error-code="attr-defined"
@@ -273,7 +273,7 @@ def __init__(
273273
def reset_workspace_buffer(
274274
self,
275275
float_workspace_buffer: torch.Tensor,
276-
int_workspace_buffers: list[torch.Tensor],
276+
int_workspace_buffers: List[torch.Tensor],
277277
) -> None:
278278
r"""Reset the workspace buffer.
279279
@@ -283,8 +283,8 @@ def reset_workspace_buffer(
283283
The new float workspace buffer, the device of the new float workspace buffer should
284284
be the same as the device of the input tensors.
285285
286-
int_workspace_buffer : torch.Tensor
287-
The new int workspace buffer, the device of the new int workspace buffer should
286+
int_workspace_buffers : List[torch.Tensor]
287+
The array of new int workspace buffer, the device of the new int workspace buffer should
288288
be the same as the device of the input tensors.
289289
"""
290290
for wrapper, int_workspace_buffer in zip(
@@ -294,10 +294,10 @@ def reset_workspace_buffer(
294294

295295
def plan(
296296
self,
297-
qo_indptr_arr: list[torch.Tensor],
298-
paged_kv_indptr_arr: list[torch.Tensor],
299-
paged_kv_indices_arr: list[torch.Tensor],
300-
paged_kv_last_page_len: list[torch.Tensor],
297+
qo_indptr_arr: List[torch.Tensor],
298+
paged_kv_indptr_arr: List[torch.Tensor],
299+
paged_kv_indices_arr: List[torch.Tensor],
300+
paged_kv_last_page_len: List[torch.Tensor],
301301
num_qo_heads: int,
302302
num_kv_heads: int,
303303
head_dim: int,
@@ -318,17 +318,17 @@ def plan(
318318
319319
Parameters
320320
----------
321-
qo_indptr_arr : list[torch.Tensor]
321+
qo_indptr_arr : List[torch.Tensor]
322322
An array of qo indptr tensors for each level, the array length should be equal to
323323
the number of levels.
324324
The last element of each tensor should be the total number of queries/outputs.
325-
paged_kv_indptr_arr : list[torch.Tensor]
325+
paged_kv_indptr_arr : List[torch.Tensor]
326326
An array of paged kv-cache indptr tensors for each level, the array length should be
327327
equal to the number of levels.
328-
paged_kv_indices_arr : list[torch.Tensor]
328+
paged_kv_indices_arr : List[torch.Tensor]
329329
An array of paged kv-cache indices tensors for each level, the array length should be
330330
equal to the number of levels.
331-
paged_kv_last_page_len : list[torch.Tensor]
331+
paged_kv_last_page_len : List[torch.Tensor]
332332
An array of paged kv-cache last page length tensors for each level, the array length
333333
should be equal to the number of levels.
334334
num_qo_heads : int

0 commit comments

Comments
 (0)