Skip to content

Commit b5b5ece

Browse files
authored
Add upsample_size to AttnUpBlock2D, AttnDownBlock2D (huggingface#3275)
The argument `upsample_size` needs to be added to these modules to allow compatibility with other blocks that require this argument.
1 parent bdffdaa commit b5b5ece

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

models/unet_2d_blocks.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -734,7 +734,7 @@ def __init__(
734734
else:
735735
self.downsamplers = None
736736

737-
def forward(self, hidden_states, temb=None):
737+
def forward(self, hidden_states, temb=None, upsample_size=None):
738738
output_states = ()
739739

740740
for resnet, attn in zip(self.resnets, self.attentions):
@@ -1720,7 +1720,7 @@ def __init__(
17201720
else:
17211721
self.upsamplers = None
17221722

1723-
def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
1723+
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
17241724
for resnet, attn in zip(self.resnets, self.attentions):
17251725
# pop res hidden states
17261726
res_hidden_states = res_hidden_states_tuple[-1]

0 commit comments

Comments
 (0)