Add GPTQ support #916

Merged · 27 commits · Dec 15, 2023

Changes from all commits
82e6b2e
Add gptq implementation compatible with awq interface
chu-tianxiang Sep 18, 2023
612d7b1
Add more models
chu-tianxiang Sep 25, 2023
049a37c
fix bug in model loading
chu-tianxiang Sep 25, 2023
5563578
Add fallback kernel for desc act models
chu-tianxiang Sep 27, 2023
0470121
Fix engine args and opt model
chu-tianxiang Sep 27, 2023
92c7f8d
Merge main branch
chu-tianxiang Oct 8, 2023
f9d0ccc
Add mistral model
chu-tianxiang Oct 9, 2023
cbf9433
Fix bug in gpt layer
chu-tianxiang Oct 11, 2023
a7b391d
Fix conflict
chu-tianxiang Oct 24, 2023
b51ebb7
Merge main branch
chu-tianxiang Oct 24, 2023
9a99461
Fix squeezellm
chu-tianxiang Oct 24, 2023
2593dfe
Use exllama v2 kernels for better performance
chu-tianxiang Nov 2, 2023
97072a7
Add Yi and ChatGLM GPTQ support
chu-tianxiang Nov 14, 2023
2d8dc1d
Fix chatglm
chu-tianxiang Nov 14, 2023
22ea9ce
merge main
chu-tianxiang Dec 1, 2023
17b6f2b
Fix phi model
chu-tianxiang Dec 1, 2023
62bd8ce
move post init to first forward pass to make code cleaner
chu-tianxiang Dec 2, 2023
e1c4c25
merge main
chu-tianxiang Dec 3, 2023
b6b8c63
Update GPTQ kernel and fix minor problems
chu-tianxiang Dec 10, 2023
1bcb832
Merge main
chu-tianxiang Dec 10, 2023
d1954ab
Fix typo
chu-tianxiang Dec 11, 2023
514021c
Merge branch 'main' into gptq_hf
WoosukKwon Dec 15, 2023
62d6760
Minor fix
WoosukKwon Dec 15, 2023
5156579
Minor
WoosukKwon Dec 15, 2023
1f3f6ee
Support Mixtral
WoosukKwon Dec 15, 2023
99cc231
Ignore warning
WoosukKwon Dec 15, 2023
17fcdd2
Fix squeezellm
WoosukKwon Dec 15, 2023
2 changes: 1 addition & 1 deletion benchmarks/benchmark_latency.py
@@ -84,7 +84,7 @@ def run_to_completion(profile_dir: Optional[str] = None):
parser.add_argument('--tokenizer', type=str, default=None)
parser.add_argument('--quantization',
'-q',
- choices=['awq', 'squeezellm', None],
+ choices=['awq', 'gptq', 'squeezellm', None],
default=None)
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
parser.add_argument('--input-len', type=int, default=32)
2 changes: 1 addition & 1 deletion benchmarks/benchmark_throughput.py
@@ -244,7 +244,7 @@ def main(args: argparse.Namespace):
parser.add_argument("--tokenizer", type=str, default=None)
parser.add_argument('--quantization',
'-q',
- choices=['awq', 'squeezellm', None],
+ choices=['awq', 'gptq', 'squeezellm', None],
default=None)
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
parser.add_argument("--n",
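With `gptq` wired into both scripts, a GPTQ checkpoint can be benchmarked the same way as an AWQ one, e.g. `python benchmarks/benchmark_throughput.py --model <path-to-gptq-model> -q gptq` plus the script's usual dataset/length arguments (model path illustrative).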
12 changes: 12 additions & 0 deletions csrc/ops.h
@@ -77,3 +77,15 @@ void squeezellm_gemm(
torch::Tensor mat,
torch::Tensor mul,
torch::Tensor lookup_table);

torch::Tensor gptq_gemm(
torch::Tensor a,
torch::Tensor b_q_weight,
torch::Tensor b_gptq_qzeros,
torch::Tensor b_gptq_scales,
torch::Tensor b_g_idx,
bool use_exllama);

void gptq_shuffle(
torch::Tensor q_weight,
torch::Tensor q_perm);
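
Taken together, the two new ops split GPTQ inference into a one-time and a per-call step: `gptq_shuffle` reorders the packed weights in place (applying the act-order permutation carried by `g_idx`) into the layout the exllama kernel expects, and `gptq_gemm` then performs the fused dequantize-plus-GEMM on every forward pass; commit 62bd8ce defers the one-time step to the first forward pass. A minimal sketch of that call order, using a hypothetical wrapper struct for illustration only (the PR's real layer logic lives in the Python quantization code):

```cpp
#include <torch/extension.h>

// Declarations from csrc/ops.h above; the kernels are linked in from the extension.
torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
                        torch::Tensor b_gptq_qzeros, torch::Tensor b_gptq_scales,
                        torch::Tensor b_g_idx, bool use_exllama);
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm);

// Hypothetical wrapper, for illustration only.
struct GPTQLinear {
  torch::Tensor q_weight;  // [in_features / 8, out_features], 4-bit packed into int32
  torch::Tensor qzeros;    // packed zero points, one row per quantization group
  torch::Tensor scales;    // fp16 scales, one row per quantization group
  torch::Tensor g_idx;     // group index of each input channel (act-order)
  bool use_exllama = true;
  bool shuffled = false;

  torch::Tensor forward(const torch::Tensor& x) {
    if (!shuffled) {
      // One-time in-place reordering of the packed weights (post init
      // deferred to the first forward pass, as in commit 62bd8ce).
      gptq_shuffle(q_weight, g_idx);
      shuffled = true;
    }
    // Fused dequantize + GEMM on every call.
    return gptq_gemm(x, q_weight, qzeros, scales, g_idx, use_exllama);
  }
};
```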
4 changes: 2 additions & 2 deletions csrc/pybind.cpp
@@ -52,8 +52,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// Quantization ops
ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
#endif
-
-
+ ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
+ ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");

// Cache ops
64 changes: 64 additions & 0 deletions csrc/quantization/gptq/compat.cuh
@@ -0,0 +1,64 @@
/*
Copied from https://github.com/turboderp/exllamav2
*/

#ifndef _compat_cuh
#define _compat_cuh

namespace vllm {
namespace gptq {
// atomicAdd for half types, to support CC < 7.x

__device__ __forceinline__ void atomicAdd_half(half* address, half val)
{
unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
unsigned int old = *address_as_ui;
unsigned int assumed;

do
{
assumed = old;
__half_raw hsum;
hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
half tmpres = __hadd(hsum, val);
hsum = __half_raw(tmpres);
old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
old = atomicCAS(address_as_ui, assumed, old);
}
while (assumed != old);
}

// atomicAdd for half2 types

__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
{
unsigned int* address_as_ui = (unsigned int*)address;
unsigned int old = *address_as_ui;
unsigned int assumed;
do
{
assumed = old;
half2 old_val = *((half2*)&old);
half2 new_val = __hadd2(old_val, val);
old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val));
}
while (assumed != old);
}

//

#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)

__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }

#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
#endif

#endif
#endif

} // namespace gptq
} // namespace vllm
#endif
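
The shim emulates a 16-bit atomic add with a 32-bit `atomicCAS` on the aligned word that contains the target `half`: read the word, add into the correct half of it, and retry until no other thread raced the update. It is only compiled in where CUDA lacks the native overloads (native `atomicAdd(half2*)` arrived with compute capability 6.0 and `atomicAdd(half*)` with 7.0, matching the two guards above). A self-contained sketch of the same trick, with illustrative names:

```cuda
#include <cstdio>
#include <cuda_fp16.h>

// Same CAS trick as atomicAdd_half above, renamed so it never collides
// with the native overload on newer architectures.
__device__ void atomic_add_f16(half* address, half val) {
    // Round the address down to the 4-byte word that contains it.
    unsigned int* word = (unsigned int*)((char*)address - ((size_t)address & 2));
    unsigned int old = *word, assumed;
    do {
        assumed = old;
        __half_raw hsum;
        // Select the upper or lower 16 bits depending on alignment.
        hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
        half tmp = __hadd(hsum, val);
        hsum = __half_raw(tmp);
        old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16)
                                  : (old & 0xffff0000) | hsum.x;
        // Publish; retry if another thread changed the word meanwhile.
        old = atomicCAS(word, assumed, old);
    } while (assumed != old);
}

__global__ void accumulate(half* sum) {
    atomic_add_f16(sum, __float2half(1.0f));
}

int main() {
    half* sum;
    cudaMallocManaged(&sum, sizeof(half));
    *sum = __float2half(0.0f);
    accumulate<<<4, 64>>>(sum);  // 256 threads each add 1.0
    cudaDeviceSynchronize();
    printf("%f\n", __half2float(*sum));  // expect 256.0 (exact in fp16)
    cudaFree(sum);
    return 0;
}
```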
151 changes: 151 additions & 0 deletions csrc/quantization/gptq/matrix_view.cuh
@@ -0,0 +1,151 @@
/*
Adapted from https://github.com/turboderp/exllamav2 and https://github.com/turboderp/exllama
*/

#ifndef _matrix_view_cuh
#define _matrix_view_cuh

#include <cuda_runtime.h>
#include <cuda_fp16.h>

#include "qdq_util.cuh"

namespace vllm {
namespace gptq {

class MatrixView_half
{
public:
const half* data;
const int height;
const int width;

__device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width)
: data(data), height(height), width(width)
{ }

__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
__device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); }
__device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; }

__device__ __forceinline__ void item4(half (&items)[4], int row, int column) const
{
half2* ptr = (half2*) item_ptr(row, column);
half2 i01 = ptr[0];
half2 i23 = ptr[1];
items[0] = __low2half(i01);
items[1] = __high2half(i01);
items[2] = __low2half(i23);
items[3] = __high2half(i23);
}
__device__ __forceinline__ void item4_f(float (&items)[4], int row, int column) const
{
half2* ptr = (half2*)item_ptr(row, column);
half2 i01 = ptr[0];
half2 i23 = ptr[1];
items[0] = __half2float(__low2half(i01));
items[1] = __half2float(__high2half(i01));
items[2] = __half2float(__low2half(i23));
items[3] = __half2float(__high2half(i23));
}

__device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, int column) const
{
half2* ptr = (half2*)item_ptr(row, column);
half2 i01 = ptr[0];
half2 i23 = ptr[1];
items[0] = __half2half2(__low2half(i01));
items[1] = __half2half2(__high2half(i01));
items[2] = __half2half2(__low2half(i23));
items[3] = __half2half2(__high2half(i23));
}
};

class MatrixView_half_rw
{
public:
half* data;
const int height;
const int width;

__device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width)
: data(data), height(height), width(width)
{ }

__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
__device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; }
__device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; }
__device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; }

__device__ __forceinline__ void set4(int row, int column, half v0, half v1, half v2, half v3)
{
half2 v01 = __halves2half2(v0, v1);
half2 v23 = __halves2half2(v2, v3);
half2* ptr = (half2*) item_ptr(row, column);
ptr[0] = v01;
ptr[1] = v23;
}
};

class MatrixView_q4_row
{
public:
const uint32_t* data;
const int height;
const int width;

__device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width)
: data(data), height(height), width(width)
{ }

__device__ __forceinline__ int item(int row, int column) const
{
int shift = (column & 0x07) * 4;
return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
}

__device__ __forceinline__ void item2(int (&items)[2], int row, int column) const
{
int shift = (column & 0x07) * 4;
uint32_t d = data[row * width / 8 + column / 8] >> shift;
items[0] = d & 0x0f;
items[1] = (d >> 4) & 0x0f;
}

__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
{
int shift = (column & 0x07) * 4;
uint32_t d = data[row * width / 8 + column / 8] >> shift;
items[0] = d & 0x0f;
items[1] = (d >> 4) & 0x0f;
items[2] = (d >> 8) & 0x0f;
items[3] = (d >> 12) & 0x0f;
}
};

class MatrixView_q4_column
{
public:
const uint32_t* data;
const int height;
const int width;

__device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width)
: data(data), height(height), width(width)
{ }

__device__ __forceinline__ int item(int row, int column) const
{
int shift = (row & 0x07) * 4;
return (data[row / 8 * width + column] >> shift) & 0x0f;
}

__device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; }
__device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; }
};

} // namespace gptq
} // namespace vllm
#endif
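
`MatrixView_q4_row` reads eight 4-bit values out of every `uint32_t` along a row, lowest nibble first, which is how GPTQ typically packs tensors such as `qzeros`; `MatrixView_q4_column` applies the same nibble layout down a column, matching the packed `qweight`. A host-side sketch that packs a matrix and reads it back with the identical shift arithmetic, to make the bit layout concrete (helper names are illustrative, not part of the PR):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Pack a row-major matrix of 4-bit values (0..15) the way
// MatrixView_q4_row expects: 8 consecutive columns per uint32_t.
std::vector<uint32_t> pack_q4_rows(const std::vector<int>& m,
                                   int height, int width) {
    assert(width % 8 == 0);
    std::vector<uint32_t> packed(height * width / 8, 0);
    for (int r = 0; r < height; ++r)
        for (int c = 0; c < width; ++c)
            packed[r * width / 8 + c / 8] |=
                uint32_t(m[r * width + c] & 0x0f) << ((c & 0x07) * 4);
    return packed;
}

// Mirror of MatrixView_q4_row::item.
int q4_row_item(const std::vector<uint32_t>& packed,
                int width, int row, int column) {
    int shift = (column & 0x07) * 4;
    return (packed[row * width / 8 + column / 8] >> shift) & 0x0f;
}

int main() {
    const int h = 2, w = 8;
    std::vector<int> m(h * w);
    for (int i = 0; i < h * w; ++i) m[i] = i % 16;  // values 0..15
    auto packed = pack_q4_rows(m, h, w);
    for (int r = 0; r < h; ++r)
        for (int c = 0; c < w; ++c)
            assert(q4_row_item(packed, w, r, c) == m[r * w + c]);
    return 0;
}
```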