Skip to content

feat: add group gemm operators #282

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Jun 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/api/python/decode.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,9 @@ Batch Decoding
.. autoclass:: BatchDecodeWithPagedKVCacheWrapper
:members:

.. automethod:: __init__

.. autoclass:: CUDAGraphBatchDecodeWithPagedKVCacheWrapper
:members:

.. automethod:: __init__
13 changes: 13 additions & 0 deletions docs/api/python/group_gemm.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
.. _apigroup_gemm:

flashinfer.group_gemm
=====================

This module provides a set of functions for grouped GEMM operations.

.. currentmodule:: flashinfer.group_gemm

.. autoclass:: SegmentGEMMWrapper
:members:

.. automethod:: __init__
5 changes: 4 additions & 1 deletion docs/api/python/prefill.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ Batch Prefill/Append Attention
.. autoclass:: BatchPrefillWithPagedKVCacheWrapper
:members:

.. automethod:: __init__

.. autoclass:: BatchPrefillWithRaggedKVCacheWrapper
:members:


.. automethod:: __init__
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,5 @@ FlashInfer is a library for Large Language Models that provides high-perform
api/python/cascade
api/python/page
api/python/sampling
api/python/group_gemm
api/python/norm
44 changes: 44 additions & 0 deletions include/flashinfer/allocator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Copyright (c) 2023 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FLASHINFER_ALLOCATOR_H_
#define FLASHINFER_ALLOCATOR_H_

#include <memory>
#include <stdexcept>

namespace flashinfer {

// Bump allocator that hands out aligned sub-ranges of a caller-provided buffer.
//
// The allocator never frees or owns memory: it only advances `ptr` through the
// buffer it was constructed with and shrinks `space` accordingly.
struct AlignedAllocator {
  // Next unallocated byte of the buffer.
  void* ptr;
  // Number of bytes still available starting at `ptr`.
  size_t space;

  // buf:   start of the buffer to carve allocations from.
  // space: number of usable bytes starting at `buf`.
  AlignedAllocator(void* buf, size_t space) : ptr(buf), space(space) {}

  // Reserves `size` bytes whose start address is a multiple of `alignment`
  // and returns that address as a T*.
  //
  // Throws std::runtime_error when the remaining space cannot hold `size`
  // bytes after alignment padding; `ptr` and `space` are left unchanged in
  // that case because std::align does not modify them on failure.
  template <typename T>
  T* aligned_alloc(size_t size, size_t alignment) {
    // On success std::align moves `ptr` up to the aligned address and deducts
    // the padding from `space`; the payload itself is accounted for below.
    if (std::align(alignment, size, ptr, space) == nullptr) {
      throw std::runtime_error("RuntimeError: Out of workspace memory in AlignedAllocator");
    }
    T* result = static_cast<T*>(ptr);
    ptr = static_cast<char*>(ptr) + size;
    space -= size;
    return result;
  }
};

} // namespace flashinfer

#endif // FLASHINFER_ALLOCATOR_H_
27 changes: 4 additions & 23 deletions include/flashinfer/attention/handler.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,15 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FLASHINFER_HANDLER_CUH_
#define FLASHINFER_HANDLER_CUH_
#ifndef FLASHINFER_ATTENTION_HANDLER_CUH_
#define FLASHINFER_ATTENTION_HANDLER_CUH_

#include <algorithm>
#include <cstddef>
#include <memory>
#include <sstream>
#include <unordered_map>
#include <vector>

#include "../allocator.h"
#include "../page.cuh"
#include "../pos_enc.cuh"
#include "../utils.cuh"
Expand Down Expand Up @@ -241,24 +240,6 @@ cudaError_t PartitionPagedKVCacheComputeAuxiliaryInfo(
return cudaSuccess;
}

// Bump allocator that hands out aligned sub-ranges of a caller-provided buffer.
//
// The allocator never frees or owns memory: it only advances `ptr` through the
// buffer it was constructed with and shrinks `space` accordingly.
struct AlignedAllocator {
  // Next unallocated byte of the buffer.
  void* ptr;
  // Number of bytes still available starting at `ptr`.
  size_t space;

  // buf:   start of the buffer to carve allocations from.
  // space: number of usable bytes starting at `buf`.
  AlignedAllocator(void* buf, size_t space) : ptr(buf), space(space) {}

  // Reserves `size` bytes whose start address is a multiple of `alignment`
  // and returns that address as a T*.
  //
  // Throws std::runtime_error when the remaining space cannot hold `size`
  // bytes after alignment padding; `ptr` and `space` are left unchanged in
  // that case because std::align does not modify them on failure.
  template <typename T>
  T* aligned_alloc(size_t size, size_t alignment) {
    // On success std::align moves `ptr` up to the aligned address and deducts
    // the padding from `space`; the payload itself is accounted for below.
    if (std::align(alignment, size, ptr, space) == nullptr) {
      throw std::runtime_error("RuntimeError: Out of workspace memory in AlignedAllocator");
    }
    T* result = static_cast<T*>(ptr);
    ptr = static_cast<char*>(ptr) + size;
    space -= size;
    return result;
  }
};

class BatchDecodeHandler {
public:
template <typename DType>
Expand Down Expand Up @@ -584,4 +565,4 @@ class BatchPrefillHandler {
};

} // namespace flashinfer
#endif // FLASHINFER_HANDLER_CUH_
#endif // FLASHINFER_ATTENTION_HANDLER_CUH_
65 changes: 65 additions & 0 deletions include/flashinfer/group_gemm/group_gemm_cutlass.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Copyright (c) 2024 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FLASHINFER_GROUP_GEMM_CUTLASS_CUH_
#define FLASHINFER_GROUP_GEMM_CUTLASS_CUH_

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_grouped.h"
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/numeric_types.h"

namespace flashinfer {

namespace group_gemm {

// Type trait that yields the element type to hand to CUTLASS for a given
// input element type. The primary template is the identity mapping:
// `cutlass_dtype<DType>::type` is `DType` itself.
template <typename DType>
struct cutlass_dtype {
  using type = DType;
};

// Specialization: `half` is mapped to `cutlass::half_t` instead of itself.
template <>
struct cutlass_dtype<half> {
using type = cutlass::half_t;
};

// Specialization: `nv_bfloat16` is mapped to `cutlass::bfloat16_t` instead of itself.
template <>
struct cutlass_dtype<nv_bfloat16> {
using type = cutlass::bfloat16_t;
};

// Fills the per-problem argument arrays for a grouped GEMM, one problem per
// thread block (the problem index is blockIdx.x).
//
// For problem `p`:
//   - m is xy_indptr[p + 1] - xy_indptr[p], k is d_in, n is d_out;
//   - ptr_x[p] / ptr_y[p] point xy_indptr[p] rows into x / y, using a row
//     stride of d_in / d_out elements respectively;
//   - ptr_w[p] points at slot w_indices[p] of w (or slot p when w_indices is
//     nullptr), where each slot spans d_in * d_out elements;
//   - ld_x[p] = k, ld_y[p] = n, and ld_w[p] is k when w_column_major is set,
//     n otherwise.
//
// NOTE(review): every thread of a block performs the same writes, so this is
// presumably launched with a single thread per block and a grid of exactly the
// number of problems; there is no bounds check here. Confirm against the launcher.
template <typename T>
__global__ void compute_cutlass_group_gemm_args(cutlass::gemm::GemmCoord* all_problems, T** ptr_x,
                                                T** ptr_w, T** ptr_y, int64_t* ld_x, int64_t* ld_w,
                                                int64_t* ld_y, T* x, T* w, T* y, int64_t* xy_indptr,
                                                int64_t* w_indices, size_t d_in, size_t d_out,
                                                bool w_column_major) {
  const int problem = blockIdx.x;

  // Row range of this problem inside x / y.
  const int64_t row_begin = xy_indptr[problem];
  const int64_t row_end = xy_indptr[problem + 1];

  // GEMM extents; narrowed to int to match GemmCoord's constructor arguments.
  const int m = row_end - row_begin;
  const int k = d_in;
  const int n = d_out;
  all_problems[problem] = cutlass::gemm::GemmCoord(m, n, k);

  // Without an index table, problem p uses the p-th weight slot.
  const int64_t weight_slot = (w_indices == nullptr) ? problem : w_indices[problem];

  // Base pointers of the three operands.
  ptr_x[problem] = x + row_begin * d_in;
  ptr_w[problem] = w + weight_slot * d_in * d_out;
  ptr_y[problem] = y + row_begin * d_out;

  // Leading dimensions.
  ld_x[problem] = k;  // m * k
  if (w_column_major) {
    ld_w[problem] = k;  // k * n if column major
  } else {
    ld_w[problem] = n;  // n * k if row major
  }
  ld_y[problem] = n;  // m * n
}

} // namespace group_gemm

} // namespace flashinfer

#endif // FLASHINFER_GROUP_GEMM_CUTLASS_CUH_
29 changes: 29 additions & 0 deletions include/flashinfer/group_gemm/group_gemm_lora.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Copyright (c) 2024 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FLASHINFER_GROUP_GEMM_LORA_CUH_
#define FLASHINFER_GROUP_GEMM_LORA_CUH_

namespace flashinfer {

namespace group_gemm {

// TODO(Zihao): port punica's sgmv kernel

} // namespace group_gemm

} // namespace flashinfer

#endif // FLASHINFER_GROUP_GEMM_LORA_CUH_
29 changes: 29 additions & 0 deletions include/flashinfer/group_gemm/group_gemv.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Copyright (c) 2024 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FLASHINFER_GROUP_GEMV_CUH_
#define FLASHINFER_GROUP_GEMV_CUH_

namespace flashinfer {

namespace group_gemm {

// TODO(Zihao): port punica's bgmv kernel

} // namespace group_gemm

} // namespace flashinfer

#endif // FLASHINFER_GROUP_GEMV_CUH_
66 changes: 66 additions & 0 deletions include/flashinfer/group_gemm/handler.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
* Copyright (c) 2024 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FLASHINFER_GROUP_GEMM_HANDLER_CUH_
#define FLASHINFER_GROUP_GEMM_HANDLER_CUH_

#include <cstddef>

#include "../allocator.h"
#include "../utils.cuh"
#include "group_gemm_cutlass.cuh"
#include "group_gemm_lora.cuh"
#include "group_gemv.cuh"

namespace flashinfer {

namespace group_gemm {

// Problem-shape categories for group GEMM, distinguished by the relative
// sizes of d_in and d_out. Enumerator values are spelled out explicitly and
// match the implicit numbering (0, 1, 2).
enum class GroupGEMMKernelConfig {
  kGeneral = 0,  // large d_in, d_out
  kShrink = 1,   // large d_in, small d_out
  kExpand = 2,   // small d_in, large d_out
};

// Bookkeeping object for the CUTLASS-based segment GEMM: it records which
// workspace buffer and which CUDA stream to use.
//
// The handler only stores the pointer, size and stream it is given; it never
// allocates or frees the workspace (the destructor does nothing), so the
// caller keeps ownership of the buffer.
//
// All members are initialized to null/zero so that the getters return
// well-defined values before RegisterWorkspace / SetCUDAStream are called.
class CutlassSegmentGEMMHandler {
 public:
  CutlassSegmentGEMMHandler() = default;

  ~CutlassSegmentGEMMHandler() = default;

  // Records `buffer` (of `size` bytes) as the workspace. Any previously
  // registered workspace is simply replaced.
  void RegisterWorkspace(void* buffer, size_t size) {
    buffer_ = buffer;
    workspace_size_in_bytes_ = size;
  }

  // Returns the registered workspace pointer, or nullptr if none was registered.
  void* GetWorkspace() const { return buffer_; }

  // Returns the registered workspace size in bytes, or 0 if none was registered.
  size_t GetWorkspaceSizeInBytes() const { return workspace_size_in_bytes_; }

  // Returns the stream last passed to SetCUDAStream, or a null stream handle
  // if it was never set.
  cudaStream_t GetCUDAStream() const { return stream_; }

  // Records the stream to use; the handler does not create or destroy streams.
  void SetCUDAStream(cudaStream_t stream) { stream_ = stream; }

 private:
  void* buffer_ = nullptr;
  size_t workspace_size_in_bytes_ = 0;
  cudaStream_t stream_ = nullptr;
};

} // namespace group_gemm

} // namespace flashinfer

#endif // FLASHINFER_GROUP_GEMM_HANDLER_CUH_
Loading