@@ -11,16 +11,30 @@
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

 from ...image_processor import VaeImageProcessor
-from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import (
+    USE_PEFT_BACKEND,
+    BaseOutput,
+    is_torch_xla_available,
+    logging,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from ..stable_diffusion import StableDiffusionSafetyChecker


+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@@ -282,7 +296,11 @@ def create_motion_field_and_warp_latents(motion_field_strength_x, motion_field_s


 class TextToVideoZeroPipeline(
-    DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin
+    DiffusionPipeline,
+    StableDiffusionMixin,
+    TextualInversionLoaderMixin,
+    StableDiffusionLoraLoaderMixin,
+    FromSingleFileMixin,
 ):
     r"""
     Pipeline for zero-shot text-to-video generation using Stable Diffusion.
@@ -440,6 +458,10 @@ def backward_loop(
                     if callback is not None and i % callback_steps == 0:
                         step_idx = i // getattr(self.scheduler, "order", 1)
                         callback(step_idx, t, latents)
+
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         return latents.clone().detach()

     # Copied from diffusers.pipelines.stable_diffusion_k_diffusion.pipeline_stable_diffusion_k_diffusion.StableDiffusionKDiffusionPipeline.check_inputs
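For reference, a minimal usage sketch of what the XLA hooks above enable, assuming `torch_xla` is installed and an XLA device (e.g. a TPU) is available; the model id and prompt below are placeholders, not part of the diff:

```python
import torch
import torch_xla.core.xla_model as xm

from diffusers import TextToVideoZeroPipeline

device = xm.xla_device()  # the XLA device, e.g. a TPU core

pipe = TextToVideoZeroPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",  # placeholder model id
    torch_dtype=torch.bfloat16,
).to(device)

# With the change above, each iteration of backward_loop ends in
# xm.mark_step(), so the lazily traced graph is compiled and executed
# once per denoising step instead of accumulating over the whole loop.
result = pipe(prompt="A panda is playing guitar on Times Square")
frames = result.images  # generated video frames
```

Adding `FromSingleFileMixin` to the base classes also exposes `from_single_file()` on the pipeline, for loading from a single Stable Diffusion checkpoint file rather than a full repository layout.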