[Distributed] Add custom allreduce support for ROCM #14125


Merged: merged 15 commits on Apr 1, 2025

Conversation

ilmarkov
Contributor

@ilmarkov ilmarkov commented Mar 3, 2025

Enable custom allreduce for AMD GPUs (MI300X)
Fix custom allreduce test
Set RAY_RUNTIME_ENV_IGNORE_GITIGNORE in the test utils that use Ray, since Ray otherwise ignores .so files when uploading sources to Ray workers.

@ilmarkov ilmarkov requested a review from tlrmchlsmth as a code owner March 3, 2025 08:49

github-actions bot commented Mar 3, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run the fastcheck CI, which covers a small, essential subset of CI tests to catch errors quickly. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@ilmarkov ilmarkov marked this pull request as draft March 3, 2025 08:49
@mergify mergify bot added the ci/build label Mar 3, 2025
@ilmarkov ilmarkov marked this pull request as ready for review March 4, 2025 15:50
@hongxiayang hongxiayang added the rocm Related to AMD ROCm label Mar 4, 2025
Contributor

@SageMoore SageMoore left a comment

Thanks for the hard work on this, @ilmarkov. I made a first pass and left some comments.


def find_loaded_library(lib_name) -> Optional[str]:
"""
According to according to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html,
Contributor

Nit: Small typo "According to"

// testing. Might be the case that hardware provides stronger guarantee than
// the memory model.
uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
#if !defined(USE_ROCM)
Contributor

Do you think it makes more sense to move the ifdefs outside of start_sync and end_sync? That way the two implementations are completely separate.

Contributor Author

If you think it is going to improve code readability, sure
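For illustration, a minimal sketch of the pattern being suggested, using hypothetical stand-in functions rather than the PR's actual kernel helpers: the #if/#else wraps whole function definitions, so each platform gets a complete, separate implementation.

```cpp
// Pattern sketch only, not the PR's code: the platform switch sits at function
// scope instead of interleaving #ifdefs inside one body.
#include <cstdio>

#if defined(USE_ROCM)
// ROCm variants: whatever HIP-specific synchronization the real code needs.
void start_sync() { std::puts("ROCm start_sync"); }
void end_sync() { std::puts("ROCm end_sync"); }
#else
// CUDA variants: the existing implementation would stay untouched here.
void start_sync() { std::puts("CUDA start_sync"); }
void end_sync() { std::puts("CUDA end_sync"); }
#endif

int main() {
  start_sync();
  end_sync();
  return 0;
}
```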

// may write counter+1 while current GPU is busy waiting for counter. We use
// alternating counter array to avoid this possibility.
alignas(128) FlagType peer_counter[2][kMaxBlocks][8];
alignas(128) FlagType start[kMaxBlocks][8];
Contributor

I think a longer comment is warranted here. Especially since you are deleting one :).

};

struct __align__(16) RankData {
const void* __restrict__ ptrs[8];
const void*
#if !defined(USE_ROCM)
Contributor

Comment?

Contributor Author

I double-checked it. The __restrict__ keyword on a struct member does not have performance implications. According to this answer, it has no effect in nvcc, while hipcc raises a compiler error when the keyword is used. So I will remove the keyword for both platforms in this case.
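As a standalone illustration of the resulting struct (using alignas in place of the CUDA __align__ macro purely so the snippet compiles on its own), dropping the member-level __restrict__ keeps the layout unchanged and is accepted by both compilers:

```cpp
#include <cstddef>

// Sketch of the struct after the change; previously the member was declared
// as `const void* __restrict__ ptrs[8];`, which hipcc rejects.
struct alignas(16) RankData {
  const void* ptrs[8];
};

static_assert(alignof(RankData) == 16, "alignment preserved");
static_assert(sizeof(RankData) == 8 * sizeof(void*), "no padding introduced");

int main() {
  RankData rd{};
  (void)rd;
  return 0;
}
```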

// do the actual reduction
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
idx += gridDim.x * blockDim.x) {
((P*)result)[idx] = packed_reduce<P, ngpus, A>((const P**)&dp.ptrs[0], idx);
}
multi_gpu_barrier<ngpus, false>(sg, self_sg, rank);
end_sync<ngpus, true>(sg, self_sg, rank);
Contributor

Could you explain why you changed the boolean value here?

Contributor Author

The meaning of the template parameter has changed; it is now the opposite. "True" used to mean that we want a memory fence here; now it means that no additional thread sync is needed because this is the end of the kernel.
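A rough host-side sketch of the flipped flag semantics (names and bodies are placeholders, not the kernel code): under the old meaning `true` requested the fence, while under the new meaning `true` marks the final sync at kernel end, so the extra fence is skipped.

```cpp
#include <cstdio>

// Old semantics: the flag asks for the memory fence explicitly.
template <bool do_fence>
void end_sync_old() {
  if (do_fence) std::puts("old: issue memory fence");
}

// New semantics: the flag marks the final sync, so no extra fence is needed.
template <bool final_sync>
void end_sync_new() {
  if (!final_sync) std::puts("new: issue memory fence (not the final sync)");
}

int main() {
  end_sync_old<true>();  // previously: "do the fence"
  end_sync_new<true>();  // now: "this is the end of the kernel, skip it"
  return 0;
}
```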

Collaborator

Small nit on naming: the names start_sync and end_sync make me think that we're starting and ending the synchronization, and that both are needed to make a barrier -- but IIUC, they should be read as sync_at_start and sync_at_end, so consider renaming the functions for clarity.

tests/utils.py Outdated
# NOTE: We need to set working_dir for distributed tests,
# otherwise we may get import errors on ray workers
ray.init(runtime_env={"working_dir": VLLM_PATH})
ray.init()
Contributor

Nit: Can you clean this up? I assume you had some issue with setting the runtime_env but it does look like it's necessary.

lib = CudaRTLibrary()
pointer = lib.cudaMalloc(size_in_bytes)
handle = lib.cudaIpcGetMemHandle(pointer)
if current_platform.is_rocm():
Contributor

All of this platform specific dispatching logic does make me think that we may be close to wanting to split CustomAllReduce into platform specific subclasses.

@@ -5,6 +5,9 @@
from typing import TYPE_CHECKING, Dict, List, Optional

import torch
from amdsmi import (AmdSmiException, amdsmi_get_gpu_asic_info,
Contributor

Nit: looks like a few of these are unused. Can you trim down the import?

Contributor

@SageMoore SageMoore left a comment

Left another round of comments. I generally feel good about this change for the AMD side of things but I'm a bit more hesitant to change the barrier logic on the Nvidia side. Assuming the algorithms are the same, I think we can run some Nvidia benchmarks and convince ourselves that we haven't regressed performance at all. CC @tlrmchlsmth

*/
template <typename T>
void allreduce(cudaStream_t stream, T* input, T* output, int size,
int threads = 512, int block_limit = 36) {
#if !defined(USE_ROCM)
Contributor

Nit: Can you #define a BLOCK_LIMIT outside of this function declaration and do the USE_ROCM dispatching there?
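A small sketch of the suggestion, written as a plain C++ stand-in so it compiles anywhere; the ROCm value below is a made-up placeholder, while 36 matches the CUDA default in the quoted signature.

```cpp
#include <cstdio>

#if defined(USE_ROCM)
  #define DEFAULT_BLOCK_LIMIT 16  // hypothetical ROCm default, not from the PR
#else
  #define DEFAULT_BLOCK_LIMIT 36  // CUDA default from the existing declaration
#endif

// The USE_ROCM dispatch happens once above, keeping the signature clean.
void allreduce(int size, int threads = 512,
               int block_limit = DEFAULT_BLOCK_LIMIT) {
  std::printf("size=%d threads=%d block_limit=%d\n", size, threads,
              block_limit);
}

int main() {
  allreduce(1 << 20);
  return 0;
}
```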

@@ -362,7 +434,11 @@ class CustomAllreduce {
// note: must share the base address of each allocation, or we get wrong
// address
if (cuPointerGetAttribute(&base_ptr,
#if defined(USE_ROCM)
Contributor

Same idea here. I generally think it's cleaner to do this kind of ifdef dispatching when declaring the constant variable and not in the implementation of a function.

@@ -237,7 +305,8 @@ DINLINE P* get_tmp_buf(Signal* sg) {
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
cross_device_reduce_2stage(RankData* _dp, RankSignals sg, Signal* self_sg,
T* __restrict__ result, int rank, int size) {
T* __restrict__ result, int rank, int size,
int compression_factor = 1) {
Contributor

Nit: It looks like you aren't using the compression factor yet? If not, let's hold off adding it until it's used to eliminate any confusion in the short-term.

@ilmarkov
Contributor Author

Thank you for the review, @SageMoore. I will do the ifdef cleaning.

Assuming the algorithms are the same, I think we can run some Nvidia benchmarks and convince ourselves that we haven't regressed performance at all.

The barrier logic on Nvidia has not been changed. Only the names of the barrier functions have changed, following the ROCm fork; barrier execution on Nvidia is left untouched.

@kahakuka

@ilmarkov Hello, I am on ROCm 5.7. The hipIpcMemLazyEnablePeerAccess flag in the following function needs to be '0', but the accuracy is incorrect after changing it to '0'. Does this custom allreduce require a newer version of ROCm?

@ilmarkov
Contributor Author

@kahakuka Thank you for the note! Are you using MI200 GPU?

CMakeLists.txt Outdated
Comment on lines 523 to 527
if(VLLM_GPU_LANG STREQUAL "HIP")
list(APPEND VLLM_EXT_SRC
"csrc/custom_all_reduce.cu")
endif()

Collaborator

We can move this into the initial set(VLLM_EXT_SRC ...) statement starting around line 229.

Contributor Author

It goes right after the long CUDA if block. Do you think it makes sense to put the HIP-related if block before the CUDA one?

Collaborator

If you put it in this block, it will compile for both CUDA and HIP so no conditionals are needed

vllm/CMakeLists.txt

Lines 229 to 244 in 2bb0e1a

set(VLLM_EXT_SRC
"csrc/cache_kernels.cu"
"csrc/attention/paged_attention_v1.cu"
"csrc/attention/paged_attention_v2.cu"
"csrc/pos_encoding_kernels.cu"
"csrc/activation_kernels.cu"
"csrc/layernorm_kernels.cu"
"csrc/layernorm_quant_kernels.cu"
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
"csrc/quantization/fp8/common.cu"
"csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
"csrc/quantization/gguf/gguf_kernel.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/prepare_inputs/advance_step.cu"
"csrc/torch_bindings.cpp")

Comment on lines 146 to 177
std::tuple<fptr_t, torch::Tensor> allocate_shared_buffer_and_handle(
int64_t size) {
auto device_index = c10::cuda::current_device();
at::DeviceGuard device_guard(at::Device(at::DeviceType::CUDA, device_index));
void* buffer;
cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed;
auto stream = c10::cuda::getCurrentCUDAStream().stream();
AT_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode));
#if defined(USE_ROCM)
// data buffers need to be "uncached" for signal on MI200
AT_CUDA_CHECK(
hipExtMallocWithFlags((void**)&buffer, size, hipDeviceMallocUncached));
#else
AT_CUDA_CHECK(cudaMalloc((void**)&buffer, size));
#endif

AT_CUDA_CHECK(cudaMemsetAsync(buffer, 0, size, stream));
AT_CUDA_CHECK(cudaStreamSynchronize(stream));
AT_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode));

auto options =
torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
auto handle =
torch::empty({static_cast<int64_t>(sizeof(cudaIpcMemHandle_t))}, options);
AT_CUDA_CHECK(
cudaIpcGetMemHandle((cudaIpcMemHandle_t*)handle.data_ptr(), buffer));

return std::make_tuple(reinterpret_cast<fptr_t>(buffer), handle);
}
Collaborator

Could you explain why the changes to the allocation and initialization code are needed?

Also I find this function very difficult to read. Could you add whitespace and comments to improve this?

Contributor Author

Previously, the allocation and IPC initialization code was done in the Python wrapper around CUDA. However, in the case of ROCm, some of the constants (e.g. hipIpcMemLazyEnablePeerAccess) depend on the library version, and handling that in Python would look ugly in my opinion, so I decided to move the allocation code to C++.
I don't think it will affect the initialization performance.

I will prettify the function.
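For context, a minimal sketch (not the PR's code, with error handling elided) of the kind of platform-dependent IPC call involved; when the project is built for ROCm, the cuda* names and cudaIpcMemLazyEnablePeerAccess map to their hip* counterparts, which is where the version-dependent constant comes in.

```cpp
#include <cuda_runtime.h>  // under ROCm this becomes <hip/hip_runtime.h>

// Opens a peer rank's buffer from its exchanged IPC handle. The lazy
// peer-access flag is the constant whose ROCm behavior varies by version.
void* open_peer_buffer(cudaIpcMemHandle_t handle) {
  void* ptr = nullptr;
  cudaIpcOpenMemHandle(&ptr, handle, cudaIpcMemLazyEnablePeerAccess);
  return ptr;
}
```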

Comment on lines 258 to 339
multi_gpu_barrier<ngpus, true>(sg, self_sg, rank);
start_sync<ngpus>(sg, self_sg, rank);
Collaborator

It looks like we could pretty easily isolate the changes to the synchronization to just the ROCm case. If so, I think we should do that, as it would make this PR much less risky.

Currently I think this PR would need quite a bit of benchmarking to convince us that there won't be a performance regression in some cases on CUDA.

Contributor Author

In terms of execution, nothing has changed on CUDA in this PR. We only change the naming of the sync functions and the Signal struct fields. All the sync functions and kernels are the same.

Collaborator

Ok, I'll take another look in that case

@@ -93,7 +93,7 @@ def eager_allreduce(tp_size, pp_size, rank, distributed_init_port):
# communicate independently
num_communication = rank // tp_size + 1
sz = 1024
fa = get_tp_group().ca_comm
fa = get_tp_group().device_communicator.ca_comm
Collaborator

Could you explain this change?

Contributor Author

The GroupCoordinator (returned by get_tp_group()) does not have a ca_comm field. Apparently, it was moved to the CudaCommunicator class at some point and the test wasn't updated.

Comment on lines 141 to 143
@staticmethod
@with_amdsmi_context
def is_full_nvlink(physical_device_ids: List[int]) -> bool:
Collaborator

I think this name is a little confusing and we should rename it so that the name sounds like it applies in both the nvidia and amd case. I suggest is_fully_connected_nvlink_or_xgmi

@kahakuka

@kahakuka Thank you for the note! Are you using MI200 GPU?

Yes, I also tried it on MI250, it's the same.


mergify bot commented Mar 17, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @ilmarkov.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@kahakuka

kahakuka commented Mar 18, 2025

@ilmarkov Hello, thank you very much, it runs with the latest code. May I ask another question? Is your test in eager mode or CUDAGraph mode? I tested eager mode and found it to be normal, but there may be accuracy issues in CUDAGraph mode.

@ilmarkov
Contributor Author

ilmarkov commented Mar 18, 2025

@kahakuka Hi, the distributed/test_custom_all_reduce.py test verifies both eager and cudaGraph modes. Can you share the script so that we could reproduce the accuracy issue?

@kahakuka

@kahakuka Hi, the distributed/test_custom_all_reduce.py test verifies both eager and cudaGraph modes. Can you share the script so that we could reproduce the accuracy issue?
@ilmarkov
Thank you very much for your reply. I validated the accuracy through testing the model. The testing method is as follows:
Server side:
eager:
HIP_VISIBLE_DEVICES=6,7 vllm serve /data/models/Qwen2-7B-Instruct --enforce-eager --dtype float16 --trust-remote-code -tp 2
cudagraph:
HIP_VISIBLE_DEVICES=6,7 vllm serve /data/models/Qwen2-7B-Instruct --dtype float16 --trust-remote-code -tp 2

client:
curl http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
        "model": "/data/models/Qwen2-7B-Instruct",
        "messages": [
          {"role": "system", "content": "hello"}
        ]
      }'
The answer from eager is normal, while the response from CUDAGraph is garbled.

@ilmarkov
Contributor Author

@kahakuka I didn't manage to reproduce the issue on MI300X. So we decided to disable custom_allreduce on older AMD GPUs so that we could land this PR.

Contributor

@SageMoore SageMoore left a comment

Looks good to me, Ilia. All of my comments are just minor cosmetic things. I'll go ahead and approve.

@@ -12,7 +12,7 @@ static_assert(sizeof(void*) == sizeof(fptr_t));

fptr_t init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs,
torch::Tensor& rank_data, int64_t rank,
bool full_nvlink) {
bool full_connected) {
Contributor

Nit: fully_connected

@@ -101,7 +101,7 @@ def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
return True

@classmethod
def is_full_nvlink(cls, device_ids: List[int]) -> bool:
def is_fully_connected_nvlink_or_xgmi(cls, device_ids: List[int]) -> bool:
Contributor

NIT: is_fully_connected should be sufficient here.

@@ -72,6 +70,8 @@ def __init__(self,
if not custom_ar:
# disable because of missing custom allreduce library
# e.g. in a non-cuda environment
logger.warning("Custom allreduce is disabled because "
Contributor

This is a good addition.

Collaborator

I think this should only be a warning on CUDA and ROCm platforms. IIUC this will always log a warning on other platforms, and this should not be a warning e.g. on TPUs.

Also, could you change the comment on line 72 to # e.g. in a non-GPU environment

Collaborator

Ok, I guess this is only constructed in the cuda_communicator, so not concerned about TPUs now.

However, it does look like this will always log a warning on AMD GPUs earlier than MI300X, and we shouldn't be warning users about expected behavior.

Contributor Author

The custom allreduce library can be built on AMD GPUs earlier than MI300X; we just don't enable it there. So there shouldn't be a warning, as we won't try to create the CustomAllreduce object.

};

struct __align__(16) RankData {
const void* __restrict__ ptrs[8];
Contributor

Can you leave a quick comment (in the PR not the code) as a reply to this comment explaining why __restrict__ is no longer safe? Or was never safe in the first place?

Contributor Author

The __restrict__ keyword within a struct does not have performance implications in CUDA. According to this answer, it has no effect in nvcc, while hipcc raises a compiler error when the keyword is used.

@kahakuka

@ilmarkov
Hello, the code interface functions for testing do not match the actual application. I'm not sure how you compiled them on your end.

@kahakuka

@ilmarkov Hello, I have a question for you: I have looked at the NVIDIA source code. In the PCIe case, even without checking NVLink, custom allreduce can still be used. Is the same possible on the ROCm side?

ilmarkov added 7 commits March 31, 2025 11:18
Signed-off-by: ilmarkov <[email protected]>
pre_commit error
Signed-off-by: ilmarkov <[email protected]>
Signed-off-by: ilmarkov <[email protected]>
Signed-off-by: ilmarkov <[email protected]>
Signed-off-by: ilmarkov <[email protected]>
Signed-off-by: ilmarkov <[email protected]>
@tlrmchlsmth tlrmchlsmth added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 31, 2025
@tlrmchlsmth tlrmchlsmth enabled auto-merge (squash) March 31, 2025 13:22
@vllm-bot vllm-bot merged commit b7b7676 into vllm-project:main Apr 1, 2025
65 of 67 checks passed
kylesayrs pushed a commit to neuralmagic/vllm that referenced this pull request Apr 2, 2025
Signed-off-by: ilmarkov <[email protected]>
Co-authored-by: ilmarkov <[email protected]>
Signed-off-by: Kyle Sayers <[email protected]>
if current_platform.is_rocm():
device_capability = current_platform.get_device_capability()
if (current_platform.is_rocm() and device_capability is not None
and device_capability < (9, 4)):
Collaborator

One small note here: device_capability should not be used on ROCm, as it is not sequential (a bigger value does not mean more capability). On Radeon it is (11, x) or (12, x), yet custom allreduce often cannot be used there.

Contributor Author

@gshtras Thank you for the note! Is there any way to enable it on MI300 and newer GPUs other than checking gcnArchName in the device properties?

Collaborator

Afraid not, nothing that I'm aware of. The best way that I know of is something like https://github.com/vllm-project/vllm/blob/main/vllm/platforms/rocm.py#L265-L267
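For reference, a hedged HIP sketch (not the PR's code) of gating on the GPU architecture string instead of device_capability; the exact "gfx94" prefix match for MI300-class parts is an assumption here.

```cpp
#include <hip/hip_runtime.h>
#include <cstdio>
#include <cstring>

// Returns true if the device reports an MI300-class GCN arch (e.g. "gfx942").
bool is_mi300_class(int device_id) {
  hipDeviceProp_t props;
  if (hipGetDeviceProperties(&props, device_id) != hipSuccess) {
    return false;
  }
  return std::strncmp(props.gcnArchName, "gfx94", 5) == 0;
}

int main() {
  std::printf("custom allreduce eligible: %d\n", is_mi300_class(0));
  return 0;
}
```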

Alex4210987 pushed a commit to LeiWang1999/vllm-bitblas that referenced this pull request Apr 5, 2025
lulmer pushed a commit to lulmer/vllm that referenced this pull request Apr 7, 2025
Signed-off-by: ilmarkov <[email protected]>
Co-authored-by: ilmarkov <[email protected]>
Signed-off-by: Louis Ulmer <[email protected]>
nishith-fujitsu pushed a commit to nishith-fujitsu/vllm that referenced this pull request Apr 9, 2025
bnellnm pushed a commit to neuralmagic/vllm that referenced this pull request Apr 29, 2025
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Apr 29, 2025
bnellnm pushed a commit to neuralmagic/vllm that referenced this pull request Apr 30, 2025
shreyankg pushed a commit to shreyankg/vllm that referenced this pull request May 3, 2025
bnellnm pushed a commit to neuralmagic/vllm that referenced this pull request May 7, 2025
bnellnm pushed a commit to neuralmagic/vllm that referenced this pull request May 7, 2025
bnellnm pushed a commit to neuralmagic/vllm that referenced this pull request May 9, 2025
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
bnellnm pushed a commit to neuralmagic/vllm that referenced this pull request May 12, 2025
bnellnm pushed a commit to neuralmagic/vllm that referenced this pull request May 13, 2025
bnellnm pushed a commit to neuralmagic/vllm that referenced this pull request May 13, 2025
bnellnm pushed a commit to neuralmagic/vllm that referenced this pull request May 14, 2025