-
Notifications
You must be signed in to change notification settings - Fork 29.2k
Constrained Beam Search [*With* Disjunctive Decoding] #15761
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 58 commits
b35aa4c
a2ba6c4
c36eaa2
13c1808
8363aa5
8a0d871
3776c25
125a9aa
9fcba0d
3773495
d214c83
2625580
f34c3aa
859262a
97dc8cc
b5ce4f4
91a6403
10f0679
e6b60f3
db9e964
8242282
88945d5
73f3acd
42efa23
fb2195a
f522031
d50fc39
12ac97d
ea5fe70
2169a9f
77660bd
b21aae0
0050621
3e35647
edd8681
71125d0
e1f6419
ba7a310
77a18ae
7a78633
bbd9e88
17ab474
aab1d9e
88e938d
f44ba46
7e05468
b21baf5
1d0f862
d047b2b
852a3a5
2d8361e
cbd92f4
8fef1f4
fb091d4
7f7d344
1bfbcca
62f451a
01bb61c
cd2c48d
a25b2cd
a7e07c9
148dcb8
e19636d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,5 @@ | ||
from abc import ABC, abstractmethod | ||
from typing import List, Optional, Union | ||
|
||
import torch | ||
from typing import List | ||
|
||
|
||
class Constraint(ABC): | ||
|
@@ -137,37 +135,38 @@ class PhrasalConstraint(Constraint): | |
The id of the token that must be generated by the output. | ||
""" | ||
|
||
def __init__(self, token_ids: Union[List[int], torch.LongTensor]): | ||
def __init__(self, token_ids: List[int]): | ||
super(Constraint, self).__init__() | ||
|
||
is_int_list = isinstance(token_ids, List) and isinstance(token_ids[0], int) | ||
is_tensor = isinstance(token_ids, torch.Tensor) | ||
is_int_tensor = ( | ||
is_tensor and token_ids.dtype in [torch.int16, torch.int32, torch.int64] and len(token_ids.size()) == 1 | ||
) | ||
not_positive = torch.any(token_ids < 0) if is_tensor else len([t for t in token_ids if t < 0]) > 0 | ||
if isinstance(token_ids, int) or not (is_int_list or is_int_tensor) or not_positive: | ||
raise ValueError(f"`token_ids` has to be a single list or tensor of positive integers but is {token_ids}") | ||
|
||
if not is_tensor: | ||
token_ids = torch.tensor(token_ids) | ||
if not isinstance(token_ids, list) or len(token_ids) == 0: | ||
raise ValueError(f"`token_ids` has to be a non-emtpy list, but is {token_ids}.") | ||
if any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids): | ||
raise ValueError(f"Each list in `token_ids` has to be a list of positive integers, but is {token_ids}.") | ||
|
||
self.token_ids = token_ids | ||
|
||
self.seqlen = self.token_ids.size(0) | ||
self.seqlen = len(self.token_ids) | ||
self.fulfilled_idx = -1 # the index of the currently fulfilled step | ||
self.completed = False | ||
|
||
def advance(self): | ||
if self.completed: | ||
return None | ||
return self.token_ids[self.fulfilled_idx + 1] | ||
|
||
def does_advance(self, token_id: int): | ||
if not isinstance(token_id, int): | ||
raise ValueError(f"`token_id` has to be an `int`, but is {token_id} of type {type(token_id)}") | ||
|
||
if self.completed: | ||
return False | ||
# move to cpu to guarantee no device issues. | ||
return token_id.cpu() == self.token_ids[self.fulfilled_idx + 1].cpu() | ||
|
||
return token_id == self.token_ids[self.fulfilled_idx + 1] | ||
|
||
def update(self, token_id: int): | ||
if not isinstance(token_id, int): | ||
raise ValueError(f"`token_id` has to be an `int`, but is {token_id} of type {type(token_id)}") | ||
|
||
stepped = False | ||
completed = False | ||
reset = False | ||
|
@@ -202,6 +201,145 @@ def copy(self, stateful=False): | |
return new_constraint | ||
|
||
|
||
class DisjunctiveTrie: | ||
def __init__(self, nested_token_ids: List[List[int]], no_subsets=True): | ||
r""" | ||
A helper class that builds a trie with the words represented in `nested_token_ids`. | ||
|
||
For example, if `nested_token_ids==[[1,2,3], [1,2,4], [1,5,6]]`, then the trie is: 1 -> 2 -> 3 | ||
\ -> 4 -> 5 -> 6 | ||
""" | ||
self.max_height = max([len(one) for one in nested_token_ids]) | ||
|
||
leaves = [] | ||
root = dict() | ||
for token_ids in nested_token_ids: | ||
level = root | ||
for tidx, token_id in enumerate(token_ids): | ||
if id(level) in leaves: | ||
raise ValueError( | ||
f"Each list in `nested_token_ids` can't be a complete subset of another list, but is {nested_token_ids}." | ||
) | ||
|
||
if token_id not in level: | ||
level[token_id] = dict() | ||
|
||
level = level[token_id] | ||
|
||
if tidx == len(token_ids) - 1: | ||
leaves.append(id(level)) | ||
|
||
self.trie = root | ||
|
||
def next_tokens(self, current_seq): | ||
""" | ||
The next possible tokens that will progress the trie, given the current sequence of tokens in `current_seq`. | ||
""" | ||
start = self.trie | ||
|
||
for current_token in current_seq: | ||
start = start[current_token] | ||
|
||
next_tokens = list(start.keys()) | ||
|
||
return next_tokens | ||
|
||
def reached_leaf(self, current_seq): | ||
next_tokens = self.next_tokens(current_seq) | ||
|
||
return len(next_tokens) == 0 | ||
|
||
|
||
class DisjunctiveConstraint(Constraint): | ||
r""" | ||
A special [`Constraint`] that is fulfilled by fulfilling just one of several constraints. | ||
|
||
Args: | ||
nested_token_ids (`List[List[int]]`): a list of words, where each word is a list of ids. This constraint | ||
is fulfilled by generating just one from the list of words. | ||
""" | ||
|
||
def __init__(self, nested_token_ids: List[List[int]]): | ||
super(Constraint, self).__init__() | ||
|
||
if not isinstance(nested_token_ids, list) or len(nested_token_ids) == 0: | ||
raise ValueError(f"`nested_token_ids` has to be a non-emtpy list, but is {nested_token_ids}.") | ||
if any(not isinstance(token_ids, list) for token_ids in nested_token_ids): | ||
raise ValueError(f"`nested_token_ids` has to be a list of lists, but is {nested_token_ids}.") | ||
if any( | ||
any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids) | ||
for token_ids in nested_token_ids | ||
): | ||
raise ValueError( | ||
f"Each list in `nested_token_ids` has to be a list of positive integers, but is {nested_token_ids}." | ||
) | ||
|
||
self.trie = DisjunctiveTrie(nested_token_ids) | ||
self.token_ids = nested_token_ids | ||
|
||
self.seqlen = self.trie.max_height | ||
self.current_seq = [] | ||
self.completed = False | ||
|
||
def advance(self): | ||
token_list = self.trie.next_tokens(self.current_seq) | ||
token_list = [t for t in token_list if t >= 0] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This seems to actually just ignore subset options. Note: This is just a comment. Should our sanitation fail, we would just ignore the subset There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Right. To be more exact, after the checking of subsets in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since you switched to not using |
||
|
||
if len(token_list) == 0: | ||
return None | ||
else: | ||
return token_list | ||
|
||
def does_advance(self, token_id: int): | ||
if not isinstance(token_id, int): | ||
raise ValueError(f"`token_id` is supposed to be type `int`, but is {token_id} of type {type(token_id)}") | ||
|
||
next_tokens = self.trie.next_tokens(self.current_seq) | ||
|
||
return token_id in next_tokens | ||
|
||
def update(self, token_id: int): | ||
if not isinstance(token_id, int): | ||
raise ValueError(f"`token_id` is supposed to be type `int`, but is {token_id} of type {type(token_id)}") | ||
|
||
stepped = False | ||
completed = False | ||
reset = False | ||
|
||
if self.does_advance(token_id): | ||
self.current_seq.append(token_id) | ||
stepped = True | ||
else: | ||
reset = True | ||
self.reset() | ||
|
||
completed = self.trie.reached_leaf(self.current_seq) | ||
self.completed = completed | ||
|
||
return stepped, completed, reset | ||
|
||
def reset(self): | ||
self.completed = False | ||
self.current_seq = [] | ||
|
||
def remaining(self): | ||
if self.completed: | ||
# since this can be completed without reaching max height | ||
return 0 | ||
else: | ||
return self.seqlen - len(self.current_seq) | ||
|
||
def copy(self, stateful=False): | ||
new_constraint = DisjunctiveConstraint(self.token_ids) | ||
|
||
if stateful: | ||
new_constraint.seq_len = self.seqlen | ||
new_constraint.current_seq = self.current_seq | ||
new_constraint.completed = self.completed | ||
|
||
return new_constraint | ||
|
||
|
||
class ConstraintListState: | ||
r""" | ||
A class for beam scorers to track its progress through a list of constraints. | ||
|
@@ -215,7 +353,7 @@ def __init__(self, constraints: List[Constraint]): | |
self.constraints = constraints | ||
|
||
# max # of steps required to fulfill a given constraint | ||
self.max_seqlen = max([c.seqlen for c in constraints if isinstance(c, PhrasalConstraint)]) | ||
self.max_seqlen = max([c.seqlen for c in constraints]) | ||
self.n_constraints = len(constraints) | ||
self.completed = False | ||
|
||
|
@@ -249,26 +387,33 @@ def advance(self): | |
Though we don't care which constraint is fulfilled first, if we are in the progress of fulfilling a constraint, | ||
that's the only one we'll return. | ||
""" | ||
token_list = [] | ||
if self.inprogress_constraint is None: | ||
token_list = [] | ||
for constraint in self.pending_constraints: # "pending" == "unfulfilled yet" | ||
advance = constraint.advance() | ||
token_list.append(advance) | ||
if isinstance(advance, int): | ||
token_list.append(advance) | ||
elif isinstance(advance, list): | ||
token_list = token_list + advance | ||
cwkeam marked this conversation as resolved.
Show resolved
Hide resolved
|
||
else: | ||
token_list = [self.inprogress_constraint.advance()] | ||
advance = self.inprogress_constraint.advance() | ||
if isinstance(advance, int): | ||
token_list.append(advance) | ||
elif isinstance(advance, list): | ||
token_list = token_list + advance | ||
|
||
if len(token_list) == 0: | ||
return None | ||
else: | ||
return torch.stack(token_list) | ||
return token_list | ||
|
||
def reset(self, token_ids: Optional[torch.LongTensor]): | ||
def reset(self, token_ids: List[int]): | ||
""" | ||
token_ids: the tokens generated thus far to reset the state of the progress through constraints. | ||
""" | ||
self.init_state() | ||
|
||
if token_ids is not None and token_ids.size(0) > 0: | ||
if token_ids is not None and len(token_ids) > 0: | ||
cwkeam marked this conversation as resolved.
Show resolved
Hide resolved
|
||
for token in token_ids: | ||
# completes or steps **one** constraint | ||
complete, stepped = self.add(token) | ||
|
@@ -279,7 +424,10 @@ def reset(self, token_ids: Optional[torch.LongTensor]): | |
|
||
return self | ||
cwkeam marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
def add(self, token_id: Union[int, torch.LongTensor]): | ||
def add(self, token_id: int): | ||
if not isinstance(token_id, int): | ||
raise ValueError(f"`token_id` should be an `int`, but is `{token_id}`.") | ||
|
||
complete, stepped = False, False | ||
|
||
if self.completed: | ||
|
@@ -324,8 +472,8 @@ def add(self, token_id: Union[int, torch.LongTensor]): | |
|
||
if not stepped: | ||
raise Exception( | ||
"constraint.update(token_id) is not yielding incremental progress, " | ||
"even though constraint.does_advance(token_id) is true." | ||
"`constraint.update(token_id)` is not yielding incremental progress, " | ||
"even though `constraint.does_advance(token_id)` is true." | ||
) | ||
|
||
if complete: | ||
|
Uh oh!
There was an error while loading. Please reload this page.