diff --git a/setup.py b/setup.py index f042dd061..061d70857 100644 --- a/setup.py +++ b/setup.py @@ -85,10 +85,12 @@ def generate_cuda() -> None: ) if enable_sm90: + SM90_ALLOWED_HEAD_DIMS = {64, 128, 256} + sm90_head_dims = [d for d in head_dims if d in SM90_ALLOWED_HEAD_DIMS] aot_kernel_uris += get_sm90_instantiation_cu( argparse.Namespace( path=gen_dir, - head_dims=head_dims, + head_dims=sm90_head_dims, pos_encoding_modes=[0], use_fp16_qk_reductions=[0], mask_modes=mask_modes,