Skip to content

Commit 8b45172

Browse files
authored
Remove USE_CUTLASS flag (#19271)
### Description Since Cutlass can be built with CUDA 11.4 (The minimum CUDA version for onnxruntime CUDA build), there is no need to have a flag to disable cutlass. Changes: (1) Reverted #18761 (2) remove the condition to build cutlass. (3) Fix a few build errors or warnings during testing CUDA 11.4 build. Note that SM 89 and 90 (including fp8) requires CUDA 11.8 or later. Flash attention and cutlass fused multihead attention will not be built for CUDA < 11.6. It is recommended to use CUDA 11.8 or above to build if you want to support latest GPUs. It is better to include it in 1.17.0 (otherwise, the release branch might encounter build failure with CUDA 11.4). Tests: (1) Build with flash attention and efficient attention off: **passed** (2) Build with CUDA 11.4: **passed** Example build command used in Ubuntu 20.04: ``` export CUDA_HOME=/usr/local/cuda-11.4 export CUDNN_HOME=/usr/lib/x86_64-linux-gnu/ export CUDACXX=/usr/local/cuda-11.4/bin/nvcc sh build.sh --config Release --build_shared_lib --parallel --use_cuda --cuda_version 11.4 \ --cuda_home $CUDA_HOME --cudnn_home $CUDNN_HOME --build_wheel --skip_tests \ --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=80 \ --disable_types float8 ``` ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
1 parent 656ca66 commit 8b45172

26 files changed

+25
-131
lines changed

cmake/CMakeLists.txt

+7-16
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,6 @@ option(onnxruntime_USE_PREINSTALLED_EIGEN "Use pre-installed EIGEN. Need to prov
9797
option(onnxruntime_BUILD_BENCHMARKS "Build ONNXRuntime micro-benchmarks" OFF)
9898
option(onnxruntime_USE_LLVM "Build TVM with LLVM" OFF)
9999

100-
cmake_dependent_option(onnxruntime_USE_CUTLASS "Build with cutlass support" ON "onnxruntime_USE_CUDA" OFF)
101100
cmake_dependent_option(onnxruntime_USE_FLASH_ATTENTION "Build flash attention kernel for scaled dot product attention" ON "NOT WIN32; onnxruntime_USE_CUDA" OFF)
102101
option(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON)
103102

@@ -707,20 +706,16 @@ if (onnxruntime_USE_CUDA)
707706
enable_language(CUDA)
708707
message( STATUS "CMAKE_CUDA_COMPILER_VERSION: ${CMAKE_CUDA_COMPILER_VERSION}")
709708

709+
if (onnxruntime_DISABLE_CONTRIB_OPS)
710+
set(onnxruntime_USE_FLASH_ATTENTION OFF)
711+
set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF)
712+
endif()
710713
if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.6)
711-
message( STATUS "Turn off cutlass since CUDA compiler version < 11.6")
712-
set(onnxruntime_USE_CUTLASS OFF)
714+
message( STATUS "Turn off flash attention since CUDA compiler version < 11.6")
715+
set(onnxruntime_USE_FLASH_ATTENTION OFF)
716+
set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF)
713717
endif()
714718
else()
715-
set(onnxruntime_USE_CUTLASS OFF)
716-
endif()
717-
718-
if (NOT onnxruntime_USE_CUTLASS OR onnxruntime_DISABLE_CONTRIB_OPS)
719-
if (onnxruntime_DISABLE_CONTRIB_OPS)
720-
message( STATUS "Turn off flash attention/memory efficient attention since contrib ops are disabled")
721-
else()
722-
message( STATUS "Turn off flash attention/memory efficient attention since cutlass is not enabled")
723-
endif()
724719
set(onnxruntime_USE_FLASH_ATTENTION OFF)
725720
set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF)
726721
endif()
@@ -906,10 +901,6 @@ function(onnxruntime_set_compile_flags target_name)
906901
target_compile_definitions(${target_name} PRIVATE ENABLE_ATEN)
907902
endif()
908903

909-
if (onnxruntime_USE_CUTLASS)
910-
target_compile_definitions(${target_name} PRIVATE USE_CUTLASS)
911-
endif()
912-
913904
if(USE_NEURAL_SPEED)
914905
target_compile_definitions(${target_name} PRIVATE ORT_NEURAL_SPEED)
915906
endif()

cmake/external/cutlass.cmake

+9-11
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
1-
if (onnxruntime_USE_CUTLASS)
2-
include(FetchContent)
3-
FetchContent_Declare(
4-
cutlass
5-
URL ${DEP_URL_cutlass}
6-
URL_HASH SHA1=${DEP_SHA1_cutlass}
7-
)
1+
include(FetchContent)
2+
FetchContent_Declare(
3+
cutlass
4+
URL ${DEP_URL_cutlass}
5+
URL_HASH SHA1=${DEP_SHA1_cutlass}
6+
)
87

9-
FetchContent_GetProperties(cutlass)
10-
if(NOT cutlass_POPULATED)
11-
FetchContent_Populate(cutlass)
12-
endif()
8+
FetchContent_GetProperties(cutlass)
9+
if(NOT cutlass_POPULATED)
10+
FetchContent_Populate(cutlass)
1311
endif()

onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc

-4
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
// Copyright (c) Microsoft Corporation. All rights reserved.
22
// Licensed under the MIT License.
33

4-
#ifdef USE_CUTLASS
5-
64
#include "core/common/safeint.h"
75
#include "core/providers/cuda/cuda_common.h"
86
#include "contrib_ops/cuda/bert/transformer_cuda_common.h"
@@ -204,5 +202,3 @@ Status ShardedMoE<T>::SynchronizeExpertsStartIndex(AllocatorPtr& allocator,
204202
} // namespace cuda
205203
} // namespace contrib
206204
} // namespace onnxruntime
207-
208-
#endif

onnxruntime/contrib_ops/cuda/collective/sharded_moe.h

-4
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
// Copyright (c) Microsoft Corporation. All rights reserved.
22
// Licensed under the MIT License.
33

4-
#ifdef USE_CUTLASS
5-
64
#pragma once
75

86
#include "contrib_ops/cuda/moe/ft_moe/moe_kernel.h"
@@ -36,5 +34,3 @@ class ShardedMoE final : public NcclKernel, public MoEBase {
3634
} // namespace cuda
3735
} // namespace contrib
3836
} // namespace onnxruntime
39-
40-
#endif

onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc

-8
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
7070
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, Crop);
7171
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, Crop);
7272
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, Crop);
73-
#ifdef USE_CUTLASS
7473
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MoE);
7574
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MoE);
76-
#endif
7775
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MultiHeadAttention);
7876
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MultiHeadAttention);
7977
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GroupQueryAttention);
@@ -169,10 +167,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllR
169167
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllGather);
170168
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllToAll);
171169

172-
#ifdef USE_CUTLASS
173170
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ShardedMoE);
174171
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ShardedMoE);
175-
#endif
176172

177173
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedMatMul);
178174
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedMatMul);
@@ -272,10 +268,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
272268
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, Crop)>,
273269
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, Crop)>,
274270
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, Crop)>,
275-
#ifdef USE_CUTLASS
276271
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MoE)>,
277272
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MoE)>,
278-
#endif
279273
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MultiHeadAttention)>,
280274
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MultiHeadAttention)>,
281275
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GroupQueryAttention)>,
@@ -377,10 +371,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
377371
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllGather)>,
378372
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllToAll)>,
379373

380-
#ifdef USE_CUTLASS
381374
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ShardedMoE)>,
382375
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ShardedMoE)>,
383-
#endif
384376

385377
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedMatMul)>,
386378
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedMatMul)>,

onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h

-5
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,6 @@
1313
* See the License for the specific language governing permissions and
1414
* limitations under the License.
1515
*/
16-
17-
#ifdef USE_CUTLASS
18-
1916
#pragma once
2017

2118
#include <cuda_runtime_api.h>
@@ -52,5 +49,3 @@ inline int compute_occupancy_for_kernel() {
5249
}
5350

5451
} // namespace ort_fastertransformer
55-
56-
#endif

onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc

+4-7
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
* See the License for the specific language governing permissions and
1414
* limitations under the License.
1515
*/
16-
#ifdef USE_CUTLASS
1716

1817
#include "cutlass_heuristic.h"
1918

@@ -66,9 +65,9 @@ bool is_valid_split_k_factor(const int64_t m, const int64_t n, const int64_t k,
6665
}
6766

6867
// Check that the workspace has sufficient space for this split-k factor
69-
const int ctas_in_m_dim = static_cast<int>((m + tile_shape.m - 1) / tile_shape.m);
70-
const int ctas_in_n_dim = static_cast<int>((n + tile_shape.n - 1) / tile_shape.n);
71-
const int required_ws_bytes = split_k_factor == 1 ? 0 : sizeof(int) * ctas_in_m_dim * ctas_in_n_dim;
68+
const size_t ctas_in_m_dim = static_cast<int>((m + tile_shape.m - 1) / tile_shape.m);
69+
const size_t ctas_in_n_dim = static_cast<int>((n + tile_shape.n - 1) / tile_shape.n);
70+
const size_t required_ws_bytes = split_k_factor == 1 ? 0 : sizeof(int) * ctas_in_m_dim * ctas_in_n_dim;
7271

7372
if (required_ws_bytes > workspace_bytes) {
7473
return false;
@@ -128,7 +127,7 @@ CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector<Cutlas
128127
int current_m_tile = 0;
129128

130129
const int max_split_k = n >= multi_processor_count * 256 ? 1 : split_k_limit;
131-
for (int ii = 0; ii < candidate_configs.size(); ++ii) {
130+
for (size_t ii = 0; ii < candidate_configs.size(); ++ii) {
132131
CutlassGemmConfig candidate_config = candidate_configs[ii];
133132
TileShape tile_shape = get_cta_shape_for_config(candidate_config.tile_config);
134133
int occupancy = occupancies[ii];
@@ -186,5 +185,3 @@ CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector<Cutlas
186185
}
187186

188187
} // namespace ort_fastertransformer
189-
190-
#endif

onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.h

-2
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
* See the License for the specific language governing permissions and
1414
* limitations under the License.
1515
*/
16-
#ifdef USE_CUTLASS
1716

1817
#pragma once
1918

@@ -38,4 +37,3 @@ CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector<Cutlas
3837
const int multi_processor_count, const int is_weight_only);
3938

4039
} // namespace ort_fastertransformer
41-
#endif

onnxruntime/contrib_ops/cuda/moe/ft_moe/epilogue_helpers.h

-4
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@
2222
*
2323
*/
2424

25-
#ifdef USE_CUTLASS
26-
2725
#pragma once
2826

2927
#include "cutlass/array.h"
@@ -133,5 +131,3 @@ struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, Epilog
133131
};
134132

135133
} // namespace ort_fastertransformer
136-
137-
#endif

onnxruntime/contrib_ops/cuda/moe/ft_moe/ft_gemm_configs.h

-4
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@
1414
* limitations under the License.
1515
*/
1616

17-
#ifdef USE_CUTLASS
18-
1917
#pragma once
2018

2119
namespace ort_fastertransformer {
@@ -58,5 +56,3 @@ struct CutlassGemmConfig {
5856
};
5957

6058
} // namespace ort_fastertransformer
61-
62-
#endif

onnxruntime/contrib_ops/cuda/moe/ft_moe/gemm_moe_problem_visitor.h

-4
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,6 @@
2929
*
3030
**************************************************************************************************/
3131

32-
#ifdef USE_CUTLASS
33-
3432
/*! \file
3533
\brief Scheduler for grouped GEMM
3634
*/
@@ -79,5 +77,3 @@ struct GemmMoeProblemVisitor
7977
} // namespace cutlass
8078

8179
/////////////////////////////////////////////////////////////////////////////////////////////////
82-
83-
#endif

onnxruntime/contrib_ops/cuda/moe/ft_moe/layout_traits_helper.h

+1-5
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@
2222
2323
*/
2424

25-
#ifdef USE_CUTLASS
26-
2725
#pragma once
2826

2927
#include "cutlass/layout/matrix.h"
@@ -152,6 +150,4 @@ struct MixedGemmArchTraits<
152150

153151
} // namespace kernel
154152
} // namespace gemm
155-
} // namespace cutlass
156-
157-
#endif
153+
} // namespace cutlass

onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_cutlass_kernel.h

-4
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,6 @@
2323
*
2424
**************************************************************************************************/
2525

26-
#ifdef USE_CUTLASS
27-
2826
#pragma once
2927

3028
#include "cutlass/complex.h"
@@ -463,5 +461,3 @@ struct MoeFCGemm {
463461
} // namespace cutlass
464462

465463
/////////////////////////////////////////////////////////////////////////////////////////////////
466-
467-
#endif

onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h

-4
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@
1414
* limitations under the License.
1515
*/
1616

17-
#ifdef USE_CUTLASS
18-
1917
#pragma once
2018

2119
#include <cuda_runtime_api.h>
@@ -64,5 +62,3 @@ class MoeGemmRunner {
6462
};
6563

6664
} // namespace ort_fastertransformer
67-
68-
#endif

onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_fp16.cu

-4
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,8 @@
1414
* limitations under the License.
1515
*/
1616

17-
#ifdef USE_CUTLASS
18-
1917
#include "moe_gemm_kernels_template.h"
2018

2119
namespace ort_fastertransformer {
2220
template class MoeGemmRunner<half, half>;
2321
} // namespace ort_fastertransformer
24-
25-
#endif

onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp32_fp32.cu

-4
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,8 @@
1414
* limitations under the License.
1515
*/
1616

17-
#ifdef USE_CUTLASS
18-
1917
#include "moe_gemm_kernels_template.h"
2018

2119
namespace ort_fastertransformer {
2220
template class MoeGemmRunner<float, float>;
2321
} // namespace ort_fastertransformer
24-
25-
#endif

onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h

-4
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@
1414
* limitations under the License.
1515
*/
1616

17-
#ifdef USE_CUTLASS
18-
1917
// Ignore CUTLASS warnings about type punning
2018
#ifdef __GNUC__
2119
#pragma GCC diagnostic push
@@ -428,5 +426,3 @@ void MoeGemmRunner<T, WeightType>::moe_gemm(const T* A, const WeightType* B, con
428426
}
429427

430428
} // namespace ort_fastertransformer
431-
432-
#endif

onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu

-4
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@
1616
// Copyright (c) Microsoft Corporation. All rights reserved.
1717
// Licensed under the MIT License.
1818

19-
#ifdef USE_CUTLASS
20-
2119
#include <cuda.h>
2220
#include <cuda_fp16.h>
2321
#include <math.h>
@@ -900,5 +898,3 @@ template void finalize_moe_routing_kernelLauncher(const half*, half*, const half
900898
cudaStream_t);
901899

902900
} // namespace ort_fastertransformer
903-
904-
#endif

onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h

+1-5
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@
1616
// Copyright (c) Microsoft Corporation. All rights reserved.
1717
// Licensed under the MIT License.
1818

19-
#ifdef USE_CUTLASS
20-
2119
#pragma once
2220

2321
#include "moe_gemm_kernels.h"
@@ -174,6 +172,4 @@ class CutlassMoeFCRunner<float, WeightType, typename std::enable_if_t<!std::is_s
174172
}
175173
};
176174

177-
} // namespace ort_fastertransformer
178-
179-
#endif
175+
} // namespace ort_fastertransformer

onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_problem_visitor.h

-4
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,6 @@
3333
\brief Base scheduler for grouped problems, using MoE
3434
*/
3535

36-
#ifdef USE_CUTLASS
37-
3836
#pragma once
3937

4038
#include "cutlass/gemm/kernel/grouped_problem_visitor.h"
@@ -290,5 +288,3 @@ struct MoeProblemVisitor<ProblemSizeHelper, ThreadblockShape, GroupScheduleMode:
290288
} // namespace kernel
291289
} // namespace gemm
292290
} // namespace cutlass
293-
294-
#endif

0 commit comments

Comments
 (0)