
Commit 332d2bb

Improve memory text to video (#3930)

* Improve memory text to video
* Apply suggestions from code review
* add test
* Apply suggestions from code review
* finish test setup

Co-authored-by: Pedro Cuenca <[email protected]>

1 parent b8a5dda commit 332d2bb

File tree: 5 files changed, +88 -1 lines changed

src/diffusers/models/attention.py (+24 -1)
@@ -119,6 +119,15 @@ def __init__(
         self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
         self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
 
+        # let chunk size default to None
+        self._chunk_size = None
+        self._chunk_dim = 0
+
+    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
+        # Sets chunk feed-forward
+        self._chunk_size = chunk_size
+        self._chunk_dim = dim
+
     def forward(
         self,
         hidden_states: torch.FloatTensor,
@@ -141,6 +150,7 @@ def forward(
         norm_hidden_states = self.norm1(hidden_states)
 
         cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+
         attn_output = self.attn1(
             norm_hidden_states,
             encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
@@ -171,7 +181,20 @@ def forward(
         if self.use_ada_layer_norm_zero:
             norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
 
-        ff_output = self.ff(norm_hidden_states)
+        if self._chunk_size is not None:
+            # "feed_forward_chunk_size" can be used to save memory
+            if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
+                raise ValueError(
+                    f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
+                )
+
+            num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
+            ff_output = torch.cat(
+                [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
+                dim=self._chunk_dim,
+            )
+        else:
+            ff_output = self.ff(norm_hidden_states)
 
         if self.use_ada_layer_norm_zero:
             ff_output = gate_mlp.unsqueeze(1) * ff_output
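Why the chunked path is safe: the feed-forward block acts on each position along the chunked dimension independently, so splitting the input, applying `self.ff` per chunk, and concatenating reproduces the unchunked output while only materializing one chunk's intermediate activations at a time. A standalone sketch, not part of this commit, with a toy `ff` standing in for diffusers' `FeedForward`:

    import torch
    import torch.nn as nn

    # Toy position-wise feed-forward layer standing in for diffusers' FeedForward.
    dim, batch, num_frames, chunk_size = 64, 2, 8, 1
    ff = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
    x = torch.randn(batch, num_frames, dim)

    # Unchunked pass materializes a (batch, num_frames, 4 * dim) intermediate.
    full = ff(x)

    # Chunked pass along dim=1 materializes only (batch, chunk_size, 4 * dim) at a time.
    num_chunks = x.shape[1] // chunk_size
    chunked = torch.cat([ff(s) for s in x.chunk(num_chunks, dim=1)], dim=1)

    # Identical results up to floating-point noise, since positions never interact.
    assert torch.allclose(full, chunked, atol=1e-6)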

src/diffusers/models/unet_3d_condition.py (+40)
@@ -389,6 +389,46 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
         for name, module in self.named_children():
             fn_recursive_attn_processor(name, module, processor)
 
+    def enable_forward_chunking(self, chunk_size=None, dim=0):
+        """
+        Sets the attention processor to use [feed forward
+        chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
+
+        Parameters:
+            chunk_size (`int`, *optional*):
+                The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
+                over each tensor of dim=`dim`.
+            dim (`int`, *optional*, defaults to `0`):
+                The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
+                or dim=1 (sequence length).
+        """
+        if dim not in [0, 1]:
+            raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
+
+        # By default chunk size is 1
+        chunk_size = chunk_size or 1
+
+        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
+            if hasattr(module, "set_chunk_feed_forward"):
+                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
+
+            for child in module.children():
+                fn_recursive_feed_forward(child, chunk_size, dim)
+
+        for module in self.children():
+            fn_recursive_feed_forward(module, chunk_size, dim)
+
+    def disable_forward_chunking(self):
+        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
+            if hasattr(module, "set_chunk_feed_forward"):
+                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
+
+            for child in module.children():
+                fn_recursive_feed_forward(child, chunk_size, dim)
+
+        for module in self.children():
+            fn_recursive_feed_forward(module, None, 0)
+
     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
     def set_default_attn_processor(self):
         """

src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py (+3)
@@ -634,6 +634,9 @@ def __call__(
         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
+        # 6.1 Chunk feed-forward computation to save memory
+        self.unet.enable_forward_chunking(chunk_size=1, dim=1)
+
         # 7. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         with self.progress_bar(total=num_inference_steps) as progress_bar:

src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py (+3)
@@ -709,6 +709,9 @@ def __call__(
         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
+        # 6.1 Chunk feed-forward computation to save memory
+        self.unet.enable_forward_chunking(chunk_size=1, dim=1)
+
         # 7. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         with self.progress_bar(total=num_inference_steps) as progress_bar:
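Since both pipelines now enable forward chunking inside `__call__`, the memory savings require no user-facing change. A minimal end-to-end sketch, assuming the damo-vilab/text-to-video-ms-1.7b checkpoint and a CUDA device:

    import torch
    from diffusers import TextToVideoSDPipeline
    from diffusers.utils import export_to_video

    pipe = TextToVideoSDPipeline.from_pretrained(
        "damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16
    ).to("cuda")

    # Feed-forward chunking is enabled internally before the denoising loop,
    # lowering peak memory with no extra configuration.
    frames = pipe("an astronaut riding a horse", num_inference_steps=25).frames
    export_to_video(frames, output_video_path="astronaut.mp4")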

tests/models/test_models_unet_3d_condition.py (+18)
@@ -399,5 +399,23 @@ def test_lora_xformers_on_off(self):
         assert (sample - on_sample).abs().max() < 1e-4
         assert (sample - off_sample).abs().max() < 1e-4
 
+    def test_feed_forward_chunking(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        init_dict["norm_num_groups"] = 32
+
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+        model.eval()
+
+        with torch.no_grad():
+            output = model(**inputs_dict)[0]
+
+        model.enable_forward_chunking()
+        with torch.no_grad():
+            output_2 = model(**inputs_dict)[0]
+
+        self.assertEqual(output.shape, output_2.shape, "Shape doesn't match")
+        assert np.abs(output.cpu() - output_2.cpu()).max() < 1e-2
+
 
 # (todo: sayakpaul) implement SLOW tests.
