Support for cross-attention bias / mask #2634
Changes from all commits
@@ -352,7 +352,7 @@ def token_drop(self, labels, force_drop_ids=None):
         labels = torch.where(drop_ids, self.num_classes, labels)
         return labels

-    def forward(self, labels, force_drop_ids=None):
+    def forward(self, labels: torch.LongTensor, force_drop_ids=None):
         use_dropout = self.dropout_prob > 0
         if (self.training and use_dropout) or (force_drop_ids is not None):
             labels = self.token_drop(labels, force_drop_ids)

Review comment on the `forward` signature: hey @Birch-san, the PR's great :) Just in future, it'd be super helpful not to add type annotations when the argument isn't part of the PR. No issue with adding type annotations you feel are important in a separate PR; it just helps us quickly scan a diff of the function definition and see which arguments were added or removed. As is, we have to look at each argument individually to tell whether it's just an added type annotation or an actual change to the argument. No biggie :)

Reply from @Birch-san: will try to be more restrained in future.
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from dataclasses import dataclass
-from typing import Optional
+from typing import Any, Dict, Optional

 import torch
 import torch.nn.functional as F
@@ -213,11 +213,13 @@ def __init__(

     def forward(
         self,
-        hidden_states,
-        encoder_hidden_states=None,
-        timestep=None,
-        class_labels=None,
-        cross_attention_kwargs=None,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        timestep: Optional[torch.LongTensor] = None,
+        class_labels: Optional[torch.LongTensor] = None,
+        cross_attention_kwargs: Dict[str, Any] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
         return_dict: bool = True,
     ):
         """
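For context on how the two new keyword arguments are meant to be used, here is a minimal caller sketch, assuming the state of `Transformer2DModel` on this PR's branch. The constructor values and tensor sizes are illustrative only (they are not taken from this diff), and the example assumes the continuous (image-like) input path so that `forward` returns an output object whose first element is the sample tensor.

```python
import torch
from diffusers import Transformer2DModel

# Illustrative sizes only; any values satisfying the model's constraints would do.
model = Transformer2DModel(
    num_attention_heads=2,
    attention_head_dim=8,
    in_channels=32,          # continuous (image-like) input path
    cross_attention_dim=16,  # width of encoder_hidden_states
)

sample = torch.randn(2, 32, 8, 8)              # [batch, channels, height, width]
encoder_hidden_states = torch.randn(2, 7, 16)  # [batch, key_tokens, cross_attention_dim]

# 2D boolean mask: True = keep, False = discard (here batch item 1 has 3 padding tokens).
encoder_attention_mask = torch.tensor([[True] * 7,
                                       [True] * 4 + [False] * 3])

out = model(
    sample,
    encoder_hidden_states=encoder_hidden_states,
    encoder_attention_mask=encoder_attention_mask,  # converted to a bias inside forward()
).sample
print(out.shape)  # expected: torch.Size([2, 32, 8, 8])
```

Passing the 2D boolean form lets the model handle the conversion to an additive bias; a caller that already holds a `(batch, 1, key_tokens)` bias can pass it directly and the conversion is skipped.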
@@ -228,11 +230,17 @@ def forward(
             encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                 Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                 self-attention.
-            timestep ( `torch.long`, *optional*):
+            timestep ( `torch.LongTensor`, *optional*):
                 Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
             class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
                 Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels
                 conditioning.
+            encoder_attention_mask ( `torch.Tensor`, *optional*):
+                Cross-attention mask, applied to encoder_hidden_states. Two formats supported:
+                    Mask `(batch, sequence_length)` True = keep, False = discard.
+                    Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
+                If ndim == 2: will be interpreted as a mask, then converted into a bias consistent with the format
+                above. This bias will be added to the cross-attention scores.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
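To make the two documented formats concrete, here is a small standalone sketch (plain PyTorch, not repo code) showing that a boolean keep/discard mask and the additive bias describe the same thing, using the same conversion the diff applies further down:

```python
import torch

# Format 1: mask of shape (batch, sequence_length); True = keep, False = discard.
mask = torch.tensor([
    [True, True, True, False, False],
    [True, True, True, True,  True],
])

# Format 2: bias of shape (batch, 1, sequence_length); 0 = keep, -10000 = discard.
bias = torch.tensor([
    [[0.0, 0.0, 0.0, -10000.0, -10000.0]],
    [[0.0, 0.0, 0.0,      0.0,      0.0]],
])

# The conversion applied when the model receives the 2D form:
converted = (1 - mask.to(torch.float32)) * -10000.0  # keep -> 0.0, discard -> -10000.0
converted = converted.unsqueeze(1)                    # add singleton query_tokens dim

assert torch.allclose(converted, bias)
```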
@@ -241,6 +249,29 @@ def forward(
             [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
             returning a tuple, the first element is the sample tensor.
         """
+        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
+        # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
+        # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
+        # expects mask of shape:
+        #   [batch, key_tokens]
+        # adds singleton query_tokens dimension:
+        #   [batch, 1, key_tokens]
+        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+        if attention_mask is not None and attention_mask.ndim == 2:
+            # assume that mask is expressed as:
+            #   (1 = keep,      0 = discard)
+            # convert mask into a bias that can be added to attention scores:
+            #   (keep = +0,     discard = -10000.0)
+            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
+            attention_mask = attention_mask.unsqueeze(1)
+
+        # convert encoder_attention_mask to a bias the same way we do for attention_mask
+        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
+            encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
+            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
         # 1. Input
         if self.is_input_continuous:
             batch, _, height, width = hidden_states.shape

Review comment on the mask-conversion block: Great!
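As a sanity check on the shapes called out in the comments above, here is a standalone sketch (not repo code) of how the `[batch, 1, key_tokens]` bias ends up suppressing masked keys once it is added to classic-attention scores of shape `[batch * heads, query_tokens, key_tokens]`. The expansion over heads via `repeat_interleave` is an assumption about how the downstream attention code consumes the bias; it is not part of this diff.

```python
import torch

batch, heads, query_tokens, key_tokens = 2, 4, 6, 5

# Mask for batch item 0 discards the last two keys; batch item 1 keeps everything.
mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1]])
bias = (1 - mask.float()) * -10000.0
bias = bias.unsqueeze(1)                              # [batch, 1, key_tokens]

# Classic / xformers-style attention scores.
scores = torch.randn(batch * heads, query_tokens, key_tokens)

# Assumed head expansion (keeps each batch item's rows adjacent).
bias_per_head = bias.repeat_interleave(heads, dim=0)  # [batch * heads, 1, key_tokens]

# The singleton query_tokens dimension broadcasts over all queries.
probs = (scores + bias_per_head).softmax(dim=-1)

# Masked keys of batch item 0 get (effectively) zero attention from every query.
print(probs[:heads, :, 3:].max())  # ~0.0
```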
@@ -264,7 +295,9 @@ def forward(
         for block in self.transformer_blocks:
             hidden_states = block(
                 hidden_states,
+                attention_mask=attention_mask,
                 encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
                 timestep=timestep,
                 cross_attention_kwargs=cross_attention_kwargs,
                 class_labels=class_labels,