Skip to content

Commit 824ce40

Browse files
authored
Fix the type annotation of q_dtype and kv_dtype on ragged prefill (#798)
1 parent eb69778 commit 824ce40

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

flashinfer/prefill.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1951,8 +1951,8 @@ def plan(
19511951
sm_scale: Optional[float] = None,
19521952
rope_scale: Optional[float] = None,
19531953
rope_theta: Optional[float] = None,
1954-
q_data_type: str = "float16",
1955-
kv_data_type: Optional[str] = None,
1954+
q_data_type: Union[str, torch.dtype] = "float16",
1955+
kv_data_type: Optional[Union[str, torch.dtype]] = None,
19561956
) -> None:
19571957
r"""Plan batch prefill/append attention on Ragged KV-Cache for given problem specification.
19581958

0 commit comments

Comments
 (0)