Skip to content

Commit 103bd17

Browse files
[Build] Only build 9.0a for scaled_mm and sparse kernels (#12339)
Signed-off-by: Lucas Wilkinson <[email protected]>
1 parent ce69f7f commit 103bd17

File tree

2 files changed

+32
-19
lines changed

2 files changed

+32
-19
lines changed

CMakeLists.txt

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
275275
# Only build Marlin kernels if we are building for at least some compatible archs.
276276
# Keep building Marlin for 9.0 as there are some group sizes and shapes that
277277
# are not supported by Machete yet.
278-
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.7;8.9;9.0" ${CUDA_ARCHS})
278+
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.7;8.9;9.0" "${CUDA_ARCHS}")
279279
if (MARLIN_ARCHS)
280280
set(MARLIN_SRCS
281281
"csrc/quantization/fp8/fp8_marlin.cu"
@@ -296,8 +296,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
296296
endif()
297297

298298
# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
299-
# CUDA 12.0 or later (and only work on Hopper, 9.0/9.0a for now).
300-
cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0;9.0a" "${CUDA_ARCHS}")
299+
# CUDA 12.0 or later (and only work on Hopper, 9.0a for now).
300+
cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0a" "${CUDA_ARCHS}")
301301
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
302302
set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu")
303303
set_gencode_flags_for_srcs(
@@ -351,7 +351,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
351351
# 2:4 Sparse Kernels
352352

353353
# The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor
354-
# require CUDA 12.2 or later (and only work on Hopper, 9.0/9.0a for now).
354+
# require CUDA 12.2 or later (and only work on Hopper, 9.0a for now).
355355
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS)
356356
set(SRCS "csrc/sparse/cutlass/sparse_compressor_c3x.cu"
357357
"csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")

cmake/utils.cmake

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ endmacro()
259259
# in `SRC_CUDA_ARCHS` that is less than or equal to the version in `TGT_CUDA_ARCHS`.
260260
# We have special handling for 9.0a, if 9.0a is in `SRC_CUDA_ARCHS` and 9.0 is
261261
# in `TGT_CUDA_ARCHS` then we should remove 9.0a from `SRC_CUDA_ARCHS` and add
262-
# 9.0a to the result.
262+
# 9.0a to the result (and remove 9.0 from TGT_CUDA_ARCHS).
263263
# The result is stored in `OUT_CUDA_ARCHS`.
264264
#
265265
# Example:
@@ -270,34 +270,47 @@ endmacro()
270270
#
271271
function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
272272
list(REMOVE_DUPLICATES SRC_CUDA_ARCHS)
273+
set(TGT_CUDA_ARCHS_ ${TGT_CUDA_ARCHS})
273274

274275
# if 9.0a is in SRC_CUDA_ARCHS and 9.0 is in TGT_CUDA_ARCHS then we should
275276
# remove 9.0a from SRC_CUDA_ARCHS and add 9.0a to _CUDA_ARCHS
276277
set(_CUDA_ARCHS)
277278
if ("9.0a" IN_LIST SRC_CUDA_ARCHS)
278279
list(REMOVE_ITEM SRC_CUDA_ARCHS "9.0a")
279-
if ("9.0" IN_LIST TGT_CUDA_ARCHS)
280+
if ("9.0" IN_LIST TGT_CUDA_ARCHS_)
281+
list(REMOVE_ITEM TGT_CUDA_ARCHS_ "9.0")
280282
set(_CUDA_ARCHS "9.0a")
281283
endif()
282284
endif()
283285

284286
list(SORT SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
285287

286-
# for each ARCH in CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that is
287-
# less or eqault to ARCH
288-
foreach(_ARCH ${CUDA_ARCHS})
289-
set(_TMP_ARCH)
290-
foreach(_SRC_ARCH ${SRC_CUDA_ARCHS})
291-
if (_SRC_ARCH VERSION_LESS_EQUAL _ARCH)
292-
set(_TMP_ARCH ${_SRC_ARCH})
293-
else()
294-
break()
288+
# for each ARCH in TGT_CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that
289+
# is less than or equal to ARCH (but has the same major version since SASS binary
290+
# compatibility is only forward compatible within the same major version).
291+
foreach(_ARCH ${TGT_CUDA_ARCHS_})
292+
set(_TMP_ARCH)
293+
# Extract the major version of the target arch
294+
string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" TGT_ARCH_MAJOR "${_ARCH}")
295+
foreach(_SRC_ARCH ${SRC_CUDA_ARCHS})
296+
# Extract the major version of the source arch
297+
string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" SRC_ARCH_MAJOR "${_SRC_ARCH}")
298+
# Check major-version match AND version-less-or-equal
299+
if (_SRC_ARCH VERSION_LESS_EQUAL _ARCH)
300+
if (SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR)
301+
set(_TMP_ARCH "${_SRC_ARCH}")
302+
endif()
303+
else()
304+
# If we hit a version greater than the target, we can break
305+
break()
306+
endif()
307+
endforeach()
308+
309+
# If we found a matching _TMP_ARCH, append it to _CUDA_ARCHS
310+
if (_TMP_ARCH)
311+
list(APPEND _CUDA_ARCHS "${_TMP_ARCH}")
295312
endif()
296313
endforeach()
297-
if (_TMP_ARCH)
298-
list(APPEND _CUDA_ARCHS ${_TMP_ARCH})
299-
endif()
300-
endforeach()
301314

302315
list(REMOVE_DUPLICATES _CUDA_ARCHS)
303316
set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE)

0 commit comments

Comments
 (0)