From ff386f7d90cf99a8819816e1cd4b8f8b21780291 Mon Sep 17 00:00:00 2001 From: Lequn Chen Date: Fri, 24 Jan 2025 23:09:46 +0000 Subject: [PATCH] Filter out unsupported head dim for sm90 --- setup.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index f042dd061..061d70857 100644 --- a/setup.py +++ b/setup.py @@ -85,10 +85,12 @@ def generate_cuda() -> None: ) if enable_sm90: + SM90_ALLOWED_HEAD_DIMS = {64, 128, 256} + sm90_head_dims = [d for d in head_dims if d in SM90_ALLOWED_HEAD_DIMS] aot_kernel_uris += get_sm90_instantiation_cu( argparse.Namespace( path=gen_dir, - head_dims=head_dims, + head_dims=sm90_head_dims, pos_encoding_modes=[0], use_fp16_qk_reductions=[0], mask_modes=mask_modes,