From cc09d84cab096e6aeca0d3d088b694746f363164 Mon Sep 17 00:00:00 2001 From: Akash Verma Date: Fri, 21 Feb 2025 00:07:32 +0000 Subject: [PATCH 1/3] Added the code to download CK library from compute-artifactory and create link target. Enable USE_CK_FLASH_ATTENTION based on USE_FLASH_ATTENTION option. --- CMakeLists.txt | 14 +++++ aten/src/ATen/CMakeLists.txt | 3 +- .../hip/flash_attn/ck/CMakeLists.txt | 63 ------------------- benchmarks/transformer/sdpa.py | 1 + caffe2/CMakeLists.txt | 4 ++ cmake/External/ck.cmake | 44 +++++++++++++ 6 files changed, 65 insertions(+), 64 deletions(-) delete mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/CMakeLists.txt create mode 100644 cmake/External/ck.cmake diff --git a/CMakeLists.txt b/CMakeLists.txt index 3fc51fa382891..da0288dc91fc3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -865,6 +865,13 @@ cmake_dependent_option( "USE_CUDA OR USE_ROCM;NOT MSVC" OFF) +cmake_dependent_option( + USE_CK_FLASH_ATTENTION + "Whether to build the CK flash_attention kernel. Will be enabled if USE_FLASH_ATTENTION is enabled." + ON + "USE_FLASH_ATTENTION" + OFF) + # We are currenlty not using alibi attention for Flash So we disable this # feature by default We dont currently document this feature because we don't # Suspect users building from source will need this @@ -888,6 +895,13 @@ if(USE_ROCM) endif() endif() +# CK shared lib linkage +if(USE_ROCM) + if(UNIX AND (USE_CK_FLASH_ATTENTION)) + include(cmake/External/ck.cmake) + endif() +endif() + if(DEBUG_CUDA) string(APPEND CMAKE_CUDA_FLAGS_DEBUG " -lineinfo") string(APPEND CMAKE_CUDA_FLAGS_RELWITHDEBINFO " -lineinfo") diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index cff157c784c66..64cec5d0ceb00 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -184,7 +184,8 @@ if(USE_FLASH_ATTENTION) endif() message(STATUS "USE_CK_FLASH_ATTENTION is set; building PyTorch with CK Flash Attention enabled") message(STATUS "Generating CK kernel instances...") - add_subdirectory(native/transformers/hip/flash_attn/ck) + # disable buidling CK files + # add_subdirectory(native/transformers/hip/flash_attn/ck) file(GLOB flash_attention_hip_ck_hip "native/transformers/hip/flash_attn/ck/*.hip") list(APPEND native_transformers_hip_hip ${flash_attention_hip_ck_hip}) endif() diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/CMakeLists.txt b/aten/src/ATen/native/transformers/hip/flash_attn/ck/CMakeLists.txt deleted file mode 100644 index a72911cd510eb..0000000000000 --- a/aten/src/ATen/native/transformers/hip/flash_attn/ck/CMakeLists.txt +++ /dev/null @@ -1,63 +0,0 @@ -# generate a list of kernels, but not actually emit files at config stage -execute_process( - COMMAND python3 ${CMAKE_CURRENT_LIST_DIR}/../../../../../../../../third_party/composable_kernel/example/ck_tile/01_fmha/generate.py - --api fwd --receipt 4 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/fwd_blob_list.txt - RESULT_VARIABLE ret -) - -if(ret AND NOT ret EQUAL 0) - message( FATAL_ERROR "CK Tile FMHA FAILED to generate a list of FWD kernels via Python.") -endif() - -execute_process( - COMMAND python3 ${CMAKE_CURRENT_LIST_DIR}/../../../../../../../../third_party/composable_kernel/example/ck_tile/01_fmha/generate.py - --api bwd --receipt 4 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/bwd_blob_list.txt - RESULT_VARIABLE ret -) - -if(ret AND NOT ret EQUAL 0) - message( FATAL_ERROR "CK Tile FMHA FAILED to generate a list of BWD kernels via Python.") -endif() - -# Generate the files for both fwd and bwd -execute_process(COMMAND python3 ${CMAKE_CURRENT_LIST_DIR}/../../../../../../../../third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api fwd --receipt 4 --output_dir ${CMAKE_CURRENT_LIST_DIR} -) - -if(ret AND NOT ret EQUAL 0) - message( FATAL_ERROR "CK Tile FMHA FAILED to generate FWD kernels.") -endif() - -execute_process(COMMAND python3 ${CMAKE_CURRENT_LIST_DIR}/../../../../../../../../third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api bwd --receipt 4 --output_dir ${CMAKE_CURRENT_LIST_DIR} - RESULT_VARIABLE ret -) - -if(ret AND NOT ret EQUAL 0) - message( FATAL_ERROR "CK Tile FMHA FAILED to generate BWD kernels.") -endif() - -# Change make_kernel to make_kernel_pt for fwd -execute_process( - COMMAND bash -c "${CMAKE_CURRENT_LIST_DIR}/add_make_kernel_pt.sh ${CMAKE_CURRENT_LIST_DIR}/fwd_blob_list.txt" - RESULT_VARIABLE ret) - -if(ret AND NOT ret EQUAL 0) - message( FATAL_ERROR "CK Tile FMHA FAILED to change make_kernel to make_kernel_pt for the fwd pass") -endif() - -# Change make_kernel to make_kernel_pt for bwd -execute_process( - COMMAND bash -c "${CMAKE_CURRENT_LIST_DIR}/add_make_kernel_pt.sh ${CMAKE_CURRENT_LIST_DIR}/bwd_blob_list.txt" - RESULT_VARIABLE ret) - -if(ret AND NOT ret EQUAL 0) - message( FATAL_ERROR "CK Tile FMHA FAILED to change make_kernel to make_kernel_pt for the bwd pass") -endif() - -# Change file extensions to .hip -execute_process(COMMAND bash -c "for file in ${CMAKE_CURRENT_LIST_DIR}/*.cpp; do mv -- \"$file\" \"\${file%.cpp}.hip\"; done" - RESULT_VARIABLE ret -) - -if(ret AND NOT ret EQUAL 0) - message( FATAL_ERROR "CK Tile FMHA FAILED to change the generated instances extensions from .cpp to .hpp") -endif() diff --git a/benchmarks/transformer/sdpa.py b/benchmarks/transformer/sdpa.py index d45970213e012..f146d0b1ecb4f 100644 --- a/benchmarks/transformer/sdpa.py +++ b/benchmarks/transformer/sdpa.py @@ -172,6 +172,7 @@ def print_results(experiments: List[Experiment]): def main(): + torch.backends.cuda.preferred_rocm_fa_library("ck") seed = 123 torch.manual_seed(seed) results = [] diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 83aa670846672..c9b520ef94990 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -929,6 +929,10 @@ if(USE_ROCM) if(USE_FLASH_ATTENTION) target_link_libraries(torch_hip PRIVATE __caffe2_aotriton) endif() +# link CK library + if(USE_CK_FLASH_ATTENTION) + target_link_libraries(torch_hip PRIVATE __ck_lib) + endif() set(CUDA_LINK_LIBRARIES_KEYWORD) torch_compile_options(torch_hip) # see cmake/public/utils.cmake # TODO: Not totally sure if this is live or not diff --git a/cmake/External/ck.cmake b/cmake/External/ck.cmake new file mode 100644 index 0000000000000..59ee354e3bd76 --- /dev/null +++ b/cmake/External/ck.cmake @@ -0,0 +1,44 @@ +# +# create INTERFACE target for CK library +# + +# get CK commit hash +execute_process( + COMMAND git -C ${CMAKE_SOURCE_DIR}/third_party submodule status composable_kernel + RESULT_VARIABLE result + OUTPUT_VARIABLE submodule_status + ERROR_VARIABLE submodule_status_error + OUTPUT_STRIP_TRAILING_WHITESPACE + ) +if(result EQUAL 0) + string(REGEX REPLACE "^[ \t]" "" submodule_status ${submodule_status}) + # extract first 8 characters of the commit hash + string(SUBSTRING "${submodule_status}" 0 8 ck_commit_hash) +else() + message(FATAL_ERROR "Failed to get submodule status for composable_kernel.") +endif() + +# get ROCm version from LoadHIP.cmake +include(${CMAKE_SOURCE_DIR}/cmake/public/LoadHIP.cmake) + +# full path for CK library on compute-artifactory.amd.com +set(url "https://compute-artifactory.amd.com/artifactory/rocm-generic-local") +set(ck_lib_full_path "${url}/torch_ck_gen_lib/ck_${ck_commit_hash}/rocm_${ROCM_VERSION_DEV}/libck_kernels.so") + +# set destination +set(destination "${CMAKE_SOURCE_DIR}/torch/lib/libck_kernels.so") + +# download CK library +file(DOWNLOAD ${ck_lib_full_path} ${destination} SHOW_PROGRESS RESULT_VARIABLE download_status) +if(NOT download_status) + message(STATUS "Downloaded CK library successfully.") +else() + message(FATAL_ERROR "Failed to download the CK library from ${SOURCE_URL}.") +endif() + +# create INTERFACE target +add_library(__ck_lib INTERFACE) + +# specify path to CK library +target_link_libraries(__ck_lib INTERFACE ${destination}) + From 6943678cdc3d79c37906161e88df783828f8d403 Mon Sep 17 00:00:00 2001 From: Akash Verma Date: Tue, 1 Apr 2025 02:54:58 +0000 Subject: [PATCH 2/3] Added USE_CK_FLASH_ATTENTION as cmake variable. Fixed lint/NIT errors using lintrunner. --- CMakeLists.txt | 2 +- aten/src/ATen/CMakeLists.txt | 29 +++++++++++++---------------- benchmarks/transformer/sdpa.py | 1 - cmake/External/ck.cmake | 25 ++++++++++++------------- 4 files changed, 26 insertions(+), 31 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index da0288dc91fc3..944d7cd557e9a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -897,7 +897,7 @@ endif() # CK shared lib linkage if(USE_ROCM) - if(UNIX AND (USE_CK_FLASH_ATTENTION)) + if(UNIX AND (USE_CK_FLASH_ATTENTION)) include(cmake/External/ck.cmake) endif() endif() diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index 64cec5d0ceb00..6378342367ff9 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -172,23 +172,20 @@ file(GLOB flash_attention_cuda_cpp "native/transformers/cuda/flash_attn/*.cpp") file(GLOB flash_attention_hip_hip "native/transformers/hip/flash_attn/*.hip") # if USE_FLASH_ATTENTION is set, ensure CK instances get generated if(USE_FLASH_ATTENTION) - if(DEFINED ENV{USE_CK_FLASH_ATTENTION}) - set(USE_CK_FLASH_ATTENTION $ENV{USE_CK_FLASH_ATTENTION}) - if(USE_CK_FLASH_ATTENTION STREQUAL "1") - if(DEFINED ENV{PYTORCH_ROCM_ARCH}) - list(LENGTH PYTORCH_ROCM_ARCH NUM_ARCHS) - if(NUM_ARCHS GREATER 1) - message(WARNING "Building CK for multiple archs can increase build time considerably! - Consider setting PYTORCH_ROCM_ARCH env var value as the gfx arch you need to build for") - endif() - endif() - message(STATUS "USE_CK_FLASH_ATTENTION is set; building PyTorch with CK Flash Attention enabled") - message(STATUS "Generating CK kernel instances...") - # disable buidling CK files - # add_subdirectory(native/transformers/hip/flash_attn/ck) - file(GLOB flash_attention_hip_ck_hip "native/transformers/hip/flash_attn/ck/*.hip") - list(APPEND native_transformers_hip_hip ${flash_attention_hip_ck_hip}) + if(USE_CK_FLASH_ATTENTION) + if(DEFINED ENV{PYTORCH_ROCM_ARCH}) + list(LENGTH PYTORCH_ROCM_ARCH NUM_ARCHS) + if(NUM_ARCHS GREATER 1) + message(WARNING "Building CK for multiple archs can increase build time considerably! + Consider setting PYTORCH_ROCM_ARCH env var value as the gfx arch you need to build for") endif() + endif() + message(STATUS "USE_CK_FLASH_ATTENTION is set; building PyTorch with CK Flash Attention enabled") + message(STATUS "Generating CK kernel instances...") + # disable buidling CK files + # add_subdirectory(native/transformers/hip/flash_attn/ck) + file(GLOB flash_attention_hip_ck_hip "native/transformers/hip/flash_attn/ck/*.hip") + list(APPEND native_transformers_hip_hip ${flash_attention_hip_ck_hip}) endif() file(GLOB flash_attention_hip_aot_hip "native/transformers/hip/flash_attn/aot/*.hip") file(GLOB flash_attention_src_hip_hip "native/transformers/hip/flash_attn/src/*.hip") diff --git a/benchmarks/transformer/sdpa.py b/benchmarks/transformer/sdpa.py index f146d0b1ecb4f..d45970213e012 100644 --- a/benchmarks/transformer/sdpa.py +++ b/benchmarks/transformer/sdpa.py @@ -172,7 +172,6 @@ def print_results(experiments: List[Experiment]): def main(): - torch.backends.cuda.preferred_rocm_fa_library("ck") seed = 123 torch.manual_seed(seed) results = [] diff --git a/cmake/External/ck.cmake b/cmake/External/ck.cmake index 59ee354e3bd76..ac2165a701d15 100644 --- a/cmake/External/ck.cmake +++ b/cmake/External/ck.cmake @@ -4,18 +4,18 @@ # get CK commit hash execute_process( - COMMAND git -C ${CMAKE_SOURCE_DIR}/third_party submodule status composable_kernel - RESULT_VARIABLE result - OUTPUT_VARIABLE submodule_status - ERROR_VARIABLE submodule_status_error - OUTPUT_STRIP_TRAILING_WHITESPACE - ) + COMMAND git -C ${CMAKE_SOURCE_DIR}/third_party submodule status composable_kernel + RESULT_VARIABLE result + OUTPUT_VARIABLE submodule_status + ERROR_VARIABLE submodule_status_error + OUTPUT_STRIP_TRAILING_WHITESPACE + ) if(result EQUAL 0) - string(REGEX REPLACE "^[ \t]" "" submodule_status ${submodule_status}) - # extract first 8 characters of the commit hash - string(SUBSTRING "${submodule_status}" 0 8 ck_commit_hash) + string(REGEX REPLACE "^[ \t]" "" submodule_status ${submodule_status}) + # extract first 8 characters of the commit hash + string(SUBSTRING "${submodule_status}" 0 8 ck_commit_hash) else() - message(FATAL_ERROR "Failed to get submodule status for composable_kernel.") + message(FATAL_ERROR "Failed to get submodule status for composable_kernel.") endif() # get ROCm version from LoadHIP.cmake @@ -31,9 +31,9 @@ set(destination "${CMAKE_SOURCE_DIR}/torch/lib/libck_kernels.so") # download CK library file(DOWNLOAD ${ck_lib_full_path} ${destination} SHOW_PROGRESS RESULT_VARIABLE download_status) if(NOT download_status) - message(STATUS "Downloaded CK library successfully.") + message(STATUS "Downloaded CK library successfully.") else() - message(FATAL_ERROR "Failed to download the CK library from ${SOURCE_URL}.") + message(FATAL_ERROR "Failed to download the CK library from ${SOURCE_URL}.") endif() # create INTERFACE target @@ -41,4 +41,3 @@ add_library(__ck_lib INTERFACE) # specify path to CK library target_link_libraries(__ck_lib INTERFACE ${destination}) - From 507e0aa5b054253f13bc8cdfe6826baea425f464 Mon Sep 17 00:00:00 2001 From: Akash Verma Date: Fri, 4 Apr 2025 21:09:14 +0000 Subject: [PATCH 3/3] Removed duplicate include for ROCM_VERSION_DEV. Corrected indentation and code cleanup. --- CMakeLists.txt | 4 ++-- aten/src/ATen/CMakeLists.txt | 7 ------- cmake/External/ck.cmake | 3 --- 3 files changed, 2 insertions(+), 12 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 944d7cd557e9a..b25220d8b14ed 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -895,9 +895,9 @@ if(USE_ROCM) endif() endif() -# CK shared lib linkage +# link CK library if(USE_ROCM) - if(UNIX AND (USE_CK_FLASH_ATTENTION)) + if(UNIX AND USE_CK_FLASH_ATTENTION) include(cmake/External/ck.cmake) endif() endif() diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index 6378342367ff9..e40b282388cd9 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -173,13 +173,6 @@ file(GLOB flash_attention_hip_hip "native/transformers/hip/flash_attn/*.hip") # if USE_FLASH_ATTENTION is set, ensure CK instances get generated if(USE_FLASH_ATTENTION) if(USE_CK_FLASH_ATTENTION) - if(DEFINED ENV{PYTORCH_ROCM_ARCH}) - list(LENGTH PYTORCH_ROCM_ARCH NUM_ARCHS) - if(NUM_ARCHS GREATER 1) - message(WARNING "Building CK for multiple archs can increase build time considerably! - Consider setting PYTORCH_ROCM_ARCH env var value as the gfx arch you need to build for") - endif() - endif() message(STATUS "USE_CK_FLASH_ATTENTION is set; building PyTorch with CK Flash Attention enabled") message(STATUS "Generating CK kernel instances...") # disable buidling CK files diff --git a/cmake/External/ck.cmake b/cmake/External/ck.cmake index ac2165a701d15..ca27911d36f83 100644 --- a/cmake/External/ck.cmake +++ b/cmake/External/ck.cmake @@ -18,9 +18,6 @@ else() message(FATAL_ERROR "Failed to get submodule status for composable_kernel.") endif() -# get ROCm version from LoadHIP.cmake -include(${CMAKE_SOURCE_DIR}/cmake/public/LoadHIP.cmake) - # full path for CK library on compute-artifactory.amd.com set(url "https://compute-artifactory.amd.com/artifactory/rocm-generic-local") set(ck_lib_full_path "${url}/torch_ck_gen_lib/ck_${ck_commit_hash}/rocm_${ROCM_VERSION_DEV}/libck_kernels.so")