Skip to content

Commit 088e81f

Browse files
authored
misc: remove head dimension 64 from AOT (flashinfer-ai#782)
This PR removes `head_dim=64` from pre-built AOT wheels, `head_dim=64` will be supported using JIT in the future.
1 parent 74a4054 commit 088e81f

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

setup.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,12 @@
2727
root = Path(__file__).parent.resolve()
2828
gen_dir = root / "csrc" / "generated"
2929

30-
head_dims = os.environ.get("FLASHINFER_HEAD_DIMS", "64,128,256").split(",")
30+
head_dims = os.environ.get("FLASHINFER_HEAD_DIMS", "128,256").split(",")
3131
head_dims = list(map(int, head_dims))
32-
SM90_ALLOWED_HEAD_DIMS = {(64, 64), (128, 128), (256, 256), (192, 128)}
33-
head_dims_sm90 = list(SM90_ALLOWED_HEAD_DIMS) # No support for custom head dims for SM90
32+
SM90_ALLOWED_HEAD_DIMS = {(128, 128), (256, 256), (192, 128)}
33+
head_dims_sm90 = list(
34+
SM90_ALLOWED_HEAD_DIMS
35+
) # No support for custom head dims for SM90
3436

3537
mask_modes = [0, 1, 2]
3638

0 commit comments

Comments
 (0)