@@ -374,7 +374,7 @@ def copy_from_last():

     for warper in logits_warper:
         #assert type(warper) == TemperatureLogitsWarper or type(warper) == TopPLogitsWarper or type(warper) == TopKLogitsWarper, f"please set top_k=0 {warper}"
-        assert type(warper) == TemperatureLogitsWarper, f"please set top_k=0.0 and top_p=1.0 {warper}"
+        assert type(warper) == TemperatureLogitsWarper or type(warper) == TopKLogitsWarper or type(warper) == TopPLogitsWarper, f"please set top_k=0.0 and top_p=1.0 {warper}"

     # auto-regressive generation
     while True:
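The relaxed assert means sampling configurations with top-k and top-p now pass validation instead of being rejected. A minimal sketch of how these three Hugging Face warpers compose, in the same order transformers' generate() builds them (the toy vocab size and parameter values are illustrative only):

import torch
from transformers import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper

warpers = LogitsProcessorList([
    TemperatureLogitsWarper(0.7),  # divide logits by the temperature
    TopKLogitsWarper(top_k=50),    # mask all but the 50 largest logits
    TopPLogitsWarper(top_p=0.9),   # mask the low-probability nucleus tail
])

input_ids = torch.tensor([[1, 2, 3]])   # (batch, seq) token ids
scores = torch.randn(1, 32000)          # (batch, vocab) raw logits
warped = warpers(input_ids, scores)     # warpers are applied in list order
probs = torch.softmax(warped, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)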
@@ -485,9 +485,8 @@ def copy_from_last():
     probs_next = torch.nn.functional.softmax(next_token_scores, dim=-1)[0]
     hits = []
     #= original model output
-    #print("size: ", input_ids.size(), outputs.guess_logits.size())
-    guess_logits = logits_warper(input_ids, outputs.guess_logits)
-    guess_probs = torch.nn.functional.softmax(guess_logits, dim=-1)[0] #
+    guess_logits = logits_warper(input_ids, outputs.guess_logits[0])
+    guess_probs = torch.nn.functional.softmax(guess_logits, dim=-1) #
     #guess_results = torch.argmax(outputs.guess_logits, dim=-1)[0].tolist()
     guess_indices = list(range(outputs.guess_logits.size(1) // GUESS_SIZE))
     #algorithm modified from specinfer
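The warped guess_probs feed the verification step that the comment says is modified from SpecInfer. A hedged sketch of the core acceptance rule in that family of algorithms, accepting a drafted token with probability min(1, p/q); the variable names here are illustrative, not the repository's:

import torch

def accept_draft_token(target_probs: torch.Tensor,  # (vocab,) target-model distribution
                       draft_probs: torch.Tensor,   # (vocab,) draft distribution
                       draft_token: int) -> bool:
    p = target_probs[draft_token]
    q = draft_probs[draft_token]
    # Accept with probability min(1, p/q); on rejection, SpecInfer-style
    # methods resample from the residual max(0, p - q), renormalized.
    return bool(torch.rand(()) < torch.clamp(p / q, max=1.0))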
@@ -887,7 +886,7 @@ def random_set():

     def copy_from():
         return random.choice(all_old_tokens)
-
+
     def order_copy_from():
         if order_copy_from_idx[0] >= len(all_old_tokens):
             order_copy_from_idx[0] = 0
@@ -915,12 +914,12 @@ def copy_from_last():

     if POOL_FROM_PROMPT:
         fill_pool_with_prompt(all_old_tokens, token_map, LEVEL, GUESS_SET_SIZE)
-
+
     if chat:
         init = self.tokenizer.decode(all_old_tokens, skip_special_tokens=True, \
             spaces_between_special_tokens=False, clean_up_tokenization_spaces=True,)
         prev = len(init)
-
+
     while True:
         if synced_gpus:
             # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
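fill_pool_with_prompt seeds the n-gram pool from the prompt before decoding starts, so the very first steps already have guess candidates to verify. A hedged sketch of that kind of pool fill, assuming a token-to-following-n-grams map capped at GUESS_SET_SIZE entries per key; this mirrors the helper's role, not its actual implementation:

from collections import OrderedDict

def fill_pool_sketch(prompt_tokens, token_map, level, guess_set_size):
    # Key each prompt token to the (level - 1)-grams that follow it.
    for i in range(len(prompt_tokens) - level + 1):
        key = prompt_tokens[i]
        ngram = tuple(prompt_tokens[i + 1 : i + level])
        bucket = token_map.setdefault(key, OrderedDict())
        bucket[ngram] = None                 # OrderedDict doubles as an ordered set
        if len(bucket) > guess_set_size:
            bucket.popitem(last=False)       # evict the oldest entry (FIFO cap)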
@@ -964,7 +963,7 @@ def copy_from_last():
                 guess_tokens = None
         else:
             guess_tokens = None
-
+
         assert return_dict_in_generate == False
         assert len(logits_processor) == 0
         # forward pass to get next token
@@ -985,7 +984,7 @@ def copy_from_last():
                 past_tokens_inp.append(tokens[window_start:window_end] if tokens is not None else None)
         else:
             past_tokens_inp = past_tokens
-
+
         outputs = self.jforward_multilevel(
             **model_inputs,
             past_tokens=past_tokens_inp,
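Before the jforward_multilevel call, each rank takes its window_start:window_end slice of every level's token list. A hedged sketch of one way such a split can work; the helper name and the even-split policy are assumptions, not the repository's exact logic:

def shard_window_sketch(past_tokens, rank, world_size):
    sharded = []
    for tokens in past_tokens:
        if tokens is None:
            sharded.append(None)
            continue
        per_rank = (len(tokens) + world_size - 1) // world_size  # ceil division
        window_start = rank * per_rank
        window_end = min(window_start + per_rank, len(tokens))
        sharded.append(tokens[window_start:window_end])
    return sharded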
@@ -1040,7 +1039,7 @@ def copy_from_last():
             assert fill_level == 0
             past_tokens[0] = past_tokens[0][1:]
             past_tokens[1] = torch.argmax(outputs.inp_logits, dim=-1)[0].tolist()
-
+
             if DIST_WORKERS > 1:
                 nn_past_tokens = [copy.deepcopy(past_tokens[1])]
                 torch.distributed.broadcast_object_list(nn_past_tokens, src=DIST_WORKERS - 1)
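broadcast_object_list is the real torch.distributed collective; it mutates the passed list in place on every rank. A minimal usage sketch, assuming an already-initialized process group:

import torch.distributed as dist

def sync_from_last_rank(local_tokens, world_size):
    buf = [local_tokens]                           # one-slot object list
    dist.broadcast_object_list(buf, src=world_size - 1)
    return buf[0]                                  # identical on every rank afterwards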
@@ -1051,7 +1050,7 @@ def copy_from_last():
             for level in range(fill_level + 1):
                 past_tokens[level] = past_tokens[level][1:]
             current_past_tokens = torch.argmax(outputs.inp_logits, dim=-1)[0].tolist()
-
+

             if DIST_WORKERS > 1:
                 nn_past_tokens = [None] * DIST_WORKERS
@@ -1063,9 +1062,9 @@ def copy_from_last():
                 past_tokens[fill_level + 1] = current_past_tokens[1:]
                 #print("new past: ", (LOCAL_RANK, past_tokens))

-
+
             fill_level += 1
-        else:
+        else:
             #time.sleep(10000)
             #multi-level window is filled
             #match guess tokens
@@ -1101,7 +1100,7 @@ def copy_from_last():
             # print("rank: ", hits, max_hit)
             #sync new_results
             new_results = torch.argmax(outputs.inp_logits, dim=-1)[0].tolist()
-
+
             if DIST_WORKERS > 1:
                 nn_past_tokens = [None] * DIST_WORKERS
                 torch.distributed.all_gather_object(nn_past_tokens, new_results)
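all_gather_object fills a pre-sized list with every rank's contribution, ordered by rank. A usage sketch under the same initialized-process-group assumption; the concatenation mirrors how the per-rank results appear to be rejoined here:

import torch.distributed as dist

def gather_results(new_results, world_size):
    slots = [None] * world_size
    dist.all_gather_object(slots, new_results)      # slots[i] is rank i's list
    return [tok for part in slots for tok in part]  # concatenate in rank order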
@@ -1149,7 +1148,7 @@ def copy_from_last():
             if DIST_WORKERS > 1 and max_hit > 0:

                 guess_skip_dist = max_hit
-                for idx, kv in enumerate(outputs.past_key_values):
+                for idx, kv in enumerate(outputs.past_key_values):
                     past_key_values.append( (kv[0][:,:,:outputs.kvcache_len,:], kv[1][:,:,:outputs.kvcache_len,:]) )
                 outputs.past_key_values = past_key_values
             else:
@@ -1160,8 +1159,8 @@ def copy_from_last():
                     if max_hit > 0:
                         kv[0][:,:,outputs.kvcache_len:outputs.kvcache_len + max_hit,:] = kv[0][:,:,offset_kv_cache:offset_kv_cache + max_hit,:]
                         kv[1][:,:,outputs.kvcache_len:outputs.kvcache_len + max_hit,:] = kv[1][:,:,offset_kv_cache:offset_kv_cache + max_hit,:]
-                    past_key_values.append( (kv[0][:,:,:outputs.kvcache_len + max_hit,:], kv[1][:,:,:outputs.kvcache_len + max_hit,:]) )
-                outputs.past_key_values = past_key_values
+                    past_key_values.append( (kv[0][:,:,:outputs.kvcache_len + max_hit,:], kv[1][:,:,:outputs.kvcache_len + max_hit,:]) )
+                outputs.past_key_values = past_key_values

             lst_token = hits[max_hit]

@@ -1176,7 +1175,7 @@ def copy_from_last():
             all_old_tokens.append(hits[max_hit])
             if POOL_FROM_PROMPT:
                 append_new_generated_pool(all_old_tokens[-LEVEL:], token_map, LEVEL, GUESS_SET_SIZE)
-
+

             if chat and LOCAL_RANK == 0:
                 all_str = self.tokenizer.decode(all_old_tokens, skip_special_tokens=True, \
@@ -1188,7 +1187,7 @@ def copy_from_last():
                         spaces_between_special_tokens=False, clean_up_tokenization_spaces=True,)
                     pt = colored(not_hit[prev:], "blue") + colored(all_str[len(not_hit):], "blue")
                 else:
-                    pt = all_str[prev:]
+                    pt = all_str[prev:]
                 print(pt, flush=True, end="")
             else:
                 print(all_str[prev:], flush=True, end="")
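The chat branch streams output by re-decoding the whole sequence each step and printing only the characters past the previously printed length, which avoids emitting broken multi-byte tokens at the boundary. A minimal sketch of the pattern:

def stream_delta(tokenizer, all_old_tokens, prev):
    all_str = tokenizer.decode(all_old_tokens, skip_special_tokens=True)
    print(all_str[prev:], flush=True, end="")
    return len(all_str)  # pass back as `prev` on the next step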
@@ -1440,7 +1439,7 @@ def greedy_search_chat(

         # prepare model inputs
         model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
-
+
         # forward pass to get next token
         outputs = self(
             **model_inputs,
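For reference, the unmodified context in this last hunk is the standard transformers greedy-decoding step; a minimal sketch of that loop body (the standard Hugging Face pattern, not code from this diff):

model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
outputs = self(**model_inputs, return_dict=True)
next_token = torch.argmax(outputs.logits[:, -1, :], dim=-1)  # pick the top token
input_ids = torch.cat([input_ids, next_token[:, None]], dim=-1)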