Commit 50f20a6

2 parents 342818c + 045bc98

File tree

1 file changed: +11 −5 lines changed

backends/exllamav2/model.py

+11 −5
@@ -585,6 +585,7 @@ def progress(loaded_modules: int, total_modules: int)
                 cache_class=draft_cache_class,
                 autosplit=True,
                 use_tp=False,
+                model=self.draft_model,
             )

             for value in self.draft_model.load_autosplit_gen(
@@ -635,6 +636,7 @@ def progress(loaded_modules: int, total_modules: int)
             cache_class=cache_class,
             autosplit=self.gpu_split_auto,
             use_tp=self.use_tp,
+            model=self.model,
         )

         # Load model with autosplit (without TP)
@@ -669,20 +671,24 @@ def get_cache_class(self, cache_mode: str):
         return ExLlamaV2Cache

     def create_cache(
-        self, cache_class: ExLlamaV2CacheBase, autosplit: bool, use_tp: bool
+        self,
+        cache_class: ExLlamaV2CacheBase,
+        autosplit: bool,
+        use_tp: bool,
+        model: ExLlamaV2,
     ):
         """Utility function to create a model cache."""

         if has_tp and use_tp:
             return ExLlamaV2Cache_TP(
-                self.model,
+                model,
                 base=cache_class,
                 max_seq_len=self.cache_size,
                 batch_size=1,
             )
         else:
             return cache_class(
-                self.model,
+                model,
                 max_seq_len=self.cache_size,
                 lazy=autosplit,
                 batch_size=1,
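Taken together, the three hunks above thread an explicit model argument through create_cache, so the draft cache is built against self.draft_model instead of always falling back to self.model. Below is a minimal standalone sketch of the same idea, not part of this commit: the exllamav2 import locations are assumptions that may differ between library versions, and tp_available stands in for the module-level has_tp flag used in the original file.

# Sketch only, not part of this commit: a free-function version of the
# refactored helper. Import paths are assumptions; tp_available is a stand-in
# for the original module-level has_tp flag.
from exllamav2 import ExLlamaV2, ExLlamaV2Cache, ExLlamaV2Cache_TP
from exllamav2.cache import ExLlamaV2CacheBase


def create_cache(
    model: ExLlamaV2,
    cache_class: type[ExLlamaV2CacheBase],
    cache_size: int,
    autosplit: bool,
    use_tp: bool,
    tp_available: bool,
):
    """Build a cache bound to whichever model was actually passed in."""

    if tp_available and use_tp:
        # Tensor-parallel cache wraps the requested base cache class.
        return ExLlamaV2Cache_TP(
            model, base=cache_class, max_seq_len=cache_size, batch_size=1
        )

    # Plain cache; lazy allocation while the model is still being autosplit.
    return cache_class(model, max_seq_len=cache_size, lazy=autosplit, batch_size=1)


# Illustrative call mirroring the main-model call site above (sizes made up):
# cache = create_cache(model, ExLlamaV2Cache, 4096, autosplit=True,
#                      use_tp=False, tp_available=False)

For the draft model, the same helper is simply called again with the draft model passed in, which is the point of the new parameter: before this change, both branches of create_cache used self.model, so the draft cache was never built against the draft model.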
@@ -865,7 +871,7 @@ def get_special_tokens(
     def get_logprobs(self, token_ids: torch.Tensor, token_probs: torch.Tensor):
         top_tokens = [
             self.tokenizer.extended_id_to_piece.get(
-                index, self.tokenizer.id_to_piece[index]
+                index, self.tokenizer.get_id_to_piece_list(True)[index]
             )
             for index in token_ids.flatten().tolist()
         ]
@@ -1140,7 +1146,7 @@ async def generate_gen(

             # Map logits to the tensor with their biases
             for token_id, bias in logit_bias.items():
-                if 0 <= token_id < len(self.tokenizer.id_to_piece):
+                if 0 <= token_id < len(self.tokenizer.get_id_to_piece_list(True)):
                     gen_settings.token_bias[token_id] = bias
                 else:
                     logger.warning(
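The last two hunks swap direct indexing of tokenizer.id_to_piece for tokenizer.get_id_to_piece_list(True); judging by these call sites, the True argument asks for a piece list long enough to cover extended/special token ids as well. A small hedged sketch of the same lookup pattern, where resolve_piece is a hypothetical helper name used only for illustration:

# Hypothetical helper, not part of this commit; mirrors the lookup pattern in
# get_logprobs above. Assumes an exllamav2 tokenizer with the attributes used
# in the diff (extended_id_to_piece, get_id_to_piece_list).
def resolve_piece(tokenizer, index: int) -> str:
    # True mirrors the call sites above (piece list including special tokens).
    pieces = tokenizer.get_id_to_piece_list(True)
    if not 0 <= index < len(pieces):
        raise IndexError(f"token id {index} is outside the tokenizer vocabulary")
    # Prefer added/special-token pieces, then fall back to the base piece list.
    return tokenizer.extended_id_to_piece.get(index, pieces[index])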
