
Commit 8b6bb50

xpu lora support
Signed-off-by: Wang, Yi A <[email protected]>
1 parent 58934c8 commit 8b6bb50

File tree

2 files changed (+226, -14 lines)

server/text_generation_server/adapters/lora.py

Lines changed: 140 additions & 14 deletions
@@ -11,7 +11,7 @@
 from peft import LoraConfig as _LoraConfig
 from torch.distributed import ProcessGroup
 from text_generation_server.utils.log import log_master
-
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.adapters.config import AdapterConfig, ModuleMap
 
 from text_generation_server.adapters.weights import (
@@ -132,12 +132,21 @@ def __init__(
         self._use_cutlass_shrink = use_cutlass_shrink(self.lora_a_r)
         self._is_transposed = False
 
-        # [num_layers, hidden_size, r]
-        weights_a = [orient_for_rank(w, w.size(1)).contiguous() for w in weights_a]
-        self._weights_a = torch.stack(weights_a)
+        if SYSTEM == "ipex":
+            # [num_layers, r, hidden_size]
+            weights_a = [w.transpose(0, 1).contiguous() for w in weights_a]
+            self._weights_a = torch.stack(weights_a)
+
+            # [num_layers, hidden_size, r]
+            weights_b = [w.transpose(0, 1).contiguous() for w in weights_b]
+            self._weights_b = torch.stack(weights_b)
+        else:
+            # [num_layers, hidden_size, r]
+            weights_a = [orient_for_rank(w, w.size(1)).contiguous() for w in weights_a]
+            self._weights_a = torch.stack(weights_a)
 
-        # [num_layers, r, hidden_size]
-        self._weights_b = torch.stack(weights_b)
+            # [num_layers, r, hidden_size]
+            self._weights_b = torch.stack(weights_b)
 
         self.adapter_config = adapter_config
 
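
Note: the IPEX branch above stores lora_a as [num_layers, r, hidden_size] and lora_b as [num_layers, hidden_size, r], i.e. each per-layer matrix is a plain transpose, while the default branch keeps the punica/SGMV orientation. A shape-only sketch of the two layouts (dummy tensors, made-up sizes; orient_for_rank is treated as a no-op here purely for illustration):

import torch

num_layers, hidden_size, r = 4, 1024, 16
lora_a = [torch.zeros(hidden_size, r) for _ in range(num_layers)]   # per-layer A
lora_b = [torch.zeros(r, hidden_size) for _ in range(num_layers)]   # per-layer B

# IPEX layout: both matrices transposed before stacking.
ipex_a = torch.stack([w.transpose(0, 1).contiguous() for w in lora_a])  # [4, 16, 1024]
ipex_b = torch.stack([w.transpose(0, 1).contiguous() for w in lora_b])  # [4, 1024, 16]

# Default (SGMV) layout from the else branch: A kept [hidden_size, r], B kept [r, hidden_size].
cuda_a = torch.stack([w.contiguous() for w in lora_a])                  # [4, 1024, 16]
cuda_b = torch.stack([w.contiguous() for w in lora_b])                  # [4, 16, 1024]
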
@@ -174,7 +183,10 @@ def _transpose_weights(self):
 
     @classmethod
     def get_batch_types(cls) -> List[Type[BatchAdapterWeights]]:
-        return [BatchLoraWeights]
+        if SYSTEM == "ipex":
+            return [IPEXBatchLoraWeights]
+        else:
+            return [BatchLoraWeights]
 
     # prepare pre-loaded lora weights for use in the model.
     #
@@ -243,14 +255,19 @@ def prepare_weights(
             lora_a_list[layer_id] = lora_a.transpose(0, 1)
             lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale
 
-        # pad lora ranks to be compatible with sgmv
-        lora_a_list = [pad_rank(w, dim=1, world_size=world_size) for w in lora_a_list]
-        lora_b_list = [pad_rank(w, dim=0, world_size=world_size) for w in lora_b_list]
+        if SYSTEM != "ipex":
+            # pad lora ranks to be compatible with sgmv
+            lora_a_list = [
+                pad_rank(w, dim=1, world_size=world_size) for w in lora_a_list
+            ]
+            lora_b_list = [
+                pad_rank(w, dim=0, world_size=world_size) for w in lora_b_list
+            ]
 
-        if lora_a_list:
-            # update rank if it was padded
-            padded_rank = lora_a_list[0].size(1)
-            config.r = padded_rank
+            if lora_a_list:
+                # update rank if it was padded
+                padded_rank = lora_a_list[0].size(1)
+                config.r = padded_rank
 
         return LoraWeights(
             *shard_lora_weights(
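
Note: the rank padding is only needed for the SGMV kernels used on the default path (per the diff's own comment, "pad lora ranks to be compatible with sgmv"), which is why the IPEX path skips it and never rewrites config.r. The helper below is a hypothetical stand-in for pad_rank, meant only to illustrate that zero-padding the rank dimension is mathematically a no-op; the real helper's target sizes may differ:

import torch
import torch.nn.functional as F

def pad_rank_to_multiple(w: torch.Tensor, dim: int, multiple: int = 8) -> torch.Tensor:
    # Hypothetical illustration: zero-pad the rank dimension up to the next
    # multiple. The extra rows/columns contribute nothing to A @ B, so the
    # LoRA output is unchanged; only the kernel-facing shape changes.
    r = w.size(dim)
    target = ((r + multiple - 1) // multiple) * multiple
    if target == r:
        return w
    pad = [0, 0] * w.dim()
    # F.pad lists pads starting from the last dimension: (last_left, last_right, ...).
    pad[2 * (w.dim() - 1 - dim) + 1] = target - r
    return F.pad(w, pad)

lora_a = torch.randn(1024, 5)                 # rank 5
print(pad_rank_to_multiple(lora_a, dim=1).shape)   # torch.Size([1024, 8])
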
@@ -466,6 +483,115 @@ def load(
         )
 
 
+@dataclass
+class IPEXBatchLoraWeights(BatchLoraWeights):
+    @classmethod
+    def load(
+        self,
+        adapter_weights: Dict[int, AdapterWeights],
+        meta: AdapterBatchMetadata,
+        prefill: bool,
+        prefill_head_indices: Optional[torch.Tensor],
+    ) -> Optional["BatchLoraWeights"]:
+        adapter_weights = {k: _convert_lora(v) for k, v in adapter_weights.items()}
+        adapter_weights = {
+            k: v for k, v in adapter_weights.items() if isinstance(v, LoraWeights)
+        }
+        if not adapter_weights:
+            return None
+
+        first_weights = next(iter(adapter_weights.values()))
+        device = first_weights.weights_a.device
+        segment_indices = meta.segment_indices
+
+        lora_a = {
+            idx: adapter_weights[idx].weights_a
+            for idx in segment_indices
+            if idx in adapter_weights
+        }
+        lora_b = {
+            idx: adapter_weights[idx].weights_b
+            for idx in segment_indices
+            if idx in adapter_weights
+        }
+        adapter_index_configs = {
+            idx: adapter_weights[idx].adapter_config
+            for idx in segment_indices
+            if idx in adapter_weights
+        }
+        if len(lora_a) != 0:
+            lora_a_ptr = torch.stack(list(lora_a.values()))
+        if len(lora_b) != 0:
+            lora_b_ptr = torch.stack(list(lora_b.values()))
+
+        use_sgmv = True if prefill else False
+
+        adapter_to_segment = {v: k for k, v in enumerate(segment_indices)}
+
+        rank_indices = defaultdict(list)
+        for segment_idx, adapter_idx in enumerate(segment_indices):
+            if adapter_idx not in adapter_weights:
+                continue
+            rank_indices[adapter_weights[adapter_idx].lora_a_r].append(segment_idx)
+
+        if prefill_head_indices is not None:
+            j, prefill_head_segment_starts, prefill_head_segment_ends = 1, [0], [0]
+            for head_index in prefill_head_indices:
+                # j cannot go out of bounds as that would mean there are tokens without corresponding adapters
+                if head_index < meta.adapter_segments[j]:
+                    prefill_head_segment_ends[-1] += 1
+                else:
+                    prefill_head_segment_starts.append(prefill_head_segment_ends[-1])
+                    prefill_head_segment_ends.append(prefill_head_segment_ends[-1] + 1)
+                    j += 1
+
+        rank_data = {}
+        segment_starts = None
+        segment_ends = None
+        if use_sgmv:
+            segment_starts = meta.adapter_segments[:-1]
+            segment_ends = meta.adapter_segments[1:]
+            if prefill_head_indices is not None:
+                segment_starts = prefill_head_segment_starts[:-1]
+                segment_ends = prefill_head_segment_ends[1:]
+        batch_indices = [
+            adapter_to_segment[idx] for idx in meta.adapter_indices.tolist()
+        ]
+        for rank, indices in rank_indices.items():
+            adapters_indices = []
+            lora_a_keys = list(lora_a.keys())
+            for segment_idx in batch_indices:
+                if segment_idx in indices:
+                    adapters_indices.append(
+                        lora_a_keys.index(segment_indices[segment_idx])
+                    )
+                else:
+                    adapters_indices.append(-1)
+            adapters_indices = torch.tensor(
+                adapters_indices, dtype=torch.int64, device=device
+            )
+            if use_sgmv:
+                adapters_indices = adapters_indices[segment_starts]
+            rank_data[rank] = RankSegments(
+                rank=rank,
+                tmp_shrink=None,
+                tmp_expand=None,
+                lora_a_ptr=lora_a_ptr,
+                lora_b_ptr=lora_b_ptr,
+                segment_starts=segment_starts,
+                segment_ends=segment_ends,
+                indices=adapters_indices,
+            )
+
+        return BatchLoraWeights(
+            lora_a=lora_a,
+            lora_b=lora_b,
+            adapter_index_configs=adapter_index_configs,
+            rank_data=rank_data,
+            use_sgmv=use_sgmv,
+        )
+
+
 def get_scaling_factor(
     lora_alpha: int,
     r: int,
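
Note: a small worked example of the bookkeeping in IPEXBatchLoraWeights.load (made-up metadata, ignoring prefill_head_indices, and assuming both adapters land in the same rank bucket) showing how the batch metadata becomes segment boundaries and per-segment adapter indices during prefill:

import torch

# Made-up prefill batch: adapter 0 covers tokens 0-2, adapter 1 covers tokens 3-6.
segment_indices = [0, 1]
adapter_segments = torch.tensor([0, 3, 7])
adapter_indices = torch.tensor([0, 0, 0, 1, 1, 1, 1])

segment_starts = adapter_segments[:-1]        # tensor([0, 3])
segment_ends = adapter_segments[1:]           # tensor([3, 7])

adapter_to_segment = {v: k for k, v in enumerate(segment_indices)}   # {0: 0, 1: 1}
batch_indices = [adapter_to_segment[i] for i in adapter_indices.tolist()]

# With both adapters loaded, lora_a.keys() would be [0, 1]:
lora_a_keys = [0, 1]
adapters_indices = torch.tensor(
    [lora_a_keys.index(segment_indices[s]) for s in batch_indices]
)                                             # one entry per token
# Prefill (use_sgmv=True) keeps one entry per segment rather than per token:
adapters_indices = adapters_indices[segment_starts]   # tensor([0, 1])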

server/text_generation_server/layers/lora.py

Lines changed: 86 additions & 0 deletions
@@ -4,6 +4,7 @@
 import torch.distributed
 from torch import nn
 from torch.distributed import ProcessGroup
+from text_generation_server.utils.import_utils import SYSTEM
 
 from text_generation_server.utils.sgmv import (
     add_lora_a_bgmv,
@@ -115,6 +116,91 @@ def forward_layer_type(
                         self.layer_id,
                     )
 
+            if end_idx - start_idx != result.shape[1]:
+                result[:, start_idx:end_idx] += proj
+        elif SYSTEM == "ipex" and data is not None:
+            from intel_extension_for_pytorch.llm.functional import (
+                bgmv_expand,
+                bgmv_shrink,
+                sgmv_expand,
+                sgmv_shrink,
+            )
+
+            # In IPEX, we provide the same API for sgmv and bgmv
+            if end_idx - start_idx != result.shape[1]:
+                proj = torch.zeros_like(result[:, start_idx:end_idx])
+            else:
+                proj = result
+
+            for r, rank_segments in data.rank_data.items():
+                lora_a_ptr = rank_segments.lora_a_ptr[:, self.layer_id, :].contiguous()
+                lora_b_ptr = rank_segments.lora_b_ptr[:, self.layer_id, :].contiguous()
+
+                if lora_a_ptr is None or lora_b_ptr is None:
+                    raise ValueError("LoRA data is missing")
+
+                if data.use_sgmv:
+                    # Use SGMV for prefill
+                    seq_len_tensor = (
+                        rank_segments.segment_ends - rank_segments.segment_starts
+                    ).to(torch.int64)
+                    b_seq_start_loc = rank_segments.segment_starts.to(torch.int64)
+                    total_tokens = seq_len_tensor.sum()
+                    v = torch.zeros(
+                        (total_tokens, r), dtype=input.dtype, device=input.device
+                    )
+                    bs = seq_len_tensor.shape[0]
+                    sgmv_shrink(
+                        input,
+                        lora_a_ptr,
+                        v,
+                        b_seq_start_loc,
+                        seq_len_tensor,
+                        rank_segments.indices,
+                        bs,
+                        seq_len_tensor.max().item(),
+                        1.0,
+                    )
+
+                    if self.process_group.size() > 1:
+                        v = self.collect_lora_a(v)
+
+                    sgmv_expand(
+                        v,
+                        lora_b_ptr,
+                        proj,
+                        b_seq_start_loc,
+                        seq_len_tensor,
+                        rank_segments.indices,
+                        bs,
+                        seq_len_tensor.max().item(),
+                        add_inputs=True,
+                    )
+                else:
+                    # Use BGMV for decode
+                    v = torch.zeros(
+                        (input.size(0), r), dtype=input.dtype, device=input.device
+                    )
+                    # TODO: error with [-1, 0], but not [0, -1]
+                    bgmv_shrink(
+                        input,
+                        lora_a_ptr,
+                        v,
+                        rank_segments.indices,
+                        1.0,
+                    )
+
+                    if self.process_group.size() > 1:
+                        v = self.collect_lora_a(v)
+
+                    bgmv_expand(
+                        v,
+                        lora_b_ptr,
+                        proj,
+                        rank_segments.indices,
+                        add_inputs=True,
+                    )
+
             if end_idx - start_idx != result.shape[1]:
                 result[:, start_idx:end_idx] += proj
         else: