 from peft import LoraConfig as _LoraConfig
 from torch.distributed import ProcessGroup
 from text_generation_server.utils.log import log_master
-
-from text_generation_server.adapters.config import AdapterConfig, ModuleMap
 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.adapters.config import AdapterConfig, ModuleMap
 from text_generation_server.utils.kernels import load_kernel
 from text_generation_server.adapters.weights import (
     AdapterBatchMetadata,

@@ -130,15 +129,24 @@ def __init__(

         self._use_cutlass_shrink = punica_sgmv.use_cutlass_shrink(self.lora_a_r)
         self._is_transposed = False
+        if SYSTEM == "ipex":
+            # [num_layers, r, hidden_size]
+            weights_a = [w.transpose(0, 1).contiguous() for w in weights_a]
+            self._weights_a = torch.stack(weights_a)
+
+            # [num_layers, hidden_size, r]
+            weights_b = [w.transpose(0, 1).contiguous() for w in weights_b]
+            self._weights_b = torch.stack(weights_b)
+        else:
+            # [num_layers, hidden_size, r]
+            weights_a = [
+                punica_sgmv.orient_for_rank(w, w.size(1)).contiguous()
+                for w in weights_a
+            ]
+            self._weights_a = torch.stack(weights_a)

-        # [num_layers, hidden_size, r]
-        weights_a = [
-            punica_sgmv.orient_for_rank(w, w.size(1)).contiguous() for w in weights_a
-        ]
-        self._weights_a = torch.stack(weights_a)
-
-        # [num_layers, r, hidden_size]
-        self._weights_b = torch.stack(weights_b)
+            # [num_layers, r, hidden_size]
+            self._weights_b = torch.stack(weights_b)

         self.adapter_config = adapter_config

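As a minimal standalone sketch of what the IPEX branch above produces (made-up shapes: 2 layers, hidden_size=16, r=4); the default branch instead lets punica_sgmv.orient_for_rank decide the orientation of A, so only the plain transpose-and-stack path is shown here:

import torch

# Toy shapes for illustration only.
num_layers, hidden_size, r = 2, 16, 4

# Per-layer LoRA A matrices arrive as [hidden_size, r], B matrices as [r, hidden_size].
weights_a = [torch.randn(hidden_size, r) for _ in range(num_layers)]
weights_b = [torch.randn(r, hidden_size) for _ in range(num_layers)]

# IPEX branch: a plain transpose per layer, then stack along a new layer axis.
stacked_a = torch.stack([w.transpose(0, 1).contiguous() for w in weights_a])
stacked_b = torch.stack([w.transpose(0, 1).contiguous() for w in weights_b])

print(stacked_a.shape)  # torch.Size([2, 4, 16]) -> [num_layers, r, hidden_size]
print(stacked_b.shape)  # torch.Size([2, 16, 4]) -> [num_layers, hidden_size, r]
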
@@ -175,7 +183,10 @@ def _transpose_weights(self):

     @classmethod
     def get_batch_types(cls) -> List[Type[BatchAdapterWeights]]:
-        return [BatchLoraWeights]
+        if SYSTEM == "ipex":
+            return [IPEXBatchLoraWeights]
+        else:
+            return [BatchLoraWeights]

     # prepare pre-loaded lora weights for use in the model.
     #

@@ -245,17 +256,20 @@ def prepare_weights(
             lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale

         # pad lora ranks to be compatible with sgmv
-        lora_a_list = [
-            punica_sgmv.pad_rank(w, dim=1, world_size=world_size) for w in lora_a_list
-        ]
-        lora_b_list = [
-            punica_sgmv.pad_rank(w, dim=0, world_size=world_size) for w in lora_b_list
-        ]
-
-        if lora_a_list:
-            # update rank if it was padded
-            padded_rank = lora_a_list[0].size(1)
-            config.r = padded_rank
+        if SYSTEM != "ipex":
+            lora_a_list = [
+                punica_sgmv.pad_rank(w, dim=1, world_size=world_size)
+                for w in lora_a_list
+            ]
+            lora_b_list = [
+                punica_sgmv.pad_rank(w, dim=0, world_size=world_size)
+                for w in lora_b_list
+            ]
+
+            if lora_a_list:
+                # update rank if it was padded
+                padded_rank = lora_a_list[0].size(1)
+                config.r = padded_rank

         return LoraWeights(
             *shard_lora_weights(

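On the padding that is now skipped for IPEX: pad_rank zero-pads the rank dimension so the SGMV kernels see a supported rank, while the IPEX path keeps the original rank. The helper below is a rough, hypothetical stand-in for illustration only; the real punica_sgmv.pad_rank may pick the target rank differently and also factors in world_size:

import torch
import torch.nn.functional as F

def pad_rank_to_multiple(w: torch.Tensor, dim: int, multiple: int = 8) -> torch.Tensor:
    """Zero-pad dimension `dim` up to the next multiple of `multiple` (illustrative stand-in)."""
    r = w.size(dim)
    target = ((r + multiple - 1) // multiple) * multiple
    if target == r:
        return w
    pad = [0, 0] * w.dim()  # F.pad takes (left, right) pairs starting from the last dim
    pad[2 * (w.dim() - 1 - dim) + 1] = target - r
    return F.pad(w, pad)

lora_a = torch.randn(16, 5)                       # [hidden_size, r] with r=5
lora_b = torch.randn(5, 16)                       # [r, hidden_size]
print(pad_rank_to_multiple(lora_a, dim=1).shape)  # torch.Size([16, 8])
print(pad_rank_to_multiple(lora_b, dim=0).shape)  # torch.Size([8, 16])
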
@@ -471,6 +485,115 @@ def load(
         )


+@dataclass
+class IPEXBatchLoraWeights(BatchLoraWeights):
+    @classmethod
+    def load(
+        self,
+        adapter_weights: Dict[int, AdapterWeights],
+        meta: AdapterBatchMetadata,
+        prefill: bool,
+        prefill_head_indices: Optional[torch.Tensor],
+    ) -> Optional["BatchLoraWeights"]:
+        adapter_weights = {k: _convert_lora(v) for k, v in adapter_weights.items()}
+        adapter_weights = {
+            k: v for k, v in adapter_weights.items() if isinstance(v, LoraWeights)
+        }
+        if not adapter_weights:
+            return None
+
+        first_weights = next(iter(adapter_weights.values()))
+        device = first_weights.weights_a.device
+        segment_indices = meta.segment_indices
+
+        lora_a = {
+            idx: adapter_weights[idx].weights_a
+            for idx in segment_indices
+            if idx in adapter_weights
+        }
+        lora_b = {
+            idx: adapter_weights[idx].weights_b
+            for idx in segment_indices
+            if idx in adapter_weights
+        }
+        adapter_index_configs = {
+            idx: adapter_weights[idx].adapter_config
+            for idx in segment_indices
+            if idx in adapter_weights
+        }
+        if len(lora_a) != 0:
+            lora_a_ptr = torch.stack(list(lora_a.values()))
+        if len(lora_b) != 0:
+            lora_b_ptr = torch.stack(list(lora_b.values()))
+
+        use_sgmv = True if prefill else False
+
+        adapter_to_segment = {v: k for k, v in enumerate(segment_indices)}
+
+        rank_indices = defaultdict(list)
+        for segment_idx, adapter_idx in enumerate(segment_indices):
+            if adapter_idx not in adapter_weights:
+                continue
+            rank_indices[adapter_weights[adapter_idx].lora_a_r].append(segment_idx)
+
+        if prefill_head_indices is not None:
+            j, prefill_head_segment_starts, prefill_head_segment_ends = 1, [0], [0]
+            for head_index in prefill_head_indices:
+                # j cannot go out of bounds as that would mean there are tokens without corresponding adapters
+                if head_index < meta.adapter_segments[j]:
+                    prefill_head_segment_ends[-1] += 1
+                else:
+                    prefill_head_segment_starts.append(prefill_head_segment_ends[-1])
+                    prefill_head_segment_ends.append(prefill_head_segment_ends[-1] + 1)
+                    j += 1
+
+        rank_data = {}
+        segment_starts = None
+        segment_ends = None
+        if use_sgmv:
+            segment_starts = meta.adapter_segments[:-1]
+            segment_ends = meta.adapter_segments[1:]
+            if prefill_head_indices is not None:
+                segment_starts = prefill_head_segment_starts[:-1]
+                segment_ends = prefill_head_segment_ends[1:]
+        batch_indices = [
+            adapter_to_segment[idx] for idx in meta.adapter_indices.tolist()
+        ]
+        for rank, indices in rank_indices.items():
+            adapters_indices = []
+            lora_a_keys = list(lora_a.keys())
+            for segment_idx in batch_indices:
+                if segment_idx in indices:
+                    adapters_indices.append(
+                        lora_a_keys.index(segment_indices[segment_idx])
+                    )
+                else:
+                    adapters_indices.append(-1)
+            adapters_indices = torch.tensor(
+                adapters_indices, dtype=torch.int64, device=device
+            )
+            if use_sgmv:
+                adapters_indices = adapters_indices[segment_starts]
+            rank_data[rank] = RankSegments(
+                rank=rank,
+                tmp_shrink=None,
+                tmp_expand=None,
+                lora_a_ptr=lora_a_ptr,
+                lora_b_ptr=lora_b_ptr,
+                segment_starts=segment_starts,
+                segment_ends=segment_ends,
+                indices=adapters_indices,
+            )
+
+        return BatchLoraWeights(
+            lora_a=lora_a,
+            lora_b=lora_b,
+            adapter_index_configs=adapter_index_configs,
+            rank_data=rank_data,
+            use_sgmv=use_sgmv,
+        )
+
+
 def get_scaling_factor(
     lora_alpha: int,
     r: int,
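
To make the rank_data bookkeeping in IPEXBatchLoraWeights.load above easier to follow, here is a tiny made-up decode-path example (use_sgmv disabled) of how adapters_indices is assembled per rank group; all values are illustrative, not taken from this diff:

# Two segments: adapter 0 (rank 8) then adapter 3 (rank 16); five tokens in the batch.
segment_indices = [0, 3]            # adapter id owning each segment
batch_indices = [0, 0, 0, 1, 1]     # segment position of every token
lora_a_keys = [0, 3]                # row order of adapters in the stacked lora_a_ptr
rank_indices = {8: [0], 16: [1]}    # LoRA rank -> segments using that rank

for rank, indices in rank_indices.items():
    adapters_indices = [
        lora_a_keys.index(segment_indices[seg]) if seg in indices else -1
        for seg in batch_indices
    ]
    print(rank, adapters_indices)
# 8  [0, 0, 0, -1, -1]   tokens served by adapter 0; other tokens masked with -1
# 16 [-1, -1, -1, 1, 1]  tokens served by adapter 3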