 from peft import LoraConfig as _LoraConfig
 from torch.distributed import ProcessGroup
 from text_generation_server.utils.log import log_master
-
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.adapters.config import AdapterConfig, ModuleMap
 
 from text_generation_server.adapters.weights import (
@@ -132,12 +132,21 @@ def __init__(
         self._use_cutlass_shrink = use_cutlass_shrink(self.lora_a_r)
         self._is_transposed = False
 
-        # [num_layers, hidden_size, r]
-        weights_a = [orient_for_rank(w, w.size(1)).contiguous() for w in weights_a]
-        self._weights_a = torch.stack(weights_a)
+        if SYSTEM == "ipex":
+            # [num_layers, r, hidden_size]
+            weights_a = [w.transpose(0, 1).contiguous() for w in weights_a]
+            self._weights_a = torch.stack(weights_a)
+
+            # [num_layers, hidden_size, r]
+            weights_b = [w.transpose(0, 1).contiguous() for w in weights_b]
+            self._weights_b = torch.stack(weights_b)
+        else:
+            # [num_layers, hidden_size, r]
+            weights_a = [orient_for_rank(w, w.size(1)).contiguous() for w in weights_a]
+            self._weights_a = torch.stack(weights_a)
 
-        # [num_layers, r, hidden_size]
-        self._weights_b = torch.stack(weights_b)
+            # [num_layers, r, hidden_size]
+            self._weights_b = torch.stack(weights_b)
 
         self.adapter_config = adapter_config
 
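For intuition, here is a minimal standalone sketch of the shape handling on the new `ipex` branch. It uses toy sizes and assumes, as the comments in the hunk indicate, that per-layer LoRA A weights arrive as `[hidden_size, r]` and B weights as `[r, hidden_size]`; it is not part of the change itself.

```python
import torch

# Toy sizes for illustration only.
num_layers, hidden_size, r = 2, 8, 4

weights_a = [torch.randn(hidden_size, r) for _ in range(num_layers)]
weights_b = [torch.randn(r, hidden_size) for _ in range(num_layers)]

# ipex branch: plain transpose + stack, no cutlass-specific orientation.
stacked_a = torch.stack([w.transpose(0, 1).contiguous() for w in weights_a])
stacked_b = torch.stack([w.transpose(0, 1).contiguous() for w in weights_b])

print(stacked_a.shape)  # torch.Size([2, 4, 8]) -> [num_layers, r, hidden_size]
print(stacked_b.shape)  # torch.Size([2, 8, 4]) -> [num_layers, hidden_size, r]
```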
@@ -174,7 +183,10 @@ def _transpose_weights(self):
 
     @classmethod
     def get_batch_types(cls) -> List[Type[BatchAdapterWeights]]:
-        return [BatchLoraWeights]
+        if SYSTEM == "ipex":
+            return [IPEXBatchLoraWeights]
+        else:
+            return [BatchLoraWeights]
 
     # prepare pre-loaded lora weights for use in the model.
     #
@@ -243,14 +255,19 @@ def prepare_weights(
             lora_a_list[layer_id] = lora_a.transpose(0, 1)
             lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale
 
-        # pad lora ranks to be compatible with sgmv
-        lora_a_list = [pad_rank(w, dim=1, world_size=world_size) for w in lora_a_list]
-        lora_b_list = [pad_rank(w, dim=0, world_size=world_size) for w in lora_b_list]
+        if SYSTEM != "ipex":
+            # pad lora ranks to be compatible with sgmv
+            lora_a_list = [
+                pad_rank(w, dim=1, world_size=world_size) for w in lora_a_list
+            ]
+            lora_b_list = [
+                pad_rank(w, dim=0, world_size=world_size) for w in lora_b_list
+            ]
 
-        if lora_a_list:
-            # update rank if it was padded
-            padded_rank = lora_a_list[0].size(1)
-            config.r = padded_rank
+            if lora_a_list:
+                # update rank if it was padded
+                padded_rank = lora_a_list[0].size(1)
+                config.r = padded_rank
 
         return LoraWeights(
             *shard_lora_weights(
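As a side note, this is roughly what rank padding amounts to. The snippet below is a hypothetical stand-in, not the library's `pad_rank` (whose exact rounding rule and `world_size` handling are not shown here): it zero-pads the rank dimension up to a chosen multiple. Zero columns of A and zero rows of B contribute nothing to the product A·B, and the padded rank is what ends up recorded in `config.r`.

```python
import torch
import torch.nn.functional as F

def pad_rank_to_multiple(w: torch.Tensor, dim: int, multiple: int = 8) -> torch.Tensor:
    """Hypothetical stand-in for rank padding: zero-pad `dim` up to a multiple."""
    pad = (-w.size(dim)) % multiple
    if pad == 0:
        return w
    # F.pad expects (left, right) pairs starting from the last dimension.
    pads = [0, 0] * (w.dim() - 1 - dim) + [0, pad]
    return F.pad(w, pads)

lora_a = torch.randn(64, 5)   # [hidden_size, r] with r=5
lora_b = torch.randn(5, 64)   # [r, hidden_size]
padded_a = pad_rank_to_multiple(lora_a, dim=1)
padded_b = pad_rank_to_multiple(lora_b, dim=0)
print(padded_a.shape, padded_b.shape)  # torch.Size([64, 8]) torch.Size([8, 64])
print(torch.allclose(lora_a @ lora_b, padded_a @ padded_b))  # True
```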
@@ -466,6 +483,115 @@ def load(
         )
 
 
+@dataclass
+class IPEXBatchLoraWeights(BatchLoraWeights):
+    @classmethod
+    def load(
+        self,
+        adapter_weights: Dict[int, AdapterWeights],
+        meta: AdapterBatchMetadata,
+        prefill: bool,
+        prefill_head_indices: Optional[torch.Tensor],
+    ) -> Optional["BatchLoraWeights"]:
+        adapter_weights = {k: _convert_lora(v) for k, v in adapter_weights.items()}
+        adapter_weights = {
+            k: v for k, v in adapter_weights.items() if isinstance(v, LoraWeights)
+        }
+        if not adapter_weights:
+            return None
+
+        first_weights = next(iter(adapter_weights.values()))
+        device = first_weights.weights_a.device
+        segment_indices = meta.segment_indices
+
+        lora_a = {
+            idx: adapter_weights[idx].weights_a
+            for idx in segment_indices
+            if idx in adapter_weights
+        }
+        lora_b = {
+            idx: adapter_weights[idx].weights_b
+            for idx in segment_indices
+            if idx in adapter_weights
+        }
+        adapter_index_configs = {
+            idx: adapter_weights[idx].adapter_config
+            for idx in segment_indices
+            if idx in adapter_weights
+        }
+        if len(lora_a) != 0:
+            lora_a_ptr = torch.stack(list(lora_a.values()))
+        if len(lora_b) != 0:
+            lora_b_ptr = torch.stack(list(lora_b.values()))
+
+        use_sgmv = True if prefill else False
+
+        adapter_to_segment = {v: k for k, v in enumerate(segment_indices)}
+
+        rank_indices = defaultdict(list)
+        for segment_idx, adapter_idx in enumerate(segment_indices):
+            if adapter_idx not in adapter_weights:
+                continue
+            rank_indices[adapter_weights[adapter_idx].lora_a_r].append(segment_idx)
+
+        if prefill_head_indices is not None:
+            j, prefill_head_segment_starts, prefill_head_segment_ends = 1, [0], [0]
+            for head_index in prefill_head_indices:
+                # j cannot go out of bounds as that would mean there are tokens without corresponding adapters
+                if head_index < meta.adapter_segments[j]:
+                    prefill_head_segment_ends[-1] += 1
+                else:
+                    prefill_head_segment_starts.append(prefill_head_segment_ends[-1])
+                    prefill_head_segment_ends.append(prefill_head_segment_ends[-1] + 1)
+                    j += 1
+
+        rank_data = {}
+        segment_starts = None
+        segment_ends = None
+        if use_sgmv:
+            segment_starts = meta.adapter_segments[:-1]
+            segment_ends = meta.adapter_segments[1:]
+            if prefill_head_indices is not None:
+                segment_starts = prefill_head_segment_starts[:-1]
+                segment_ends = prefill_head_segment_ends[1:]
+        batch_indices = [
+            adapter_to_segment[idx] for idx in meta.adapter_indices.tolist()
+        ]
+        for rank, indices in rank_indices.items():
+            adapters_indices = []
+            lora_a_keys = list(lora_a.keys())
+            for segment_idx in batch_indices:
+                if segment_idx in indices:
+                    adapters_indices.append(
+                        lora_a_keys.index(segment_indices[segment_idx])
+                    )
+                else:
+                    adapters_indices.append(-1)
+            adapters_indices = torch.tensor(
+                adapters_indices, dtype=torch.int64, device=device
+            )
+            if use_sgmv:
+                adapters_indices = adapters_indices[segment_starts]
+            rank_data[rank] = RankSegments(
+                rank=rank,
+                tmp_shrink=None,
+                tmp_expand=None,
+                lora_a_ptr=lora_a_ptr,
+                lora_b_ptr=lora_b_ptr,
+                segment_starts=segment_starts,
+                segment_ends=segment_ends,
+                indices=adapters_indices,
+            )
+
+        return BatchLoraWeights(
+            lora_a=lora_a,
+            lora_b=lora_b,
+            adapter_index_configs=adapter_index_configs,
+            rank_data=rank_data,
+            use_sgmv=use_sgmv,
+        )
+
+
 def get_scaling_factor(
     lora_alpha: int,
     r: int,
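To make the index bookkeeping in `IPEXBatchLoraWeights.load` easier to follow, here is a self-contained trace of the per-rank mapping it builds, using made-up adapter ids and ranks (decode path, so no SGMV segment gathering is applied).

```python
from collections import defaultdict

# Made-up example: three segments in the batch, served by two loaded adapters.
segment_indices = [0, 1, 0]        # adapter id owning each segment
adapter_ranks = {0: 8, 1: 16}      # lora_a_r of each loaded adapter
batch_indices = [0, 0, 1, 2, 2]    # segment id of each token in the batch
lora_a_keys = [0, 1]               # order of adapters in the stacked lora_a tensor

# Group segments by LoRA rank, as in the loop above.
rank_indices = defaultdict(list)
for segment_idx, adapter_idx in enumerate(segment_indices):
    rank_indices[adapter_ranks[adapter_idx]].append(segment_idx)
# rank_indices == {8: [0, 2], 16: [1]}

# For each rank, map every token to its adapter slot, or -1 if another rank serves it.
for rank, indices in rank_indices.items():
    adapters_indices = [
        lora_a_keys.index(segment_indices[s]) if s in indices else -1
        for s in batch_indices
    ]
    print(rank, adapters_indices)
# 8 [0, 0, -1, 0, 0]
# 16 [-1, -1, 1, -1, -1]
```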