diff --git a/setup.py b/setup.py index fc18489fc..1a606a3cf 100644 --- a/setup.py +++ b/setup.py @@ -29,10 +29,11 @@ head_dims = os.environ.get("FLASHINFER_HEAD_DIMS", "128,256").split(",") head_dims = list(map(int, head_dims)) -SM90_ALLOWED_HEAD_DIMS = {(128, 128), (256, 256), (192, 128)} -head_dims_sm90 = list( - SM90_ALLOWED_HEAD_DIMS -) # No support for custom head dims for SM90 +SM90_ALLOWED_HEAD_DIMS = {(64, 64), (128, 128), (256, 256), (192, 128)} +head_dims_sm90 = [(d, d) for d in head_dims if (d, d) in SM90_ALLOWED_HEAD_DIMS] +head_dims_sm90.extend( + [(k, v) for k, v in SM90_ALLOWED_HEAD_DIMS if k != v] +) # Always enable (192,128) mask_modes = [0, 1, 2]