[wip] attention refactor #2143
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
Force-pushed from 2658f5a to c6709ed (compare)
src/diffusers/models/attention.py (Outdated)
attn = CrossAttention(
    self.channels,
    heads=self.num_heads,
    dim_head=dim_head,
    bias=True,
    upcast_softmax=True,
    norm_num_groups=self.group_norm.num_groups,
    processor=processor,
    eps=self.group_norm.eps,
    rescale_output_factor=self.rescale_output_factor,
)

# compute next hidden_states
hidden_states = self.proj_attn(hidden_states)
attn.group_norm = self.group_norm
attn.to_q = self.query
attn.to_k = self.key
attn.to_v = self.value
attn.to_out[0] = self.proj_attn

hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
hidden_states = attn(hidden_states)
Creating CrossAttention on the fly like this causes some of the mps tests that rely on reproducibility to fail. Could this have something to do with the mps warm-up passes? cc @pcuenca
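One possible explanation, offered here only as a guess rather than something verified in this PR: a module constructed inside forward initializes fresh weights, and that initialization consumes global RNG state, which can desynchronize runs that are expected to be bit-identical. A minimal, standalone sketch of that effect:

```python
import torch
import torch.nn as nn

# Standalone illustration (not diffusers code): building a layer inside the forward
# pass draws from the global RNG for its weight init, so later random draws shift
# even though the seed was identical. This is only a guess at the failure mode.
torch.manual_seed(0)
baseline = torch.randn(4)        # draw with no extra module construction

torch.manual_seed(0)
_ = nn.Linear(8, 8)              # fresh init consumes RNG state
shifted = torch.randn(4)         # same seed, but this draw no longer matches

print(torch.allclose(baseline, shifted))  # False: the extra init changed the stream
```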
Force-pushed from e11a51a to 52f16b8 (compare)
from ...configuration_utils import ConfigMixin, register_to_config
from ...schedulers.scheduling_utils import SchedulerMixin

warnings.filterwarnings("ignore")
This warnings filter silences all warnings. It needs to be removed so the attention block deprecation warning can be seen.
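If some warning genuinely needs silencing in this pipeline, a narrower filter is one option. This is only a sketch; the message pattern below is a placeholder, not anything taken from this PR:

```python
import warnings

# Sketch: scope the filter to a specific category/message instead of everything,
# so the AttentionBlock deprecation warning can still reach users.
# The message pattern here is a placeholder, not the actual warning text.
warnings.filterwarnings(
    "ignore",
    message=".*some specific noisy warning.*",
    category=UserWarning,
)
```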
Force-pushed from cbec387 to ae18d1d (compare)
Force-pushed from ae18d1d to 06f3e9b (compare)
re: #1880. When AttentionBlock's forward is called, dynamically create a CrossAttention module with the same parameters and call its forward method instead.
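For readers skimming the conversation, here is a condensed sketch of that delegation pattern, pieced together from the diff hunks above rather than the merged code. The CrossAttention signature is the one this PR introduces, the submodule names follow the existing AttentionBlock, and the deprecation message is a placeholder:

```python
import warnings

# Condensed sketch (not the merged implementation). The `processor` argument from
# the hunk is omitted because its construction isn't shown in this excerpt.
def forward(self, hidden_states):
    # The conversation mentions an attention block deprecation warning; exactly how
    # and where the PR emits it isn't shown here, so this line is illustrative only.
    warnings.warn("AttentionBlock is deprecated in favor of CrossAttention.", FutureWarning)

    attn = CrossAttention(
        self.channels,
        heads=self.num_heads,
        dim_head=self.channels // self.num_heads,  # assumed; the hunk passes a local `dim_head`
        bias=True,
        upcast_softmax=True,
        norm_num_groups=self.group_norm.num_groups,
        eps=self.group_norm.eps,
        rescale_output_factor=self.rescale_output_factor,
    )
    # Reuse the trained projections so the refactored path stays weight-compatible.
    attn.group_norm = self.group_norm
    attn.to_q = self.query
    attn.to_k = self.key
    attn.to_v = self.value
    attn.to_out[0] = self.proj_attn
    return attn(hidden_states)
```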