Skip to content

Commit e08ba42

Browse files
authored
feat: add group gemm operators (#282)
First step towards #199. Group GEMM should also be helpful for MoE.
1 parent 7aadc0d commit e08ba42

24 files changed

+764
-37
lines changed

Diff for: docs/api/python/decode.rst

+4
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,9 @@ Batch Decoding
2525
.. autoclass:: BatchDecodeWithPagedKVCacheWrapper
2626
:members:
2727

28+
.. automethod:: __init__
29+
2830
.. autoclass:: CUDAGraphBatchDecodeWithPagedKVCacheWrapper
2931
:members:
32+
33+
.. automethod:: __init__

Diff for: docs/api/python/group_gemm.rst

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
.. _apigroup_gemm:
2+
3+
flashinfer.group_gemm
4+
=====================
5+
6+
This module provides a set of functions for grouped GEMM operations.
7+
8+
.. currentmodule:: flashinfer.group_gemm
9+
10+
.. autoclass:: SegmentGEMMWrapper
11+
:members:
12+
13+
.. automethod:: __init__

Diff for: docs/api/python/prefill.rst

+4-1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ Batch Prefill/Append Attention
2222
.. autoclass:: BatchPrefillWithPagedKVCacheWrapper
2323
:members:
2424

25+
.. automethod:: __init__
26+
2527
.. autoclass:: BatchPrefillWithRaggedKVCacheWrapper
2628
:members:
27-
29+
30+
.. automethod:: __init__

Diff for: docs/index.rst

+1
Original file line numberDiff line numberDiff line change
@@ -32,4 +32,5 @@ FlashInfer is a library for Large Language Models that provides high-perform
3232
api/python/cascade
3333
api/python/page
3434
api/python/sampling
35+
api/python/group_gemm
3536
api/python/norm

Diff for: include/flashinfer/allocator.h

+44
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
/*
2+
* Copyright (c) 2023 by FlashInfer team.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#ifndef FLASHINFER_ALLOCATOR_H_
17+
#define FLASHINFER_ALLOCATOR_H_
18+
19+
#include <memory>
20+
#include <stdexcept>
21+
22+
namespace flashinfer {
23+
24+
// Bump allocator that carves aligned sub-buffers out of a caller-provided
// workspace. It never owns or frees the underlying memory; it only advances
// a cursor through it.
struct AlignedAllocator {
  // Cursor: first byte of the workspace that has not been handed out yet.
  void* ptr;
  // Number of bytes remaining between `ptr` and the end of the workspace.
  size_t space;

  // Wraps the workspace starting at `buf` that is `space` bytes long.
  AlignedAllocator(void* buf, size_t space) : ptr(buf), space(space) {}

  // Reserves `size` bytes aligned to `alignment` and returns them as a T*.
  //
  // `size` is a byte count, not an element count. On success the cursor is
  // moved past the returned region (including any padding that was skipped
  // to satisfy the alignment).
  //
  // Throws std::runtime_error when the remaining workspace cannot hold the
  // aligned request; in that case the allocator state is left untouched.
  template <typename T>
  T* aligned_alloc(size_t size, size_t alignment) {
    // std::align bumps `ptr` up to the next aligned address and shrinks
    // `space` by the padding, or returns nullptr (changing nothing) when the
    // request does not fit.
    if (std::align(alignment, size, ptr, space) == nullptr) {
      throw std::runtime_error("RuntimeError: Out of workspace memory in AlignedAllocator");
    }
    T* result = reinterpret_cast<T*>(ptr);
    // Consume the region that was just handed out.
    ptr = static_cast<char*>(ptr) + size;
    space -= size;
    return result;
  }
};
41+
42+
} // namespace flashinfer
43+
44+
#endif // FLASHINFER_ALLOCATOR_H_

Diff for: include/flashinfer/attention/handler.cuh

+4-23
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,15 @@
1313
* See the License for the specific language governing permissions and
1414
* limitations under the License.
1515
*/
16-
#ifndef FLASHINFER_HANDLER_CUH_
17-
#define FLASHINFER_HANDLER_CUH_
16+
#ifndef FLASHINFER_ATTENTION_HANDLER_CUH_
17+
#define FLASHINFER_ATTENTION_HANDLER_CUH_
1818

1919
#include <algorithm>
2020
#include <cstddef>
21-
#include <memory>
2221
#include <sstream>
23-
#include <unordered_map>
2422
#include <vector>
2523

24+
#include "../allocator.h"
2625
#include "../page.cuh"
2726
#include "../pos_enc.cuh"
2827
#include "../utils.cuh"
@@ -241,24 +240,6 @@ cudaError_t PartitionPagedKVCacheComputeAuxiliaryInfo(
241240
return cudaSuccess;
242241
}
243242

244-
struct AlignedAllocator {
245-
void* ptr;
246-
size_t space;
247-
AlignedAllocator(void* buf, size_t space) : ptr(buf), space(space) {}
248-
template <typename T>
249-
T* aligned_alloc(size_t size, size_t alignment) {
250-
if (std::align(alignment, size, ptr, space)) {
251-
T* result = reinterpret_cast<T*>(ptr);
252-
ptr = (char*)ptr + size;
253-
space -= size;
254-
return result;
255-
} else {
256-
throw std::runtime_error("RuntimeError: Out of workspace memory in AlignedAlloactor");
257-
}
258-
return nullptr;
259-
}
260-
};
261-
262243
class BatchDecodeHandler {
263244
public:
264245
template <typename DType>
@@ -584,4 +565,4 @@ class BatchPrefillHandler {
584565
};
585566

586567
} // namespace flashinfer
587-
#endif // FLASHINFER_HANDLER_CUH_
568+
#endif // FLASHINFER_ATTENTION_HANDLER_CUH_

Diff for: include/flashinfer/group_gemm/group_gemm_cutlass.cuh

+65
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
/*
2+
* Copyright (c) 2024 by FlashInfer team.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#ifndef FLASHINFER_GROUP_GEMM_CUTLASS_CUH_
17+
#define FLASHINFER_GROUP_GEMM_CUTLASS_CUH_
18+
19+
#include "cutlass/cutlass.h"
20+
#include "cutlass/gemm/device/gemm_grouped.h"
21+
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
22+
#include "cutlass/layout/matrix.h"
23+
#include "cutlass/numeric_types.h"
24+
25+
namespace flashinfer {
26+
27+
namespace group_gemm {
28+
29+
// Type trait that maps an element type T to the type exposed as
// cutlass_dtype<T>::type. The primary template is the identity mapping.
template <typename T>
struct cutlass_dtype {
  typedef T type;
};

// half is represented as cutlass::half_t.
template <>
struct cutlass_dtype<half> {
  typedef cutlass::half_t type;
};

// nv_bfloat16 is represented as cutlass::bfloat16_t.
template <>
struct cutlass_dtype<nv_bfloat16> {
  typedef cutlass::bfloat16_t type;
};
43+
44+
template <typename T>
45+
__global__ void compute_cutlass_group_gemm_args(cutlass::gemm::GemmCoord* all_problems, T** ptr_x,
46+
T** ptr_w, T** ptr_y, int64_t* ld_x, int64_t* ld_w,
47+
int64_t* ld_y, T* x, T* w, T* y, int64_t* xy_indptr,
48+
int64_t* w_indices, size_t d_in, size_t d_out,
49+
bool w_column_major) {
50+
int i = blockIdx.x;
51+
int m = xy_indptr[i + 1] - xy_indptr[i], k = d_in, n = d_out;
52+
all_problems[i] = cutlass::gemm::GemmCoord(m, n, k);
53+
ptr_w[i] = w + (w_indices == nullptr ? i : w_indices[i]) * d_in * d_out;
54+
ptr_x[i] = x + xy_indptr[i] * d_in;
55+
ptr_y[i] = y + xy_indptr[i] * d_out;
56+
ld_x[i] = k; // m * k
57+
ld_w[i] = w_column_major ? k : n; // k * n if column major, n * k if row major
58+
ld_y[i] = n; // m * n
59+
}
60+
61+
} // namespace group_gemm
62+
63+
} // namespace flashinfer
64+
65+
#endif // FLASHINFER_GROUP_GEMM_CUTLASS_CUH_

Diff for: include/flashinfer/group_gemm/group_gemm_lora.cuh

+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
/*
2+
* Copyright (c) 2024 by FlashInfer team.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#ifndef FLASHINFER_GROUP_GEMM_LORA_CUH_
17+
#define FLASHINFER_GROUP_GEMM_LORA_CUH_
18+
19+
namespace flashinfer {
20+
21+
namespace group_gemm {
22+
23+
// TODO(Zihao): port punica's sgmv kernel
24+
25+
} // namespace group_gemm
26+
27+
} // namespace flashinfer
28+
29+
#endif // FLASHINFER_GROUP_GEMM_LORA_CUH_

Diff for: include/flashinfer/group_gemm/group_gemv.cuh

+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
/*
2+
* Copyright (c) 2024 by FlashInfer team.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#ifndef FLASHINFER_GROUP_GEMV_CUH_
17+
#define FLASHINFER_GROUP_GEMV_CUH_
18+
19+
namespace flashinfer {
20+
21+
namespace group_gemm {
22+
23+
// TODO(Zihao): port punica's bgmv kernel
24+
25+
} // namespace group_gemm
26+
27+
} // namespace flashinfer
28+
29+
#endif // FLASHINFER_GROUP_GEMV_CUH_

Diff for: include/flashinfer/group_gemm/handler.cuh

+66
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
/*
2+
* Copyright (c) 2024 by FlashInfer team.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#ifndef FLASHINFER_GROUP_GEMM_HANDLER_CUH_
17+
#define FLASHINFER_GROUP_GEMM_HANDLER_CUH_
18+
19+
#include <cstddef>
20+
21+
#include "../allocator.h"
22+
#include "../utils.cuh"
23+
#include "group_gemm_cutlass.cuh"
24+
#include "group_gemm_lora.cuh"
25+
#include "group_gemv.cuh"
26+
27+
namespace flashinfer {
28+
29+
namespace group_gemm {
30+
31+
// Shape regimes of a group GEMM, distinguished by the relative sizes of the
// input (d_in) and output (d_out) feature dimensions.
enum class GroupGEMMKernelConfig {
  // Both d_in and d_out are large.
  kGeneral,
  // d_in is large, d_out is small.
  kShrink,
  // d_in is small, d_out is large.
  kExpand,
};
36+
37+
// Holds the state shared across segment GEMM calls: a caller-provided
// workspace buffer (pointer + size in bytes) and the CUDA stream to use.
//
// The handler does not allocate, free, or otherwise manage the workspace; it
// only records what the caller registered.
class CutlassSegmentGEMMHandler {
 public:
  // Records `buffer` (of `size` bytes) as the workspace. Any previously
  // registered workspace is simply forgotten, not released.
  void RegisterWorkspace(void* buffer, size_t size) {
    buffer_ = buffer;
    workspace_size_in_bytes_ = size;
  }

  // Returns the registered workspace pointer, or nullptr if none was
  // registered yet.
  void* GetWorkspace() const { return buffer_; }

  // Returns the registered workspace size in bytes, or 0 if none was
  // registered yet.
  size_t GetWorkspaceSizeInBytes() const { return workspace_size_in_bytes_; }

  // Returns the stream set via SetCUDAStream, or nullptr if none was set.
  cudaStream_t GetCUDAStream() const { return stream_; }

  // Sets the stream returned by GetCUDAStream.
  void SetCUDAStream(cudaStream_t stream) { stream_ = stream; }

  // Starts with no workspace and a null stream so the getters never expose
  // indeterminate values before registration.
  CutlassSegmentGEMMHandler() : buffer_(nullptr), workspace_size_in_bytes_(0), stream_(nullptr) {}

  ~CutlassSegmentGEMMHandler() {}

 private:
  void* buffer_;
  size_t workspace_size_in_bytes_;
  cudaStream_t stream_;
};
61+
62+
} // namespace group_gemm
63+
64+
} // namespace flashinfer
65+
66+
#endif // FLASHINFER_GROUP_GEMM_HANDLER_CUH_

0 commit comments

Comments
 (0)