
Commit 917cde4

scsudhak-intel authored and shreyankg committed

[Hardware][Intel-Gaudi] Enable long-contexts + LoRA support for Intel Gaudi (vllm-project#12812)

Signed-off-by: Sanju C Sudhakaran <[email protected]>

1 parent 8a30c5f commit 917cde4

File tree

vllm/lora/punica_wrapper/punica_hpu.py
vllm/model_executor/layers/rotary_embedding.py
vllm/worker/hpu_model_runner.py

3 files changed: +73 -4 lines changed

vllm/lora/punica_wrapper/punica_hpu.py

Lines changed: 56 additions & 1 deletion

@@ -1,12 +1,18 @@
 # SPDX-License-Identifier: Apache-2.0

-from typing import Optional, Tuple, Union, final
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union, final

 import torch
 from vllm_hpu_extension.ops import (dispatch_bgmv_embedding,
                                     dispatch_bgmv_linear)

 from .punica_base import PunicaWrapperBase
+from .utils import convert_mapping
+
+if TYPE_CHECKING:
+    # avoid circuit import
+    from vllm.lora.layers import LoRAMapping
+    from vllm.lora.models import LongContextLoRAContext


 @final
@@ -19,6 +25,55 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int,
         PunicaWrapperBase.__init__(self, 3 * max_num_batched_tokens,
                                    max_batches, device)

+    def _update_base_metadata(
+        self,
+        mapping: "LoRAMapping",
+        lora_index_to_id: List[Optional[int]],
+        max_loras: int,
+        vocab_size: int,
+        extra_vocab_size: int,
+        long_lora_context: Optional["LongContextLoRAContext"] = None,
+    ):
+        (
+            base_indices,
+            sampler_indices,
+            sampler_indices_padded,
+            embeddings_indices,
+            long_lora_offsets_tensor,
+            indices_len,
+        ) = convert_mapping(mapping, lora_index_to_id, max_loras, vocab_size,
+                            extra_vocab_size, self.device, None)
+        # Updating each element in `long_lora_offsets` with `lora_offset` slows
+        # down perf in HPU due to a series of `strided_insert` ops during lazy
+        # graph accumulation. Hence HPU appends `lora_offset` to a list and
+        # converts it to a tensor only after it is ready.
+        if long_lora_context:
+            index_mapping_indices: List[int] = list(
+                mapping.index_mapping).copy()
+            long_lora_offsets: List[int] = []
+            for i in range(len(index_mapping_indices)):
+                lora_offset: int = long_lora_context.offsets_by_lora_id.get(
+                    index_mapping_indices[i], 0)
+                long_lora_offsets.append(lora_offset)
+            long_lora_offsets_tensor = torch.tensor(long_lora_offsets,
+                                                    device=self.device,
+                                                    dtype=torch.long)
+            indices_len[-1] = long_lora_offsets_tensor.shape[-1]
+
+        self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices)
+        self._sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices)
+        self._sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_(
+            sampler_indices_padded)
+        self._embeddings_indices[:embeddings_indices.
+                                 shape[0], :embeddings_indices.shape[1]].copy_(
+                                     embeddings_indices)
+        if long_lora_offsets_tensor is not None:
+            self._long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_(
+                long_lora_offsets_tensor)
+        else:
+            self._long_lora_indices.zero_()
+        self.indices_len[:] = indices_len
+
     def add_lora_embedding(self,
                            y: torch.Tensor,
                            x: torch.Tensor,
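The in-code comment above explains the key HPU-specific choice: the long-LoRA offsets are collected in a plain Python list and materialized with a single torch.tensor call, because updating a device tensor element by element would emit one strided_insert op per token during lazy graph accumulation. A minimal standalone sketch of the two patterns, using hypothetical toy data rather than the commit's own structures:

import torch

offsets_by_lora_id = {0: 0, 1: 4096}   # hypothetical offset per LoRA id
index_mapping = [1, 1, 0, 1]            # hypothetical per-token LoRA ids

# Slower on HPU lazy mode: each assignment becomes a separate graph op.
slow = torch.zeros(len(index_mapping), dtype=torch.long)
for i, lora_id in enumerate(index_mapping):
    slow[i] = offsets_by_lora_id.get(lora_id, 0)

# Pattern used by _update_base_metadata: build the list on the host,
# then create the tensor once.
fast = torch.tensor([offsets_by_lora_id.get(lora_id, 0)
                     for lora_id in index_mapping],
                    dtype=torch.long)

assert torch.equal(slow, fast)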

vllm/model_executor/layers/rotary_embedding.py

Lines changed: 2 additions & 1 deletion

@@ -206,9 +206,10 @@ def forward_hpu(
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         from habana_frameworks.torch.hpex.kernels import (
             RotaryPosEmbeddingMode, apply_rotary_pos_emb)
-        positions = positions.flatten()
         if offsets is not None:
+            offsets = offsets.view(positions.shape[0], -1)
             positions = positions + offsets
+        positions = positions.flatten()
         num_tokens = positions.shape[0]
         cos_sin = self.cos_sin_cache.index_select(0, positions).view(
             num_tokens, 1, -1)
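The reordering matters because positions can arrive as a 2-D (batch, seq) tensor while the long-context LoRA offsets are flat; flattening positions first would leave the two shapes misaligned, so the change reshapes offsets to match, adds them, and only then flattens for the cos/sin cache lookup. A standalone illustration with hypothetical shapes, not the vLLM code itself:

import torch

batch, seq = 2, 4
positions = torch.arange(batch * seq).view(batch, seq)  # hypothetical (batch, seq) positions
offsets = torch.full((batch * seq,), 8192)               # hypothetical flat per-token offsets

# Align shapes first, then add, then flatten for the cache index_select.
offsets = offsets.view(positions.shape[0], -1)            # -> (batch, seq)
positions = (positions + offsets).flatten()               # -> (batch * seq,)
print(positions.shape)                                    # torch.Size([8])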

vllm/worker/hpu_model_runner.py

Lines changed: 15 additions & 2 deletions

@@ -639,12 +639,25 @@ def load_model(self) -> None:
                 "Bias support in LoRA is not enabled in HPU yet."
             assert not self.lora_config.fully_sharded_loras, \
                 "Fully sharded LoRAs is not enabled in HPU yet."
+            # It's necessary to distinguish between the
+            # max_position_embeddings of VLMs and LLMs.
+            if hasattr(self.model.config, "max_position_embeddings"):
+                max_pos_embeddings = (
+                    self.model.config.max_position_embeddings)
+            else:
+                max_pos_embeddings = (
+                    self.model.config.text_config.max_position_embeddings)
+
             self.lora_manager = LRUCacheWorkerLoRAManager(
                 self.scheduler_config.max_num_seqs,
                 self.scheduler_config.max_num_batched_tokens,
-                self.vocab_size, self.lora_config, self.device,
+                self.vocab_size,
+                self.lora_config,
+                self.device,
                 self.model.embedding_modules,
-                self.model.embedding_padding_modules)
+                self.model.embedding_padding_modules,
+                max_position_embeddings=max_pos_embeddings,
+            )
             self.model = self.lora_manager.create_lora_manager(self.model)

         if self.model_config.quantization == 'inc':
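The new branch handles the fact that vision-language models usually nest the language settings under a text_config sub-config, while plain LLM configs expose max_position_embeddings at the top level. A standalone sketch of the fallback pattern, using a hypothetical helper and SimpleNamespace stand-ins for real HuggingFace config objects:

from types import SimpleNamespace

def resolve_max_pos_embeddings(config) -> int:
    # Mirrors the check added in load_model(): prefer the top-level value,
    # fall back to the nested text_config for VLM-style configs.
    if hasattr(config, "max_position_embeddings"):
        return config.max_position_embeddings
    return config.text_config.max_position_embeddings

llm_config = SimpleNamespace(max_position_embeddings=32768)        # hypothetical LLM config
vlm_config = SimpleNamespace(                                       # hypothetical VLM config
    text_config=SimpleNamespace(max_position_embeddings=8192))

print(resolve_max_pos_embeddings(llm_config))   # 32768
print(resolve_max_pos_embeddings(vlm_config))   # 8192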
