# SPDX-License-Identifier: Apache-2.0

-from typing import Optional, Tuple, Union, final
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union, final

import torch
from vllm_hpu_extension.ops import (dispatch_bgmv_embedding,
                                     dispatch_bgmv_linear)

from .punica_base import PunicaWrapperBase
+from .utils import convert_mapping
+
+if TYPE_CHECKING:
+    # avoid circular import
+    from vllm.lora.layers import LoRAMapping
+    from vllm.lora.models import LongContextLoRAContext


@final
@@ -19,6 +25,55 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int,
        PunicaWrapperBase.__init__(self, 3 * max_num_batched_tokens,
                                   max_batches, device)

+    def _update_base_metadata(
+        self,
+        mapping: "LoRAMapping",
+        lora_index_to_id: List[Optional[int]],
+        max_loras: int,
+        vocab_size: int,
+        extra_vocab_size: int,
+        long_lora_context: Optional["LongContextLoRAContext"] = None,
+    ):
+        (
+            base_indices,
+            sampler_indices,
+            sampler_indices_padded,
+            embeddings_indices,
+            long_lora_offsets_tensor,
+            indices_len,
+        ) = convert_mapping(mapping, lora_index_to_id, max_loras, vocab_size,
+                            extra_vocab_size, self.device, None)
+        # Updating each element in `long_lora_offsets` with `lora_offset` slows
+        # down perf in HPU due to a series of `strided_insert` ops during lazy
+        # graph accumulation. Hence HPU appends `lora_offset` to a list and
+        # converts it to a tensor only after it is ready.
+        if long_lora_context:
+            index_mapping_indices: List[int] = list(
+                mapping.index_mapping).copy()
+            long_lora_offsets: List[int] = []
+            for i in range(len(index_mapping_indices)):
+                lora_offset: int = long_lora_context.offsets_by_lora_id.get(
+                    index_mapping_indices[i], 0)
+                long_lora_offsets.append(lora_offset)
+            long_lora_offsets_tensor = torch.tensor(long_lora_offsets,
+                                                    device=self.device,
+                                                    dtype=torch.long)
+            indices_len[-1] = long_lora_offsets_tensor.shape[-1]
+
+        self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices)
+        self._sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices)
+        self._sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_(
+            sampler_indices_padded)
+        self._embeddings_indices[:embeddings_indices.
+                                 shape[0], :embeddings_indices.shape[1]].copy_(
+                                     embeddings_indices)
+        if long_lora_offsets_tensor is not None:
+            self._long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_(
+                long_lora_offsets_tensor)
+        else:
+            self._long_lora_indices.zero_()
+        self.indices_len[:] = indices_len
+
    def add_lora_embedding(self,
                           y: torch.Tensor,
                           x: torch.Tensor,
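For context on the comment about `strided_insert` overhead under lazy graph accumulation: below is a minimal, standalone sketch (not part of the change) contrasting element-wise writes into a preallocated tensor with the list-then-single-tensor pattern the diff adopts for `long_lora_offsets`. The offset table and index mapping values are made up for illustration.

```python
import torch

# Illustrative stand-ins (not from the PR): a per-LoRA offset table and a
# token-level index mapping, roughly shaped like the data convert_mapping sees.
offsets_by_lora_id = {1: 4096, 2: 8192}
index_mapping = [1, 1, 2, 0, 2]

# Element-wise writes into a preallocated tensor: each assignment becomes a
# separate strided update, which is what the comment says is slow on HPU
# under lazy graph accumulation.
slow = torch.zeros(len(index_mapping), dtype=torch.long)
for i, lora_id in enumerate(index_mapping):
    slow[i] = offsets_by_lora_id.get(lora_id, 0)

# The pattern the diff uses instead: accumulate offsets in a Python list and
# materialize a single tensor once the list is complete.
offsets = [offsets_by_lora_id.get(lora_id, 0) for lora_id in index_mapping]
fast = torch.tensor(offsets, dtype=torch.long)

assert torch.equal(slow, fast)
```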