@@ -585,6 +585,7 @@ def progress(loaded_modules: int, total_modules: int)
             cache_class=draft_cache_class,
             autosplit=True,
             use_tp=False,
+            model=self.draft_model,
         )

         for value in self.draft_model.load_autosplit_gen(
@@ -635,6 +636,7 @@ def progress(loaded_modules: int, total_modules: int)
             cache_class=cache_class,
             autosplit=self.gpu_split_auto,
             use_tp=self.use_tp,
+            model=self.model,
         )

         # Load model with autosplit (without TP)
@@ -669,20 +671,24 @@ def get_cache_class(self, cache_mode: str):
             return ExLlamaV2Cache

     def create_cache(
-        self, cache_class: ExLlamaV2CacheBase, autosplit: bool, use_tp: bool
+        self,
+        cache_class: ExLlamaV2CacheBase,
+        autosplit: bool,
+        use_tp: bool,
+        model: ExLlamaV2,
     ):
         """Utility function to create a model cache."""

         if has_tp and use_tp:
             return ExLlamaV2Cache_TP(
-                self.model,
+                model,
                 base=cache_class,
                 max_seq_len=self.cache_size,
                 batch_size=1,
             )
         else:
             return cache_class(
-                self.model,
+                model,
                 max_seq_len=self.cache_size,
                 lazy=autosplit,
                 batch_size=1,
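
Taken together with the two call-site hunks above, this signature change makes create_cache model-agnostic: instead of always wrapping self.model, it wraps whichever ExLlamaV2 instance the caller passes in, so the same helper can also build the draft model's cache. A minimal usage sketch under that reading (the assignment targets self.cache and self.draft_cache are assumed here, not shown in the hunks):

# Main model cache: honors the configured split / tensor-parallel settings.
self.cache = self.create_cache(
    cache_class=cache_class,
    autosplit=self.gpu_split_auto,
    use_tp=self.use_tp,
    model=self.model,
)

# Draft model cache: always autosplit, never TP, and now bound to
# self.draft_model rather than self.model.
self.draft_cache = self.create_cache(
    cache_class=draft_cache_class,
    autosplit=True,
    use_tp=False,
    model=self.draft_model,
)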
@@ -865,7 +871,7 @@ def get_special_tokens(
     def get_logprobs(self, token_ids: torch.Tensor, token_probs: torch.Tensor):
         top_tokens = [
             self.tokenizer.extended_id_to_piece.get(
-                index, self.tokenizer.id_to_piece[index]
+                index, self.tokenizer.get_id_to_piece_list(True)[index]
             )
             for index in token_ids.flatten().tolist()
         ]
@@ -1140,7 +1146,7 @@ async def generate_gen(

             # Map logits to the tensor with their biases
             for token_id, bias in logit_bias.items():
-                if 0 <= token_id < len(self.tokenizer.id_to_piece):
+                if 0 <= token_id < len(self.tokenizer.get_id_to_piece_list(True)):
                     gen_settings.token_bias[token_id] = bias
                 else:
                     logger.warning(
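
The last two hunks replace direct indexing of self.tokenizer.id_to_piece with the get_id_to_piece_list(True) accessor; the boolean argument presumably selects the piece list that includes special tokens (an assumption, not stated in the diff). A rough sketch of the lookup pattern, assuming an already-loaded ExLlamaV2Tokenizer bound to tokenizer:

# Illustrative only; `tokenizer` and `token_id` are placeholders.
piece_list = tokenizer.get_id_to_piece_list(True)

token_id = 42  # arbitrary example id
# Same fallback chain get_logprobs uses: prefer the extended (added) vocab,
# otherwise fall back to the base piece list.
piece = tokenizer.extended_id_to_piece.get(token_id, piece_list[token_id])

# Bounds check mirroring the logit_bias hunk: only in-vocab ids get a bias.
is_valid = 0 <= token_id < len(piece_list)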