Constrained Beam Search [*With* Disjunctive Decoding] #15761


Merged: 63 commits, Mar 4, 2022
Commits (63; changes shown from 58 commits)
b35aa4c
added classes to get started with constrained beam search
cwkeam Jan 9, 2022
a2ba6c4
in progress, think i can directly force tokens now but not yet with t…
cwkeam Jan 15, 2022
c36eaa2
think now i have total control, now need to code the bank selection
cwkeam Jan 19, 2022
13c1808
technically works as desired, need to optimize and fix design choices…
cwkeam Jan 20, 2022
8363aa5
Merge branch 'huggingface:master' into master
cwkeam Jan 23, 2022
8a0d871
complete PR #1 without disjunctive decoding
cwkeam Jan 23, 2022
3776c25
Merge branch 'master' of https://github.com/cwkeam/transformers
cwkeam Jan 23, 2022
125a9aa
removed incorrect tests
cwkeam Jan 23, 2022
9fcba0d
Delete k.txt
cwkeam Jan 23, 2022
3773495
Delete test.py
cwkeam Jan 23, 2022
d214c83
Delete test.sh
cwkeam Jan 23, 2022
2625580
revert changes to test scripts
cwkeam Jan 23, 2022
f34c3aa
Merge branch 'master' of https://github.com/cwkeam/transformers
cwkeam Jan 23, 2022
859262a
Merge branch 'master' into master
cwkeam Jan 29, 2022
97dc8cc
genutils
cwkeam Jan 29, 2022
b5ce4f4
Merge branch 'master' of https://github.com/cwkeam/transformers
cwkeam Jan 29, 2022
91a6403
full implementation with testing, no disjunctive yet
cwkeam Jan 31, 2022
10f0679
shifted docs
cwkeam Jan 31, 2022
e6b60f3
Merge branch 'huggingface:master' into constrained_beam_search
cwkeam Jan 31, 2022
db9e964
passing all tests realistically ran locally
cwkeam Jan 31, 2022
8242282
removing accidentally included print statements
cwkeam Jan 31, 2022
88945d5
fixed source of error in initial PR test
cwkeam Jan 31, 2022
73f3acd
fixing the get_device() vs device trap
cwkeam Jan 31, 2022
42efa23
fixed documentation docstrings about constrained_beam_search
cwkeam Jan 31, 2022
fb2195a
fixed tests having failing for Speech2TextModel's floating point inputs
cwkeam Jan 31, 2022
f522031
fix cuda long tensor
patrickvonplaten Jan 31, 2022
d50fc39
Merge branch 'constrained_beam_search' of https://github.com/cwkeam/t…
patrickvonplaten Jan 31, 2022
12ac97d
added examples and testing for them and founx & fixed a bug in beam_s…
cwkeam Feb 3, 2022
ea5fe70
merge fix
cwkeam Feb 3, 2022
2169a9f
deleted accidentally added test halting code with assert False
cwkeam Feb 3, 2022
77660bd
code reformat
cwkeam Feb 3, 2022
b21aae0
Update tests/test_generation_utils.py
cwkeam Feb 4, 2022
0050621
Update tests/test_generation_utils.py
cwkeam Feb 4, 2022
3e35647
Update tests/test_generation_utils.py
cwkeam Feb 4, 2022
edd8681
Update tests/test_generation_utils.py
cwkeam Feb 4, 2022
71125d0
Merge branch 'huggingface:master' into constrained_beam_search
cwkeam Feb 4, 2022
e1f6419
Update tests/test_generation_utils.py
patrickvonplaten Feb 7, 2022
ba7a310
fixing based on comments on PR
cwkeam Feb 8, 2022
77a18ae
took out the testing code that should but work fails without the beam…
cwkeam Feb 8, 2022
7a78633
fixing comments issues
cwkeam Feb 9, 2022
bbd9e88
docstrings for ConstraintListState
cwkeam Feb 9, 2022
17ab474
typo in PhrsalConstraint docstring
cwkeam Feb 9, 2022
aab1d9e
Merge branch 'huggingface:master' into constrained_beam_search
cwkeam Feb 9, 2022
88e938d
docstrings improvements
cwkeam Feb 9, 2022
f44ba46
merge
cwkeam Feb 17, 2022
7e05468
Merge branch 'huggingface-master' into constrained_beam_search
cwkeam Feb 17, 2022
b21baf5
Merge branch 'huggingface:master' into constrained_beam_search
cwkeam Feb 20, 2022
1d0f862
finished adding what is sort of an opinionated implementation of disj…
cwkeam Feb 21, 2022
d047b2b
fixed bug found in constrained beam search that used beam_idx that we…
cwkeam Feb 21, 2022
852a3a5
disjunctive constraint working 100% correctly
cwkeam Feb 22, 2022
2d8361e
Merge branch 'huggingface:master' into constrained_beam_search
cwkeam Feb 22, 2022
cbd92f4
passing all tests
cwkeam Feb 22, 2022
8fef1f4
Accidentally included mlruns
cwkeam Feb 22, 2022
fb091d4
Merge branch 'constrained_beam_search' of https://github.com/cwkeam/t…
cwkeam Feb 22, 2022
7f7d344
Update src/transformers/generation_beam_constraints.py
cwkeam Feb 23, 2022
1bfbcca
Update src/transformers/generation_beam_constraints.py
cwkeam Feb 23, 2022
62f451a
complete overhaul of type complexities and other nits
cwkeam Feb 23, 2022
01bb61c
strict type checks in generate()
cwkeam Feb 23, 2022
cd2c48d
fixing second round of feedback by narsil
cwkeam Feb 24, 2022
a25b2cd
fixing merge conflict with master
cwkeam Feb 24, 2022
a7e07c9
fixed failing generation test because of type check overhaul
cwkeam Feb 24, 2022
148dcb8
generation test fail fix
cwkeam Feb 24, 2022
e19636d
fixing test fails
cwkeam Feb 24, 2022
2 changes: 2 additions & 0 deletions docs/source/internal/generation_utils.mdx
@@ -217,6 +217,8 @@ A [`Constraint`] can be used to force the generation to include specific tokens

[[autodoc]] PhrasalConstraint

[[autodoc]] DisjunctiveConstraint

[[autodoc]] ConstraintListState

## BeamSearch
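To ground the docs addition above, here is an end-to-end usage sketch (editorial, not part of the PR diff). It assumes the `constraints` argument this PR wires into `generate()`, and uses `gpt2` purely as an illustrative checkpoint:

```python
from transformers import DisjunctiveConstraint, GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

# The constraint is satisfied once ANY one of these tokenizations appears in
# the output (leading spaces matter for GPT-2's BPE vocabulary).
words = [" scared", " scary", " frightened"]
nested_ids = [tokenizer(w, add_special_tokens=False).input_ids for w in words]

input_ids = tokenizer("The movie was", return_tensors="pt").input_ids
out = model.generate(
    input_ids,
    constraints=[DisjunctiveConstraint(nested_ids)],
    num_beams=5,  # constrained generation runs as a beam search
    max_length=20,
)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```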
8 changes: 7 additions & 1 deletion src/transformers/__init__.py
@@ -619,6 +619,7 @@
_import_structure["generation_beam_constraints"] = [
"Constraint",
"ConstraintListState",
"DisjunctiveConstraint",
"PhrasalConstraint",
]
_import_structure["generation_beam_search"] = ["BeamScorer", "BeamSearchScorer", "ConstrainedBeamSearchScorer"]
@@ -2795,7 +2796,12 @@
TextDataset,
TextDatasetForNextSentencePrediction,
)
from .generation_beam_constraints import Constraint, ConstraintListState, PhrasalConstraint
from .generation_beam_constraints import (
Constraint,
ConstraintListState,
DisjunctiveConstraint,
PhrasalConstraint,
)
from .generation_beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
from .generation_logits_process import (
ForcedBOSTokenLogitsProcessor,
204 changes: 176 additions & 28 deletions src/transformers/generation_beam_constraints.py
@@ -1,7 +1,5 @@
from abc import ABC, abstractmethod
from typing import List, Optional, Union

import torch
from typing import List


class Constraint(ABC):
@@ -137,37 +135,38 @@ class PhrasalConstraint(Constraint):
The id of the token that must be generated by the output.
"""

def __init__(self, token_ids: Union[List[int], torch.LongTensor]):
def __init__(self, token_ids: List[int]):
super(Constraint, self).__init__()

is_int_list = isinstance(token_ids, List) and isinstance(token_ids[0], int)
is_tensor = isinstance(token_ids, torch.Tensor)
is_int_tensor = (
is_tensor and token_ids.dtype in [torch.int16, torch.int32, torch.int64] and len(token_ids.size()) == 1
)
not_positive = torch.any(token_ids < 0) if is_tensor else len([t for t in token_ids if t < 0]) > 0
if isinstance(token_ids, int) or not (is_int_list or is_int_tensor) or not_positive:
raise ValueError(f"`token_ids` has to be a single list or tensor of positive integers but is {token_ids}")

if not is_tensor:
token_ids = torch.tensor(token_ids)
if not isinstance(token_ids, list) or len(token_ids) == 0:
raise ValueError(f"`token_ids` has to be a non-empty list, but is {token_ids}.")
if any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids):
raise ValueError(f"`token_ids` has to be a list of positive integers, but is {token_ids}.")

self.token_ids = token_ids

self.seqlen = self.token_ids.size(0)
self.seqlen = len(self.token_ids)
self.fulfilled_idx = -1 # the index of the currently fulfilled step
self.completed = False

def advance(self):
if self.completed:
return None
return self.token_ids[self.fulfilled_idx + 1]

def does_advance(self, token_id: int):
if not isinstance(token_id, int):
raise ValueError(f"`token_id` has to be an `int`, but is {token_id} of type {type(token_id)}")

if self.completed:
return False
# move to cpu to guarantee no device issues.
return token_id.cpu() == self.token_ids[self.fulfilled_idx + 1].cpu()

return token_id == self.token_ids[self.fulfilled_idx + 1]

def update(self, token_id: int):
if not isinstance(token_id, int):
raise ValueError(f"`token_id` has to be an `int`, but is {token_id} of type {type(token_id)}")

stepped = False
completed = False
reset = False
@@ -202,6 +201,145 @@ def copy(self, stateful=False):
return new_constraint
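
With the tensor handling gone, `PhrasalConstraint` now walks a plain list of ids one `int` at a time. A minimal editorial sketch of the state machine (made-up token ids):

```python
from transformers import PhrasalConstraint

constraint = PhrasalConstraint([5, 6, 7])  # force the exact phrase 5 -> 6 -> 7

print(constraint.advance())           # 5 -- the next token needed
stepped, completed, reset = constraint.update(5)
print(stepped, completed, reset)      # True False False
print(constraint.advance())           # 6

constraint.update(6)
constraint.update(7)
print(constraint.completed)           # True -- the full phrase was generated
```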


class DisjunctiveTrie:
def __init__(self, nested_token_ids: List[List[int]], no_subsets=True):
r"""
A helper class that builds a trie with the words represented in `nested_token_ids`.

For example, if `nested_token_ids == [[1, 2, 3], [1, 2, 4], [1, 5, 6]]`, then the trie is:

    1 -> 2 -> 3
         \-> 4
     \-> 5 -> 6
"""
self.max_height = max([len(one) for one in nested_token_ids])

leaves = []
root = dict()
for token_ids in nested_token_ids:
level = root
for tidx, token_id in enumerate(token_ids):
if id(level) in leaves:
raise ValueError(
f"Each list in `nested_token_ids` can't be a complete subset of another list, but is {nested_token_ids}."
)

if token_id not in level:
level[token_id] = dict()

level = level[token_id]

if tidx == len(token_ids) - 1:
leaves.append(id(level))

self.trie = root

def next_tokens(self, current_seq):
"""
The next possible tokens that will progress the trie, given the current sequence of tokens in `current_seq`.
"""
start = self.trie

for current_token in current_seq:
start = start[current_token]

next_tokens = list(start.keys())

return next_tokens

def reached_leaf(self, current_seq):
next_tokens = self.next_tokens(current_seq)

return len(next_tokens) == 0
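
As an editorial illustration (not part of the diff), the trie for the docstring example behaves as follows; note that the constructor's leaf check rejects a word that is a complete prefix of a later one. `DisjunctiveTrie` is assumed importable from the module defining it:

```python
from transformers.generation_beam_constraints import DisjunctiveTrie

trie = DisjunctiveTrie([[1, 2, 3], [1, 2, 4], [1, 5, 6]])

print(trie.next_tokens([]))          # [1] -- every word starts with 1
print(trie.next_tokens([1]))         # [2, 5]
print(trie.next_tokens([1, 2]))      # [3, 4]
print(trie.reached_leaf([1, 2, 3]))  # True -- a complete word was consumed
print(trie.reached_leaf([1, 5]))     # False -- 6 is still pending

# [1, 2] is a complete subset (prefix) of [1, 2, 3], so construction fails:
try:
    DisjunctiveTrie([[1, 2], [1, 2, 3]])
except ValueError as err:
    print(err)
```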


class DisjunctiveConstraint(Constraint):
r"""
A special [`Constraint`] that is fulfilled by fulfilling just one of several constraints.

Args:
nested_token_ids (`List[List[int]]`): a list of words, where each word is a list of ids. This constraint
is fulfilled by generating just one from the list of words.
"""

def __init__(self, nested_token_ids: List[List[int]]):
super(Constraint, self).__init__()

if not isinstance(nested_token_ids, list) or len(nested_token_ids) == 0:
raise ValueError(f"`nested_token_ids` has to be a non-emtpy list, but is {nested_token_ids}.")
if any(not isinstance(token_ids, list) for token_ids in nested_token_ids):
raise ValueError(f"`nested_token_ids` has to be a list of lists, but is {nested_token_ids}.")
if any(
any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids)
for token_ids in nested_token_ids
):
raise ValueError(
f"Each list in `nested_token_ids` has to be a list of positive integers, but is {nested_token_ids}."
)

self.trie = DisjunctiveTrie(nested_token_ids)
self.token_ids = nested_token_ids

self.seqlen = self.trie.max_height
self.current_seq = []
self.completed = False

def advance(self):
token_list = self.trie.next_tokens(self.current_seq)
token_list = [t for t in token_list if t >= 0]
[Review thread on the line above]

@Narsil (Contributor), Feb 23, 2022:
This seems to actually just ignore subset options.
Note: this is just a comment. Should our sanitation fail, we would just ignore the subset word_ids passed as input. Is that right? That's actually a nice property IMO.

cwkeam (Contributor Author):
Right. To be more exact, after the checking of subsets in `__init__`, the rest of the code can safely assume there are no subset options.

@Narsil (Contributor):
Since you switched to not using -1 anymore, this line can be safely removed, I think.


if len(token_list) == 0:
return None
else:
return token_list

def does_advance(self, token_id: int):
if not isinstance(token_id, int):
raise ValueError(f"`token_id` is supposed to be type `int`, but is {token_id} of type {type(token_id)}")

next_tokens = self.trie.next_tokens(self.current_seq)

return token_id in next_tokens

def update(self, token_id: int):
if not isinstance(token_id, int):
raise ValueError(f"`token_id` is supposed to be type `int`, but is {token_id} of type {type(token_id)}")

stepped = False
completed = False
reset = False

if self.does_advance(token_id):
self.current_seq.append(token_id)
stepped = True
else:
reset = True
self.reset()

completed = self.trie.reached_leaf(self.current_seq)
self.completed = completed

return stepped, completed, reset

def reset(self):
self.completed = False
self.current_seq = []

def remaining(self):
if self.completed:
# since this can be completed without reaching max height
return 0
else:
return self.seqlen - len(self.current_seq)

def copy(self, stateful=False):
new_constraint = DisjunctiveConstraint(self.token_ids)

if stateful:
new_constraint.seqlen = self.seqlen
new_constraint.current_seq = self.current_seq
new_constraint.completed = self.completed

return new_constraint
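
A short editorial walk-through of the `advance`/`update` contract above, with made-up token ids:

```python
from transformers import DisjunctiveConstraint

constraint = DisjunctiveConstraint([[1, 2, 3], [1, 2, 4], [1, 5, 6]])

print(constraint.advance())                       # [1]
constraint.update(1)
print(constraint.advance())                       # [2, 5]
constraint.update(2)
stepped, completed, reset = constraint.update(4)  # completes the word [1, 2, 4]
print(completed, constraint.completed)            # True True

# A token that cannot advance the trie resets all progress:
other = DisjunctiveConstraint([[1, 2, 3]])
other.update(1)
print(other.update(9))    # (False, False, True)
print(other.current_seq)  # []
```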


class ConstraintListState:
r"""
A class for beam scorers to track their progress through a list of constraints.
@@ -215,7 +353,7 @@ def __init__(self, constraints: List[Constraint]):
self.constraints = constraints

# max # of steps required to fulfill a given constraint
self.max_seqlen = max([c.seqlen for c in constraints if isinstance(c, PhrasalConstraint)])
self.max_seqlen = max([c.seqlen for c in constraints])
self.n_constraints = len(constraints)
self.completed = False

@@ -249,26 +387,33 @@ def advance(self):
Though we don't care which constraint is fulfilled first, if we are in the process of fulfilling a constraint,
that's the only one we'll return.
"""
token_list = []
if self.inprogress_constraint is None:
token_list = []
for constraint in self.pending_constraints: # "pending" == "unfulfilled yet"
advance = constraint.advance()
token_list.append(advance)
if isinstance(advance, int):
token_list.append(advance)
elif isinstance(advance, list):
token_list = token_list + advance
else:
token_list = [self.inprogress_constraint.advance()]
advance = self.inprogress_constraint.advance()
if isinstance(advance, int):
token_list.append(advance)
elif isinstance(advance, list):
token_list = token_list + advance

if len(token_list) == 0:
return None
else:
return torch.stack(token_list)
return token_list

def reset(self, token_ids: Optional[torch.LongTensor]):
def reset(self, token_ids: List[int]):
"""
token_ids: the tokens generated thus far to reset the state of the progress through constraints.
"""
self.init_state()

if token_ids is not None and token_ids.size(0) > 0:
if token_ids is not None and len(token_ids) > 0:
for token in token_ids:
# completes or steps **one** constraint
complete, stepped = self.add(token)
@@ -279,7 +424,10 @@ def reset(self, token_ids: Optional[torch.LongTensor]):

return self

def add(self, token_id: Union[int, torch.LongTensor]):
def add(self, token_id: int):
if not isinstance(token_id, int):
raise ValueError(f"`token_id` should be an `int`, but is `{token_id}`.")

complete, stepped = False, False

if self.completed:
@@ -324,8 +472,8 @@ def add(self, token_id: Union[int, torch.LongTensor]):

if not stepped:
raise Exception(
"constraint.update(token_id) is not yielding incremental progress, "
"even though constraint.does_advance(token_id) is true."
"`constraint.update(token_id)` is not yielding incremental progress, "
"even though `constraint.does_advance(token_id)` is true."
)

if complete:
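Before moving on to the beam-search changes, an editorial sketch of how `ConstraintListState` coordinates several constraints at once (made-up ids; mirrors the `advance`/`add` logic above, assuming the usual pending/in-progress bookkeeping):

```python
from transformers import ConstraintListState, DisjunctiveConstraint, PhrasalConstraint

state = ConstraintListState([
    PhrasalConstraint([7, 8]),
    DisjunctiveConstraint([[1, 2, 3], [1, 2, 4]]),
])

print(state.advance())  # [7, 1] -- one candidate token per pending constraint
state.add(1)            # the disjunctive constraint becomes "in progress"
print(state.advance())  # [2] -- only the in-progress constraint is offered
state.add(2)
state.add(3)            # completes the disjunctive constraint
print(state.advance())  # [7] -- back to the remaining pending constraint
```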
20 changes: 11 additions & 9 deletions src/transformers/generation_beam_search.py
@@ -484,6 +484,7 @@ def process(
- **next_beam_scores** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Updated scores of all
  non-finished beams.

- **next_beam_tokens** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Next tokens to be added
  to the non-finished beam_hypotheses.
@@ -537,7 +538,7 @@ def process(
if is_beam_token_worse_than_top_num_beams:
continue

completes_constraint = self.check_completes_constraints(input_ids[batch_beam_idx])
completes_constraint = self.check_completes_constraints(input_ids[batch_beam_idx].cpu().tolist())
if completes_constraint:
beam_hyp.add(
input_ids[batch_beam_idx].clone(),
@@ -628,23 +629,23 @@ def step_sentence_constraint(
# hypotheses.

topk_state = topk_contraint_states[seq_idx]
topk_state.reset(full_hypotheses[seq_idx])
topk_state.reset(full_hypotheses[seq_idx].cpu().tolist())

advance_state = advance_constraint_states[seq_idx]
advance_state.reset(pre_seq)
advance_state.reset(pre_seq.cpu().tolist())

if not advance_state.completed:
advance_tokens = advance_state.advance()
for advance_token in advance_tokens.to(device):
advance_tokens = torch.LongTensor(advance_state.advance()).to(device)
for advance_token in advance_tokens:
# since adding each `advance_token` leads to a different hypothesis, create new state instance.
new_state = advance_state.copy(stateful=True)
new_state.add(advance_token)
new_state.add(advance_token.cpu().tolist())

advance_seq = torch.cat((pre_seq, advance_token.unsqueeze(0)), -1).cpu().tolist()
if advance_seq not in track_new["new_seqs"]:
# prevent duplicates, which are basically bound to happen in this process.
track_new["new_seqs"].append(advance_seq)
track_new["new_indices"].append(seq_idx)
track_new["new_indices"].append(sidx + seq_idx) # idx -> global idx across all the batches
track_new["new_tokens"].append(advance_token)
track_new["new_scores"].append(this_batch_token_scores[seq_idx].take(advance_token))
track_new["new_states"].append(new_state)
@@ -673,8 +674,9 @@

advance_state = advance_constraint_states[seq_idx]

advance_state.reset(advance_seq)
advance_seq = advance_seq.cpu().tolist()

advance_state.reset(advance_seq)
if advance_seq not in track_new["new_seqs"]:
# but still don't want to have duplicates
track_new["new_seqs"].append(advance_seq)
@@ -745,7 +747,7 @@ def finalize(
final_score = final_beam_scores[batch_beam_idx].item()
final_tokens = input_ids[batch_beam_idx]

completes_constraint = self.check_completes_constraints(final_tokens)
completes_constraint = self.check_completes_constraints(final_tokens.cpu().tolist())
if completes_constraint:
beam_hyp.add(final_tokens, final_score)
ids_collect.append(beam_id)
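
The recurring pattern in this file's changes is worth calling out: hypotheses now cross the torch/constraint boundary as plain Python lists, since the constraint classes no longer import torch. A minimal editorial sketch:

```python
import torch

hypothesis = torch.tensor([5, 17, 42])   # a beam's token ids, on any device
state_input = hypothesis.cpu().tolist()  # -> [5, 17, 42], plain Python ints
assert all(isinstance(t, int) for t in state_input)
# e.g. advance_state.reset(state_input), as in the diff above
```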