Skip to content

Commit 2d2e13a

Browse files
authored
misc: allow head_dim=64 for sm90 AOT (#783)
Keeps the default behavior of #782, i.e., build AOT without `head_dim=64`. But gives an option to specifically enable `head_dim=64`. `head_dim=64` is used by some small models like [Qwen2.5-0.5B](https://huggingface.co/Qwen/Qwen2.5-0.5B/blob/main/config.json).
1 parent 088e81f commit 2d2e13a

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

setup.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,11 @@
2929

3030
head_dims = os.environ.get("FLASHINFER_HEAD_DIMS", "128,256").split(",")
3131
head_dims = list(map(int, head_dims))
32-
SM90_ALLOWED_HEAD_DIMS = {(128, 128), (256, 256), (192, 128)}
33-
head_dims_sm90 = list(
34-
SM90_ALLOWED_HEAD_DIMS
35-
) # No support for custom head dims for SM90
32+
SM90_ALLOWED_HEAD_DIMS = {(64, 64), (128, 128), (256, 256), (192, 128)}
33+
head_dims_sm90 = [(d, d) for d in head_dims if (d, d) in SM90_ALLOWED_HEAD_DIMS]
34+
head_dims_sm90.extend(
35+
[(k, v) for k, v in SM90_ALLOWED_HEAD_DIMS if k != v]
36+
) # Always enable (192,128)
3637

3738
mask_modes = [0, 1, 2]
3839

0 commit comments

Comments
 (0)