
Commit 37516d0

add sampling
1 parent f7bca84 commit 37516d0

File tree

7 files changed: +1735 −491 lines changed


README.md

+14-1
@@ -4,6 +4,11 @@
 | <a href="https://arxiv.org/abs/2402.02057"><b>Paper</b></a> | <a href="https://lmsys.org/blog/2023-11-21-lookahead-decoding/"><b>Blog</b></a> | <a href="https://github.com/hao-ai-lab/LookaheadDecoding/issues/13"><b>Roadmap</b></a> |
 </p>
 
+---
+*News* 🔥
+- [2024/2] The Lookahead Decoding paper is now available on [arXiv](https://arxiv.org/abs/2402.02057). Sampling and FlashAttention are supported, along with advanced features for better token prediction.
+
+---
 ## Introduction
 We introduce lookahead decoding:
 - A parallel decoding algorithm to accelerate LLM inference.
@@ -148,14 +153,22 @@ lade.config_lade(LEVEL=5, WINDOW_SIZE=7, GUESS_SET_SIZE=7, DEBUG=0)
 #You can obtain a better performance by tuning LEVEL/WINDOW_SIZE/GUESS_SET_SIZE on your own device.
 ```
 
-Then you can speedup the decoding process.
+Then you can speed up the decoding process. Here is an example using greedy search:
 ```
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map=torch_device)
 model_inputs = tokenizer(input_text, return_tensors='pt').to(torch_device)
 greedy_output = model.generate(**model_inputs, max_new_tokens=1024) #speedup obtained
 ```
 
+Here is an example using sampling:
+```
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map=torch_device)
+model_inputs = tokenizer(input_text, return_tensors='pt').to(torch_device)
+sample_output = model.generate(**model_inputs, max_new_tokens=1024, temperature=0.7) #speedup obtained
+```
+
 ## Citation
 ```bibtex
 @misc{fu2024break,
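
For readers trying the new sampling path end to end, the snippet below expands the README example into a runnable script. It is a minimal sketch, not part of this commit: it assumes the repository's usual setup calls (`lade.augment_all()` together with the `lade.config_lade(...)` line shown above) and the `USE_LADE=1` environment switch, and `model_name` / `input_text` are placeholders. In stock Hugging Face `transformers`, `temperature` only changes the output when sampling is enabled, so the sketch passes `do_sample=True` explicitly.

```
#A hedged, self-contained sketch of the sampling example above (not part of this commit).
#Assumptions: lade is installed, USE_LADE=1 and lade.augment_all() are the usual setup,
#and model_name / input_text are placeholders you replace yourself.
import os
os.environ["USE_LADE"] = "1"   #assumption: env switch used by the lade install

import torch
import lade
from transformers import AutoModelForCausalLM, AutoTokenizer

lade.augment_all()
lade.config_lade(LEVEL=5, WINDOW_SIZE=7, GUESS_SET_SIZE=7, DEBUG=0)

torch_device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "meta-llama/Llama-2-7b-chat-hf"                   #placeholder model id
input_text = "Explain lookahead decoding in one paragraph."    #placeholder prompt

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map=torch_device)
model_inputs = tokenizer(input_text, return_tensors='pt').to(torch_device)

#temperature only reshapes the distribution when sampling is on,
#so do_sample=True is passed alongside it here.
sample_output = model.generate(**model_inputs, max_new_tokens=1024, do_sample=True, temperature=0.7)
print(tokenizer.decode(sample_output[0], skip_special_tokens=True))
```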

lade/decoding.py

+19-20
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,7 @@ def copy_from_last():
 
 for warper in logits_warper:
 #assert type(warper) == TemperatureLogitsWarper or type(warper) == TopPLogitsWarper or type(warper) == TopKLogitsWarper, f"please set top_k=0 {warper}"
-assert type(warper) == TemperatureLogitsWarper, f"please set top_k=0.0 and top_p=1.0 {warper}"
+assert type(warper) == TemperatureLogitsWarper or type(warper) == TopKLogitsWarper or type(warper) == TopPLogitsWarper, f"please set top_k=0.0 and top_p=1.0 {warper}"
 
 # auto-regressive generation
 while True:
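
The relaxed assertion above now accepts the three stock Hugging Face warpers (`TemperatureLogitsWarper`, `TopKLogitsWarper`, `TopPLogitsWarper`) instead of temperature only. As a standalone refresher, not code from this repository, here is how such a warper stack reshapes raw logits into the distribution that sampling-based verification later works with:

```
#Standalone illustration of the three logits warpers the assert now allows.
#Not lookahead-specific; it only shows how a warper stack filters logits
#before softmax.
import torch
from transformers import (LogitsProcessorList, TemperatureLogitsWarper,
                          TopKLogitsWarper, TopPLogitsWarper)

vocab_size = 8
logits = torch.randn(1, vocab_size)                 #one step, batch size 1
input_ids = torch.zeros((1, 1), dtype=torch.long)   #ignored by these warpers

logits_warper = LogitsProcessorList([
    TemperatureLogitsWarper(0.7),   #sharpen (T<1) or flatten (T>1) the logits
    TopKLogitsWarper(top_k=5),      #keep only the 5 largest logits
    TopPLogitsWarper(top_p=0.9),    #keep the smallest set with 0.9 total mass
])

warped = logits_warper(input_ids, logits)
probs = torch.nn.functional.softmax(warped, dim=-1)
print(probs)                        #filtered-out tokens have probability 0
```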
@@ -485,9 +485,8 @@ def copy_from_last():
 probs_next = torch.nn.functional.softmax(next_token_scores, dim=-1)[0]
 hits = []
 #= original model output
-#print("size: ", input_ids.size(), outputs.guess_logits.size())
-guess_logits = logits_warper(input_ids, outputs.guess_logits)
-guess_probs = torch.nn.functional.softmax(guess_logits, dim=-1)[0] #
+guess_logits = logits_warper(input_ids, outputs.guess_logits[0])
+guess_probs = torch.nn.functional.softmax(guess_logits, dim=-1) #
 #guess_results = torch.argmax(outputs.guess_logits, dim=-1)[0].tolist()
 guess_indices = list(range(outputs.guess_logits.size(1) // GUESS_SIZE))
 #algorithm modified from specinfer
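
The warped `guess_probs` above are what the guessed n-gram tokens are verified against, using a routine the code attributes to SpecInfer. The exact multi-candidate bookkeeping lives further down in `decoding.py` and is not shown in this diff, so the following is only a generic sketch of the underlying speculative-sampling accept/reject rule; `verify_branch`, `target_probs`, `draft_probs`, and `guess` are hypothetical names, not identifiers from the repository.

```
#Generic sketch of speculative-sampling verification for a single guessed
#branch. This is NOT the repository's routine; it only illustrates the
#accept/reject rule that SpecInfer-style verification builds on.
import torch

def verify_branch(target_probs, draft_probs, guess):
    #target_probs: (len(guess) + 1, vocab) target-model distributions
    #draft_probs:  (len(guess), vocab) distributions the guesses were drawn from
    #guess:        list[int] guessed token ids (assumes draft prob > 0 for each)
    accepted = []
    for i, tok in enumerate(guess):
        p, q = target_probs[i, tok], draft_probs[i, tok]
        if torch.rand(1).item() < (p / q).clamp(max=1.0).item():
            accepted.append(tok)          #guess i accepted, keep going
            continue
        #rejected: resample from the residual distribution and stop
        residual = (target_probs[i] - draft_probs[i]).clamp(min=0.0)
        accepted.append(int(torch.multinomial(residual / residual.sum(), 1)))
        return accepted
    #all guesses accepted: take one bonus token from the next distribution
    accepted.append(int(torch.multinomial(target_probs[len(guess)], 1)))
    return accepted
```

Accepting a prefix this way preserves the target model's sampling distribution, which is why the speedup comes without changing output quality.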
@@ -887,7 +886,7 @@ def random_set():
 
 def copy_from():
 return random.choice(all_old_tokens)
-
+
 def order_copy_from():
 if order_copy_from_idx[0] >= len(all_old_tokens):
 order_copy_from_idx[0] = 0
@@ -915,12 +914,12 @@ def copy_from_last():
 
 if POOL_FROM_PROMPT:
 fill_pool_with_prompt(all_old_tokens, token_map, LEVEL, GUESS_SET_SIZE)
-
+
 if chat:
 init = self.tokenizer.decode(all_old_tokens, skip_special_tokens=True, \
 spaces_between_special_tokens=False, clean_up_tokenization_spaces=True,)
 prev = len(init)
-
+
 while True:
 if synced_gpus:
 # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
@@ -964,7 +963,7 @@ def copy_from_last():
 guess_tokens = None
 else:
 guess_tokens = None
-
+
 assert return_dict_in_generate == False
 assert len(logits_processor) == 0
 # forward pass to get next token
@@ -985,7 +984,7 @@ def copy_from_last():
 past_tokens_inp.append(tokens[window_start: window_end] if tokens is not None else None)
 else:
 past_tokens_inp = past_tokens
-
+
 outputs = self.jforward_multilevel(
 **model_inputs,
 past_tokens=past_tokens_inp,
@@ -1040,7 +1039,7 @@ def copy_from_last():
 assert fill_level == 0
 past_tokens[0] = past_tokens[0][1:]
 past_tokens[1] = torch.argmax(outputs.inp_logits, dim=-1)[0].tolist()
-
+
 if DIST_WORKERS > 1:
 nn_past_tokens = [copy.deepcopy(past_tokens[1])]
 torch.distributed.broadcast_object_list(nn_past_tokens, src=DIST_WORKERS - 1)
@@ -1051,7 +1050,7 @@ def copy_from_last():
 for level in range(fill_level + 1):
 past_tokens[level] = past_tokens[level][1:]
 current_past_tokens = torch.argmax(outputs.inp_logits, dim=-1)[0].tolist()
-
+
 
 if DIST_WORKERS > 1:
 nn_past_tokens = [None] * DIST_WORKERS
@@ -1063,9 +1062,9 @@ def copy_from_last():
 past_tokens[fill_level + 1] = current_past_tokens[1:]
 #print("new past: ", (LOCAL_RANK, past_tokens))
 
-
+
 fill_level += 1
-else:
+else:
 #time.sleep(10000)
 #multi-level window is filled
 #match guess tokens
@@ -1101,7 +1100,7 @@ def copy_from_last():
 # print("rank: ",hits, max_hit)
 #sync new_results
 new_results = torch.argmax(outputs.inp_logits, dim=-1)[0].tolist()
-
+
 if DIST_WORKERS > 1:
 nn_past_tokens = [None] * DIST_WORKERS
 torch.distributed.all_gather_object(nn_past_tokens, new_results)
@@ -1149,7 +1148,7 @@ def copy_from_last():
 if DIST_WORKERS > 1 and max_hit > 0:
 
 guess_skip_dist = max_hit
-for idx, kv in enumerate(outputs.past_key_values):
+for idx, kv in enumerate(outputs.past_key_values):
 past_key_values.append( (kv[0][:,:,:outputs.kvcache_len,:], kv[1][:,:,:outputs.kvcache_len,:]) )
 outputs.past_key_values = past_key_values
 else:
@@ -1160,8 +1159,8 @@ def copy_from_last():
 if max_hit > 0:
 kv[0][:,:,outputs.kvcache_len:outputs.kvcache_len+max_hit,:] = kv[0][:,:,offset_kv_cache:offset_kv_cache+max_hit,:]
 kv[1][:,:,outputs.kvcache_len:outputs.kvcache_len+max_hit,:] = kv[1][:,:,offset_kv_cache:offset_kv_cache+max_hit,:]
-past_key_values.append( (kv[0][:,:,:outputs.kvcache_len + max_hit,:], kv[1][:,:,:outputs.kvcache_len + max_hit,:]) )
-outputs.past_key_values = past_key_values
+past_key_values.append( (kv[0][:,:,:outputs.kvcache_len + max_hit,:], kv[1][:,:,:outputs.kvcache_len + max_hit,:]) )
+outputs.past_key_values = past_key_values
 
 lst_token = hits[max_hit]
 
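The hunk above, together with the one before it, handles the KV-cache bookkeeping once a guess of length `max_hit` has been accepted: the key/value entries computed for the accepted guess tokens are copied so that they sit directly behind the verified prefix, and the cache is truncated to `kvcache_len + max_hit`. Below is a toy sketch of that operation on a single (key, value) pair; the shapes are made up, only the index names are borrowed from the diff.

```
#Toy sketch of the KV-cache compaction shown above; shapes are made up and
#only a single (key, value) pair is shown. kvcache_len is the length of the
#verified prefix, offset_kv_cache is where the accepted guess tokens' entries
#currently live, and max_hit is how many guess tokens were accepted.
import torch

batch, heads, total_len, head_dim = 1, 2, 16, 4
kvcache_len, offset_kv_cache, max_hit = 10, 13, 2

k = torch.randn(batch, heads, total_len, head_dim)
v = torch.randn(batch, heads, total_len, head_dim)

#move the accepted guess entries so they directly follow the verified prefix
k[:, :, kvcache_len:kvcache_len + max_hit, :] = k[:, :, offset_kv_cache:offset_kv_cache + max_hit, :]
v[:, :, kvcache_len:kvcache_len + max_hit, :] = v[:, :, offset_kv_cache:offset_kv_cache + max_hit, :]

#truncate the cache to prefix + accepted guesses
k, v = k[:, :, :kvcache_len + max_hit, :], v[:, :, :kvcache_len + max_hit, :]
print(k.shape, v.shape)   #torch.Size([1, 2, 12, 4]) twice
```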
@@ -1176,7 +1175,7 @@ def copy_from_last():
 all_old_tokens.append(hits[max_hit])
 if POOL_FROM_PROMPT:
 append_new_generated_pool(all_old_tokens[-LEVEL:], token_map, LEVEL, GUESS_SET_SIZE)
-
+
 
 if chat and LOCAL_RANK == 0:
 all_str = self.tokenizer.decode(all_old_tokens, skip_special_tokens=True, \
@@ -1188,7 +1187,7 @@ def copy_from_last():
 spaces_between_special_tokens=False, clean_up_tokenization_spaces=True,)
 pt = colored(not_hit[prev:],"blue") + colored(all_str[len(not_hit):], "blue")
 else:
-pt = all_str[prev:]
+pt = all_str[prev:]
 print(pt, flush=True, end="")
 else:
 print(all_str[prev:], flush=True, end="")
@@ -1440,7 +1439,7 @@ def greedy_search_chat(
 
 # prepare model inputs
 model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
-
+
 # forward pass to get next token
 outputs = self(
 **model_inputs,
