
Commit e0c6f55

[Build] Avoid building too many extensions (#1624)
1 parent de23687 commit e0c6f55

25 files changed: +206 -272 lines

benchmarks/kernels/benchmark_paged_attention.py

Lines changed: 3 additions & 3 deletions
@@ -4,7 +4,7 @@

 import torch

-from vllm import attention_ops
+from vllm._C import ops

 NUM_BLOCKS = 1024
 PARTITION_SIZE = 512
@@ -98,7 +98,7 @@ def run_benchmark(num_iters: int, profile: bool = False) -> float:

     for _ in range(num_iters):
         if version == "v1":
-            attention_ops.paged_attention_v1(
+            ops.paged_attention_v1(
                 output,
                 query,
                 key_cache,
@@ -112,7 +112,7 @@ def run_benchmark(num_iters: int, profile: bool = False) -> float:
                 alibi_slopes,
             )
         elif version == "v2":
-            attention_ops.paged_attention_v2(
+            ops.paged_attention_v2(
                 output,
                 exp_sums,
                 max_logits,
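The benchmark itself is unchanged; only the module providing the kernels moves, from the old per-family `attention_ops` extension to the `ops` submodule of the single `vllm._C` extension. A minimal before/after sketch of the call site, assuming a CUDA build of vLLM at this commit (argument order taken from csrc/ops.h below):

    # Old layout: one compiled extension per kernel family.
    #   from vllm import attention_ops
    #   attention_ops.paged_attention_v1(output, query, key_cache, value_cache, ...)

    # New layout: a single vllm._C extension exposing an `ops` submodule.
    from vllm._C import ops

    # The positional arguments are untouched; only the import changed
    # (order follows the declaration in csrc/ops.h).
    # ops.paged_attention_v1(output, query, key_cache, value_cache, head_mapping,
    #                        scale, block_tables, context_lens, block_size,
    #                        max_context_len, alibi_slopes)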

csrc/activation.cpp

Lines changed: 0 additions & 28 deletions
This file was deleted.

csrc/attention.cpp

Lines changed: 0 additions & 42 deletions
This file was deleted.

csrc/cache.cpp renamed to csrc/cache.h

Lines changed: 0 additions & 19 deletions
@@ -26,22 +26,3 @@ void gather_cached_kv(
   torch::Tensor& key_cache,
   torch::Tensor& value_cache,
   torch::Tensor& slot_mapping);
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def(
-    "swap_blocks",
-    &swap_blocks,
-    "Swap in (out) the cache blocks from src to dst");
-  m.def(
-    "copy_blocks",
-    &copy_blocks,
-    "Copy the cache blocks from src to dst");
-  m.def(
-    "reshape_and_cache",
-    &reshape_and_cache,
-    "Reshape the key and value tensors and cache them");
-  m.def(
-    "gather_cached_kv",
-    &gather_cached_kv,
-    "Gather key and value from the cache into contiguous QKV tensors");
-}

csrc/cuda_utils.cpp

Lines changed: 0 additions & 13 deletions
This file was deleted.

csrc/cuda_utils.h

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+#include <torch/extension.h>
+
+int get_device_attribute(
+  int attribute,
+  int device_id);
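This declaration is exposed to Python through the `cuda_utils` submodule defined in csrc/pybind.cpp below. A rough usage sketch, with the assumption (not stated in this diff) that the attribute code is interpreted as a CUDA runtime `cudaDeviceAttr` value:

    from vllm._C import cuda_utils

    # cudaDevAttrMultiProcessorCount in the CUDA runtime enum; the numeric value
    # and its interpretation here are assumptions, not part of this diff.
    CUDA_DEV_ATTR_MULTIPROCESSOR_COUNT = 16

    num_sms = cuda_utils.get_device_attribute(CUDA_DEV_ATTR_MULTIPROCESSOR_COUNT, 0)
    print(f"SM count on device 0: {num_sms}")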

csrc/layernorm.cpp

Lines changed: 0 additions & 24 deletions
This file was deleted.

csrc/ops.h

Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
+#include <torch/extension.h>
+
+void paged_attention_v1(
+  torch::Tensor& out,
+  torch::Tensor& query,
+  torch::Tensor& key_cache,
+  torch::Tensor& value_cache,
+  torch::Tensor& head_mapping,
+  float scale,
+  torch::Tensor& block_tables,
+  torch::Tensor& context_lens,
+  int block_size,
+  int max_context_len,
+  const c10::optional<torch::Tensor>& alibi_slopes);
+
+void paged_attention_v2(
+  torch::Tensor& out,
+  torch::Tensor& exp_sums,
+  torch::Tensor& max_logits,
+  torch::Tensor& tmp_out,
+  torch::Tensor& query,
+  torch::Tensor& key_cache,
+  torch::Tensor& value_cache,
+  torch::Tensor& head_mapping,
+  float scale,
+  torch::Tensor& block_tables,
+  torch::Tensor& context_lens,
+  int block_size,
+  int max_context_len,
+  const c10::optional<torch::Tensor>& alibi_slopes);
+
+void rms_norm(
+  torch::Tensor& out,
+  torch::Tensor& input,
+  torch::Tensor& weight,
+  float epsilon);
+
+void fused_add_rms_norm(
+  torch::Tensor& input,
+  torch::Tensor& residual,
+  torch::Tensor& weight,
+  float epsilon);
+
+void rotary_embedding(
+  torch::Tensor& positions,
+  torch::Tensor& query,
+  torch::Tensor& key,
+  int head_size,
+  torch::Tensor& cos_sin_cache,
+  bool is_neox);
+
+void silu_and_mul(
+  torch::Tensor& out,
+  torch::Tensor& input);
+
+void gelu_new(
+  torch::Tensor& out,
+  torch::Tensor& input);
+
+void gelu_fast(
+  torch::Tensor& out,
+  torch::Tensor& input);
+
+torch::Tensor awq_gemm(
+  torch::Tensor _in_feats,
+  torch::Tensor _kernel,
+  torch::Tensor _scaling_factors,
+  torch::Tensor _zeros,
+  int split_k_iters);
+
+void squeezellm_gemm(
+  torch::Tensor vec,
+  torch::Tensor mat,
+  torch::Tensor mul,
+  torch::Tensor lookup_table);
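Every declaration in this header is bound under the `ops` submodule in csrc/pybind.cpp, so Python callers pass pre-allocated output tensors and the kernels write into them in place. A short usage sketch for `rms_norm`, assuming a CUDA build of vLLM; the tensor shapes follow the usual RMSNorm convention and are not dictated by this header:

    import torch

    from vllm._C import ops

    hidden_size = 4096
    x = torch.randn(8, hidden_size, dtype=torch.float16, device="cuda")
    weight = torch.ones(hidden_size, dtype=torch.float16, device="cuda")
    out = torch.empty_like(x)

    # Matches the declaration above: rms_norm(out, input, weight, epsilon).
    ops.rms_norm(out, x, weight, 1e-6)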

csrc/pos_encoding.cpp

Lines changed: 0 additions & 16 deletions
This file was deleted.

csrc/pybind.cpp

Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
+#include "cache.h"
+#include "cuda_utils.h"
+#include "ops.h"
+#include <torch/extension.h>
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  // vLLM custom ops
+  pybind11::module ops = m.def_submodule("ops", "vLLM custom operators");
+
+  // Attention ops
+  ops.def(
+    "paged_attention_v1",
+    &paged_attention_v1,
+    "Compute the attention between an input query and the cached keys/values using PagedAttention.");
+  ops.def(
+    "paged_attention_v2",
+    &paged_attention_v2,
+    "PagedAttention V2.");
+
+  // Activation ops
+  ops.def(
+    "silu_and_mul",
+    &silu_and_mul,
+    "Activation function used in SwiGLU.");
+  ops.def(
+    "gelu_new",
+    &gelu_new,
+    "GELU implementation used in GPT-2.");
+  ops.def(
+    "gelu_fast",
+    &gelu_fast,
+    "Approximate GELU implementation.");
+
+  // Layernorm
+  ops.def(
+    "rms_norm",
+    &rms_norm,
+    "Apply Root Mean Square (RMS) Normalization to the input tensor.");
+
+  ops.def(
+    "fused_add_rms_norm",
+    &fused_add_rms_norm,
+    "In-place fused Add and RMS Normalization");
+
+  // Rotary embedding
+  ops.def(
+    "rotary_embedding",
+    &rotary_embedding,
+    "Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
+
+  // Quantization ops
+  ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
+  ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
+
+  // Cache ops
+  pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
+  cache_ops.def(
+    "swap_blocks",
+    &swap_blocks,
+    "Swap in (out) the cache blocks from src to dst");
+  cache_ops.def(
+    "copy_blocks",
+    &copy_blocks,
+    "Copy the cache blocks from src to dst");
+  cache_ops.def(
+    "reshape_and_cache",
+    &reshape_and_cache,
+    "Reshape the key and value tensors and cache them");
+  cache_ops.def(
+    "gather_cached_kv",
+    &gather_cached_kv,
+    "Gather key and value from the cache into contiguous QKV tensors");
+
+  // Cuda utils
+  pybind11::module cuda_utils = m.def_submodule("cuda_utils", "vLLM cuda utils");
+  cuda_utils.def(
+    "get_device_attribute",
+    &get_device_attribute,
+    "Gets the specified device attribute.");
+}
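With all bindings consolidated into one `TORCH_EXTENSION_NAME` module, the build produces a single `vllm._C` extension instead of one per kernel family. A quick way to confirm the new layout from Python, assuming the extension built successfully:

    from vllm import _C

    # The three submodules registered above; each should list its kernels.
    for mod in (_C.ops, _C.cache_ops, _C.cuda_utils):
        public = sorted(name for name in dir(mod) if not name.startswith("_"))
        print(mod.__name__, public)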

csrc/quantization.cpp

Lines changed: 0 additions & 19 deletions
This file was deleted.

0 commit comments
