
Commit 03553da

yzh119 and liangyuRain authored
feat: initial support of distributed operators (#289)
This PR implements the attention all-reduce kernel, which will be used to merge attention states from different GPUs in sequence parallelism. We use [mscclpp](https://github.com/microsoft/mscclpp) for collective communication; thanks to @liangyuRain for teaching me how to use mscclpp. Co-authored-by: Liangyu Zhao <[email protected]>
1 parent 809abaa commit 03553da
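
For reference, the reduction this kernel performs is the standard attention-state merge (the same rule flashinfer's MergeStates applies on a single GPU): given each rank's partial output $v_i$ and log-sum-exp $s_i$ for the same query position and head,

$$s_{\mathrm{merged}} = \log \sum_{i=1}^{R} e^{s_i}, \qquad v_{\mathrm{merged}} = \frac{\sum_{i=1}^{R} e^{s_i}\, v_i}{\sum_{i=1}^{R} e^{s_i}},$$

where $R$ is the number of ranks.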

File tree

8 files changed, +528 -4 lines changed

Diff for: .gitmodules

+3
@@ -13,3 +13,6 @@
 [submodule "3rdparty/composable_kernels"]
 	path = 3rdparty/composable_kernels
 	url = https://github.com/ROCm/composable_kernel.git
+[submodule "3rdparty/spdlog"]
+	path = 3rdparty/spdlog
+	url = [email protected]:gabime/spdlog.git

Diff for: 3rdparty/spdlog

Submodule spdlog added at c3aed4b

Diff for: CMakeLists.txt

+24 -1

@@ -31,6 +31,7 @@ flashinfer_option(FLASHINFER_PAGE "Whether to compile page kernel tests/benchmar
 flashinfer_option(FLASHINFER_CASCADE "Whether to compile cascade kernel tests/benchmarks or not." OFF)
 flashinfer_option(FLASHINFER_SAMPLING "Whether to compile sampling kernel tests/benchmarks or not." OFF)
 flashinfer_option(FLASHINFER_NORM "Whether to compile normalization kernel tests/benchmarks or not." OFF)
+flashinfer_option(FLASHINFER_DISTRIBUTED "Whether to compile distributed kernel tests/benchmarks or not." OFF)
 flashinfer_option(FLASHINFER_TVM_BINDING "Whether to compile tvm binding or not." OFF)
 flashinfer_option(FLASHINFER_TVM_SOURCE_DIR "The path to tvm for building tvm binding." "")

@@ -55,7 +56,11 @@ list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/modules")
 if(FLASHINFER_PREFILL OR FLASHINFER_DECODE OR FLASHINFER_PAGE OR FLASHINFER_CASCADE OR FLASHINFER_SAMPLING OR FLASHINFER_NORM)
   message(STATUS "NVBench and GoogleTest enabled")
   add_subdirectory(3rdparty/nvbench)
-  add_subdirectory(3rdparty/googletest)
+  if(FLASHINFER_DISTRIBUTED)
+    add_subdirectory(3rdparty/mscclpp)
+  else(FLASHINFER_DISTRIBUTED)
+    add_subdirectory(3rdparty/googletest)
+  endif(FLASHINFER_DISTRIBUTED)
 endif(FLASHINFER_PREFILL OR FLASHINFER_DECODE OR FLASHINFER_PAGE OR FLASHINFER_CASCADE OR FLASHINFER_SAMPLING OR FLASHINFER_NORM)
 find_package(Thrust REQUIRED)

@@ -470,3 +475,21 @@ if(FLASHINFER_FASTDIV_TEST)
   target_include_directories(test_fastdiv PRIVATE ${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR})
   target_link_libraries(test_fastdiv PRIVATE gtest gtest_main)
 endif(FLASHINFER_FASTDIV_TEST)
+
+if (FLASHINFER_DISTRIBUTED)
+  find_package(MPI REQUIRED)
+
+  message(STATUS "Compile sum all-reduce kernel tests.")
+  file(GLOB_RECURSE TEST_DIST_SUM_ALL_REDUCE_SRCS ${PROJECT_SOURCE_DIR}/src/test_sum_all_reduce.cu)
+  add_executable(test_sum_all_reduce ${TEST_DIST_SUM_ALL_REDUCE_SRCS})
+  target_include_directories(test_sum_all_reduce PRIVATE ${FLASHINFER_INCLUDE_DIR} 3rdparty/mscclpp/include 3rdparty/spdlog/include)
+  target_link_libraries(test_sum_all_reduce PRIVATE MPI::MPI_CXX mscclpp)
+  target_compile_definitions(test_sum_all_reduce PRIVATE -DENABLE_MPI)
+
+  message(STATUS "Compile attention allreduce kernel tests.")
+  file(GLOB_RECURSE TEST_DIST_ATTN_ALL_REDUCE_SRCS ${PROJECT_SOURCE_DIR}/src/test_attn_all_reduce.cu)
+  add_executable(test_attn_all_reduce ${TEST_DIST_ATTN_ALL_REDUCE_SRCS})
+  target_include_directories(test_attn_all_reduce PRIVATE ${FLASHINFER_INCLUDE_DIR} 3rdparty/mscclpp/include 3rdparty/spdlog/include)
+  target_link_libraries(test_attn_all_reduce PRIVATE MPI::MPI_CXX mscclpp)
+  target_compile_definitions(test_attn_all_reduce PRIVATE -DENABLE_MPI)
+endif(FLASHINFER_DISTRIBUTED)
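
Both test targets link MPI::MPI_CXX and define ENABLE_MPI, which suggests the drivers under src/ gate their multi-process setup on that macro. A minimal sketch of that pattern (the driver structure is an assumption, not shown in this page; one process drives one GPU):

#ifdef ENABLE_MPI
#include <mpi.h>
#endif

int main(int argc, char** argv) {
  int rank = 0, nranks = 1;
#ifdef ENABLE_MPI
  // Rank and world size come from the MPI launcher (e.g. one rank per GPU).
  MPI_Init(&argc, &argv);
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  MPI_Comm_size(MPI_COMM_WORLD, &nranks);
#endif
  // ... set up mscclpp channels and launch the all-reduce kernels ...
#ifdef ENABLE_MPI
  MPI_Finalize();
#endif
  return 0;
}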

Diff for: cmake/config.cmake

+3 -1

@@ -17,7 +17,9 @@ set(FLASHINFER_SAMPLING ON)
 # Whether to compile normalization kernel tests/benchmarks or not.
 set(FLASHINFER_NORMALIZATION ON)
 # Whether to compile fastdiv tests
-set(FLASHINFER_FASTDIV_TEST OFF)
+set(FLASHINFER_FASTDIV_TEST ON)
+# Whether to compile distributed tests
+set(FLASHINFER_DISTRIBUTED ON)
 # The following configurations can impact the binary
 # size of the generated library
 set(FLASHINFER_GEN_GROUP_SIZES 1 4 6 8)

Diff for: include/flashinfer/attention/cascade.cuh

+2 -2

@@ -408,8 +408,8 @@ cudaError_t MergeStateInPlace(DType* v, float* s, DType* v_other, float* s_other
  * \brief Merge self-attention states of a list of index sets.
  * \tparam DTypeIn The data type of v.
  * \tparam DTypeOut The data type of v_merged.
- * \param v The partial v of index sets. (num_index_sets, n, h, d)
- * \param s The logsumexp value of index sets. (num_index_sets, n, h)
+ * \param v The partial v of index sets. (n, num_index_sets, h, d)
+ * \param s The logsumexp value of index sets. (n, num_index_sets, h)
  * \param v_merged The merged v of index sets union. (n, h, d)
  * \param s_merged The merged logsumexp value of index sets union. (n, h)
  * \param num_index_sets The number of index sets.
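
The corrected shapes move num_index_sets from the outermost to the second axis. As a concrete illustration (a hypothetical helper, not part of the commit), the row-major offset of element (i, j, head, dim) in v under the new (n, num_index_sets, h, d) layout is:

// Row-major offset into v with shape (n, num_index_sets, h, d): the index-set
// axis j now varies second instead of first.
size_t VOffset(size_t i, size_t j, size_t head, size_t dim,
               size_t num_index_sets, size_t h, size_t d) {
  return ((i * num_index_sets + j) * h + head) * d + dim;
}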

Diff for: include/flashinfer/distributed/all_reduce.cuh

+205

@@ -0,0 +1,205 @@
/*
 * Copyright (c) 2024 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef FLASHINFER_DISTRIBUTED_ALL_REDUCE_CUH_
#define FLASHINFER_DISTRIBUTED_ALL_REDUCE_CUH_

#include <mscclpp/concurrency_device.hpp>
#include <mscclpp/core.hpp>
#include <mscclpp/proxy_channel.hpp>
#include <mscclpp/proxy_channel_device.hpp>
#include <mscclpp/sm_channel.hpp>
#include <mscclpp/sm_channel_device.hpp>

#include "../attention/state.cuh"
#include "../vec_dtypes.cuh"

namespace flashinfer {

namespace distributed {

void SetupChannels(mscclpp::Communicator* comm, std::vector<mscclpp::SmChannel>& sm_channels,
                   int rank, int nranks, void* buff, size_t buff_size_in_bytes) {
  const mscclpp::TransportFlags all_transports = mscclpp::Transport::CudaIpc;
  mscclpp::RegisteredMemory buf_reg_mem =
      comm->registerMemory(buff, buff_size_in_bytes, all_transports);

  std::vector<std::shared_ptr<mscclpp::Connection>> connections;
  std::vector<mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>> remote_reg_mem;
  std::vector<mscclpp::NonblockingFuture<std::shared_ptr<mscclpp::Connection>>> conn_futures;

  for (int r = 0; r < nranks; ++r) {
    if (r == rank) continue;

    mscclpp::Transport transport = mscclpp::Transport::CudaIpc;
    conn_futures.push_back(comm->connectOnSetup(r, 0, transport));

    comm->sendMemoryOnSetup(buf_reg_mem, r, 0);
    auto remoteMemory = comm->recvMemoryOnSetup(r, 0);
    remote_reg_mem.push_back(remoteMemory);
  }
  comm->setup();
  std::transform(
      conn_futures.begin(), conn_futures.end(), std::back_inserter(connections),
      [](const mscclpp::NonblockingFuture<std::shared_ptr<mscclpp::Connection>>& future) {
        return future.get();
      });

  std::unordered_map<size_t, std::shared_ptr<mscclpp::SmDevice2DeviceSemaphore>> sm_semaphores;
  for (size_t cid = 0; cid < connections.size(); ++cid) {
    sm_semaphores.emplace(
        cid, std::make_shared<mscclpp::SmDevice2DeviceSemaphore>(*comm, connections[cid]));
  }
  comm->setup();

  for (size_t cid = 0; cid < connections.size(); ++cid) {
    if (connections[cid]->transport() == mscclpp::Transport::CudaIpc) {
      sm_channels.emplace_back(sm_semaphores[cid], remote_reg_mem[cid].get(), buf_reg_mem.data());
    }
  }
}
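
// NOTE (annotation, hypothetical usage; not part of the committed header):
// SetupChannels builds a full CUDA-IPC mesh. Each rank registers its local
// buffer, exchanges registered memory with every peer, and wraps each
// connection in an SmDevice2DeviceSemaphore, so each SmChannel's dst_ points
// into a peer rank's buffer. A caller would do roughly:
//
//   std::vector<mscclpp::SmChannel> channels;
//   flashinfer::distributed::SetupChannels(&comm, channels, rank, nranks,
//                                          buf_d, buf_size_in_bytes);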

constexpr uint32_t MAX_RANKS = 8;
__device__ mscclpp::DeviceSyncer device_syncer;

template <typename DType>
__global__ void AttentionAllReduceInplaceKernel(mscclpp::SmChannelDeviceHandle* sm_channels,
                                                uint8_t* buf, const uint32_t rank,
                                                const uint32_t num_ranks, const uint32_t batch_size,
                                                const uint32_t num_heads, const uint32_t head_dim) {
  const uint32_t vec_size = 16 / sizeof(DType);
  const size_t chunk_size = head_dim / num_ranks;
  if (num_ranks == 1) return;
  const uint32_t num_peers = num_ranks - 1;
  const uint32_t tid = threadIdx.x + blockDim.x * (threadIdx.y + blockIdx.x * blockDim.y);
  const uint32_t tx = threadIdx.x;
  const uint32_t head_id = threadIdx.y;
  const uint32_t batch_id = blockIdx.x;
  DType* v_buf = (DType*)buf;
  float* s_buf = (float*)(buf + batch_size * num_heads * head_dim * sizeof(DType));

  if (tid < num_peers) {
    sm_channels[tid].signal();
    sm_channels[tid].wait();
  }
  device_syncer.sync(gridDim.x);

  float other_lse[MAX_RANKS - 1], self_lse = s_buf[batch_id * num_heads + head_id];
  for (uint32_t round_idx = 0; round_idx < num_peers; ++round_idx) {
    int peer_idx = (round_idx + rank);
    if (peer_idx >= num_peers) peer_idx -= num_peers;
    other_lse[round_idx] =
        ((float*)(sm_channels[peer_idx].dst_ + batch_size * num_heads * head_dim *
                                                   sizeof(DType)))[batch_id * num_heads + head_id];
  }

  state_t<vec_size> tmp;
  for (uint32_t elem_idx = tx; elem_idx < chunk_size / vec_size; elem_idx += blockDim.x) {
    tmp.init();
    tmp.o.cast_load(v_buf + (batch_id * num_heads + head_id) * head_dim + rank * chunk_size +
                    elem_idx * vec_size);
    tmp.m = self_lse;
    for (uint32_t round_idx = 0; round_idx < num_peers; ++round_idx) {
      int peer_idx = (round_idx + rank);
      if (peer_idx >= num_peers) peer_idx -= num_peers;
      vec_t<float, vec_size> other_v;
      other_v.cast_load(((DType*)sm_channels[peer_idx].dst_) +
                        (batch_id * num_heads + head_id) * head_dim + rank * chunk_size +
                        elem_idx * vec_size);
      tmp.merge(other_v, other_lse[round_idx], 1);
    }
    tmp.normalize();
    for (uint32_t round_idx = 0; round_idx < num_peers; ++round_idx) {
      int peer_idx = (round_idx + rank);
      if (peer_idx >= num_peers) peer_idx -= num_peers;
      tmp.o.cast_store(((DType*)sm_channels[peer_idx].dst_) +
                       (batch_id * num_heads + head_id) * head_dim + rank * chunk_size +
                       elem_idx * vec_size);
    }
    tmp.o.cast_store(v_buf + (batch_id * num_heads + head_id) * head_dim + rank * chunk_size +
                     elem_idx * vec_size);
  }
  float lse = tmp.get_lse();
  device_syncer.sync(gridDim.x);

  if (tx == 0) {
    for (uint32_t round_idx = 0; round_idx < num_peers; ++round_idx) {
      int peer_idx = (round_idx + rank);
      if (peer_idx >= num_peers) peer_idx -= num_peers;
      ((float*)(sm_channels[peer_idx].dst_ + batch_size * num_heads * head_dim *
                                                 sizeof(DType)))[batch_id * num_heads + head_id] =
          lse;
    }
    s_buf[batch_id * num_heads + head_id] = lse;
  }

  device_syncer.sync(gridDim.x);
  if (tid < num_peers) {
    sm_channels[tid].signal();
    sm_channels[tid].wait();
  }
}
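
// NOTE (annotation, hypothetical launch; not part of the committed header):
// the indexing above (batch_id = blockIdx.x, head_id = threadIdx.y) implies
// one block per request and one threadIdx.y lane per head, with the buffer
// laid out as [V: batch_size x num_heads x head_dim DType values] followed by
// [S: batch_size x num_heads float logsumexp values]. For example:
//
//   dim3 nblks(batch_size);
//   dim3 nthrs(32, num_heads);
//   AttentionAllReduceInplaceKernel<half><<<nblks, nthrs>>>(
//       sm_channel_handles_d, buf_d, rank, num_ranks, batch_size, num_heads,
//       head_dim);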

template <typename DType, typename ReduceDType>
__global__ void SumAllReduceInplaceKernel(mscclpp::SmChannelDeviceHandle* sm_channels, DType* buf,
                                          const uint32_t rank, const uint32_t num_ranks,
                                          const size_t num_elems) {
  const uint32_t vec_size = 16 / sizeof(DType);
  const size_t chunk_size = num_elems / num_ranks;
  if (num_ranks == 1) return;
  const uint32_t num_peers = num_ranks - 1;
  const uint32_t tid = threadIdx.x + blockIdx.x * blockDim.x;

  if (tid < num_peers) {
    sm_channels[tid].signal();
    sm_channels[tid].wait();
  }
  device_syncer.sync(gridDim.x);

  size_t num_vec_per_chunk = chunk_size / vec_size;
  // use int4 as much as possible
  for (uint32_t i = tid; i < num_vec_per_chunk; i += blockDim.x * gridDim.x) {
    vec_t<ReduceDType, vec_size> tmp;
    tmp.cast_load(buf + rank * chunk_size + i * vec_size);
    for (uint32_t round_idx = 0; round_idx < num_peers; ++round_idx) {
      int peer_idx = (round_idx + rank);
      if (peer_idx >= num_peers) peer_idx -= num_peers;
      vec_t<ReduceDType, vec_size> val;
      val.cast_load(((DType*)sm_channels[peer_idx].dst_) + rank * chunk_size + i * vec_size);
#pragma unroll
      for (int j = 0; j < vec_size; ++j) {
        tmp[j] += val[j];
      }
    }
    for (uint32_t round_idx = 0; round_idx < num_peers; ++round_idx) {
      int peer_idx = (round_idx + rank);
      if (peer_idx >= num_peers) peer_idx -= num_peers;
      tmp.cast_store(((DType*)sm_channels[peer_idx].dst_) + rank * chunk_size + i * vec_size);
    }
    tmp.cast_store(buf + rank * chunk_size + i * vec_size);
  }

  device_syncer.sync(gridDim.x);
  if (tid < num_peers) {
    sm_channels[tid].signal();
    sm_channels[tid].wait();
  }
}

}  // namespace distributed

}  // namespace flashinfer

#endif  // FLASHINFER_DISTRIBUTED_ALL_REDUCE_CUH_
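
To make the pieces above concrete, here is a hedged sketch of a host-side driver in the spirit of src/test_sum_all_reduce.cu (the actual test source is not shown on this page). The TcpBootstrap/createUniqueId/deviceHandle calls follow mscclpp's usual bootstrap pattern; the buffer size and launch dimensions are illustrative assumptions. It assumes MPI_Init has already run, as in the sketch after the CMakeLists.txt diff.

#include <cuda_fp16.h>
#include <mpi.h>

#include <memory>
#include <vector>

#include "flashinfer/distributed/all_reduce.cuh"

// Hypothetical driver: bootstrap mscclpp over MPI, wire up the CUDA-IPC
// channels, and run the in-place sum all-reduce on this rank's GPU.
void RunSumAllReduce(int rank, int nranks) {
  cudaSetDevice(rank);  // one GPU per rank

  // Bootstrap an mscclpp communicator: rank 0 creates a unique id and
  // broadcasts it to the other ranks over MPI.
  auto bootstrap = std::make_shared<mscclpp::TcpBootstrap>(rank, nranks);
  mscclpp::UniqueId id;
  if (rank == 0) id = bootstrap->createUniqueId();
  MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD);
  bootstrap->initialize(id);
  mscclpp::Communicator comm(bootstrap);

  // Each rank registers one buffer; the reduction happens in place.
  const size_t num_elems = 1 << 20;
  half* buf_d;
  cudaMalloc(&buf_d, num_elems * sizeof(half));

  std::vector<mscclpp::SmChannel> channels;
  flashinfer::distributed::SetupChannels(&comm, channels, rank, nranks, buf_d,
                                         num_elems * sizeof(half));

  // Copy the per-peer device handles to GPU memory for the kernel.
  std::vector<mscclpp::SmChannelDeviceHandle> handles;
  for (auto& ch : channels) handles.push_back(ch.deviceHandle());
  mscclpp::SmChannelDeviceHandle* handles_d;
  cudaMalloc(&handles_d, handles.size() * sizeof(handles[0]));
  cudaMemcpy(handles_d, handles.data(), handles.size() * sizeof(handles[0]),
             cudaMemcpyHostToDevice);

  // Each rank sums its 1/nranks chunk in fp32 and writes it back to all peers.
  flashinfer::distributed::SumAllReduceInplaceKernel<half, float>
      <<<16, 1024>>>(handles_d, buf_d, rank, nranks, num_elems);
  cudaDeviceSynchronize();
}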
