Skip to content

Merge CK validation to release/2.6 #2016

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -865,6 +865,13 @@ cmake_dependent_option(
"USE_CUDA OR USE_ROCM;NOT MSVC"
OFF)

# Build option for the CK (composable_kernel) flash_attention kernel.
# cmake_dependent_option semantics: the option is presented with a default of
# ON while USE_FLASH_ATTENTION is true, and is forced to OFF otherwise.
cmake_dependent_option(
USE_CK_FLASH_ATTENTION
"Whether to build the CK flash_attention kernel. Will be enabled if USE_FLASH_ATTENTION is enabled."
ON
"USE_FLASH_ATTENTION"
OFF)

# We are currently not using alibi attention for Flash, so we disable this
# feature by default. We don't currently document this feature because we don't
# suspect users building from source will need this.
Expand All @@ -888,6 +895,13 @@ if(USE_ROCM)
endif()
endif()

# CK shared lib linkage
if(USE_ROCM)
if(UNIX AND (USE_CK_FLASH_ATTENTION))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need the UNIX part since USE_ROCM is defined as cmake_dependent_option(USE_ROCM "Use ROCm" ON "LINUX" OFF)?
Also, please correct indentation

Copy link
Author

@akashveramd akashveramd Apr 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

aotriton.cmake is getting linked similarly using condition if(UNIX AND (USE_FLASH_ATTENTION...
I borrowed the logic from there and linked in a similar way.
Will correct the indentation and the spacing though.

include(cmake/External/ck.cmake)
endif()
endif()

if(DEBUG_CUDA)
string(APPEND CMAKE_CUDA_FLAGS_DEBUG " -lineinfo")
string(APPEND CMAKE_CUDA_FLAGS_RELWITHDEBINFO " -lineinfo")
Expand Down
28 changes: 13 additions & 15 deletions aten/src/ATen/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -172,22 +172,20 @@ file(GLOB flash_attention_cuda_cpp "native/transformers/cuda/flash_attn/*.cpp")
file(GLOB flash_attention_hip_hip "native/transformers/hip/flash_attn/*.hip")
# if USE_FLASH_ATTENTION is set, ensure CK instances get generated
if(USE_FLASH_ATTENTION)
if(DEFINED ENV{USE_CK_FLASH_ATTENTION})
set(USE_CK_FLASH_ATTENTION $ENV{USE_CK_FLASH_ATTENTION})
if(USE_CK_FLASH_ATTENTION STREQUAL "1")
if(DEFINED ENV{PYTORCH_ROCM_ARCH})
list(LENGTH PYTORCH_ROCM_ARCH NUM_ARCHS)
if(NUM_ARCHS GREATER 1)
message(WARNING "Building CK for multiple archs can increase build time considerably!
Consider setting PYTORCH_ROCM_ARCH env var value as the gfx arch you need to build for")
endif()
endif()
message(STATUS "USE_CK_FLASH_ATTENTION is set; building PyTorch with CK Flash Attention enabled")
message(STATUS "Generating CK kernel instances...")
add_subdirectory(native/transformers/hip/flash_attn/ck)
file(GLOB flash_attention_hip_ck_hip "native/transformers/hip/flash_attn/ck/*.hip")
list(APPEND native_transformers_hip_hip ${flash_attention_hip_ck_hip})
# When CK flash attention is enabled, warn about multi-arch build cost and add
# the CK .hip sources to the HIP transformer source list.
if(USE_CK_FLASH_ATTENTION)
# The multi-arch warning is only considered when the PYTORCH_ROCM_ARCH
# environment variable is defined.
if(DEFINED ENV{PYTORCH_ROCM_ARCH})
# NOTE(review): the length is taken from the CMake variable PYTORCH_ROCM_ARCH,
# not from the environment variable tested above - presumably it is populated
# from the environment elsewhere; confirm.
list(LENGTH PYTORCH_ROCM_ARCH NUM_ARCHS)
if(NUM_ARCHS GREATER 1)
message(WARNING "Building CK for multiple archs can increase build time considerably!
Consider setting PYTORCH_ROCM_ARCH env var value as the gfx arch you need to build for")
endif()
endif()
message(STATUS "USE_CK_FLASH_ATTENTION is set; building PyTorch with CK Flash Attention enabled")
message(STATUS "Generating CK kernel instances...")
# Building the CK files from this directory is disabled: the add_subdirectory()
# call below is intentionally left commented out.
# add_subdirectory(native/transformers/hip/flash_attn/ck)
# Collect the CK .hip sources and append them to native_transformers_hip_hip.
file(GLOB flash_attention_hip_ck_hip "native/transformers/hip/flash_attn/ck/*.hip")
list(APPEND native_transformers_hip_hip ${flash_attention_hip_ck_hip})
endif()
file(GLOB flash_attention_hip_aot_hip "native/transformers/hip/flash_attn/aot/*.hip")
file(GLOB flash_attention_src_hip_hip "native/transformers/hip/flash_attn/src/*.hip")
Expand Down
63 changes: 0 additions & 63 deletions aten/src/ATen/native/transformers/hip/flash_attn/ck/CMakeLists.txt

This file was deleted.

4 changes: 4 additions & 0 deletions caffe2/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -929,6 +929,10 @@ if(USE_ROCM)
if(USE_FLASH_ATTENTION)
target_link_libraries(torch_hip PRIVATE __caffe2_aotriton)
endif()
# Link the CK library target (__ck_lib) privately into torch_hip when CK flash
# attention is enabled.
# NOTE(review): __ck_lib must already exist as a target at this point; if it
# does not, CMake treats the name as a plain library name - confirm the
# conditions that create the target match this one.
if(USE_CK_FLASH_ATTENTION)
target_link_libraries(torch_hip PRIVATE __ck_lib)
endif()
set(CUDA_LINK_LIBRARIES_KEYWORD)
torch_compile_options(torch_hip) # see cmake/public/utils.cmake
# TODO: Not totally sure if this is live or not
Expand Down
43 changes: 43 additions & 0 deletions cmake/External/ck.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#
# create INTERFACE target for CK library
#

# Get the CK commit hash from the composable_kernel submodule.
#
# `git submodule status` prints a line of the form
#   "<prefix><sha1> <path> (<describe>)"
# where <prefix> is ' ' (in sync), '-' (submodule not initialized),
# '+' (checked-out commit differs from the superproject index) or
# 'U' (merge conflicts). The prefix must be removed before the hash is cut out,
# otherwise it would become part of ck_commit_hash.
#
# Result: ck_commit_hash holds the first 8 characters of the commit hash.
# Configuration stops with FATAL_ERROR if git fails or no hash can be parsed.
execute_process(
  COMMAND git -C ${CMAKE_SOURCE_DIR}/third_party submodule status composable_kernel
  RESULT_VARIABLE result
  OUTPUT_VARIABLE submodule_status
  ERROR_VARIABLE submodule_status_error
  OUTPUT_STRIP_TRAILING_WHITESPACE
)
if(result EQUAL 0)
  # Strip any leading status prefix / whitespace. The expansion is quoted so an
  # empty output does not drop the input argument.
  string(REGEX REPLACE "^[-+U \t]+" "" submodule_status "${submodule_status}")
  # extract first 8 characters of the commit hash
  string(SUBSTRING "${submodule_status}" 0 8 ck_commit_hash)
  # Guard against empty or unexpected output so that a malformed hash never
  # ends up in the download URL.
  if(NOT ck_commit_hash MATCHES "^[0-9a-fA-F]+$")
    message(FATAL_ERROR "Failed to get submodule status for composable_kernel. Unexpected output: '${submodule_status}'")
  endif()
else()
  message(FATAL_ERROR "Failed to get submodule status for composable_kernel. ${submodule_status_error}")
endif()

# get ROCm version from LoadHIP.cmake
# (ROCM_VERSION_DEV is used further down to build the download URL)
include(${CMAKE_SOURCE_DIR}/cmake/public/LoadHIP.cmake)

# Base URL on compute-artifactory.amd.com under which the CK library is
# published; the commit- and ROCm-specific path is appended to it later.
set(url "https://compute-artifactory.amd.com/artifactory/rocm-generic-local")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@pruthvistony @jeffdaily Should this be a concern about putting artifactory links in a public repo?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider reusing the release page of https://github.com/ROCm/composable_kernel
(There is no need to create new releases, just re-use any existing release to store the assets)

Copy link
Collaborator

@jithunnair-amd jithunnair-amd Apr 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, I don't think https://github.com/ROCm/composable_kernel is an option, since we don't own that repo and I don't think it'd be appropriate to store these artifacts on their repo, given that they are generated using scripts that are not even in their repo.
Alternately, we could instead create releases on https://github.com/ROCm/CK_kernels, since we own that repo.
@pruthvistony, can we file an OSRB request for CK_kernels repo (if we want to use the same solution for upstream)?

# Full URL of the prebuilt CK library. It is keyed by the CK submodule commit
# hash (ck_commit_hash) and the ROCm version (ROCM_VERSION_DEV), appended to
# the base URL held in `url`; all three are set earlier in this file.
set(ck_lib_full_path "${url}/torch_ck_gen_lib/ck_${ck_commit_hash}/rocm_${ROCM_VERSION_DEV}/libck_kernels.so")

# set destination
set(destination "${CMAKE_SOURCE_DIR}/torch/lib/libck_kernels.so")

# download CK library
# file(DOWNLOAD) reports its outcome through the STATUS keyword as a
# two-element list "<code>;<message>", where code 0 means success.
# (RESULT_VARIABLE is not a file(DOWNLOAD) keyword, so it never set the
# variable and a failed download went unnoticed.)
file(DOWNLOAD "${ck_lib_full_path}" "${destination}" SHOW_PROGRESS STATUS download_status)
list(GET download_status 0 download_code)
list(GET download_status 1 download_message)
if(download_code EQUAL 0)
  message(STATUS "Downloaded CK library successfully.")
else()
  # Do not leave a partial or empty file behind that a later step could
  # mistake for a valid library.
  file(REMOVE "${destination}")
  message(FATAL_ERROR "Failed to download the CK library from ${ck_lib_full_path}: ${download_message}")
endif()

# create INTERFACE target
# __ck_lib builds nothing itself; it only carries the link requirement below
# to whatever target links against it.
add_library(__ck_lib INTERFACE)

# specify path to CK library
# `destination` is the full path of the downloaded libck_kernels.so, so
# consumers link against that file directly.
target_link_libraries(__ck_lib INTERFACE ${destination})