Commit 29c5bc0

lora enable in xpu

Signed-off-by: Wang, Yi A <[email protected]>
1 parent e325287

File tree: 2 files changed, +231 −23 lines

server/text_generation_server/adapters/lora.py
145 additions, 22 deletions
@@ -11,9 +11,8 @@
 from peft import LoraConfig as _LoraConfig
 from torch.distributed import ProcessGroup
 from text_generation_server.utils.log import log_master
-
-from text_generation_server.adapters.config import AdapterConfig, ModuleMap
 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.adapters.config import AdapterConfig, ModuleMap
 from text_generation_server.utils.kernels import load_kernel
 from text_generation_server.adapters.weights import (
     AdapterBatchMetadata,
@@ -130,15 +129,24 @@ def __init__(
 
         self._use_cutlass_shrink = punica_sgmv.use_cutlass_shrink(self.lora_a_r)
         self._is_transposed = False
+        if SYSTEM == "ipex":
+            # [num_layers, r, hidden_size]
+            weights_a = [w.transpose(0, 1).contiguous() for w in weights_a]
+            self._weights_a = torch.stack(weights_a)
+
+            # [num_layers, hidden_size, r]
+            weights_b = [w.transpose(0, 1).contiguous() for w in weights_b]
+            self._weights_b = torch.stack(weights_b)
+        else:
+            # [num_layers, hidden_size, r]
+            weights_a = [
+                punica_sgmv.orient_for_rank(w, w.size(1)).contiguous()
+                for w in weights_a
+            ]
+            self._weights_a = torch.stack(weights_a)
 
-        # [num_layers, hidden_size, r]
-        weights_a = [
-            punica_sgmv.orient_for_rank(w, w.size(1)).contiguous() for w in weights_a
-        ]
-        self._weights_a = torch.stack(weights_a)
-
-        # [num_layers, r, hidden_size]
-        self._weights_b = torch.stack(weights_b)
+            # [num_layers, r, hidden_size]
+            self._weights_b = torch.stack(weights_b)
 
         self.adapter_config = adapter_config
 
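
As a sanity check on the shapes above: the IPEX branch simply transposes each per-layer matrix before stacking, instead of using punica's rank-oriented layout. A minimal sketch with invented sizes (num_layers, hidden_size, r are hypothetical here):

import torch

# Invented sizes for illustration only.
num_layers, hidden_size, r = 2, 64, 8

# As loaded, each entry of weights_a is [hidden_size, r] and each entry of
# weights_b is [r, hidden_size], matching the comments in the diff above.
weights_a = [torch.randn(hidden_size, r) for _ in range(num_layers)]
weights_b = [torch.randn(r, hidden_size) for _ in range(num_layers)]

# IPEX path: transpose each matrix, then stack across layers.
ipex_a = torch.stack([w.transpose(0, 1).contiguous() for w in weights_a])
ipex_b = torch.stack([w.transpose(0, 1).contiguous() for w in weights_b])
assert ipex_a.shape == (num_layers, r, hidden_size)
assert ipex_b.shape == (num_layers, hidden_size, r)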

@@ -175,7 +183,10 @@ def _transpose_weights(self):
 
     @classmethod
     def get_batch_types(cls) -> List[Type[BatchAdapterWeights]]:
-        return [BatchLoraWeights]
+        if SYSTEM == "ipex":
+            return [IPEXBatchLoraWeights]
+        else:
+            return [BatchLoraWeights]
 
     # prepare pre-loaded lora weights for use in the model.
     #
@@ -245,17 +256,20 @@ def prepare_weights(
             lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale
 
         # pad lora ranks to be compatible with sgmv
-        lora_a_list = [
-            punica_sgmv.pad_rank(w, dim=1, world_size=world_size) for w in lora_a_list
-        ]
-        lora_b_list = [
-            punica_sgmv.pad_rank(w, dim=0, world_size=world_size) for w in lora_b_list
-        ]
-
-        if lora_a_list:
-            # update rank if it was padded
-            padded_rank = lora_a_list[0].size(1)
-            config.r = padded_rank
+        if SYSTEM != "ipex":
+            lora_a_list = [
+                punica_sgmv.pad_rank(w, dim=1, world_size=world_size)
+                for w in lora_a_list
+            ]
+            lora_b_list = [
+                punica_sgmv.pad_rank(w, dim=0, world_size=world_size)
+                for w in lora_b_list
+            ]
+
+            if lora_a_list:
+                # update rank if it was padded
+                padded_rank = lora_a_list[0].size(1)
+                config.r = padded_rank
 
         return LoraWeights(
             *shard_lora_weights(
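
The padding that the IPEX path skips can be pictured with plain zero-padding. This is an illustrative sketch only, with a made-up target rank of 16, since the exact sizes punica_sgmv.pad_rank rounds up to are kernel-specific:

import torch
import torch.nn.functional as F

# Hypothetical example: pad the rank dimension (dim=1 for the A matrix,
# as in the non-IPEX branch above) from r=10 to an assumed supported r=16.
lora_a = torch.randn(64, 10)                      # [hidden_size, r]
padded = F.pad(lora_a, (0, 16 - lora_a.size(1)))  # zero-pad the last dim
assert padded.shape == (64, 16)                   # config.r is then updated to 16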
@@ -471,6 +485,115 @@ def load(
         )
 
 
+@dataclass
+class IPEXBatchLoraWeights(BatchLoraWeights):
+    @classmethod
+    def load(
+        self,
+        adapter_weights: Dict[int, AdapterWeights],
+        meta: AdapterBatchMetadata,
+        prefill: bool,
+        prefill_head_indices: Optional[torch.Tensor],
+    ) -> Optional["BatchLoraWeights"]:
+        adapter_weights = {k: _convert_lora(v) for k, v in adapter_weights.items()}
+        adapter_weights = {
+            k: v for k, v in adapter_weights.items() if isinstance(v, LoraWeights)
+        }
+        if not adapter_weights:
+            return None
+
+        first_weights = next(iter(adapter_weights.values()))
+        device = first_weights.weights_a.device
+        segment_indices = meta.segment_indices
+
+        lora_a = {
+            idx: adapter_weights[idx].weights_a
+            for idx in segment_indices
+            if idx in adapter_weights
+        }
+        lora_b = {
+            idx: adapter_weights[idx].weights_b
+            for idx in segment_indices
+            if idx in adapter_weights
+        }
+        adapter_index_configs = {
+            idx: adapter_weights[idx].adapter_config
+            for idx in segment_indices
+            if idx in adapter_weights
+        }
+        if len(lora_a) != 0:
+            lora_a_ptr = torch.stack(list(lora_a.values()))
+        if len(lora_b) != 0:
+            lora_b_ptr = torch.stack(list(lora_b.values()))
+
+        use_sgmv = True if prefill else False
+
+        adapter_to_segment = {v: k for k, v in enumerate(segment_indices)}
+
+        rank_indices = defaultdict(list)
+        for segment_idx, adapter_idx in enumerate(segment_indices):
+            if adapter_idx not in adapter_weights:
+                continue
+            rank_indices[adapter_weights[adapter_idx].lora_a_r].append(segment_idx)
+
+        if prefill_head_indices is not None:
+            j, prefill_head_segment_starts, prefill_head_segment_ends = 1, [0], [0]
+            for head_index in prefill_head_indices:
+                # j cannot go out of bounds as that would mean there are tokens without corresponding adapters
+                if head_index < meta.adapter_segments[j]:
+                    prefill_head_segment_ends[-1] += 1
+                else:
+                    prefill_head_segment_starts.append(prefill_head_segment_ends[-1])
+                    prefill_head_segment_ends.append(prefill_head_segment_ends[-1] + 1)
+                    j += 1
+
+        rank_data = {}
+        segment_starts = None
+        segment_ends = None
+        if use_sgmv:
+            segment_starts = meta.adapter_segments[:-1]
+            segment_ends = meta.adapter_segments[1:]
+            if prefill_head_indices is not None:
+                segment_starts = prefill_head_segment_starts[:-1]
+                segment_ends = prefill_head_segment_ends[1:]
+        batch_indices = [
+            adapter_to_segment[idx] for idx in meta.adapter_indices.tolist()
+        ]
+        for rank, indices in rank_indices.items():
+            adapters_indices = []
+            lora_a_keys = list(lora_a.keys())
+            for segment_idx in batch_indices:
+                if segment_idx in indices:
+                    adapters_indices.append(
+                        lora_a_keys.index(segment_indices[segment_idx])
+                    )
+                else:
+                    adapters_indices.append(-1)
+            adapters_indices = torch.tensor(
+                adapters_indices, dtype=torch.int64, device=device
+            )
+            if use_sgmv:
+                adapters_indices = adapters_indices[segment_starts]
+            rank_data[rank] = RankSegments(
+                rank=rank,
+                tmp_shrink=None,
+                tmp_expand=None,
+                lora_a_ptr=lora_a_ptr,
+                lora_b_ptr=lora_b_ptr,
+                segment_starts=segment_starts,
+                segment_ends=segment_ends,
+                indices=adapters_indices,
+            )
+
+        return BatchLoraWeights(
+            lora_a=lora_a,
+            lora_b=lora_b,
+            adapter_index_configs=adapter_index_configs,
+            rank_data=rank_data,
+            use_sgmv=use_sgmv,
+        )
+
+
 def get_scaling_factor(
     lora_alpha: int,
     r: int,
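
To make the segment bookkeeping in IPEXBatchLoraWeights.load concrete: meta.adapter_segments holds cumulative token boundaries, and consecutive pairs give each segment's [start, end) range on the SGMV path. A toy example with invented values:

import torch

# Three adapter segments covering 3, 2, and 4 tokens respectively.
adapter_segments = torch.tensor([0, 3, 5, 9])
segment_starts = adapter_segments[:-1]  # tensor([0, 3, 5])
segment_ends = adapter_segments[1:]     # tensor([3, 5, 9])
print(list(zip(segment_starts.tolist(), segment_ends.tolist())))
# [(0, 3), (3, 5), (5, 9)]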

server/text_generation_server/layers/lora.py
86 additions, 1 deletion
@@ -4,8 +4,8 @@
 import torch.distributed
 from torch import nn
 from torch.distributed import ProcessGroup
-
 from text_generation_server.utils.import_utils import SYSTEM
+
 from text_generation_server.utils.kernels import load_kernel
 
 if SYSTEM == "cuda":
@@ -121,6 +121,91 @@ def forward_layer_type(
                         self.layer_id,
                     )
 
+            if end_idx - start_idx != result.shape[1]:
+                result[:, start_idx:end_idx] += proj
+        elif SYSTEM == "ipex" and data is not None:
+            from intel_extension_for_pytorch.llm.functional import (
+                bgmv_expand,
+                bgmv_shrink,
+                sgmv_expand,
+                sgmv_shrink,
+            )
+
+            # In IPEX, we provide the same API for sgmv and bgmv
+            if end_idx - start_idx != result.shape[1]:
+                proj = torch.zeros_like(result[:, start_idx:end_idx])
+            else:
+                proj = result
+
+            for r, rank_segments in data.rank_data.items():
+                lora_a_ptr = rank_segments.lora_a_ptr[:, self.layer_id, :].contiguous()
+                lora_b_ptr = rank_segments.lora_b_ptr[:, self.layer_id, :].contiguous()
+
+                if lora_a_ptr is None or lora_b_ptr is None:
+                    raise ValueError("LoRA data is missing")
+
+                if data.use_sgmv:
+                    # Use SGMV for prefill
+                    seq_len_tensor = (
+                        rank_segments.segment_ends - rank_segments.segment_starts
+                    ).to(torch.int64)
+                    b_seq_start_loc = rank_segments.segment_starts.to(torch.int64)
+                    total_tokens = seq_len_tensor.sum()
+                    v = torch.zeros(
+                        (total_tokens, r), dtype=input.dtype, device=input.device
+                    )
+                    bs = seq_len_tensor.shape[0]
+                    sgmv_shrink(
+                        input,
+                        lora_a_ptr,
+                        v,
+                        b_seq_start_loc,
+                        seq_len_tensor,
+                        rank_segments.indices,
+                        bs,
+                        seq_len_tensor.max().item(),
+                        1.0,
+                    )
+
+                    if self.process_group.size() > 1:
+                        v = self.collect_lora_a(v)
+
+                    sgmv_expand(
+                        v,
+                        lora_b_ptr,
+                        proj,
+                        b_seq_start_loc,
+                        seq_len_tensor,
+                        rank_segments.indices,
+                        bs,
+                        seq_len_tensor.max().item(),
+                        add_inputs=True,
+                    )
+                else:
+                    # Use BGMV for decode
+                    v = torch.zeros(
+                        (input.size(0), r), dtype=input.dtype, device=input.device
+                    )
+                    # TODO: error with [-1, 0], but not [0, -1]
+                    bgmv_shrink(
+                        input,
+                        lora_a_ptr,
+                        v,
+                        rank_segments.indices,
+                        1.0,
+                    )
+
+                    if self.process_group.size() > 1:
+                        v = self.collect_lora_a(v)
+
+                    bgmv_expand(
+                        v,
+                        lora_b_ptr,
+                        proj,
+                        rank_segments.indices,
+                        add_inputs=True,
+                    )
+
             if end_idx - start_idx != result.shape[1]:
                 result[:, start_idx:end_idx] += proj
         else:
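
For reference, the shrink/expand pairs above implement the usual two-step LoRA product; per adapter and per layer it reduces to two matmuls (shapes follow the IPEX branch, sizes invented). SGMV batches this over variable-length prefill segments, while BGMV handles one token per sequence during decode; both are selected by data.use_sgmv:

import torch

tokens, hidden_size, r = 4, 64, 8
x = torch.randn(tokens, hidden_size)
lora_a = torch.randn(r, hidden_size)   # one adapter's lora_a_ptr slice for a layer
lora_b = torch.randn(hidden_size, r)   # one adapter's lora_b_ptr slice for a layer

v = x @ lora_a.t()     # "shrink" into the low-rank space: [tokens, r]
proj = v @ lora_b.t()  # "expand" back to the hidden size: [tokens, hidden_size]
assert proj.shape == (tokens, hidden_size)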
