Skip to content

[wip] attention refactor #2143

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed

Conversation

williamberman
Copy link
Contributor

@williamberman williamberman commented Jan 27, 2023

re: #1880

  1. When AttentionBlock's forward is called, dynamically create a CrossAttention module with the same parameters and call its forward method instead.
  2. If an AttentionBlock is constructed, We will log a deprecation warning and instructions for converting the model. We can use a context manager that manages a one off logging method in attention.py to de-dup deprecation messages.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@williamberman williamberman force-pushed the attention_refactor branch 2 times, most recently from 2658f5a to c6709ed Compare January 27, 2023 22:01
Comment on lines 119 to 137
attn = CrossAttention(
self.channels,
heads=self.num_heads,
dim_head=dim_head,
bias=True,
upcast_softmax=True,
norm_num_groups=self.group_norm.num_groups,
processor=processor,
eps=self.group_norm.eps,
rescale_output_factor=self.rescale_output_factor,
)

# compute next hidden_states
hidden_states = self.proj_attn(hidden_states)
attn.group_norm = self.group_norm
attn.to_q = self.query
attn.to_k = self.key
attn.to_v = self.value
attn.to_out[0] = self.proj_attn

hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
hidden_states = attn(hidden_states)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Creating CrossAttention on the fly like this causes some of the mps tests that rely on reproducibility to fail. Could this have something to do with the mps warm up passes? cc @pcuenca

@williamberman williamberman force-pushed the attention_refactor branch 10 times, most recently from e11a51a to 52f16b8 Compare January 30, 2023 20:43

from ...configuration_utils import ConfigMixin, register_to_config
from ...schedulers.scheduling_utils import SchedulerMixin


warnings.filterwarnings("ignore")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This warnings filter silences all warnings. Need to remove to see the attention block deprectation warning

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants