Skip to content

Commit 5c6f57e

Browse files
Constrained Beam Search [*With* Disjunctive Decoding] (#15761)
* added classes to get started with constrained beam search * in progress, think i can directly force tokens now but not yet with the round robin * think now i have total control, now need to code the bank selection * technically works as desired, need to optimize and fix design choices leading to undesirable outputs * complete PR #1 without disjunctive decoding * removed incorrect tests * Delete k.txt * Delete test.py * Delete test.sh * revert changes to test scripts * genutils * full implementation with testing, no disjunctive yet * shifted docs * passing all tests realistically ran locally * removing accidentally included print statements * fixed source of error in initial PR test * fixing the get_device() vs device trap * fixed documentation docstrings about constrained_beam_search * fixed tests failing for Speech2TextModel's floating point inputs * fix cuda long tensor * added examples and testing for them and found & fixed a bug in beam_search and constrained_beam_search * deleted accidentally added test halting code with assert False * code reformat * Update tests/test_generation_utils.py Co-authored-by: Patrick von Platen <[email protected]> * Update tests/test_generation_utils.py Co-authored-by: Patrick von Platen <[email protected]> * Update tests/test_generation_utils.py Co-authored-by: Patrick von Platen <[email protected]> * Update tests/test_generation_utils.py Co-authored-by: Patrick von Platen <[email protected]> * Update tests/test_generation_utils.py * fixing based on comments on PR * took out the testing code that should work but fails without the beam search modification; style changes * fixing comments issues * docstrings for ConstraintListState * typo in PhrasalConstraint docstring * docstrings improvements * finished adding what is sort of an opinionated implementation of disjunctive generation, but it revealed errors in inner beam search logic during testing. 
* fixed bug found in constrained beam search that used beam_idx that were not global across all the batches * disjunctive constraint working 100% correctly * passing all tests * Accidentally included mlruns * Update src/transformers/generation_beam_constraints.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/transformers/generation_beam_constraints.py Co-authored-by: Patrick von Platen <[email protected]> * complete overhaul of type complexities and other nits * strict type checks in generate() * fixing second round of feedback by narsil * fixed failing generation test because of type check overhaul * generation test fail fix * fixing test fails Co-authored-by: Patrick von Platen <[email protected]>
1 parent 040c11f commit 5c6f57e

File tree

9 files changed

+586
-75
lines changed

9 files changed

+586
-75
lines changed

docs/source/internal/generation_utils.mdx

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,8 @@ A [`Constraint`] can be used to force the generation to include specific tokens
229229

230230
[[autodoc]] PhrasalConstraint
231231

232+
[[autodoc]] DisjunctiveConstraint
233+
232234
[[autodoc]] ConstraintListState
233235

234236
## BeamSearch

src/transformers/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -623,6 +623,7 @@
623623
_import_structure["generation_beam_constraints"] = [
624624
"Constraint",
625625
"ConstraintListState",
626+
"DisjunctiveConstraint",
626627
"PhrasalConstraint",
627628
]
628629
_import_structure["generation_beam_search"] = ["BeamScorer", "BeamSearchScorer", "ConstrainedBeamSearchScorer"]
@@ -2857,7 +2858,12 @@
28572858
TextDataset,
28582859
TextDatasetForNextSentencePrediction,
28592860
)
2860-
from .generation_beam_constraints import Constraint, ConstraintListState, PhrasalConstraint
2861+
from .generation_beam_constraints import (
2862+
Constraint,
2863+
ConstraintListState,
2864+
DisjunctiveConstraint,
2865+
PhrasalConstraint,
2866+
)
28612867
from .generation_beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
28622868
from .generation_logits_process import (
28632869
ForcedBOSTokenLogitsProcessor,

src/transformers/generation_beam_constraints.py

Lines changed: 181 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
from abc import ABC, abstractmethod
2-
from typing import List, Optional, Union
3-
4-
import torch
2+
from typing import List, Optional
53

64

75
class Constraint(ABC):
@@ -137,37 +135,38 @@ class PhrasalConstraint(Constraint):
137135
The id of the token that must be generated by the output.
138136
"""
139137

140-
def __init__(self, token_ids: Union[List[int], torch.LongTensor]):
138+
def __init__(self, token_ids: List[int]):
141139
super(Constraint, self).__init__()
142140

143-
is_int_list = isinstance(token_ids, List) and isinstance(token_ids[0], int)
144-
is_tensor = isinstance(token_ids, torch.Tensor)
145-
is_int_tensor = (
146-
is_tensor and token_ids.dtype in [torch.int16, torch.int32, torch.int64] and len(token_ids.size()) == 1
147-
)
148-
not_positive = torch.any(token_ids < 0) if is_tensor else len([t for t in token_ids if t < 0]) > 0
149-
if isinstance(token_ids, int) or not (is_int_list or is_int_tensor) or not_positive:
150-
raise ValueError(f"`token_ids` has to be a single list or tensor of positive integers but is {token_ids}")
151-
152-
if not is_tensor:
153-
token_ids = torch.tensor(token_ids)
141+
if not isinstance(token_ids, list) or len(token_ids) == 0:
142+
raise ValueError(f"`token_ids` has to be a non-emtpy list, but is {token_ids}.")
143+
if any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids):
144+
raise ValueError(f"Each list in `token_ids` has to be a list of positive integers, but is {token_ids}.")
154145

155146
self.token_ids = token_ids
156147

157-
self.seqlen = self.token_ids.size(0)
148+
self.seqlen = len(self.token_ids)
158149
self.fulfilled_idx = -1 # the index of the currently fulfilled step
159150
self.completed = False
160151

161152
def advance(self):
153+
if self.completed:
154+
return None
162155
return self.token_ids[self.fulfilled_idx + 1]
163156

164157
def does_advance(self, token_id: int):
158+
if not isinstance(token_id, int):
159+
raise ValueError(f"`token_id` has to be an `int`, but is {token_id} of type {type(token_id)}")
160+
165161
if self.completed:
166162
return False
167-
# move to cpu to guarantee no device issues.
168-
return token_id.cpu() == self.token_ids[self.fulfilled_idx + 1].cpu()
163+
164+
return token_id == self.token_ids[self.fulfilled_idx + 1]
169165

170166
def update(self, token_id: int):
167+
if not isinstance(token_id, int):
168+
raise ValueError(f"`token_id` has to be an `int`, but is {token_id} of type {type(token_id)}")
169+
171170
stepped = False
172171
completed = False
173172
reset = False
@@ -202,6 +201,151 @@ def copy(self, stateful=False):
202201
return new_constraint
203202

204203

204+
class DisjunctiveTrie:
    r"""
    A helper class that builds a trie with the words represented in `nested_token_ids`.

    Every node is a plain `dict` mapping a token id to its child node. A node without children (an empty `dict`)
    is a leaf, i.e. the end of one of the words.

    Args:
        nested_token_ids (`List[List[int]]`): the words to store, each word being a list of token ids.
        no_subsets (`bool`, *optional*, defaults to `True`): when `True`, construction fails if the number of
            leaves differs from the number of words (see `has_subsets`).

    Raises:
        ValueError: if `nested_token_ids` is empty, or if `no_subsets` is `True` and the subset check fails.
    """

    def __init__(self, nested_token_ids: List[List[int]], no_subsets=True):
        # Fail with an explicit message instead of the opaque "max() arg is an empty sequence".
        if len(nested_token_ids) == 0:
            raise ValueError(f"`nested_token_ids` has to be a non-empty list of lists, but is {nested_token_ids}.")

        # The longest word determines the depth of the trie.
        self.max_height = max(len(token_ids) for token_ids in nested_token_ids)

        root = {}
        for token_ids in nested_token_ids:
            level = root
            for token_id in token_ids:
                # Descend into the child for this token, creating it on first sight.
                level = level.setdefault(token_id, {})

        if no_subsets and self.has_subsets(root, nested_token_ids):
            raise ValueError(
                f"Each list in `nested_token_ids` can't be a complete subset of another list, but is {nested_token_ids}."
            )

        self.trie = root

    def next_tokens(self, current_seq):
        """
        The next possible tokens that will progress the trie, given the current sequence of tokens in `current_seq`.

        Raises `KeyError` if `current_seq` is not a path that exists in the trie.
        """
        node = self.trie

        for current_token in current_seq:
            node = node[current_token]

        return list(node)

    def reached_leaf(self, current_seq):
        """
        Returns whether `current_seq` ends on a leaf, i.e. no token can progress the trie any further.
        """
        return len(self.next_tokens(current_seq)) == 0

    def count_leaves(self, root):
        """
        Recursively counts the leaves below `root`; a node without children counts as one leaf.
        """
        children = list(root.values())
        if len(children) == 0:
            return 1
        return sum(self.count_leaves(child) for child in children)

    def has_subsets(self, trie, nested_token_ids):
        """
        Returns whether # of leaves != # of words, which means that some word is a subset (prefix) of another.
        Duplicated words share a single leaf and are therefore reported as well.
        """
        leaf_count = self.count_leaves(trie)
        return len(nested_token_ids) != leaf_count
258+
259+
260+
class DisjunctiveConstraint(Constraint):
    r"""
    A special [`Constraint`] that is fulfilled by fulfilling just one of several constraints.

    Progress is tracked as the list of tokens matched so far (`current_seq`), which is always a path from the root
    of the internal [`DisjunctiveTrie`]. The constraint is completed as soon as that path ends on a leaf.

    Args:
        nested_token_ids (`List[List[int]]`): a list of words, where each word is a list of ids. This constraint
        is fulfilled by generating just one from the list of words.
    """

    def __init__(self, nested_token_ids: List[List[int]]):
        # NOTE(review): the two-argument form starts the MRO lookup *after* `Constraint`, so
        # `Constraint.__init__` is not executed here; kept as-is to match `PhrasalConstraint` - confirm intent.
        super(Constraint, self).__init__()

        if not isinstance(nested_token_ids, list) or len(nested_token_ids) == 0:
            raise ValueError(f"`nested_token_ids` has to be a non-emtpy list, but is {nested_token_ids}.")
        if any(not isinstance(token_ids, list) for token_ids in nested_token_ids):
            raise ValueError(f"`nested_token_ids` has to be a list of lists, but is {nested_token_ids}.")
        if any(
            any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids)
            for token_ids in nested_token_ids
        ):
            raise ValueError(
                f"Each list in `nested_token_ids` has to be a list of positive integers, but is {nested_token_ids}."
            )

        self.trie = DisjunctiveTrie(nested_token_ids)
        self.token_ids = nested_token_ids

        # upper bound on the number of steps needed: the length of the longest word
        self.seqlen = self.trie.max_height
        self.current_seq = []
        self.completed = False

    def advance(self):
        """
        Returns the list of token ids that would make progress from the current state, or `None` when there is
        none (the current sequence already ends on a leaf).
        """
        token_list = self.trie.next_tokens(self.current_seq)

        if len(token_list) == 0:
            return None
        else:
            return token_list

    def does_advance(self, token_id: int):
        """
        Returns whether appending `token_id` to the tokens matched so far stays on a path of the trie.

        Raises:
            ValueError: if `token_id` is not an `int`.
        """
        if not isinstance(token_id, int):
            raise ValueError(f"`token_id` is supposed to be type `int`, but is {token_id} of type {type(token_id)}")

        next_tokens = self.trie.next_tokens(self.current_seq)

        return token_id in next_tokens

    def update(self, token_id: int):
        """
        Consumes `token_id` and updates the progress of this constraint.

        Returns:
            A tuple `(stepped, completed, reset)`: whether the token made progress, whether the constraint is now
            fulfilled, and whether the progress was thrown away because the token did not fit.

        Raises:
            ValueError: if `token_id` is not an `int`.
        """
        if not isinstance(token_id, int):
            raise ValueError(f"`token_id` is supposed to be type `int`, but is {token_id} of type {type(token_id)}")

        stepped = False
        completed = False
        reset = False

        if self.does_advance(token_id):
            self.current_seq.append(token_id)
            stepped = True
        else:
            # the token breaks the word in progress, so start over from the root of the trie
            reset = True
            self.reset()

        completed = self.trie.reached_leaf(self.current_seq)
        self.completed = completed

        return stepped, completed, reset

    def reset(self):
        """
        Discards all progress made so far.
        """
        self.completed = False
        self.current_seq = []

    def remaining(self):
        """
        Returns an upper bound on the number of steps still needed to fulfill this constraint.
        """
        if self.completed:
            # since this can be completed without reaching max height
            return 0
        else:
            return self.seqlen - len(self.current_seq)

    def copy(self, stateful=False):
        """
        Creates a new instance of this constraint.

        Args:
            stateful (`bool`, *optional*, defaults to `False`): whether to also carry over the current progress.

        Returns:
            A new `DisjunctiveConstraint` whose progress is independent of this one.
        """
        new_constraint = DisjunctiveConstraint(self.token_ids)

        if stateful:
            # `seqlen` is the attribute read by `remaining()`; the previous spelling `seq_len` set a stray field.
            new_constraint.seqlen = self.seqlen
            # Copy the list: `update()` appends in place, so sharing it would leak progress between the
            # original and every stateful copy made from it.
            new_constraint.current_seq = list(self.current_seq)
            new_constraint.completed = self.completed

        return new_constraint
347+
348+
205349
class ConstraintListState:
206350
r"""
207351
A class for beam scorers to track its progress through a list of constraints.
@@ -215,7 +359,7 @@ def __init__(self, constraints: List[Constraint]):
215359
self.constraints = constraints
216360

217361
# max # of steps required to fulfill a given constraint
218-
self.max_seqlen = max([c.seqlen for c in constraints if isinstance(c, PhrasalConstraint)])
362+
self.max_seqlen = max([c.seqlen for c in constraints])
219363
self.n_constraints = len(constraints)
220364
self.completed = False
221365

@@ -249,26 +393,33 @@ def advance(self):
249393
Though we don't care which constraint is fulfilled first, if we are in the progress of fulfilling a constraint,
250394
that's the only one we'll return.
251395
"""
396+
token_list = []
252397
if self.inprogress_constraint is None:
253-
token_list = []
254398
for constraint in self.pending_constraints: # "pending" == "unfulfilled yet"
255399
advance = constraint.advance()
256-
token_list.append(advance)
400+
if isinstance(advance, int):
401+
token_list.append(advance)
402+
elif isinstance(advance, list):
403+
token_list.extend(advance)
257404
else:
258-
token_list = [self.inprogress_constraint.advance()]
405+
advance = self.inprogress_constraint.advance()
406+
if isinstance(advance, int):
407+
token_list.append(advance)
408+
elif isinstance(advance, list):
409+
token_list.extend(advance)
259410

260411
if len(token_list) == 0:
261412
return None
262413
else:
263-
return torch.stack(token_list)
414+
return token_list
264415

265-
def reset(self, token_ids: Optional[torch.LongTensor]):
416+
def reset(self, token_ids: Optional[List[int]]):
266417
"""
267418
token_ids: the tokens generated thus far to reset the state of the progress through constraints.
268419
"""
269420
self.init_state()
270421

271-
if token_ids is not None and token_ids.size(0) > 0:
422+
if token_ids is not None:
272423
for token in token_ids:
273424
# completes or steps **one** constraint
274425
complete, stepped = self.add(token)
@@ -277,9 +428,10 @@ def reset(self, token_ids: Optional[torch.LongTensor]):
277428
if self.completed:
278429
break
279430

280-
return self
431+
def add(self, token_id: int):
432+
if not isinstance(token_id, int):
433+
raise ValueError(f"`token_id` should be an `int`, but is `{token_id}`.")
281434

282-
def add(self, token_id: Union[int, torch.LongTensor]):
283435
complete, stepped = False, False
284436

285437
if self.completed:
@@ -324,8 +476,8 @@ def add(self, token_id: Union[int, torch.LongTensor]):
324476

325477
if not stepped:
326478
raise Exception(
327-
"constraint.update(token_id) is not yielding incremental progress, "
328-
"even though constraint.does_advance(token_id) is true."
479+
"`constraint.update(token_id)` is not yielding incremental progress, "
480+
"even though `constraint.does_advance(token_id)` is true."
329481
)
330482

331483
if complete:

src/transformers/generation_beam_search.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -443,7 +443,7 @@ def make_constraint_states(self, n):
443443

444444
def check_completes_constraints(self, sequence):
445445
new_state = self.make_constraint_states(1)[0]
446-
new_state = new_state.reset(sequence)
446+
new_state.reset(sequence)
447447
return new_state.completed
448448

449449
def process(
@@ -484,6 +484,7 @@ def process(
484484
- **next_beam_scores** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Updated scores of
485485
all
486486
non-finished beams.
487+
487488
- **next_beam_tokens** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Next tokens to be
488489
added
489490
to the non-finished beam_hypotheses.
@@ -537,7 +538,7 @@ def process(
537538
if is_beam_token_worse_than_top_num_beams:
538539
continue
539540

540-
completes_constraint = self.check_completes_constraints(input_ids[batch_beam_idx])
541+
completes_constraint = self.check_completes_constraints(input_ids[batch_beam_idx].cpu().tolist())
541542
if completes_constraint:
542543
beam_hyp.add(
543544
input_ids[batch_beam_idx].clone(),
@@ -628,23 +629,23 @@ def step_sentence_constraint(
628629
# hypotheses.
629630

630631
topk_state = topk_contraint_states[seq_idx]
631-
topk_state.reset(full_hypotheses[seq_idx])
632+
topk_state.reset(full_hypotheses[seq_idx].cpu().tolist())
632633

633634
advance_state = advance_constraint_states[seq_idx]
634-
advance_state.reset(pre_seq)
635+
advance_state.reset(pre_seq.cpu().tolist())
635636

636637
if not advance_state.completed:
637-
advance_tokens = advance_state.advance()
638-
for advance_token in advance_tokens.to(device):
638+
advance_tokens = torch.LongTensor(advance_state.advance()).to(device)
639+
for advance_token in advance_tokens:
639640
# since adding each `advance_token` leads to a different hypothesis, create new state instance.
640641
new_state = advance_state.copy(stateful=True)
641-
new_state.add(advance_token)
642+
new_state.add(advance_token.cpu().tolist())
642643

643644
advance_seq = torch.cat((pre_seq, advance_token.unsqueeze(0)), -1).cpu().tolist()
644645
if advance_seq not in track_new["new_seqs"]:
645646
# prevent duplicates, which are basically bound to happen in this process.
646647
track_new["new_seqs"].append(advance_seq)
647-
track_new["new_indices"].append(seq_idx)
648+
track_new["new_indices"].append(sidx + seq_idx) # idx -> global idx across all the batches
648649
track_new["new_tokens"].append(advance_token)
649650
track_new["new_scores"].append(this_batch_token_scores[seq_idx].take(advance_token))
650651
track_new["new_states"].append(new_state)
@@ -673,8 +674,9 @@ def step_sentence_constraint(
673674

674675
advance_state = advance_constraint_states[seq_idx]
675676

676-
advance_state.reset(advance_seq)
677677
advance_seq = advance_seq.cpu().tolist()
678+
679+
advance_state.reset(advance_seq)
678680
if advance_seq not in track_new["new_seqs"]:
679681
# but still don't want to have duplicates
680682
track_new["new_seqs"].append(advance_seq)
@@ -745,7 +747,7 @@ def finalize(
745747
final_score = final_beam_scores[batch_beam_idx].item()
746748
final_tokens = input_ids[batch_beam_idx]
747749

748-
completes_constraint = self.check_completes_constraints(final_tokens)
750+
completes_constraint = self.check_completes_constraints(final_tokens.cpu().tolist())
749751
if completes_constraint:
750752
beam_hyp.add(final_tokens, final_score)
751753
ids_collect.append(beam_id)

0 commit comments

Comments
 (0)