Skip to content

Commit fdcbbdf

Browse files
hlky and sayakpaul authored
Add torch_xla and from_single_file support to TextToVideoZeroPipeline (#10445)
Co-authored-by: Sayak Paul <[email protected]>
1 parent 4e44534 commit fdcbbdf

File tree

1 file changed

+25
-3
lines changed

1 file changed

+25
-3
lines changed

src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py

+25-3
Original file line number | Diff line number | Diff line change
@@ -11,16 +11,30 @@
1111
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
1212

1313
from ...image_processor import VaeImageProcessor
14-
from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
14+
from ...loaders import FromSingleFileMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
1515
from ...models import AutoencoderKL, UNet2DConditionModel
1616
from ...models.lora import adjust_lora_scale_text_encoder
1717
from ...schedulers import KarrasDiffusionSchedulers
18-
from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
18+
from ...utils import (
19+
USE_PEFT_BACKEND,
20+
BaseOutput,
21+
is_torch_xla_available,
22+
logging,
23+
scale_lora_layers,
24+
unscale_lora_layers,
25+
)
1926
from ...utils.torch_utils import randn_tensor
2027
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
2128
from ..stable_diffusion import StableDiffusionSafetyChecker
2229

2330

31+
if is_torch_xla_available():
32+
import torch_xla.core.xla_model as xm
33+
34+
XLA_AVAILABLE = True
35+
else:
36+
XLA_AVAILABLE = False
37+
2438
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
2539

2640

@@ -282,7 +296,11 @@ def create_motion_field_and_warp_latents(motion_field_strength_x, motion_field_s
282296

283297

284298
class TextToVideoZeroPipeline(
285-
DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin
299+
DiffusionPipeline,
300+
StableDiffusionMixin,
301+
TextualInversionLoaderMixin,
302+
StableDiffusionLoraLoaderMixin,
303+
FromSingleFileMixin,
286304
):
287305
r"""
288306
Pipeline for zero-shot text-to-video generation using Stable Diffusion.
@@ -440,6 +458,10 @@ def backward_loop(
440458
if callback is not None and i % callback_steps == 0:
441459
step_idx = i // getattr(self.scheduler, "order", 1)
442460
callback(step_idx, t, latents)
461+
462+
if XLA_AVAILABLE:
463+
xm.mark_step()
464+
443465
return latents.clone().detach()
444466

445467
# Copied from diffusers.pipelines.stable_diffusion_k_diffusion.pipeline_stable_diffusion_k_diffusion.StableDiffusionKDiffusionPipeline.check_inputs

0 commit comments

Comments (0)