
Commit e325287

Switch to punica-sgmv kernel from the Hub (#3236)
* Switch to punica-sgmv kernel from the Hub

  This also switches (temporarily) to the tgi-nix/kernel-builder merge branch, bumping up to CUDA 12.8 (same as non-Nix Torch).

* nix: client depends on aiohttp

  This probably worked before the nixpkgs bump because a dependency propagated aiohttp.
1 parent 43b1b07 commit e325287
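
For context: the Docker image no longer builds the Lorax Punica extension; instead the server loads the SGMV kernel from the Hugging Face Hub at import time through TGI's `load_kernel` helper, guarded on CUDA. A minimal sketch of the pattern the diffs below introduce (same code as in adapters/lora.py and layers/lora.py):

    from text_generation_server.utils.import_utils import SYSTEM
    from text_generation_server.utils.kernels import load_kernel

    # The Hub repo only ships CUDA builds (see kernels.lock below), so the kernel
    # is loaded conditionally and call sites check for None instead of the old
    # has_sgmv() helper.
    if SYSTEM == "cuda":
        punica_sgmv = load_kernel(
            module="punica_sgmv", repo_id="kernels-community/punica-sgmv"
        )
    else:
        punica_sgmv = None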

File tree: 12 files changed, +115 −317 lines

Dockerfile

Lines changed: 0 additions & 9 deletions
@@ -121,13 +121,6 @@ COPY server/Makefile-awq Makefile
 # Build specific version of transformers
 RUN . .venv/bin/activate && make build-awq
 
-# Build Lorax Punica kernels
-FROM kernel-builder AS lorax-punica-builder
-WORKDIR /usr/src
-COPY server/Makefile-lorax-punica Makefile
-# Build specific version of transformers
-RUN . .venv/bin/activate && TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-lorax-punica
-
 # Build Transformers CUDA kernels
 FROM kernel-builder AS custom-kernels-builder
 WORKDIR /usr/src
@@ -210,8 +203,6 @@ COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311
 COPY --from=exllamav2-kernels-builder /usr/src/exllamav2/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
 # Copy build artifacts from awq kernels builder
 COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
-# Copy build artifacts from lorax punica kernels builder
-COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
 # Copy build artifacts from mamba builder
 COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-311/ /usr/src/.venv/lib/python3.11/site-packages
 COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-311/ /usr/src/.venv/lib/python3.11/site-packages

flake.lock

Lines changed: 8 additions & 8 deletions
Some generated files are not rendered by default.

flake.nix

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
       inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
     };
     nix-filter.url = "github:numtide/nix-filter";
-    tgi-nix.url = "github:huggingface/text-generation-inference-nix/torch-2.7";
+    tgi-nix.url = "github:huggingface/text-generation-inference-nix/merge-with-kernel-builder";
     nixpkgs.follows = "tgi-nix/nixpkgs";
     flake-utils.url = "github:numtide/flake-utils";
     rust-overlay = {

nix/client.nix

Lines changed: 2 additions & 0 deletions
@@ -1,6 +1,7 @@
 {
   buildPythonPackage,
   poetry-core,
+  aiohttp,
   huggingface-hub,
   pydantic,
 }:
@@ -15,6 +16,7 @@ buildPythonPackage {
   build-system = [ poetry-core ];
 
   dependencies = [
+    aiohttp
     huggingface-hub
     pydantic
   ];

nix/server.nix

Lines changed: 2 additions & 2 deletions
@@ -31,7 +31,7 @@
   peft,
   pillow,
   prometheus-client,
-  punica-kernels,
+  punica-sgmv,
   py-cpuinfo,
   pydantic,
   quantization,
@@ -107,7 +107,7 @@ buildPythonPackage {
     peft
     pillow
     prometheus-client
-    punica-kernels
+    punica-sgmv
     py-cpuinfo
     pydantic
     quantization

server/Makefile

Lines changed: 0 additions & 1 deletion
@@ -3,7 +3,6 @@ include Makefile-flash-att-v2
 include Makefile-vllm
 include Makefile-awq
 include Makefile-selective-scan
-include Makefile-lorax-punica
 include Makefile-exllamav2
 include Makefile-flashinfer

server/Makefile-lorax-punica

Lines changed: 0 additions & 12 deletions
This file was deleted.

server/kernels.lock

Lines changed: 58 additions & 0 deletions
@@ -163,6 +163,64 @@
       }
     }
   },
+  {
+    "repo_id": "kernels-community/punica-sgmv",
+    "sha": "9ae1b469cb39c33df9ddd61657c6359acc423714",
+    "variants": {
+      "torch26-cxx11-cu118-x86_64-linux": {
+        "hash": "sha256-766062cd845bdebbe4e4391fda6f2663bebc2c110cbc2642d09c8c09ccf3f1d4",
+        "hash_type": "git_lfs_concat"
+      },
+      "torch26-cxx11-cu124-x86_64-linux": {
+        "hash": "sha256-c9cd76df7c84851aa566deb1c0d04ebddc1b1908a29df218344f2b3d53c4e683",
+        "hash_type": "git_lfs_concat"
+      },
+      "torch26-cxx11-cu126-aarch64-linux": {
+        "hash": "sha256-ae444bf53be3d469d4c9c58faef7d61a92e873e6104afe5aed2b2a1397333e99",
+        "hash_type": "git_lfs_concat"
+      },
+      "torch26-cxx11-cu126-x86_64-linux": {
+        "hash": "sha256-0706cc5ccf9cedae0bb6a938acdf2d5599a7b8f8b1fe46118b6ad61c0f3432af",
+        "hash_type": "git_lfs_concat"
+      },
+      "torch26-cxx98-cu118-x86_64-linux": {
+        "hash": "sha256-42cf390c6ae48b18041e201d4c67b4bf820b9f9cafe49a12c505f7920bae56ae",
+        "hash_type": "git_lfs_concat"
+      },
+      "torch26-cxx98-cu124-x86_64-linux": {
+        "hash": "sha256-75c97c23bfe32f65830341420d093a07df051828f385cbc5357b073c635f442f",
+        "hash_type": "git_lfs_concat"
+      },
+      "torch26-cxx98-cu126-aarch64-linux": {
+        "hash": "sha256-2ff5590ff6c298220c6e06142c971b08a686b98abb8d7dd1e6eb4539fa115cba",
+        "hash_type": "git_lfs_concat"
+      },
+      "torch26-cxx98-cu126-x86_64-linux": {
+        "hash": "sha256-70bcf04490865df6518c9d6a4c7eb2fee76b14642651f04a061c20ffa6fdb283",
+        "hash_type": "git_lfs_concat"
+      },
+      "torch27-cxx11-cu118-x86_64-linux": {
+        "hash": "sha256-727b8f5b22e4e91b956516235f26c39013a87ac6e196a0ce5f1897c2d959e69d",
+        "hash_type": "git_lfs_concat"
+      },
+      "torch27-cxx11-cu126-aarch64-linux": {
+        "hash": "sha256-bfddd19db7c9268a83e3cc5e281b007de80ab0fe611b3856ffd1691b400eca46",
+        "hash_type": "git_lfs_concat"
+      },
+      "torch27-cxx11-cu126-x86_64-linux": {
+        "hash": "sha256-940c68f5d4d8a2391b1eb3c7c5a56623428862f428aa5c6c1f7e62588c0e36fb",
+        "hash_type": "git_lfs_concat"
+      },
+      "torch27-cxx11-cu128-aarch64-linux": {
+        "hash": "sha256-781259a371b67bfbf744431c88a6ee847ab48459e73cb57264590de2728d6b3a",
+        "hash_type": "git_lfs_concat"
+      },
+      "torch27-cxx11-cu128-x86_64-linux": {
+        "hash": "sha256-8977a33d7884bebb9fb5e3d7daf157119206f0f18a22edb2b96ec593d5c81ae1",
+        "hash_type": "git_lfs_concat"
+      }
+    }
+  },
   {
     "repo_id": "kernels-community/quantization",
     "sha": "6470f9b005797e00279eb9103463dfe0f8b7da00",

server/pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -58,6 +58,7 @@ build-backend = "setuptools.build_meta"
 [tool.kernels.dependencies]
 "kernels-community/paged-attention" = ">=0.0.2"
 "kernels-community/moe" = ">=0.1.1"
+"kernels-community/punica-sgmv" = ">=0.0.1"
 "kernels-community/quantization" = ">=0.0.3"
 "kernels-community/quantization-eetq" = ">=0.0.1"
 "kernels-community/rotary" = ">=0.0.1"

server/text_generation_server/adapters/lora.py

Lines changed: 23 additions & 18 deletions
@@ -13,21 +13,20 @@
 from text_generation_server.utils.log import log_master
 
 from text_generation_server.adapters.config import AdapterConfig, ModuleMap
-
+from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.kernels import load_kernel
 from text_generation_server.adapters.weights import (
     AdapterBatchMetadata,
     AdapterWeights,
     BatchAdapterWeights,
 )
-from text_generation_server.utils.sgmv import (
-    BGMV_MAX_RANK,
-    MAX_RANK_CUSTOM,
-    get_tmp_tensors,
-    orient_for_rank,
-    pad_rank,
-    use_cutlass_shrink,
-    has_sgmv,
-)
+
+if SYSTEM == "cuda":
+    punica_sgmv = load_kernel(
+        module="punica_sgmv", repo_id="kernels-community/punica-sgmv"
+    )
+else:
+    punica_sgmv = None
 
 
 def get_start_stop_idxs_for_rank(offset, size, rank, world_size):
@@ -129,11 +128,13 @@ def __init__(
         self.lora_a_r = weights_a[0].size(1) if len(weights_a) > 0 else 1
         self.lora_b_r = weights_b[0].size(0) if len(weights_a) > 0 else 1
 
-        self._use_cutlass_shrink = use_cutlass_shrink(self.lora_a_r)
+        self._use_cutlass_shrink = punica_sgmv.use_cutlass_shrink(self.lora_a_r)
         self._is_transposed = False
 
         # [num_layers, hidden_size, r]
-        weights_a = [orient_for_rank(w, w.size(1)).contiguous() for w in weights_a]
+        weights_a = [
+            punica_sgmv.orient_for_rank(w, w.size(1)).contiguous() for w in weights_a
+        ]
         self._weights_a = torch.stack(weights_a)
 
         # [num_layers, r, hidden_size]
@@ -244,8 +245,12 @@ def prepare_weights(
             lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale
 
         # pad lora ranks to be compatible with sgmv
-        lora_a_list = [pad_rank(w, dim=1, world_size=world_size) for w in lora_a_list]
-        lora_b_list = [pad_rank(w, dim=0, world_size=world_size) for w in lora_b_list]
+        lora_a_list = [
+            punica_sgmv.pad_rank(w, dim=1, world_size=world_size) for w in lora_a_list
+        ]
+        lora_b_list = [
+            punica_sgmv.pad_rank(w, dim=0, world_size=world_size) for w in lora_b_list
+        ]
 
         if lora_a_list:
             # update rank if it was padded
@@ -293,7 +298,7 @@ def has_adapter(self, adapter_index: int) -> bool:
 
     def can_vectorize(self, pg: ProcessGroup) -> bool:
         return all(
-            rank_data.rank // pg.size() <= MAX_RANK_CUSTOM
+            rank_data.rank // pg.size() <= punica_sgmv.MAX_RANK_CUSTOM
             for rank_data in self.rank_data.values()
         )
 
@@ -337,8 +342,8 @@ def load(
         )
 
         use_sgmv = False
-        if prefill or max_rank > BGMV_MAX_RANK:
-            if has_sgmv():
+        if prefill or max_rank > punica_sgmv.BGMV_MAX_RANK:
+            if punica_sgmv is not None:
                 use_sgmv = True
                 lora_a_ptr = torch.tensor(
                     [
@@ -425,7 +430,7 @@ def load(
 
             if use_sgmv:
                 lora_a_ptr_indices = lora_a_ptr[indices]
-                tmp_shrink, tmp_expand = get_tmp_tensors(
+                tmp_shrink, tmp_expand = punica_sgmv.get_tmp_tensors(
                     lora_a_ptr_indices.size(0), rank, device
                 )
                 segment_starts = meta.adapter_segments[indices]
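
The call sites keep the old function names (`pad_rank`, `orient_for_rank`, `use_cutlass_shrink`, `get_tmp_tensors`) but now reach them through the loaded `punica_sgmv` module. As background, ranks are padded because the SGMV kernels only support certain rank sizes; the helper below is a hypothetical sketch of that padding idea, not the kernel's implementation (the real `punica_sgmv.pad_rank(w, dim=..., world_size=...)` derives the target rank internally):

    import torch

    def pad_rank_sketch(w: torch.Tensor, dim: int, target_rank: int) -> torch.Tensor:
        # Zero-pad the LoRA rank dimension up to a size the SGMV kernels accept,
        # so adapters with arbitrary ranks can still be stacked and batched.
        if w.size(dim) >= target_rank:
            return w
        pad_shape = list(w.shape)
        pad_shape[dim] = target_rank - w.size(dim)
        return torch.cat(
            [w, torch.zeros(pad_shape, dtype=w.dtype, device=w.device)], dim=dim
        )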

server/text_generation_server/layers/lora.py

Lines changed: 20 additions & 14 deletions
@@ -5,14 +5,16 @@
 from torch import nn
 from torch.distributed import ProcessGroup
 
-from text_generation_server.utils.sgmv import (
-    add_lora_a_bgmv,
-    add_lora_b_bgmv,
-    has_sgmv,
-    lora_a_sgmv_cutlass,
-    lora_b_sgmv_cutlass,
-    orient_for_rank,
-)
+from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.kernels import load_kernel
+
+if SYSTEM == "cuda":
+    punica_sgmv = load_kernel(
+        module="punica_sgmv", repo_id="kernels-community/punica-sgmv"
+    )
+else:
+    punica_sgmv = None
+
 
 if TYPE_CHECKING:
     from text_generation_server.adapters import AdapterBatchData
@@ -41,7 +43,11 @@ def forward_layer_type(
             return result
         data: Optional["BatchLoraWeights"] = adapter_data.data.get(layer_type)
 
-        if has_sgmv() and data is not None and data.can_vectorize(self.process_group):
+        if (
+            punica_sgmv is not None
+            and data is not None
+            and data.can_vectorize(self.process_group)
+        ):
             # In tensor-parallel configurations, each GPU processes a specific segment of the output.
             # The 'result' tensor represents the full output, which can vary in size based on
             # the layer type (e.g., attention vs. feed-forward layers). We define the current
@@ -68,7 +74,7 @@ def forward_layer_type(
 
                 if data.use_sgmv:
                     # Use SGMV for prefill
-                    v = lora_a_sgmv_cutlass(
+                    v = punica_sgmv.lora_a_sgmv_cutlass(
                         input,
                         rank_segments.tmp_shrink,
                         lora_a_ptr,
@@ -81,7 +87,7 @@ def forward_layer_type(
                     if self.process_group.size() > 1:
                         v = self.collect_lora_a(v)
 
-                    lora_b_sgmv_cutlass(
+                    punica_sgmv.lora_b_sgmv_cutlass(
                         proj,
                         v,
                         rank_segments.tmp_expand,
@@ -96,7 +102,7 @@ def forward_layer_type(
                         (input.size(0), r), dtype=input.dtype, device=input.device
                     )
                     # TODO: error with [-1, 0], but not [0, -1]
-                    add_lora_a_bgmv(
+                    punica_sgmv.add_lora_a_bgmv(
                         v,
                         input,
                         lora_a_ptr,
@@ -107,7 +113,7 @@ def forward_layer_type(
 
                     if self.process_group.size() > 1:
                         v = self.collect_lora_a(v)
-                    add_lora_b_bgmv(
+                    punica_sgmv.add_lora_b_bgmv(
                         proj,
                         v,
                         lora_b_ptr,
@@ -142,7 +148,7 @@ def forward_lora(
         lora_a = data.lora_a[adapter_index][self.layer_id, :, :]
         lora_b = data.lora_b[adapter_index][self.layer_id, :, :]
 
-        lora_a = orient_for_rank(lora_a, lora_b.size(0))
+        lora_a = punica_sgmv.orient_for_rank(lora_a, lora_b.size(0))
 
         a_out = input @ lora_a
         if self.process_group.size() > 1:
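
Taken together with the `use_sgmv` logic in adapters/lora.py, the dispatch is: SGMV for prefill or when the adapter rank exceeds `BGMV_MAX_RANK`, BGMV for decode, and the plain matmul path (`forward_lora`) when the CUDA kernel is unavailable or the batch cannot be vectorized. A condensed, hypothetical helper that restates that logic:

    from types import ModuleType
    from typing import Optional

    def pick_lora_path(punica_sgmv: Optional[ModuleType], prefill: bool, max_rank: int) -> str:
        # Mirrors the branches above: no kernel -> matmul fallback, prefill or
        # large ranks -> SGMV, otherwise BGMV for the decode path.
        if punica_sgmv is None:
            return "fallback-matmul"
        if prefill or max_rank > punica_sgmv.BGMV_MAX_RANK:
            return "sgmv"
        return "bgmv"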
