From 860adf63c9da0d52ac12e0345b8e24deaecc3cf3 Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Sat, 5 Nov 2022 09:49:13 -0700 Subject: [PATCH 01/42] =?UTF-8?q?lint(ldm.invoke.generator):=20?= =?UTF-8?q?=F0=9F=9A=AE=20remove=20unused=20imports?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ldm/invoke/generator/base.py | 12 +++++++----- ldm/invoke/generator/embiggen.py | 14 ++++++++------ ldm/invoke/generator/img2img.py | 10 +++++----- ldm/invoke/generator/inpaint.py | 23 ++++++++++++----------- ldm/invoke/generator/omnibus.py | 8 +++----- ldm/invoke/generator/txt2img.py | 3 +-- ldm/invoke/generator/txt2img2img.py | 11 ++++++----- scripts/invoke.py | 3 +-- 8 files changed, 43 insertions(+), 41 deletions(-) diff --git a/ldm/invoke/generator/base.py b/ldm/invoke/generator/base.py index 3c6eca08a21..427f8d07984 100644 --- a/ldm/invoke/generator/base.py +++ b/ldm/invoke/generator/base.py @@ -2,15 +2,17 @@ Base class for ldm.invoke.generator.* including img2img, txt2img, and inpaint ''' -import torch -import numpy as np -import random import os +import random import traceback -from tqdm import tqdm, trange + +import numpy as np +import torch from PIL import Image, ImageFilter -from einops import rearrange, repeat +from einops import rearrange from pytorch_lightning import seed_everything +from tqdm import trange + from ldm.invoke.devices import choose_autocast from ldm.util import rand_perlin_2d diff --git a/ldm/invoke/generator/embiggen.py b/ldm/invoke/generator/embiggen.py index dc6af35a6c9..0b9fda7ac29 100644 --- a/ldm/invoke/generator/embiggen.py +++ b/ldm/invoke/generator/embiggen.py @@ -3,14 +3,16 @@ and generates with ldm.invoke.generator.img2img ''' +import numpy as np import torch -import numpy as np +from PIL import Image from tqdm import trange -from PIL import Image -from ldm.invoke.generator.base import Generator -from ldm.invoke.generator.img2img import Img2Img + from ldm.invoke.devices import choose_autocast -from ldm.models.diffusion.ddim import DDIMSampler +from ldm.invoke.generator.base import Generator +from ldm.invoke.generator.img2img import Img2Img +from ldm.models.diffusion.ddim import DDIMSampler + class Embiggen(Generator): def __init__(self, model, precision): @@ -493,7 +495,7 @@ def make_image(): # Layer tile onto final image outputsuperimage.alpha_composite(intileimage, (left, top)) else: - print(f'Error: could not find all Embiggen output tiles in memory? Something must have gone wrong with img2img generation.') + print('Error: could not find all Embiggen output tiles in memory? 
Something must have gone wrong with img2img generation.') # after internal loops and patching up return Embiggen image return outputsuperimage diff --git a/ldm/invoke/generator/img2img.py b/ldm/invoke/generator/img2img.py index 1981b4eacb6..edcc855a290 100644 --- a/ldm/invoke/generator/img2img.py +++ b/ldm/invoke/generator/img2img.py @@ -2,15 +2,15 @@ ldm.invoke.generator.img2img descends from ldm.invoke.generator ''' -import torch -import numpy as np import PIL -from torch import Tensor +import numpy as np +import torch from PIL import Image +from torch import Tensor + from ldm.invoke.devices import choose_autocast from ldm.invoke.generator.base import Generator -from ldm.models.diffusion.ddim import DDIMSampler -from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent + class Img2Img(Generator): def __init__(self, model, precision): diff --git a/ldm/invoke/generator/inpaint.py b/ldm/invoke/generator/inpaint.py index 2ba45ed8048..55636b1315f 100644 --- a/ldm/invoke/generator/inpaint.py +++ b/ldm/invoke/generator/inpaint.py @@ -3,19 +3,20 @@ ''' import math -import torch -import torchvision.transforms as T -import numpy as np -import cv2 as cv + import PIL +import cv2 as cv +import numpy as np +import torch from PIL import Image, ImageFilter, ImageOps -from skimage.exposure.histogram_matching import match_histograms -from einops import rearrange, repeat -from ldm.invoke.devices import choose_autocast -from ldm.invoke.generator.img2img import Img2Img -from ldm.models.diffusion.ddim import DDIMSampler -from ldm.models.diffusion.ksampler import KSampler +from einops import repeat + +from ldm.invoke.devices import choose_autocast from ldm.invoke.generator.base import downsampling +from ldm.invoke.generator.img2img import Img2Img +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.models.diffusion.ksampler import KSampler + class Inpaint(Img2Img): def __init__(self, model, precision): @@ -187,7 +188,7 @@ def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta, # klms samplers not supported yet, so ignore previous sampler if isinstance(sampler,KSampler): print( - f">> Using recommended DDIM sampler for inpainting." + ">> Using recommended DDIM sampler for inpainting." 
) sampler = DDIMSampler(self.model, device=self.model.device) diff --git a/ldm/invoke/generator/omnibus.py b/ldm/invoke/generator/omnibus.py index e8426a9205e..e4b9b9d8ef6 100644 --- a/ldm/invoke/generator/omnibus.py +++ b/ldm/invoke/generator/omnibus.py @@ -1,14 +1,14 @@ """omnibus module to be used with the runwayml 9-channel custom inpainting model""" import torch -import numpy as np -from einops import repeat from PIL import Image, ImageOps +from einops import repeat + from ldm.invoke.devices import choose_autocast -from ldm.invoke.generator.base import downsampling from ldm.invoke.generator.img2img import Img2Img from ldm.invoke.generator.txt2img import Txt2Img + class Omnibus(Img2Img,Txt2Img): def __init__(self, model, precision): super().__init__(model, precision) @@ -49,8 +49,6 @@ def get_make_image( if isinstance(mask_image, Image.Image): mask_image = self._image_to_tensor(ImageOps.invert(mask_image).convert('L'),normalize=False) - t_enc = steps - if init_image is not None and mask_image is not None: # inpainting masked_image = init_image * (1 - mask_image) # masked image is the image masked by mask - masked regions zero diff --git a/ldm/invoke/generator/txt2img.py b/ldm/invoke/generator/txt2img.py index ba49d2ef558..a04207259b8 100644 --- a/ldm/invoke/generator/txt2img.py +++ b/ldm/invoke/generator/txt2img.py @@ -3,9 +3,8 @@ ''' import torch -import numpy as np + from ldm.invoke.generator.base import Generator -from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent class Txt2Img(Generator): diff --git a/ldm/invoke/generator/txt2img2img.py b/ldm/invoke/generator/txt2img2img.py index 759ba2dba4e..3da42ebb8af 100644 --- a/ldm/invoke/generator/txt2img2img.py +++ b/ldm/invoke/generator/txt2img2img.py @@ -2,14 +2,15 @@ ldm.invoke.generator.txt2img inherits from ldm.invoke.generator ''' -import torch -import numpy as np import math + +import torch +from PIL import Image + from ldm.invoke.generator.base import Generator -from ldm.models.diffusion.ddim import DDIMSampler from ldm.invoke.generator.omnibus import Omnibus -from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent -from PIL import Image +from ldm.models.diffusion.ddim import DDIMSampler + class Txt2Img2Img(Generator): def __init__(self, model, precision): diff --git a/scripts/invoke.py b/scripts/invoke.py index 1e9a84295eb..54161e5ed9a 100644 --- a/scripts/invoke.py +++ b/scripts/invoke.py @@ -11,9 +11,8 @@ import traceback import yaml -from ldm.invoke.prompt_parser import PromptParser - sys.path.append('.') # corrects a weird problem on Macs +from ldm.invoke.prompt_parser import PromptParser from ldm.invoke.readline import get_completer from ldm.invoke.args import Args, metadata_dumps, metadata_from_png, dream_cmd_from_png from ldm.invoke.pngwriter import PngWriter, retrieve_metadata, write_metadata From e7794c086cfe0b28d5a53c0782116461f8f582d4 Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Wed, 9 Nov 2022 10:08:49 -0800 Subject: [PATCH 02/42] initial commit of DiffusionPipeline class --- ldm/invoke/generator/diffusers_pipeline.py | 325 +++++++++++++++++++++ 1 file changed, 325 insertions(+) create mode 100644 ldm/invoke/generator/diffusers_pipeline.py diff --git a/ldm/invoke/generator/diffusers_pipeline.py b/ldm/invoke/generator/diffusers_pipeline.py new file mode 100644 index 00000000000..b13f85b645e --- /dev/null +++ b/ldm/invoke/generator/diffusers_pipeline.py @@ -0,0 +1,325 @@ +import secrets +from dataclasses import 
dataclass +from typing import List, Optional, Union + +import torch +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer + + +@dataclass +class PipelineIntermediateState: + run_id: str + step: int + timestep: int + latents: torch.Tensor + predicted_original: Optional[torch.Tensor] = None + + +class StableDiffusionGeneratorPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Implementation note: This class started as a refactored copy of diffusers.StableDiffusionPipeline. + Hopefully future versions of diffusers provide access to more of these functions so that we don't + need to duplicate them here: https://github.com/huggingface/diffusers/issues/551#issuecomment-1281508384 + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offsensive or harmful. + Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + ID_LENGTH = 8 + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. 
This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, + `attention_head_dim` must be a multiple of `slice_size`. + """ + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.unet.config.attention_head_dim // 2 + self.unet.set_attention_slice(slice_size) + + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + # set slice_size = `None` to disable `attention slicing` + self.enable_attention_slicing(None) + + def enable_xformers_memory_efficient_attention(self): + r""" + Enable memory efficient attention as implemented in xformers. + + When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference + time. Speed up at training time is not guaranteed. + + Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention + is used. + """ + self.unet.set_use_memory_efficient_attention_xformers(True) + + def disable_xformers_memory_efficient_attention(self): + r""" + Disable memory efficient attention as implemented in xformers. + """ + self.unet.set_use_memory_efficient_attention_xformers(False) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + height: Optional[int] = 512, + width: Optional[int] = 512, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + **extra_step_kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. 
+ + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + result = None + for result in self.generate( + prompt, height=height, width=width, num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, generator=generator, latents=latents, + **extra_step_kwargs): + pass # discarding intermediates + if result is None: + raise AssertionError("why was that an empty generator?") + return result + + def generate( + self, + prompt: Union[str, List[str]], + *, + opposing_prompt: Union[str, List[str]] = None, + height: Optional[int] = 512, + width: Optional[int] = 512, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + run_id: str = None, + **extra_step_kwargs, + ): + if isinstance(prompt, str): + batch_size = 1 + else: + batch_size = len(prompt) + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if run_id is None: + run_id = secrets.token_urlsafe(self.ID_LENGTH) + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + text_embeddings = self.get_text_embeddings(prompt, opposing_prompt, do_classifier_free_guidance, batch_size)\ + .to(self.unet.device) + self.scheduler.set_timesteps(num_inference_steps) + latents = self.prepare_latents(latents, batch_size, height, width, + generator, self.unet.dtype) + + yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps, + latents=latents) + for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): + step_output = self.step(t, latents, guidance_scale, text_embeddings, **extra_step_kwargs) + latents = step_output.prev_sample + yield PipelineIntermediateState(run_id=run_id, step=i, timestep=int(t), latents=latents, + predicted_original=step_output.pred_original_sample) + + # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 + torch.cuda.empty_cache() + + image = self.decode_to_image(latents) + output = StableDiffusionPipelineOutput(images=image, nsfw_content_detected=[]) + yield self.check_for_safety(output) + + @torch.inference_mode() + def step(self, t, latents: torch.Tensor, guidance_scale, text_embeddings: torch.Tensor, **extra_step_kwargs): + do_classifier_free_guidance = guidance_scale > 1.0 + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * 
(noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + return self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs) + + @torch.inference_mode() + def check_for_safety(self, output): + if not getattr(self, 'feature_extractor') or not getattr(self, 'safety_checker'): + return output + images = output.images + safety_checker_output = self.feature_extractor(self.numpy_to_pil(images), + return_tensors="pt").to(self.device) + screened_images, has_nsfw_concept = self.safety_checker( + images=images, clip_input=safety_checker_output.pixel_values) + return StableDiffusionPipelineOutput(screened_images, has_nsfw_concept) + + @torch.inference_mode() + def decode_to_image(self, latents): + # scale and decode the image latents with vae + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + return image + + @torch.inference_mode() + def get_text_embeddings(self, + prompt: Union[str, List[str]], + opposing_prompt: Union[str, List[str]], + do_classifier_free_guidance: bool, + batch_size: int): + # get prompt text embeddings + text_input = self._tokenize(prompt) + + text_embeddings = self.text_encoder(text_input.input_ids)[0] + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + # opposing prompt defaults to blank caption for everything in the batch + text_anti_input = self._tokenize(opposing_prompt or [""] * batch_size) + uncond_embeddings = self.text_encoder(text_anti_input.input_ids)[0] + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + # FIXME: assert these two are the same size + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + return text_embeddings + + @torch.inference_mode() + def _tokenize(self, prompt: Union[str, List[str]]): + return self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + def prepare_latents(self, latents, batch_size, height, width, generator, dtype): + # get the initial random noise unless the user supplied it + # Unlike in other pipelines, latents need to be generated in the target device + # for 1-to-1 results reproducibility with the CompVis implementation. + # However this currently doesn't work in `mps`. 
+ latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) + if latents is None: + latents = torch.randn( + latents_shape, + generator=generator, + device=self.unet.device, + dtype=dtype + ) + else: + if latents.shape != latents_shape: + raise ValueError( + f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + if latents.device != self.unet.device: + raise ValueError(f"Unexpected latents device, got {latents.device}, " + f"expected {self.unet.device}") + + # scale the initial noise by the standard deviation required by the scheduler + latents *= self.scheduler.init_noise_sigma + return latents From d009a094dda9d9763f3ef1ae3a00acf306c79b78 Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Wed, 9 Nov 2022 11:33:19 -0800 Subject: [PATCH 03/42] spike: proof of concept using diffusers for txt2img --- environment.yml | 3 +- ldm/invoke/generator/diffusers_pipeline.py | 36 +++++++++--- ldm/invoke/generator/txt2img.py | 66 ++++++++++++---------- requirements-linux-arm64.txt | 2 +- requirements.txt | 2 +- 5 files changed, 68 insertions(+), 41 deletions(-) diff --git a/environment.yml b/environment.yml index fc648e82625..a4d1392f4ae 100644 --- a/environment.yml +++ b/environment.yml @@ -12,6 +12,7 @@ dependencies: - pytorch=1.12.1 - cudatoolkit=11.6 - pip: + - accelerate~=0.13 - albumentations==0.4.3 - opencv-python==4.5.5.64 - pudb==2019.2 @@ -27,7 +28,7 @@ dependencies: - pyreadline3 - torch-fidelity==0.3.0 - transformers==4.21.3 - - diffusers==0.6.0 + - diffusers~=0.7 - torchmetrics==0.7.0 - flask==2.1.3 - flask_socketio==5.3.0 diff --git a/ldm/invoke/generator/diffusers_pipeline.py b/ldm/invoke/generator/diffusers_pipeline.py index b13f85b645e..d4d60b761a7 100644 --- a/ldm/invoke/generator/diffusers_pipeline.py +++ b/ldm/invoke/generator/diffusers_pipeline.py @@ -1,6 +1,6 @@ import secrets from dataclasses import dataclass -from typing import List, Optional, Union +from typing import List, Optional, Union, Callable import torch from diffusers.models import AutoencoderKL, UNet2DConditionModel @@ -131,6 +131,7 @@ def __call__( guidance_scale: Optional[float] = 7.5, generator: Optional[torch.Generator] = None, latents: Optional[torch.FloatTensor] = None, + callback: Optional[Callable[[PipelineIntermediateState], None]] = None, **extra_step_kwargs, ): r""" @@ -172,7 +173,22 @@ def __call__( prompt, height=height, width=width, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=generator, latents=latents, **extra_step_kwargs): - pass # discarding intermediates + if callback is not None: + callback(result) + if result is None: + raise AssertionError("why was that an empty generator?") + return result + + def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int, + text_embeddings: torch.Tensor, guidance_scale: float, + *, callback: Callable[[PipelineIntermediateState], None]=None, run_id=None, + **extra_step_kwargs) -> StableDiffusionPipelineOutput: + self.scheduler.set_timesteps(num_inference_steps) + result = None + for result in self.generate_from_embeddings( + latents, text_embeddings, guidance_scale, run_id, **extra_step_kwargs): + if callback is not None: + callback(result) if result is None: raise AssertionError("why was that an empty generator?") return result @@ -199,9 +215,6 @@ def generate( if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - if run_id is None: - run_id = 
secrets.token_urlsafe(self.ID_LENGTH) - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. @@ -209,16 +222,23 @@ def generate( text_embeddings = self.get_text_embeddings(prompt, opposing_prompt, do_classifier_free_guidance, batch_size)\ .to(self.unet.device) self.scheduler.set_timesteps(num_inference_steps) - latents = self.prepare_latents(latents, batch_size, height, width, - generator, self.unet.dtype) + latents = self.prepare_latents(latents, batch_size, height, width, generator, self.unet.dtype) + yield from self.generate_from_embeddings(latents, text_embeddings, guidance_scale, run_id, **extra_step_kwargs) + + def generate_from_embeddings(self, latents: torch.Tensor, text_embeddings: torch.Tensor, guidance_scale: float, + run_id: str = None, **extra_step_kwargs): + if run_id is None: + run_id = secrets.token_urlsafe(self.ID_LENGTH) yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps, latents=latents) + # NOTE: Depends on scheduler being already initialized! for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): step_output = self.step(t, latents, guidance_scale, text_embeddings, **extra_step_kwargs) latents = step_output.prev_sample + predicted_original = getattr(step_output, 'pred_original_sample', None) yield PipelineIntermediateState(run_id=run_id, step=i, timestep=int(t), latents=latents, - predicted_original=step_output.pred_original_sample) + predicted_original=predicted_original) # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 torch.cuda.empty_cache() diff --git a/ldm/invoke/generator/txt2img.py b/ldm/invoke/generator/txt2img.py index a04207259b8..a882b156713 100644 --- a/ldm/invoke/generator/txt2img.py +++ b/ldm/invoke/generator/txt2img.py @@ -1,10 +1,11 @@ ''' ldm.invoke.generator.txt2img inherits from ldm.invoke.generator ''' - +import PIL.Image import torch -from ldm.invoke.generator.base import Generator +from .base import Generator +from .diffusers_pipeline import StableDiffusionGeneratorPipeline class Txt2Img(Generator): @@ -13,7 +14,8 @@ def __init__(self, model, precision): @torch.no_grad() def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta, - conditioning,width,height,step_callback=None,threshold=0.0,perlin=0.0,**kwargs): + conditioning,width,height,step_callback=None,threshold=0.0,perlin=0.0, + **kwargs): """ Returns a function returning an image derived from the prompt and the initial image Return value depends on the seed at the time you call it @@ -22,38 +24,42 @@ def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta, self.perlin = perlin uc, c, extra_conditioning_info = conditioning - @torch.no_grad() - def make_image(x_T): - shape = [ - self.latent_channels, - height // self.downsampling_factor, - width // self.downsampling_factor, - ] + # FIXME: this should probably be either passed in to __init__ instead of model & precision, + # or be constructed in __init__ from those inputs. 
+ pipeline = StableDiffusionGeneratorPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", + revision="fp16", torch_dtype=torch.float16, + safety_checker=None, # TODO + # scheduler=sampler + ddim_eta, # TODO + # TODO: local_files_only=True + ) + pipeline.unet.to("cuda") + pipeline.vae.to("cuda") + + def make_image(x_T) -> PIL.Image.Image: + # FIXME: restore free_gpu_mem functionality + # if self.free_gpu_mem and self.model.model.device != self.model.device: + # self.model.model.to(self.model.device) - if self.free_gpu_mem and self.model.model.device != self.model.device: - self.model.model.to(self.model.device) - - sampler.make_schedule(ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False) + # FIXME: how the embeddings are combined should be internal to the pipeline + combined_text_embeddings = torch.cat([uc, c]) - samples, _ = sampler.sample( - batch_size = 1, - S = steps, - x_T = x_T, - conditioning = c, - shape = shape, - verbose = False, - unconditional_guidance_scale = cfg_scale, - unconditional_conditioning = uc, - extra_conditioning_info = extra_conditioning_info, - eta = ddim_eta, - img_callback = step_callback, - threshold = threshold, + pipeline_output = pipeline.image_from_embeddings( + latents=x_T, + num_inference_steps=steps, + text_embeddings=combined_text_embeddings, + guidance_scale=cfg_scale, + callback=step_callback, + # TODO: extra_conditioning_info = extra_conditioning_info, + # TODO: eta = ddim_eta, + # TODO: threshold = threshold, ) - if self.free_gpu_mem: - self.model.model.to("cpu") + # FIXME: restore free_gpu_mem functionality + # if self.free_gpu_mem: + # self.model.model.to("cpu") - return self.sample_to_image(samples) + return pipeline.numpy_to_pil(pipeline_output.images)[0] return make_image diff --git a/requirements-linux-arm64.txt b/requirements-linux-arm64.txt index a0be77057b9..7f73a585313 100644 --- a/requirements-linux-arm64.txt +++ b/requirements-linux-arm64.txt @@ -1,6 +1,6 @@ albumentations==0.4.3 einops==0.3.0 -diffusers==0.6.0 +diffusers[torch]~=0.7 huggingface-hub==0.8.1 imageio==2.9.0 imageio-ffmpeg==0.4.2 diff --git a/requirements.txt b/requirements.txt index 939463e36e7..76adcbe558b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -33,7 +33,7 @@ send2trash dependency_injector==4.40.0 eventlet realesrgan -diffusers +diffusers[torch]~=0.7 git+https://github.com/openai/CLIP.git@main#egg=clip git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k-diffusion git+https://github.com/invoke-ai/Real-ESRGAN.git#egg=realesrgan From a267b455421f9f73fff887b1f43db262ab7d094a Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Wed, 9 Nov 2022 15:23:42 -0800 Subject: [PATCH 04/42] doc: type hints for Generator --- ldm/invoke/generator/base.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/ldm/invoke/generator/base.py b/ldm/invoke/generator/base.py index 427f8d07984..6cacd969f72 100644 --- a/ldm/invoke/generator/base.py +++ b/ldm/invoke/generator/base.py @@ -14,13 +14,19 @@ from tqdm import trange from ldm.invoke.devices import choose_autocast +from ldm.models.diffusion.ddpm import DiffusionWrapper from ldm.util import rand_perlin_2d downsampling = 8 CAUTION_IMG = 'assets/caution.png' -class Generator(): - def __init__(self, model, precision): +class Generator: + downsampling_factor: int + latent_channels: int + precision: str + model: DiffusionWrapper + + def __init__(self, model: DiffusionWrapper, precision: str): self.model = model self.precision = precision self.seed = 
None From 9f5e496053a97103559b7c3a0cd198f2bfcd056a Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Wed, 9 Nov 2022 15:25:56 -0800 Subject: [PATCH 05/42] refactor(model_cache): factor out load_ckpt --- ldm/invoke/model_cache.py | 63 ++++++++++++++++++++++++--------------- 1 file changed, 39 insertions(+), 24 deletions(-) diff --git a/ldm/invoke/model_cache.py b/ldm/invoke/model_cache.py index 1999973ea88..4c7297e087e 100644 --- a/ldm/invoke/model_cache.py +++ b/ldm/invoke/model_cache.py @@ -94,7 +94,7 @@ def get_model(self, model_name:str): 'hash': hash } - def default_model(self) -> str: + def default_model(self) -> str | None: ''' Returns the name of the default model, or None if none is defined. @@ -191,13 +191,6 @@ def _load_model(self, model_name:str): return None mconfig = self.config[model_name] - config = mconfig.config - weights = mconfig.weights - vae = mconfig.get('vae',None) - width = mconfig.width - height = mconfig.height - - print(f'>> Loading {model_name} from {weights}') # for usage statistics if self._has_cuda(): @@ -207,15 +200,44 @@ def _load_model(self, model_name:str): tic = time.time() # this does the work - c = OmegaConf.load(config) - with open(weights,'rb') as f: + model_format = mconfig.get('format', 'ckpt') + if model_format == 'ckpt': + weights = mconfig.weights + print(f'>> Loading {model_name} from {weights}') + model, width, height, model_hash = self._load_ckpt_model(mconfig) + elif model_format == 'diffusers': + model, width, height, model_hash = self._load_diffusers_model(mconfig) + else: + raise NotImplementedError(f"Unknown model format {model_name}: {model_format}") + + # usage statistics + toc = time.time() + print(f'>> Model loaded in', '%4.2fs' % (toc - tic)) + if self._has_cuda(): + print( + '>> Max VRAM used to load the model:', + '%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9), + '\n>> Current VRAM usage:' + '%4.2fG' % (torch.cuda.memory_allocated() / 1e9), + ) + return model, width, height, model_hash + + def _load_ckpt_model(self, mconfig): + config = mconfig.config + weights = mconfig.weights + vae = mconfig.get('vae', None) + width = mconfig.width + height = mconfig.height + + c = OmegaConf.load(config) + with open(weights, 'rb') as f: weight_bytes = f.read() - model_hash = self._cached_sha256(weights,weight_bytes) + model_hash = self._cached_sha256(weights, weight_bytes) pl_sd = torch.load(io.BytesIO(weight_bytes), map_location='cpu') del weight_bytes - sd = pl_sd['state_dict'] + sd = pl_sd['state_dict'] model = instantiate_from_config(c.model) - m, u = model.load_state_dict(sd, strict=False) + m, u = model.load_state_dict(sd, strict=False) if self.precision == 'float16': print(' | Using faster float16 precision') @@ -243,18 +265,11 @@ def _load_model(self, model_name:str): if isinstance(m, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)): m._orig_padding_mode = m.padding_mode - # usage statistics - toc = time.time() - print(f'>> Model loaded in', '%4.2fs' % (toc - tic)) - if self._has_cuda(): - print( - '>> Max VRAM used to load the model:', - '%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9), - '\n>> Current VRAM usage:' - '%4.2fG' % (torch.cuda.memory_allocated() / 1e9), - ) return model, width, height, model_hash - + + def _load_diffusers_model(self, mconfig): + raise NotImplementedError() # return pipeline, width, height, model_hash + def offload_model(self, model_name:str): ''' Offload the indicated model to CPU. 
Will call From b39d04d40c7b516e393a06a266e41dbd128523d2 Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Wed, 9 Nov 2022 17:17:52 -0800 Subject: [PATCH 06/42] model_cache: add ability to load a diffusers model pipeline and update associated things in Generate & Generator to not instantly fail when that happens --- ldm/generate.py | 53 +++++++++++++++++++++- ldm/invoke/generator/base.py | 5 +- ldm/invoke/generator/diffusers_pipeline.py | 28 ++++++++++++ ldm/invoke/generator/txt2img.py | 13 +----- ldm/invoke/model_cache.py | 41 ++++++++++++++++- 5 files changed, 124 insertions(+), 16 deletions(-) diff --git a/ldm/generate.py b/ldm/generate.py index e2d4a40de70..cefa85be5d2 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -18,6 +18,8 @@ import hashlib import cv2 import skimage +from diffusers import DiffusionPipeline, DDIMScheduler, LMSDiscreteScheduler, EulerDiscreteScheduler, \ + EulerAncestralDiscreteScheduler from omegaconf import OmegaConf from ldm.invoke.generator.base import downsampling @@ -386,7 +388,10 @@ def process_image(image,seed): width = width or self.width height = height or self.height - configure_model_padding(model, seamless, seamless_axes) + if isinstance(model, DiffusionPipeline): + configure_model_padding(model.unet, seamless, seamless_axes) + else: + configure_model_padding(model, seamless, seamless_axes) assert cfg_scale > 1.0, 'CFG_Scale (-C) must be >1.0' assert threshold >= 0.0, '--threshold must be >=0.0' @@ -930,9 +935,15 @@ def sample_to_image(self, samples): def sample_to_lowres_estimated_image(self, samples): return self._make_base().sample_to_lowres_estimated_image(samples) + def _set_sampler(self): + if isinstance(self.model, DiffusionPipeline): + return self._set_scheduler() + else: + return self._set_sampler_legacy() + # very repetitive code - can this be simplified? The KSampler names are # consistent, at least - def _set_sampler(self): + def _set_sampler_legacy(self): msg = f'>> Setting Sampler to {self.sampler_name}' if self.sampler_name == 'plms': self.sampler = PLMSSampler(self.model, device=self.device) @@ -956,6 +967,44 @@ def _set_sampler(self): print(msg) + def _set_scheduler(self): + msg = f'>> Setting Sampler to {self.sampler_name}' + default = self.model.scheduler + # TODO: Test me! Not all schedulers take the same args. 
+ scheduler_args = dict( + num_train_timesteps=default.num_train_timesteps, + beta_start=default.beta_start, + beta_end=default.beta_end, + beta_schedule=default.beta_schedule, + ) + trained_betas = getattr(self.model.scheduler, 'trained_betas') + if trained_betas is not None: + scheduler_args.update(trained_betas=trained_betas) + if self.sampler_name == 'plms': + raise NotImplementedError("What's the diffusers implementation of PLMS?") + elif self.sampler_name == 'ddim': + self.sampler = DDIMScheduler(**scheduler_args) + elif self.sampler_name == 'k_dpm_2_a': + raise NotImplementedError("no diffusers implementation of dpm_2 samplers") + elif self.sampler_name == 'k_dpm_2': + raise NotImplementedError("no diffusers implementation of dpm_2 samplers") + elif self.sampler_name == 'k_euler_a': + self.sampler = EulerAncestralDiscreteScheduler(**scheduler_args) + elif self.sampler_name == 'k_euler': + self.sampler = EulerDiscreteScheduler(**scheduler_args) + elif self.sampler_name == 'k_heun': + raise NotImplementedError("no diffusers implementation of Heun's sampler") + elif self.sampler_name == 'k_lms': + self.sampler = LMSDiscreteScheduler(**scheduler_args) + else: + msg = f'>> Unsupported Sampler: {self.sampler_name}, Defaulting to {default}' + + print(msg) + + if not hasattr(self.sampler, 'uses_inpainting_model'): + # FIXME: terrible kludge! + self.sampler.uses_inpainting_model = lambda: False + def _load_img(self, img)->Image: if isinstance(img, Image.Image): image = img diff --git a/ldm/invoke/generator/base.py b/ldm/invoke/generator/base.py index 6cacd969f72..58da250c5c2 100644 --- a/ldm/invoke/generator/base.py +++ b/ldm/invoke/generator/base.py @@ -9,6 +9,7 @@ import numpy as np import torch from PIL import Image, ImageFilter +from diffusers import DiffusionPipeline from einops import rearrange from pytorch_lightning import seed_everything from tqdm import trange @@ -24,9 +25,9 @@ class Generator: downsampling_factor: int latent_channels: int precision: str - model: DiffusionWrapper + model: DiffusionWrapper | DiffusionPipeline - def __init__(self, model: DiffusionWrapper, precision: str): + def __init__(self, model: DiffusionWrapper | DiffusionPipeline, precision: str): self.model = model self.precision = precision self.seed = None diff --git a/ldm/invoke/generator/diffusers_pipeline.py b/ldm/invoke/generator/diffusers_pipeline.py index d4d60b761a7..c9fff12e0cb 100644 --- a/ldm/invoke/generator/diffusers_pipeline.py +++ b/ldm/invoke/generator/diffusers_pipeline.py @@ -1,4 +1,5 @@ import secrets +import warnings from dataclasses import dataclass from typing import List, Optional, Union, Callable @@ -309,6 +310,28 @@ def get_text_embeddings(self, text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) return text_embeddings + def get_learned_conditioning(self, c: List[List[str]], return_tokens=True, + fragment_weights=None, **kwargs): + """ + Compatibility function for ldm.models.diffusion.ddpm.LatentDiffusion. 
+ """ + assert return_tokens == True + if fragment_weights: + weights = fragment_weights[0] + if any(weight != 1.0 for weight in weights): + warnings.warn(f"fragment weights not implemented yet {fragment_weights}", stacklevel=2) + + if kwargs: + warnings.warn(f"unsupported args {kwargs}", stacklevel=2) + + text_fragments = c[0] + text_input = self._tokenize(text_fragments) + + with torch.inference_mode(): + token_ids = text_input.input_ids.to(self.text_encoder.device) + text_embeddings = self.text_encoder(token_ids)[0] + return text_embeddings, text_input.input_ids + @torch.inference_mode() def _tokenize(self, prompt: Union[str, List[str]]): return self.tokenizer( @@ -319,6 +342,11 @@ def _tokenize(self, prompt: Union[str, List[str]]): return_tensors="pt", ) + @property + def channels(self) -> int: + """Compatible with DiffusionWrapper""" + return self.unet.in_channels + def prepare_latents(self, latents, batch_size, height, width, generator, dtype): # get the initial random noise unless the user supplied it # Unlike in other pipelines, latents need to be generated in the target device diff --git a/ldm/invoke/generator/txt2img.py b/ldm/invoke/generator/txt2img.py index a882b156713..7b36d37df44 100644 --- a/ldm/invoke/generator/txt2img.py +++ b/ldm/invoke/generator/txt2img.py @@ -24,17 +24,8 @@ def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta, self.perlin = perlin uc, c, extra_conditioning_info = conditioning - # FIXME: this should probably be either passed in to __init__ instead of model & precision, - # or be constructed in __init__ from those inputs. - pipeline = StableDiffusionGeneratorPipeline.from_pretrained( - "runwayml/stable-diffusion-v1-5", - revision="fp16", torch_dtype=torch.float16, - safety_checker=None, # TODO - # scheduler=sampler + ddim_eta, # TODO - # TODO: local_files_only=True - ) - pipeline.unet.to("cuda") - pipeline.vae.to("cuda") + pipeline = self.model + # TODO: customize a new pipeline for the given sampler (Scheduler) def make_image(x_T) -> PIL.Image.Image: # FIXME: restore free_gpu_mem functionality diff --git a/ldm/invoke/model_cache.py b/ldm/invoke/model_cache.py index 4c7297e087e..65efecffcde 100644 --- a/ldm/invoke/model_cache.py +++ b/ldm/invoke/model_cache.py @@ -4,6 +4,7 @@ below a preset minimum, the least recently used model will be cleared and loaded from disk when next needed. ''' +from pathlib import Path import torch import os @@ -18,6 +19,8 @@ from sys import getrefcount from omegaconf import OmegaConf from omegaconf.errors import ConfigAttributeError + +from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline from ldm.util import instantiate_from_config DEFAULT_MAX_MODELS=2 @@ -268,7 +271,43 @@ def _load_ckpt_model(self, mconfig): return model, width, height, model_hash def _load_diffusers_model(self, mconfig): - raise NotImplementedError() # return pipeline, width, height, model_hash + pipeline_args = {} + + if 'repo_name' in mconfig: + name_or_path = mconfig['repo_name'] + model_hash = "FIXME" + # model_hash = huggingface_hub.get_hf_file_metadata(url).commit_hash + elif 'path' in mconfig: + name_or_path = Path(mconfig['path']) + # FIXME: What should the model_hash be? A hash of the unet weights? Of all files of all + # the submodels hashed together? The commit ID from the repo? 
+ model_hash = "FIXME TOO" + else: + raise ValueError("Model config must specify either repo_name or path.") + + print(f'>> Loading diffusers model from {name_or_path}') + + if self.precision == 'float16': + print(' | Using faster float16 precision') + pipeline_args.update(revision="fp16", torch_dtype=torch.float16) + else: + # TODO: more accurately, "using the model's default precision." + # How do we find out what that is? + print(' | Using more accurate float32 precision') + + pipeline = StableDiffusionGeneratorPipeline.from_pretrained( + name_or_path, + safety_checker=None, # TODO + # TODO: alternate VAE + # TODO: local_files_only=True + **pipeline_args + ) + pipeline.to(self.device) + + width = pipeline.vae.sample_size + height = pipeline.vae.sample_size + + return pipeline, width, height, model_hash def offload_model(self, model_name:str): ''' From f49317c4f5e4a29c747aaaa433e617ff8ae98f13 Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Wed, 9 Nov 2022 19:21:58 -0800 Subject: [PATCH 07/42] model_cache: fix model default image dimensions --- ldm/invoke/generator/txt2img.py | 1 - ldm/invoke/model_cache.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/ldm/invoke/generator/txt2img.py b/ldm/invoke/generator/txt2img.py index 7b36d37df44..15650ccbd98 100644 --- a/ldm/invoke/generator/txt2img.py +++ b/ldm/invoke/generator/txt2img.py @@ -5,7 +5,6 @@ import torch from .base import Generator -from .diffusers_pipeline import StableDiffusionGeneratorPipeline class Txt2Img(Generator): diff --git a/ldm/invoke/model_cache.py b/ldm/invoke/model_cache.py index 65efecffcde..13cedaba8a4 100644 --- a/ldm/invoke/model_cache.py +++ b/ldm/invoke/model_cache.py @@ -304,8 +304,8 @@ def _load_diffusers_model(self, mconfig): ) pipeline.to(self.device) - width = pipeline.vae.sample_size - height = pipeline.vae.sample_size + width = pipeline.vae.block_out_channels[-1] + height = pipeline.vae.block_out_channels[-1] return pipeline, width, height, model_hash From 8db7054807ae3a2804cac331917b9fdcb855f16d Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Thu, 10 Nov 2022 14:36:45 -0800 Subject: [PATCH 08/42] txt2img: support switching diffusers schedulers --- backend/modules/parameters.py | 3 ++ ldm/generate.py | 59 +++++++++++++++++---------------- ldm/invoke/args.py | 3 ++ ldm/invoke/generator/txt2img.py | 2 +- ldm/invoke/model_cache.py | 12 +++++++ 5 files changed, 50 insertions(+), 29 deletions(-) diff --git a/backend/modules/parameters.py b/backend/modules/parameters.py index f3079e04973..37f3f921ce1 100644 --- a/backend/modules/parameters.py +++ b/backend/modules/parameters.py @@ -10,6 +10,9 @@ "k_heun", "k_lms", "plms", + # diffusers: + "ipndm", + "pndm", ] diff --git a/ldm/generate.py b/ldm/generate.py index cefa85be5d2..8fc60c0b116 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -19,7 +19,7 @@ import cv2 import skimage from diffusers import DiffusionPipeline, DDIMScheduler, LMSDiscreteScheduler, EulerDiscreteScheduler, \ - EulerAncestralDiscreteScheduler + EulerAncestralDiscreteScheduler, PNDMScheduler, IPNDMScheduler from omegaconf import OmegaConf from ldm.invoke.generator.base import downsampling @@ -968,36 +968,39 @@ def _set_sampler_legacy(self): print(msg) def _set_scheduler(self): - msg = f'>> Setting Sampler to {self.sampler_name}' default = self.model.scheduler - # TODO: Test me! Not all schedulers take the same args. 
- scheduler_args = dict( - num_train_timesteps=default.num_train_timesteps, - beta_start=default.beta_start, - beta_end=default.beta_end, - beta_schedule=default.beta_schedule, + + higher_order_samplers = [ + 'k_dpm_2', + 'k_dpm_2_a', + 'k_heun', + 'plms', # Its first step is like Heun + ] + scheduler_map = dict( + ddim=DDIMScheduler, + ipndm=IPNDMScheduler, + k_euler=EulerDiscreteScheduler, + k_euler_a=EulerAncestralDiscreteScheduler, + k_lms=LMSDiscreteScheduler, + pndm=PNDMScheduler, ) - trained_betas = getattr(self.model.scheduler, 'trained_betas') - if trained_betas is not None: - scheduler_args.update(trained_betas=trained_betas) - if self.sampler_name == 'plms': - raise NotImplementedError("What's the diffusers implementation of PLMS?") - elif self.sampler_name == 'ddim': - self.sampler = DDIMScheduler(**scheduler_args) - elif self.sampler_name == 'k_dpm_2_a': - raise NotImplementedError("no diffusers implementation of dpm_2 samplers") - elif self.sampler_name == 'k_dpm_2': - raise NotImplementedError("no diffusers implementation of dpm_2 samplers") - elif self.sampler_name == 'k_euler_a': - self.sampler = EulerAncestralDiscreteScheduler(**scheduler_args) - elif self.sampler_name == 'k_euler': - self.sampler = EulerDiscreteScheduler(**scheduler_args) - elif self.sampler_name == 'k_heun': - raise NotImplementedError("no diffusers implementation of Heun's sampler") - elif self.sampler_name == 'k_lms': - self.sampler = LMSDiscreteScheduler(**scheduler_args) + + if self.sampler_name in scheduler_map: + sampler_class = scheduler_map[self.sampler_name] + msg = f'>> Setting Sampler to {self.sampler_name} ({sampler_class.__name__})' + self.sampler = sampler_class.from_config( + self.model_cache.model_name_or_path(self.model_name), + subfolder="scheduler" + ) + elif self.sampler_name in higher_order_samplers: + msg = (f'>> Unsupported Sampler: {self.sampler_name} ' + f'— diffusers does not yet support higher-order samplers, ' + f'Defaulting to {default}') + self.sampler = default else: - msg = f'>> Unsupported Sampler: {self.sampler_name}, Defaulting to {default}' + msg = (f'>> Unsupported Sampler: {self.sampler_name} ' + f'Defaulting to {default}') + self.sampler = default print(msg) diff --git a/ldm/invoke/args.py b/ldm/invoke/args.py index 3c3b2059d52..66c0a209881 100644 --- a/ldm/invoke/args.py +++ b/ldm/invoke/args.py @@ -104,6 +104,9 @@ 'k_heun', 'k_lms', 'plms', + # diffusers: + "ipndm", + "pndm", ] PRECISION_CHOICES = [ diff --git a/ldm/invoke/generator/txt2img.py b/ldm/invoke/generator/txt2img.py index 15650ccbd98..219e8131724 100644 --- a/ldm/invoke/generator/txt2img.py +++ b/ldm/invoke/generator/txt2img.py @@ -24,7 +24,7 @@ def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta, uc, c, extra_conditioning_info = conditioning pipeline = self.model - # TODO: customize a new pipeline for the given sampler (Scheduler) + pipeline.scheduler = sampler def make_image(x_T) -> PIL.Image.Image: # FIXME: restore free_gpu_mem functionality diff --git a/ldm/invoke/model_cache.py b/ldm/invoke/model_cache.py index 13cedaba8a4..5745cdf6ae5 100644 --- a/ldm/invoke/model_cache.py +++ b/ldm/invoke/model_cache.py @@ -309,6 +309,18 @@ def _load_diffusers_model(self, mconfig): return pipeline, width, height, model_hash + def model_name_or_path(self, model_name:str) -> str | Path: + if model_name not in self.config: + raise ValueError(f'"{model_name}" is not a known model name. 
Please check your models.yaml file') + + mconfig = self.config[model_name] + if 'repo_name' in mconfig: + return mconfig['repo_name'] + elif 'path' in mconfig: + return Path(mconfig['path']) + else: + raise ValueError("Model config must specify either repo_name or path.") + def offload_model(self, model_name:str): ''' Offload the indicated model to CPU. Will call From 1f83920dd10d6872e096cd9a4f9e50bf51e32ee2 Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Thu, 10 Nov 2022 15:27:25 -0800 Subject: [PATCH 09/42] diffusers: let the scheduler do its scaling of the initial latents Remove IPNDM scheduler; it is not behaving. --- backend/modules/parameters.py | 1 - ldm/invoke/args.py | 1 - ldm/invoke/generator/diffusers_pipeline.py | 2 ++ 3 files changed, 2 insertions(+), 2 deletions(-) diff --git a/backend/modules/parameters.py b/backend/modules/parameters.py index 37f3f921ce1..4cc0831c764 100644 --- a/backend/modules/parameters.py +++ b/backend/modules/parameters.py @@ -11,7 +11,6 @@ "k_lms", "plms", # diffusers: - "ipndm", "pndm", ] diff --git a/ldm/invoke/args.py b/ldm/invoke/args.py index 83f8a824961..28014098be9 100644 --- a/ldm/invoke/args.py +++ b/ldm/invoke/args.py @@ -105,7 +105,6 @@ 'k_lms', 'plms', # diffusers: - "ipndm", "pndm", ] diff --git a/ldm/invoke/generator/diffusers_pipeline.py b/ldm/invoke/generator/diffusers_pipeline.py index c9fff12e0cb..bad21b09565 100644 --- a/ldm/invoke/generator/diffusers_pipeline.py +++ b/ldm/invoke/generator/diffusers_pipeline.py @@ -231,6 +231,8 @@ def generate_from_embeddings(self, latents: torch.Tensor, text_embeddings: torch run_id: str = None, **extra_step_kwargs): if run_id is None: run_id = secrets.token_urlsafe(self.ID_LENGTH) + # scale the initial noise by the standard deviation required by the scheduler + latents *= self.scheduler.init_noise_sigma yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps, latents=latents) # NOTE: Depends on scheduler being already initialized! 
From 6b586b7b8fd3e72348bb77d25898f8b02d5c0834 Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Thu, 10 Nov 2022 15:28:22 -0800 Subject: [PATCH 10/42] web server: update image_progress callback for diffusers data --- backend/invoke_ai_web_server.py | 5 ++++- ldm/invoke/generator/diffusers_pipeline.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/backend/invoke_ai_web_server.py b/backend/invoke_ai_web_server.py index 0ca94a63187..88a3f5e1f88 100644 --- a/backend/invoke_ai_web_server.py +++ b/backend/invoke_ai_web_server.py @@ -15,6 +15,7 @@ from threading import Event from ldm.invoke.args import Args, APP_ID, APP_VERSION, calculate_init_img_hash +from ldm.invoke.generator.diffusers_pipeline import PipelineIntermediateState from ldm.invoke.pngwriter import PngWriter, retrieve_metadata from ldm.invoke.prompt_parser import split_weighted_subprompts @@ -602,7 +603,9 @@ def generate_images( self.socketio.emit("progressUpdate", progress.to_formatted_dict()) eventlet.sleep(0) - def image_progress(sample, step): + def image_progress(progress_state: PipelineIntermediateState): + step = progress_state.step + sample = progress_state.latents if self.canceled.is_set(): raise CanceledException diff --git a/ldm/invoke/generator/diffusers_pipeline.py b/ldm/invoke/generator/diffusers_pipeline.py index bad21b09565..20caecada0b 100644 --- a/ldm/invoke/generator/diffusers_pipeline.py +++ b/ldm/invoke/generator/diffusers_pipeline.py @@ -188,7 +188,7 @@ def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int, result = None for result in self.generate_from_embeddings( latents, text_embeddings, guidance_scale, run_id, **extra_step_kwargs): - if callback is not None: + if callback is not None and isinstance(result, PipelineIntermediateState): callback(result) if result is None: raise AssertionError("why was that an empty generator?") From 7904d0c74a52ce49efb1543781283ddcbff25333 Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Fri, 11 Nov 2022 13:16:09 -0800 Subject: [PATCH 11/42] diffusers: restore prompt weighting feature --- ldm/invoke/generator/diffusers_pipeline.py | 29 ++++++++-------------- ldm/modules/encoders/modules.py | 12 ++++++--- 2 files changed, 18 insertions(+), 23 deletions(-) diff --git a/ldm/invoke/generator/diffusers_pipeline.py b/ldm/invoke/generator/diffusers_pipeline.py index 20caecada0b..6846ff84567 100644 --- a/ldm/invoke/generator/diffusers_pipeline.py +++ b/ldm/invoke/generator/diffusers_pipeline.py @@ -1,5 +1,4 @@ import secrets -import warnings from dataclasses import dataclass from typing import List, Optional, Union, Callable @@ -11,6 +10,8 @@ from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from ldm.modules.encoders.modules import WeightedFrozenCLIPEmbedder + @dataclass class PipelineIntermediateState: @@ -76,6 +77,11 @@ def __init__( safety_checker=safety_checker, feature_extractor=feature_extractor, ) + # InvokeAI's interface for text embeddings and whatnot + self.clip_embedder = WeightedFrozenCLIPEmbedder( + tokenizer=self.tokenizer, + transformer=self.text_encoder + ) def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): r""" @@ -312,27 +318,12 @@ def get_text_embeddings(self, text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) return text_embeddings - def get_learned_conditioning(self, c: 
List[List[str]], return_tokens=True, - fragment_weights=None, **kwargs): + @torch.inference_mode() + def get_learned_conditioning(self, c: List[List[str]], *, return_tokens=True, fragment_weights=None): """ Compatibility function for ldm.models.diffusion.ddpm.LatentDiffusion. """ - assert return_tokens == True - if fragment_weights: - weights = fragment_weights[0] - if any(weight != 1.0 for weight in weights): - warnings.warn(f"fragment weights not implemented yet {fragment_weights}", stacklevel=2) - - if kwargs: - warnings.warn(f"unsupported args {kwargs}", stacklevel=2) - - text_fragments = c[0] - text_input = self._tokenize(text_fragments) - - with torch.inference_mode(): - token_ids = text_input.input_ids.to(self.text_encoder.device) - text_embeddings = self.text_encoder(token_ids)[0] - return text_embeddings, text_input.input_ids + return self.clip_embedder.encode(c, return_tokens=return_tokens, fragment_weights=fragment_weights) @torch.inference_mode() def _tokenize(self, prompt: Union[str, List[str]]): diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py index cf8644a7fb2..263d00bdb60 100644 --- a/ldm/modules/encoders/modules.py +++ b/ldm/modules/encoders/modules.py @@ -240,17 +240,17 @@ class FrozenCLIPEmbedder(AbstractEncoder): def __init__( self, version='openai/clip-vit-large-patch14', - device=choose_torch_device(), max_length=77, + tokenizer=None, + transformer=None, ): super().__init__() - self.tokenizer = CLIPTokenizer.from_pretrained( + self.tokenizer = tokenizer or CLIPTokenizer.from_pretrained( version, local_files_only=True ) - self.transformer = CLIPTextModel.from_pretrained( + self.transformer = transformer or CLIPTextModel.from_pretrained( version, local_files_only=True ) - self.device = device self.max_length = max_length self.freeze() @@ -456,6 +456,10 @@ def forward(self, text, **kwargs): def encode(self, text, **kwargs): return self(text, **kwargs) + @property + def device(self): + return self.transformer.device + class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder): fragment_weights_key = "fragment_weights" From fdf2ed258a1d59500ae58e0ad2809456fb1af76e Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Fri, 11 Nov 2022 13:17:36 -0800 Subject: [PATCH 12/42] diffusers: fix set-sampler error following model switch --- ldm/generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ldm/generate.py b/ldm/generate.py index 678c473a32d..8fb0ca91473 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -847,8 +847,8 @@ def set_model(self,model_name): self.embedding_path, self.precision == 'float32' or self.precision == 'autocast' ) - self._set_sampler() self.model_name = model_name + self._set_sampler() # requires self.model_name to be set first return self.model def correct_colors(self, From 1b326e7e12072eba5198b395e1be67e89a24067e Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Fri, 11 Nov 2022 16:25:27 -0800 Subject: [PATCH 13/42] diffusers: use InvokeAIDiffuserComponent for conditioning --- ldm/invoke/generator/diffusers_pipeline.py | 160 +++++++++--------- ldm/invoke/generator/txt2img.py | 12 +- .../diffusion/shared_invokeai_diffusion.py | 1 + 3 files changed, 83 insertions(+), 90 deletions(-) diff --git a/ldm/invoke/generator/diffusers_pipeline.py b/ldm/invoke/generator/diffusers_pipeline.py index 6846ff84567..0bd096ff6b2 100644 --- a/ldm/invoke/generator/diffusers_pipeline.py +++ b/ldm/invoke/generator/diffusers_pipeline.py @@ -1,4 +1,5 @@ 
import secrets +import warnings from dataclasses import dataclass from typing import List, Optional, Union, Callable @@ -10,6 +11,7 @@ from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent from ldm.modules.encoders.modules import WeightedFrozenCLIPEmbedder @@ -82,6 +84,7 @@ def __init__( tokenizer=self.tokenizer, transformer=self.text_encoder ) + self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward) def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): r""" @@ -128,72 +131,36 @@ def disable_xformers_memory_efficient_attention(self): """ self.unet.set_use_memory_efficient_attention_xformers(False) - @torch.no_grad() - def __call__( - self, - prompt: Union[str, List[str]], - height: Optional[int] = 512, - width: Optional[int] = 512, - num_inference_steps: Optional[int] = 50, - guidance_scale: Optional[float] = 7.5, - generator: Optional[torch.Generator] = None, - latents: Optional[torch.FloatTensor] = None, - callback: Optional[Callable[[PipelineIntermediateState], None]] = None, - **extra_step_kwargs, - ): + def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int, + text_embeddings: torch.Tensor, unconditioned_embeddings: torch.Tensor, + guidance_scale: float, + *, callback: Callable[[PipelineIntermediateState], None]=None, + extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo=None, + run_id=None, + **extra_step_kwargs) -> StableDiffusionPipelineOutput: r""" Function invoked when calling the pipeline for generation. - Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. - height (`int`, *optional*, defaults to 512): - The height in pixels of the generated image. - width (`int`, *optional*, defaults to 512): - The width in pixels of the generated image. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. 
- When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. + :param latents: Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for + image generation. Can be used to tweak the same generation with different prompts. + :param num_inference_steps: The number of denoising steps. More denoising steps usually lead to a higher quality + image at the expense of slower inference. + :param text_embeddings: + :param guidance_scale: Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). + Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate + images that are closely linked to the text `prompt`, usually at the expense of lower image quality. + :param callback: + :param extra_conditioning_info: + :param run_id: + :param extra_step_kwargs: """ - result = None - for result in self.generate( - prompt, height=height, width=width, num_inference_steps=num_inference_steps, - guidance_scale=guidance_scale, generator=generator, latents=latents, - **extra_step_kwargs): - if callback is not None: - callback(result) - if result is None: - raise AssertionError("why was that an empty generator?") - return result - - def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int, - text_embeddings: torch.Tensor, guidance_scale: float, - *, callback: Callable[[PipelineIntermediateState], None]=None, run_id=None, - **extra_step_kwargs) -> StableDiffusionPipelineOutput: - self.scheduler.set_timesteps(num_inference_steps) + self.scheduler.set_timesteps(num_inference_steps, device=self.unet.device) result = None for result in self.generate_from_embeddings( - latents, text_embeddings, guidance_scale, run_id, **extra_step_kwargs): + latents, text_embeddings, unconditioned_embeddings, guidance_scale, + extra_conditioning_info=extra_conditioning_info, + run_id=run_id, **extra_step_kwargs): if callback is not None and isinstance(result, PipelineIntermediateState): callback(result) if result is None: @@ -226,24 +193,40 @@ def generate( # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. 
do_classifier_free_guidance = guidance_scale > 1.0 - text_embeddings = self.get_text_embeddings(prompt, opposing_prompt, do_classifier_free_guidance, batch_size)\ + text_embeddings, unconditioned_embeddings = self.get_text_embeddings(prompt, opposing_prompt, do_classifier_free_guidance, batch_size)\ .to(self.unet.device) self.scheduler.set_timesteps(num_inference_steps) latents = self.prepare_latents(latents, batch_size, height, width, generator, self.unet.dtype) - yield from self.generate_from_embeddings(latents, text_embeddings, guidance_scale, run_id, **extra_step_kwargs) - - def generate_from_embeddings(self, latents: torch.Tensor, text_embeddings: torch.Tensor, guidance_scale: float, - run_id: str = None, **extra_step_kwargs): + yield from self.generate_from_embeddings(latents, text_embeddings, unconditioned_embeddings, + guidance_scale, run_id=run_id, **extra_step_kwargs) + + def generate_from_embeddings( + self, + latents: torch.Tensor, + text_embeddings: torch.Tensor, + unconditioned_embeddings: torch.Tensor, + guidance_scale: float, + *, + run_id: str = None, + extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo = None, + **extra_step_kwargs): if run_id is None: run_id = secrets.token_urlsafe(self.ID_LENGTH) # scale the initial noise by the standard deviation required by the scheduler latents *= self.scheduler.init_noise_sigma yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps, latents=latents) + + batch_size = latents.shape[0] + batched_t = torch.full((batch_size,), self.scheduler.timesteps[0], + dtype=self.scheduler.timesteps.dtype, device=self.unet.device) # NOTE: Depends on scheduler being already initialized! for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): - step_output = self.step(t, latents, guidance_scale, text_embeddings, **extra_step_kwargs) + batched_t.fill_(t) + step_output = self.step(batched_t, latents, guidance_scale, + text_embeddings, unconditioned_embeddings, + i, **extra_step_kwargs) latents = step_output.prev_sample predicted_original = getattr(step_output, 'pred_original_sample', None) yield PipelineIntermediateState(run_id=run_id, step=i, timestep=int(t), latents=latents, @@ -257,23 +240,30 @@ def generate_from_embeddings(self, latents: torch.Tensor, text_embeddings: torch yield self.check_for_safety(output) @torch.inference_mode() - def step(self, t, latents: torch.Tensor, guidance_scale, text_embeddings: torch.Tensor, **extra_step_kwargs): - do_classifier_free_guidance = guidance_scale > 1.0 + def step(self, t: torch.Tensor, latents: torch.Tensor, guidance_scale: float, + text_embeddings: torch.Tensor, unconditioned_embeddings: torch.Tensor, + step_index:int | None = None, + **extra_step_kwargs): + # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value + timestep = t[0] - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + # TODO: should this scaling happen here or inside self._unet_forward? + # i.e. 
before or after passing it to InvokeAIDiffuserComponent + latent_model_input = self.scheduler.scale_model_input(latents, timestep) # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = self.invokeai_diffuser.do_diffusion_step( + latent_model_input, t, + unconditioned_embeddings, text_embeddings, + guidance_scale, + step_index=step_index) # compute the previous noisy sample x_t -> x_t-1 - return self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs) + return self.scheduler.step(noise_pred, timestep, latents, **extra_step_kwargs) + + def _unet_forward(self, latents, t, text_embeddings): + # predict the noise residual + return self.unet(latents, t, encoder_hidden_states=text_embeddings).sample @torch.inference_mode() def check_for_safety(self, output): @@ -310,13 +300,10 @@ def get_text_embeddings(self, # opposing prompt defaults to blank caption for everything in the batch text_anti_input = self._tokenize(opposing_prompt or [""] * batch_size) uncond_embeddings = self.text_encoder(text_anti_input.input_ids)[0] + else: + uncond_embeddings = None - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - # FIXME: assert these two are the same size - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) - return text_embeddings + return text_embeddings, uncond_embeddings @torch.inference_mode() def get_learned_conditioning(self, c: List[List[str]], *, return_tokens=True, fragment_weights=None): @@ -325,6 +312,11 @@ def get_learned_conditioning(self, c: List[List[str]], *, return_tokens=True, fr """ return self.clip_embedder.encode(c, return_tokens=return_tokens, fragment_weights=fragment_weights) + @property + def cond_stage_model(self): + warnings.warn("legacy compatibility layer", DeprecationWarning) + return self.clip_embedder + @torch.inference_mode() def _tokenize(self, prompt: Union[str, List[str]]): return self.tokenizer( diff --git a/ldm/invoke/generator/txt2img.py b/ldm/invoke/generator/txt2img.py index 219e8131724..f9af1ac3ed7 100644 --- a/ldm/invoke/generator/txt2img.py +++ b/ldm/invoke/generator/txt2img.py @@ -5,6 +5,7 @@ import torch from .base import Generator +from .diffusers_pipeline import StableDiffusionGeneratorPipeline class Txt2Img(Generator): @@ -23,7 +24,8 @@ def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta, self.perlin = perlin uc, c, extra_conditioning_info = conditioning - pipeline = self.model + # noinspection PyTypeChecker + pipeline: StableDiffusionGeneratorPipeline = self.model pipeline.scheduler = sampler def make_image(x_T) -> PIL.Image.Image: @@ -31,16 +33,14 @@ def make_image(x_T) -> PIL.Image.Image: # if self.free_gpu_mem and self.model.model.device != self.model.device: # self.model.model.to(self.model.device) - # FIXME: how the embeddings are combined should be internal to the pipeline - combined_text_embeddings = torch.cat([uc, c]) - pipeline_output = pipeline.image_from_embeddings( latents=x_T, num_inference_steps=steps, - text_embeddings=combined_text_embeddings, + text_embeddings=c, + unconditioned_embeddings=uc, guidance_scale=cfg_scale, callback=step_callback, - # TODO: extra_conditioning_info = 
extra_conditioning_info, + extra_conditioning_info=extra_conditioning_info, # TODO: eta = ddim_eta, # TODO: threshold = threshold, ) diff --git a/ldm/models/diffusion/shared_invokeai_diffusion.py b/ldm/models/diffusion/shared_invokeai_diffusion.py index 0a18eb25c84..27cc734ccf7 100644 --- a/ldm/models/diffusion/shared_invokeai_diffusion.py +++ b/ldm/models/diffusion/shared_invokeai_diffusion.py @@ -34,6 +34,7 @@ def __init__(self, model, model_forward_callback: :param model: the unet model to pass through to cross attention control :param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning) """ + self.conditioning = None self.model = model self.model_forward_callback = model_forward_callback self.cross_attention_control_context = None From cbbe3a6ad47ff2011feba00c33ca374ffec011f3 Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Sat, 12 Nov 2022 10:10:46 -0800 Subject: [PATCH 14/42] cross_attention_control: stub (no-op) implementations for diffusers --- ldm/invoke/generator/diffusers_pipeline.py | 7 ++++ .../diffusion/cross_attention_control.py | 35 +++++++++++++------ 2 files changed, 31 insertions(+), 11 deletions(-) diff --git a/ldm/invoke/generator/diffusers_pipeline.py b/ldm/invoke/generator/diffusers_pipeline.py index 0bd096ff6b2..861bf22a7a7 100644 --- a/ldm/invoke/generator/diffusers_pipeline.py +++ b/ldm/invoke/generator/diffusers_pipeline.py @@ -213,6 +213,13 @@ def generate_from_embeddings( **extra_step_kwargs): if run_id is None: run_id = secrets.token_urlsafe(self.ID_LENGTH) + + if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control: + self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, + step_count=len(self.scheduler.timesteps)) + else: + self.invokeai_diffuser.remove_cross_attention_control() + # scale the initial noise by the standard deviation required by the scheduler latents *= self.scheduler.init_noise_sigma yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps, diff --git a/ldm/models/diffusion/cross_attention_control.py b/ldm/models/diffusion/cross_attention_control.py index ff90a248566..f0d02776f08 100644 --- a/ldm/models/diffusion/cross_attention_control.py +++ b/ldm/models/diffusion/cross_attention_control.py @@ -1,4 +1,5 @@ import enum +import warnings from typing import Optional import torch @@ -243,20 +244,32 @@ def attention_slice_wrangler(module, suggested_attention_slice:torch.Tensor, dim return attention_slice - for name, module in unet.named_modules(): - module_name = type(module).__name__ - if module_name == "CrossAttention": - module.identifier = name + cross_attention_modules = [(name, module) for (name, module) in unet.named_modules() + if type(module).__name__ == "CrossAttention"] + for identifier, module in cross_attention_modules: + module.identifier = identifier + try: module.set_attention_slice_wrangler(attention_slice_wrangler) - module.set_slicing_strategy_getter(lambda module, module_identifier=name: \ - context.get_slicing_strategy(module_identifier)) + module.set_slicing_strategy_getter( + lambda module: context.get_slicing_strategy(identifier) + ) + except AttributeError as e: + if e.name == 'set_attention_slice_wrangler': + warnings.warn(f"TODO: implement for {type(module)}") # TODO + else: + raise @classmethod def remove_attention_function(cls, unet): - # 
clear wrangler callback - for name, module in unet.named_modules(): - module_name = type(module).__name__ - if module_name == "CrossAttention": + cross_attention_modules = [module for (_, module) in unet.named_modules() + if type(module).__name__ == "CrossAttention"] + for module in cross_attention_modules: + try: + # clear wrangler callback module.set_attention_slice_wrangler(None) module.set_slicing_strategy_getter(None) - + except AttributeError as e: + if e.name == 'set_attention_slice_wrangler': + warnings.warn(f"TODO: implement for {type(module)}") # TODO + else: + raise From 0d52b8571c2c9fba6232c26e49ea3f4d58398d45 Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Sat, 12 Nov 2022 10:11:50 -0800 Subject: [PATCH 15/42] model_cache: let offload_model work with DiffusionPipeline, sorta. --- ldm/invoke/model_cache.py | 12 ++++++++---- ldm/modules/encoders/modules.py | 15 +++++++++++---- 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/ldm/invoke/model_cache.py b/ldm/invoke/model_cache.py index b06b880df65..e659c77a0b9 100644 --- a/ldm/invoke/model_cache.py +++ b/ldm/invoke/model_cache.py @@ -4,6 +4,7 @@ below a preset minimum, the least recently used model will be cleared and loaded from disk when next needed. ''' +import warnings from pathlib import Path import torch @@ -387,10 +388,13 @@ def _invalidate_cached_model(self,model_name:str): def _model_to_cpu(self,model): if self.device != 'cpu': - model.cond_stage_model.device = 'cpu' - model.first_stage_model.to('cpu') - model.cond_stage_model.to('cpu') - model.model.to('cpu') + try: + model.cond_stage_model.device = 'cpu' + model.first_stage_model.to('cpu') + model.cond_stage_model.to('cpu') + model.model.to('cpu') + except AttributeError as e: + warnings.warn(f"TODO: clean up legacy model-management: {e}") return model.to('cpu') else: return model diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py index 263d00bdb60..cb805489566 100644 --- a/ldm/modules/encoders/modules.py +++ b/ldm/modules/encoders/modules.py @@ -1,4 +1,5 @@ import math +from typing import Optional import torch import torch.nn as nn @@ -236,13 +237,15 @@ def encode(self, x): class FrozenCLIPEmbedder(AbstractEncoder): """Uses the CLIP transformer encoder for text (from Hugging Face)""" + tokenizer: CLIPTokenizer + transformer: CLIPTextModel def __init__( self, - version='openai/clip-vit-large-patch14', - max_length=77, - tokenizer=None, - transformer=None, + version:str='openai/clip-vit-large-patch14', + max_length:int=77, + tokenizer:Optional[CLIPTokenizer]=None, + transformer:Optional[CLIPTextModel]=None, ): super().__init__() self.tokenizer = tokenizer or CLIPTokenizer.from_pretrained( @@ -460,6 +463,10 @@ def encode(self, text, **kwargs): def device(self): return self.transformer.device + @device.setter + def device(self, device): + self.transformer.to(device=device) + class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder): fragment_weights_key = "fragment_weights" From 8d6189df1c3bdbc75dff4c55bfa16ef517c47bf6 Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Sun, 13 Nov 2022 15:19:38 -0800 Subject: [PATCH 16/42] models.yaml.example: add diffusers-format model, set as default --- configs/models.yaml.example | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/configs/models.yaml.example b/configs/models.yaml.example index 9c152c25c1e..576a4de7af9 100644 --- a/configs/models.yaml.example +++ b/configs/models.yaml.example @@ -5,6 +5,11 @@ # model 
requires a model config file, a weights file, # and the width and height of the images it # was trained on. +diffusers-1.5: + description: Diffusers version of Stable Diffusion version 1.5 + format: diffusers + repo_name: runwayml/stable-diffusion-v1-5 + default: true stable-diffusion-1.5: description: The newest Stable Diffusion version 1.5 weight file (4.27 GB) weights: ./models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt From 90ec3a771fded82e252df60cb702be0f85c685bf Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Sun, 13 Nov 2022 15:36:51 -0800 Subject: [PATCH 17/42] test-invoke-conda: use diffusers-format model --- .github/workflows/test-invoke-conda.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test-invoke-conda.yml b/.github/workflows/test-invoke-conda.yml index 41838ba565a..7c6ff599f52 100644 --- a/.github/workflows/test-invoke-conda.yml +++ b/.github/workflows/test-invoke-conda.yml @@ -32,7 +32,7 @@ jobs: # stable-diffusion-model-switch: stable-diffusion-1.4 - stable-diffusion-model: https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt stable-diffusion-model-dl-path: models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt - stable-diffusion-model-switch: stable-diffusion-1.5 + stable-diffusion-model-switch: diffusers-1.5 name: ${{ matrix.os }} with ${{ matrix.stable-diffusion-model-switch }} runs-on: ${{ matrix.os }} env: From 12efacc988ce537eece91f2182334a1ffd1a6748 Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Sun, 13 Nov 2022 16:09:49 -0800 Subject: [PATCH 18/42] test-invoke-conda: put huggingface-token where the library can use it --- .github/workflows/test-invoke-conda.yml | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/.github/workflows/test-invoke-conda.yml b/.github/workflows/test-invoke-conda.yml index 7c6ff599f52..26ce2cf362b 100644 --- a/.github/workflows/test-invoke-conda.yml +++ b/.github/workflows/test-invoke-conda.yml @@ -37,6 +37,7 @@ jobs: runs-on: ${{ matrix.os }} env: CONDA_ENV_NAME: invokeai + PYTHONUNBUFFERED: 1 defaults: run: shell: ${{ matrix.default-shell }} @@ -82,25 +83,16 @@ jobs: id: cache-sd-model uses: actions/cache@v3 env: - cache-name: cache-${{ matrix.stable-diffusion-model-switch }} + cache-name: cache-huggingface-${{ matrix.stable-diffusion-model-switch }} with: - path: ${{ matrix.stable-diffusion-model-dl-path }} + path: ~/.cache/huggingface key: ${{ env.cache-name }} - - name: Download ${{ matrix.stable-diffusion-model-switch }} - id: download-stable-diffusion-model - if: ${{ steps.cache-sd-model.outputs.cache-hit != 'true' }} - run: | - [[ -d models/ldm/stable-diffusion-v1 ]] \ - || mkdir -p models/ldm/stable-diffusion-v1 - curl \ - -H "Authorization: Bearer ${{ secrets.HUGGINGFACE_TOKEN }}" \ - -o ${{ matrix.stable-diffusion-model-dl-path }} \ - -L ${{ matrix.stable-diffusion-model }} - - name: run preload_models.py id: run-preload-models run: | + mkdir -p ~/.huggingface + echo -n '${{ secrets.HUGGINGFACE_TOKEN }}' > ~/.huggingface/token python scripts/preload_models.py \ --no-interactive From e1c678c146eb16a3b73c00b03655b10e05de56e3 Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Sun, 13 Nov 2022 16:29:34 -0800 Subject: [PATCH 19/42] test-invoke-conda: some diagnostic info from huggingface-cli --- .github/workflows/test-invoke-conda.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git 
a/.github/workflows/test-invoke-conda.yml b/.github/workflows/test-invoke-conda.yml index 26ce2cf362b..72f12d93f6a 100644 --- a/.github/workflows/test-invoke-conda.yml +++ b/.github/workflows/test-invoke-conda.yml @@ -93,6 +93,8 @@ jobs: run: | mkdir -p ~/.huggingface echo -n '${{ secrets.HUGGINGFACE_TOKEN }}' > ~/.huggingface/token + wc ~/.huggingface/token + echo -n Logged in to huggingface as: ; huggingface-cli whoami python scripts/preload_models.py \ --no-interactive From d8dd1f94fa7f644762c4f0a89d499b8427b147ee Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Sun, 13 Nov 2022 16:48:48 -0800 Subject: [PATCH 20/42] test-invoke-conda: some diagnostic info from huggingface-cli --- .github/workflows/test-invoke-conda.yml | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test-invoke-conda.yml b/.github/workflows/test-invoke-conda.yml index 72f12d93f6a..a921cc7f38f 100644 --- a/.github/workflows/test-invoke-conda.yml +++ b/.github/workflows/test-invoke-conda.yml @@ -93,8 +93,11 @@ jobs: run: | mkdir -p ~/.huggingface echo -n '${{ secrets.HUGGINGFACE_TOKEN }}' > ~/.huggingface/token - wc ~/.huggingface/token - echo -n Logged in to huggingface as: ; huggingface-cli whoami + if [ -s ~/.huggingface/token ] ; then + echo -n Logged in to huggingface as: ; huggingface-cli whoami + else + echo -e '\a ⛔ I have no huggingface token!' ; exit 1 + fi python scripts/preload_models.py \ --no-interactive From a63a5f6d429c8a0025ce759a6a58ceb19b68e254 Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Sun, 13 Nov 2022 19:02:44 -0800 Subject: [PATCH 21/42] test-invoke-conda: it's a cache name, it doesn't need cache in the name --- .github/workflows/test-invoke-conda.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test-invoke-conda.yml b/.github/workflows/test-invoke-conda.yml index a921cc7f38f..1e733b56b7d 100644 --- a/.github/workflows/test-invoke-conda.yml +++ b/.github/workflows/test-invoke-conda.yml @@ -83,7 +83,7 @@ jobs: id: cache-sd-model uses: actions/cache@v3 env: - cache-name: cache-huggingface-${{ matrix.stable-diffusion-model-switch }} + cache-name: huggingface-${{ matrix.stable-diffusion-model-switch }} with: path: ~/.cache/huggingface key: ${{ env.cache-name }} From be1e6f7645174c2a5e0eb43f7863c7c930ca2d25 Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Sun, 13 Nov 2022 22:30:40 -0800 Subject: [PATCH 22/42] test-invoke-conda: run on pushes to forks The problem is that PRs from forks don't have access to secrets.HUGGINGFACE_TOKEN in the main project. pull_request actions run in the main project, but push actions may run in forks (if the fork has actions enabled). Running in the forked project should allow the forked project to use its own secrets. 
--- .github/workflows/test-invoke-conda.yml | 50 +++++++++++++------------ 1 file changed, 26 insertions(+), 24 deletions(-) diff --git a/.github/workflows/test-invoke-conda.yml b/.github/workflows/test-invoke-conda.yml index 1e733b56b7d..915f91fe586 100644 --- a/.github/workflows/test-invoke-conda.yml +++ b/.github/workflows/test-invoke-conda.yml @@ -1,16 +1,16 @@ name: Test invoke.py -on: - push: - branches: - - 'main' - - 'development' - pull_request: - branches: - - 'main' - - 'development' +on: [push, pull_request] jobs: matrix: + # Run on: + # - pull requests + # - pushes to forks (will run in the forked project with that fork's secrets) + # - pushes to branches that are *not* pull requests + if: | + github.event_name == 'pull_request' + || github.repository != 'invoke-ai/InvokeAI' + || github.ref_protected strategy: fail-fast: false matrix: @@ -38,6 +38,7 @@ jobs: env: CONDA_ENV_NAME: invokeai PYTHONUNBUFFERED: 1 + HAVE_SECRETS: ${{ secrets.HUGGINGFACE_TOKEN != '' }} defaults: run: shell: ${{ matrix.default-shell }} @@ -52,6 +53,19 @@ jobs: - name: create environment.yml run: cp environments-and-requirements/${{ matrix.environment-file }} environment.yml + - name: Use Cached Stable Diffusion Model + id: cache-sd-model + uses: actions/cache@v3 + env: + cache-name: huggingface-${{ matrix.stable-diffusion-model-switch }} + with: + path: ~/.cache/huggingface + key: ${{ env.cache-name }} + + - name: Check model availability + if: steps.cache-sd-model.outputs.cache-hit != true && !env.HAVE_SECRETS + run: echo -e '\a ⛔ GitHub model cache not found, and no HUGGINGFACE_TOKEN is available. Will not be able to load Stable Diffusion.' ; exit 1 + - name: Use cached conda packages id: use-cached-conda-packages uses: actions/cache@v3 @@ -79,24 +93,12 @@ jobs: if: ${{ github.ref != 'refs/heads/main' && github.ref != 'refs/heads/development' }} run: echo "TEST_PROMPTS=tests/validate_pr_prompt.txt" >> $GITHUB_ENV - - name: Use Cached Stable Diffusion Model - id: cache-sd-model - uses: actions/cache@v3 - env: - cache-name: huggingface-${{ matrix.stable-diffusion-model-switch }} - with: - path: ~/.cache/huggingface - key: ${{ env.cache-name }} - - name: run preload_models.py id: run-preload-models run: | - mkdir -p ~/.huggingface - echo -n '${{ secrets.HUGGINGFACE_TOKEN }}' > ~/.huggingface/token - if [ -s ~/.huggingface/token ] ; then - echo -n Logged in to huggingface as: ; huggingface-cli whoami - else - echo -e '\a ⛔ I have no huggingface token!' 
; exit 1 + if [ "${HAVE_SECRETS}" == true ] ; then + mkdir -p ~/.huggingface + echo -n '${{ secrets.HUGGINGFACE_TOKEN }}' > ~/.huggingface/token fi python scripts/preload_models.py \ --no-interactive From bc5a7ba906fdb360860d9b781380fad7462975b1 Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Sun, 13 Nov 2022 23:02:54 -0800 Subject: [PATCH 23/42] test-invoke-conda: fix string comparison things in the env context are strings and so `!` probably didn't do what I meant --- .github/workflows/test-invoke-conda.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test-invoke-conda.yml b/.github/workflows/test-invoke-conda.yml index 915f91fe586..6a3b842a8da 100644 --- a/.github/workflows/test-invoke-conda.yml +++ b/.github/workflows/test-invoke-conda.yml @@ -63,7 +63,7 @@ jobs: key: ${{ env.cache-name }} - name: Check model availability - if: steps.cache-sd-model.outputs.cache-hit != true && !env.HAVE_SECRETS + if: steps.cache-sd-model.outputs.cache-hit != true && env.HAVE_SECRETS != 'true' run: echo -e '\a ⛔ GitHub model cache not found, and no HUGGINGFACE_TOKEN is available. Will not be able to load Stable Diffusion.' ; exit 1 - name: Use cached conda packages From a71ec35c958323f90a7e70423121612e296c81ac Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Sun, 13 Nov 2022 23:35:33 -0800 Subject: [PATCH 24/42] environment-mac: upgrade to diffusers 0.7 (from 0.6) this was already done for linux; mac must have been lost in the merge. --- environments-and-requirements/environment-mac.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/environments-and-requirements/environment-mac.yml b/environments-and-requirements/environment-mac.yml index 1ff49ec585e..777bfeb2336 100644 --- a/environments-and-requirements/environment-mac.yml +++ b/environments-and-requirements/environment-mac.yml @@ -22,7 +22,7 @@ dependencies: - albumentations=1.2 - coloredlogs=15.0 - - diffusers=0.6 + - diffusers~=0.7 - einops=0.3 - eventlet - grpcio=1.46 From 638489db4e313c9f7ba814a9cfc8074e0b7bb23b Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Tue, 15 Nov 2022 20:22:11 -0800 Subject: [PATCH 25/42] =?UTF-8?q?lint(preload=5Fmodels):=20pyflakes=20?= =?UTF-8?q?=F0=9F=9A=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- scripts/preload_models.py | 37 +++++++++++++++++++------------------ 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/scripts/preload_models.py b/scripts/preload_models.py index 67b143d82db..0f7716f59bf 100644 --- a/scripts/preload_models.py +++ b/scripts/preload_models.py @@ -8,24 +8,23 @@ # print('Loading Python libraries...\n') import argparse -import sys import os +import sys +import traceback import warnings from urllib import request -from tqdm import tqdm -from omegaconf import OmegaConf -from huggingface_hub import HfFolder, hf_hub_url -from pathlib import Path -from getpass_asterisk import getpass_asterisk -from transformers import CLIPTokenizer, CLIPTextModel -import traceback + import requests -import clip -import transformers import torch +import transformers +from getpass_asterisk import getpass_asterisk +from huggingface_hub import HfFolder, hf_hub_url +from omegaconf import OmegaConf +from tqdm import tqdm +from transformers import CLIPTokenizer, CLIPTextModel + transformers.logging.set_verbosity_error() -import warnings warnings.filterwarnings('ignore') 
#warnings.simplefilter('ignore') #warnings.filterwarnings('ignore',category=DeprecationWarning) @@ -277,7 +276,7 @@ def download_weight_datasets(models:dict, access_token:str): if success: successful[mod] = True if len(successful) < len(models): - print(f'\n\n** There were errors downloading one or more files. **') + print('\n\n** There were errors downloading one or more files. **') print('Please double-check your license agreements, and your access token.') HfFolder.delete_token() print('Press any key to try again. Type ^C to quit.\n') @@ -403,8 +402,8 @@ def download_bert(): sys.stdout.flush() with warnings.catch_warnings(): warnings.filterwarnings('ignore', category=DeprecationWarning) - from transformers import BertTokenizerFast, AutoFeatureExtractor - tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased') + from transformers import BertTokenizerFast + BertTokenizerFast.from_pretrained('bert-base-uncased') print('...success') #--------------------------------------------- @@ -413,6 +412,8 @@ def download_kornia(): print('Installing Kornia requirements (ignore deprecation errors)...', end='') sys.stdout.flush() import kornia + # Is importing it all we need to do to get it to download weights for all models? + assert kornia.__version__ # reference kornia in some way to avoid `unused` warning. print('...success') #--------------------------------------------- @@ -420,8 +421,8 @@ def download_clip(): print('Loading CLIP model (ignore deprecation errors)...',end='') sys.stdout.flush() version = 'openai/clip-vit-large-patch14' - tokenizer = CLIPTokenizer.from_pretrained(version) - transformer = CLIPTextModel.from_pretrained(version) + CLIPTokenizer.from_pretrained(version) + CLIPTextModel.from_pretrained(version) print('...success') #--------------------------------------------- @@ -528,8 +529,8 @@ def download_safety_checker(): print(traceback.format_exc()) return safety_model_id = "CompVis/stable-diffusion-safety-checker" - safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id) - safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id) + AutoFeatureExtractor.from_pretrained(safety_model_id) + StableDiffusionSafetyChecker.from_pretrained(safety_model_id) print('...success') #------------------------------------- From 5b44a051636684eb60129f94b25f1eda0e1b51ef Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Tue, 15 Nov 2022 22:04:38 -0800 Subject: [PATCH 26/42] preload_models: explicitly load diffusers models In non-interactive mode too, as long as you're logged in. 
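As a rough sketch of the download step this patch adds (illustrative only; the repo id and the fp16-revision convention come from the diff below, but the helper name and wrapper are made up for this note), pre-populating the Hugging Face cache so a later run can work offline looks roughly like:

    # Sketch: pre-populate the Hugging Face cache so later runs can go offline.
    # Assumes a token is already stored (e.g. written to ~/.huggingface/token).
    from diffusers import StableDiffusionPipeline

    def prefetch_diffusers_model(repo_id: str, full_precision: bool) -> None:
        kwargs = {} if full_precision else {'revision': 'fp16'}
        # from_pretrained downloads every pipeline component into ~/.cache/huggingface;
        # the returned pipeline is discarded because only the download matters here.
        StableDiffusionPipeline.from_pretrained(repo_id, **kwargs)

    prefetch_diffusers_model('runwayml/stable-diffusion-v1-5', full_precision=False)
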
--- .github/workflows/test-invoke-conda.yml | 7 +- scripts/preload_models.py | 90 ++++++++++++++++++++++++- 2 files changed, 94 insertions(+), 3 deletions(-) diff --git a/.github/workflows/test-invoke-conda.yml b/.github/workflows/test-invoke-conda.yml index 6a3b842a8da..99e2593869d 100644 --- a/.github/workflows/test-invoke-conda.yml +++ b/.github/workflows/test-invoke-conda.yml @@ -101,10 +101,15 @@ jobs: echo -n '${{ secrets.HUGGINGFACE_TOKEN }}' > ~/.huggingface/token fi python scripts/preload_models.py \ - --no-interactive + --no-interactive \ + --full-precision # can't use fp16 weights without a GPU - name: Run the tests id: run-tests + env: + HF_HUB_OFFLINE: 1 + HF_DATASETS_OFFLINE: 1 + TRANSFORMERS_OFFLINE: 1 run: | time python scripts/invoke.py \ --model ${{ matrix.stable-diffusion-model-switch }} \ diff --git a/scripts/preload_models.py b/scripts/preload_models.py index 0f7716f59bf..2befb144bf1 100644 --- a/scripts/preload_models.py +++ b/scripts/preload_models.py @@ -12,17 +12,27 @@ import sys import traceback import warnings +from pathlib import Path +from typing import Dict from urllib import request +import huggingface_hub import requests import torch import transformers +from diffusers import StableDiffusionPipeline, AutoencoderKL from getpass_asterisk import getpass_asterisk from huggingface_hub import HfFolder, hf_hub_url from omegaconf import OmegaConf from tqdm import tqdm from transformers import CLIPTokenizer, CLIPTextModel +try: + from ldm.invoke.model_cache import ModelCache +except ImportError: + sys.path.append('.') + from ldm.invoke.model_cache import ModelCache + transformers.logging.set_verbosity_error() warnings.filterwarnings('ignore') @@ -287,7 +297,20 @@ def download_weight_datasets(models:dict, access_token:str): keys = ', '.join(successful.keys()) print(f'Successfully installed {keys}') return successful - + +#--------------------------------------------- +def is_huggingface_authenticated(): + # huggingface_hub 0.10 API isn't great for this, it could be OSError, ValueError, + # maybe other things, not all end-user-friendly. + # noinspection PyBroadException + try: + response = huggingface_hub.whoami() + if response.get('id') is not None: + return True + except Exception: + pass + return False + #--------------------------------------------- def download_with_resume(repo_id:str, model_name:str, access_token:str)->bool: model_dest = os.path.join(Model_dir, model_name) @@ -336,7 +359,55 @@ def download_with_resume(repo_id:str, model_name:str, access_token:str)->bool: print(f'An error occurred while downloading {model_name}: {str(e)}') return False return True - + +#--------------------------------------------- +def download_diffusers(models: Dict, full_precision: bool): + # This is a minimal implementation until https://github.com/invoke-ai/InvokeAI/pull/1490 lands, + # which moves a bunch of stuff. + # We can be more complete after we know it won't be all merge conflicts. 
+ diffusers_repos = { + 'CompVis/stable-diffusion-v1-4-original': 'CompVis/stable-diffusion-v1-4', + 'runwayml/stable-diffusion-v1-5': 'runwayml/stable-diffusion-v1-5', + 'runwayml/stable-diffusion-inpainting': 'runwayml/stable-diffusion-inpainting', + 'hakurei/waifu-diffusion-v1-3': 'hakurei/waifu-diffusion' + } + vae_repos = { + 'stabilityai/sd-vae-ft-mse-original': 'stabilityai/sd-vae-ft-mse', + } + precision_args = {} + if not full_precision: + precision_args.update(revision='fp16') + + for model_name, model in models.items(): + repo_id = model['repo_id'] + if repo_id in vae_repos: + print(f" * Downloading diffusers VAE {model_name}...") + # TODO: can we autodetect when a repo has no fp16 revision? + AutoencoderKL.from_pretrained(repo_id) + elif repo_id not in diffusers_repos: + print(f" * Downloading diffusers {model_name}...") + StableDiffusionPipeline.from_pretrained(repo_id, **precision_args) + else: + warnings.warn(f" ⚠ FIXME: add diffusers repo for {repo_id}") + continue + + +def download_diffusers_in_config(config_path: Path, full_precision: bool): + # This is a minimal implementation until https://github.com/invoke-ai/InvokeAI/pull/1490 lands, + # which moves a bunch of stuff. + # We can be more complete after we know it won't be all merge conflicts. + precision = 'full' if full_precision else 'float16' + cache = ModelCache(OmegaConf.load(config_path), precision=precision, + device_type='cpu', max_loaded_models=1) + for model_name in cache.list_models(): + # TODO: download model without loading it. + # https://github.com/huggingface/diffusers/issues/1301 + model_config = cache.config[model_name] + if model_config.get('format') == 'diffusers': + print(f" * Downloading diffusers {model_name}...") + cache.get_model(model_name) + cache.offload_model(model_name) + #--------------------------------------------- def update_config_file(successfully_downloaded:dict,opt:dict): Config_file = opt.config_file or Default_config_file @@ -541,6 +612,12 @@ def download_safety_checker(): action=argparse.BooleanOptionalAction, default=True, help='run in interactive mode (default)') + parser.add_argument('--full-precision', + dest='full_precision', + action=argparse.BooleanOptionalAction, + type=bool, + default=False, + help='use 32-bit weights instead of faster 16-bit weights') parser.add_argument('--config_file', '-c', dest='config_file', @@ -563,7 +640,16 @@ def download_safety_checker(): access_token = authenticate() print('\n** DOWNLOADING WEIGHTS **') successfully_downloaded = download_weight_datasets(models, access_token) + download_diffusers(models, full_precision=opt.full_precision) update_config_file(successfully_downloaded,opt) + elif is_huggingface_authenticated(): + config_path = Path(opt.config_file or Default_config_file) + if config_path.exists(): + download_diffusers_in_config(config_path, full_precision=opt.full_precision) + else: + print("*⚠ No config file found; downloading no weights.") + else: + print("*⚠ No Hugging Face access; downloading no weights.") print('\n** DOWNLOADING SUPPORT MODELS **') download_bert() download_kornia() From 7267d88926e7aec3ba24f44bd5899f2af3ae33cd Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Mon, 21 Nov 2022 16:46:32 -0800 Subject: [PATCH 27/42] fix(model_cache): don't check `model.config` in diffusers format clean-up from recent merge. 
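Context for the cleanup below: once a models.yaml stanza declares format: diffusers there is no per-model LDM config file to resolve, so the path handling belongs inside the ckpt branch only. A minimal sketch of that dispatch (keys follow models.yaml.example; the function is illustrative, not the project's actual loader):

    # Illustrative dispatch on the models.yaml 'format' field.
    import os

    def resolve_model_paths(mconfig: dict, root: str):
        model_format = mconfig.get('format', 'ckpt')
        if model_format == 'ckpt':
            # legacy checkpoints carry their own LDM config and weights paths
            config = mconfig['config']
            if not os.path.isabs(config):
                config = os.path.join(root, config)
            return config, mconfig['weights']
        elif model_format == 'diffusers':
            # diffusers models are identified by a repo name (or local directory);
            # there is no separate config file to make absolute
            return None, mconfig['repo_name']
        raise ValueError(f'unknown model format {model_format!r}')
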
--- ldm/invoke/model_cache.py | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/ldm/invoke/model_cache.py b/ldm/invoke/model_cache.py index 2ee5eae0958..c701cce38ab 100644 --- a/ldm/invoke/model_cache.py +++ b/ldm/invoke/model_cache.py @@ -4,27 +4,25 @@ below a preset minimum, the least recently used model will be cleared and loaded from disk when next needed. ''' +import gc +import hashlib +import io +import os +import sys +import time +import traceback import warnings from pathlib import Path import torch -import os -import io -import time -import gc -import hashlib -import psutil -import sys import transformers -import traceback -import os from omegaconf import OmegaConf from omegaconf.errors import ConfigAttributeError +from picklescan.scanner import scan_file_path from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline -from ldm.util import instantiate_from_config from ldm.invoke.globals import Globals -from picklescan.scanner import scan_file_path +from ldm.util import instantiate_from_config DEFAULT_MAX_MODELS=2 @@ -198,7 +196,6 @@ def _load_model(self, model_name:str): return None mconfig = self.config[model_name] - config = mconfig.config # for usage statistics if self._has_cuda(): @@ -208,8 +205,6 @@ def _load_model(self, model_name:str): tic = time.time() # this does the work - if not os.path.isabs(config): - config = os.path.join(Globals.root,config) model_format = mconfig.get('format', 'ckpt') if model_format == 'ckpt': weights = mconfig.weights @@ -239,6 +234,8 @@ def _load_ckpt_model(self, model_name, mconfig): width = mconfig.width height = mconfig.height + if not os.path.isabs(config): + config = os.path.join(Globals.root,config) if not os.path.isabs(weights): weights = os.path.normpath(os.path.join(Globals.root,weights)) # scan model From 98dacba4012680190f80955588ab9989110b2fd7 Mon Sep 17 00:00:00 2001 From: mauwii Date: Mon, 21 Nov 2022 02:36:57 +0100 Subject: [PATCH 28/42] fix typo in setup.py - `scripts/preload_models.py` --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 0a2a808d320..1f95227d5dc 100644 --- a/setup.py +++ b/setup.py @@ -72,7 +72,7 @@ def frontend_files(directory): 'Topic :: Scientific/Engineering :: Artificial Intelligence', 'Topic :: Scientific/Engineering :: Image Processing', ], - scripts = ['scripts/invoke.py','scripts/load_models.py','scripts/sd-metadata.py'], + scripts = ['scripts/invoke.py','scripts/preload_models.py','scripts/sd-metadata.py'], data_files=[('frontend',frontend_files)], ) From 95848c9f8346b20f8d9c36dd74f7838704f36d04 Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Wed, 23 Nov 2022 14:46:41 -0800 Subject: [PATCH 29/42] dev: upgrade to diffusers 0.8 (from 0.7.1) We get to remove some code by using methods that were factored out in the base class. 
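The pipeline diff below can drop its hand-rolled attention-slicing, prompt-embedding, latent-preparation and decoding code because diffusers 0.8's StableDiffusionPipeline already provides those helpers. A compressed sketch of the pattern, using only methods that appear in the diff (_encode_prompt, prepare_latents, set_timesteps); the class and method here are illustrative, not part of the project:

    # Illustrative only: leaning on diffusers 0.8 base-class helpers from a subclass.
    import torch
    from diffusers import StableDiffusionPipeline

    class MinimalPipeline(StableDiffusionPipeline):
        @torch.no_grad()
        def initial_latents_for(self, prompt: str, steps: int = 50, guidance_scale: float = 7.5):
            device = self._execution_device
            # returns unconditioned and conditioned embeddings stacked along the batch dim
            embeddings = self._encode_prompt(prompt, device, 1, guidance_scale > 1.0, None)
            self.scheduler.set_timesteps(steps, device=device)
            # the inherited helper replaces the old hand-written prepare_latents override
            return self.prepare_latents(1, self.unet.in_channels, 512, 512,
                                        embeddings.dtype, device, generator=None)
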
--- .../environment-lin-amd.yml | 2 +- .../environment-lin-cuda.yml | 2 +- .../environment-mac.yml | 2 +- .../environment-win-cuda.yml | 2 +- .../requirements-base.txt | 2 +- ldm/invoke/generator/diffusers_pipeline.py | 142 +++--------------- 6 files changed, 28 insertions(+), 124 deletions(-) diff --git a/environments-and-requirements/environment-lin-amd.yml b/environments-and-requirements/environment-lin-amd.yml index 15a8b9b0db4..e834da00bd3 100644 --- a/environments-and-requirements/environment-lin-amd.yml +++ b/environments-and-requirements/environment-lin-amd.yml @@ -11,7 +11,7 @@ dependencies: - --extra-index-url https://download.pytorch.org/whl/rocm5.2/ - albumentations==0.4.3 - dependency_injector==4.40.0 - - diffusers==0.6.0 + - diffusers~=0.8 - einops==0.3.0 - eventlet - flask==2.1.3 diff --git a/environments-and-requirements/environment-lin-cuda.yml b/environments-and-requirements/environment-lin-cuda.yml index b0aec548aaf..8f4c6e8ab21 100644 --- a/environments-and-requirements/environment-lin-cuda.yml +++ b/environments-and-requirements/environment-lin-cuda.yml @@ -15,7 +15,7 @@ dependencies: - accelerate~=0.13 - albumentations==0.4.3 - dependency_injector==4.40.0 - - diffusers~=0.7 + - diffusers~=0.8 - einops==0.3.0 - eventlet - flask==2.1.3 diff --git a/environments-and-requirements/environment-mac.yml b/environments-and-requirements/environment-mac.yml index 7c8aa2f0f94..ed6faff7866 100644 --- a/environments-and-requirements/environment-mac.yml +++ b/environments-and-requirements/environment-mac.yml @@ -22,7 +22,7 @@ dependencies: - albumentations=1.2 - coloredlogs=15.0 - - diffusers~=0.7 + - diffusers~=0.8 - einops=0.3 - eventlet - grpcio=1.46 diff --git a/environments-and-requirements/environment-win-cuda.yml b/environments-and-requirements/environment-win-cuda.yml index d42aa62a3d3..88624cae0b5 100644 --- a/environments-and-requirements/environment-win-cuda.yml +++ b/environments-and-requirements/environment-win-cuda.yml @@ -15,7 +15,7 @@ dependencies: - albumentations==0.4.3 - basicsr==1.4.1 - dependency_injector==4.40.0 - - diffusers==0.6.0 + - diffusers~=0.8 - einops==0.3.0 - eventlet - flask==2.1.3 diff --git a/environments-and-requirements/requirements-base.txt b/environments-and-requirements/requirements-base.txt index fb2b211571b..aae3ddec622 100644 --- a/environments-and-requirements/requirements-base.txt +++ b/environments-and-requirements/requirements-base.txt @@ -1,7 +1,7 @@ # pip will resolve the version which matches torch albumentations dependency_injector==4.40.0 -diffusers[torch]~=0.7 +diffusers[torch]~=0.8 einops eventlet flask==2.1.3 diff --git a/ldm/invoke/generator/diffusers_pipeline.py b/ldm/invoke/generator/diffusers_pipeline.py index 861bf22a7a7..ab0f7b4db72 100644 --- a/ldm/invoke/generator/diffusers_pipeline.py +++ b/ldm/invoke/generator/diffusers_pipeline.py @@ -4,8 +4,8 @@ from typing import List, Optional, Union, Callable import torch +from diffusers import StableDiffusionPipeline from diffusers.models import AutoencoderKL, UNet2DConditionModel -from diffusers.pipeline_utils import DiffusionPipeline from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler @@ -24,7 +24,7 @@ class PipelineIntermediateState: predicted_original: Optional[torch.Tensor] = None -class StableDiffusionGeneratorPipeline(DiffusionPipeline): +class 
StableDiffusionGeneratorPipeline(StableDiffusionPipeline): r""" Pipeline for text-to-image generation using Stable Diffusion. @@ -65,10 +65,10 @@ def __init__( tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], - safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, + safety_checker: Optional[StableDiffusionSafetyChecker], + feature_extractor: Optional[CLIPFeatureExtractor], ): - super().__init__() + super().__init__(vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor) self.register_modules( vae=vae, @@ -86,51 +86,6 @@ def __init__( ) self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward) - def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): - r""" - Enable sliced attention computation. - - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. - - Args: - slice_size (`str` or `int`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, - `attention_head_dim` must be a multiple of `slice_size`. - """ - if slice_size == "auto": - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = self.unet.config.attention_head_dim // 2 - self.unet.set_attention_slice(slice_size) - - def disable_attention_slicing(self): - r""" - Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go - back to computing attention in one step. - """ - # set slice_size = `None` to disable `attention slicing` - self.enable_attention_slicing(None) - - def enable_xformers_memory_efficient_attention(self): - r""" - Enable memory efficient attention as implemented in xformers. - - When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference - time. Speed up at training time is not guaranteed. - - Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention - is used. - """ - self.unet.set_use_memory_efficient_attention_xformers(True) - - def disable_xformers_memory_efficient_attention(self): - r""" - Disable memory efficient attention as implemented in xformers. - """ - self.unet.set_use_memory_efficient_attention_xformers(False) - def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int, text_embeddings: torch.Tensor, unconditioned_embeddings: torch.Tensor, guidance_scale: float, @@ -193,10 +148,17 @@ def generate( # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. 
do_classifier_free_guidance = guidance_scale > 1.0 - text_embeddings, unconditioned_embeddings = self.get_text_embeddings(prompt, opposing_prompt, do_classifier_free_guidance, batch_size)\ - .to(self.unet.device) + + combined_embeddings = self._encode_prompt(prompt, device=self._execution_device, num_images_per_prompt=1, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=opposing_prompt) + text_embeddings, unconditioned_embeddings = combined_embeddings.chunk(2) self.scheduler.set_timesteps(num_inference_steps) - latents = self.prepare_latents(latents, batch_size, height, width, generator, self.unet.dtype) + latents = self.prepare_latents(batch_size=batch_size, num_channels_latents=self.unet.in_channels, + height=height, width=width, + dtype=self.unet.dtype, device=self._execution_device, + generator=generator, + latents=latents) yield from self.generate_from_embeddings(latents, text_embeddings, unconditioned_embeddings, guidance_scale, run_id=run_id, **extra_step_kwargs) @@ -242,9 +204,10 @@ def generate_from_embeddings( # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 torch.cuda.empty_cache() - image = self.decode_to_image(latents) - output = StableDiffusionPipelineOutput(images=image, nsfw_content_detected=[]) - yield self.check_for_safety(output) + with torch.inference_mode(): + image = self.decode_latents(latents) + output = StableDiffusionPipelineOutput(images=image, nsfw_content_detected=[]) + yield self.check_for_safety(output, dtype=text_embeddings.dtype) @torch.inference_mode() def step(self, t: torch.Tensor, latents: torch.Tensor, guidance_scale: float, @@ -272,46 +235,12 @@ def _unet_forward(self, latents, t, text_embeddings): # predict the noise residual return self.unet(latents, t, encoder_hidden_states=text_embeddings).sample - @torch.inference_mode() - def check_for_safety(self, output): - if not getattr(self, 'feature_extractor') or not getattr(self, 'safety_checker'): - return output - images = output.images - safety_checker_output = self.feature_extractor(self.numpy_to_pil(images), - return_tensors="pt").to(self.device) - screened_images, has_nsfw_concept = self.safety_checker( - images=images, clip_input=safety_checker_output.pixel_values) + def check_for_safety(self, output, dtype): + with torch.inference_mode(): + screened_images, has_nsfw_concept = self.run_safety_checker( + output.images, device=self._execution_device, dtype=dtype) return StableDiffusionPipelineOutput(screened_images, has_nsfw_concept) - @torch.inference_mode() - def decode_to_image(self, latents): - # scale and decode the image latents with vae - latents = 1 / 0.18215 * latents - image = self.vae.decode(latents).sample - image = (image / 2 + 0.5).clamp(0, 1) - image = image.cpu().permute(0, 2, 3, 1).numpy() - return image - - @torch.inference_mode() - def get_text_embeddings(self, - prompt: Union[str, List[str]], - opposing_prompt: Union[str, List[str]], - do_classifier_free_guidance: bool, - batch_size: int): - # get prompt text embeddings - text_input = self._tokenize(prompt) - - text_embeddings = self.text_encoder(text_input.input_ids)[0] - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance: - # opposing prompt defaults to blank caption for everything in the batch - text_anti_input = self._tokenize(opposing_prompt or [""] * batch_size) - uncond_embeddings = self.text_encoder(text_anti_input.input_ids)[0] - else: - uncond_embeddings = None - - return text_embeddings, uncond_embeddings - 
@torch.inference_mode() def get_learned_conditioning(self, c: List[List[str]], *, return_tokens=True, fragment_weights=None): """ @@ -338,28 +267,3 @@ def _tokenize(self, prompt: Union[str, List[str]]): def channels(self) -> int: """Compatible with DiffusionWrapper""" return self.unet.in_channels - - def prepare_latents(self, latents, batch_size, height, width, generator, dtype): - # get the initial random noise unless the user supplied it - # Unlike in other pipelines, latents need to be generated in the target device - # for 1-to-1 results reproducibility with the CompVis implementation. - # However this currently doesn't work in `mps`. - latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) - if latents is None: - latents = torch.randn( - latents_shape, - generator=generator, - device=self.unet.device, - dtype=dtype - ) - else: - if latents.shape != latents_shape: - raise ValueError( - f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") - if latents.device != self.unet.device: - raise ValueError(f"Unexpected latents device, got {latents.device}, " - f"expected {self.unet.device}") - - # scale the initial noise by the standard deviation required by the scheduler - latents *= self.scheduler.init_noise_sigma - return latents From 375f3beb308282e26195f0526861212f33a17c24 Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Wed, 23 Nov 2022 17:38:31 -0800 Subject: [PATCH 30/42] diffusers integration: support img2img --- ldm/invoke/generator/diffusers_pipeline.py | 74 +++++++++++++++++- ldm/invoke/generator/img2img.py | 89 +++++++++------------- 2 files changed, 106 insertions(+), 57 deletions(-) diff --git a/ldm/invoke/generator/diffusers_pipeline.py b/ldm/invoke/generator/diffusers_pipeline.py index 861bf22a7a7..911de67601a 100644 --- a/ldm/invoke/generator/diffusers_pipeline.py +++ b/ldm/invoke/generator/diffusers_pipeline.py @@ -3,10 +3,12 @@ from dataclasses import dataclass from typing import List, Optional, Union, Callable +import PIL.Image import torch from diffusers.models import AutoencoderKL, UNet2DConditionModel from diffusers.pipeline_utils import DiffusionPipeline from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import preprocess from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer @@ -210,6 +212,7 @@ def generate_from_embeddings( *, run_id: str = None, extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo = None, + timesteps = None, **extra_step_kwargs): if run_id is None: run_id = secrets.token_urlsafe(self.ID_LENGTH) @@ -220,16 +223,19 @@ def generate_from_embeddings( else: self.invokeai_diffuser.remove_cross_attention_control() + if timesteps is None: + timesteps = self.scheduler.timesteps + # scale the initial noise by the standard deviation required by the scheduler latents *= self.scheduler.init_noise_sigma yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps, latents=latents) batch_size = latents.shape[0] - batched_t = torch.full((batch_size,), self.scheduler.timesteps[0], - dtype=self.scheduler.timesteps.dtype, device=self.unet.device) + batched_t = torch.full((batch_size,), timesteps[0], + dtype=timesteps.dtype, 
device=self.unet.device) # NOTE: Depends on scheduler being already initialized! - for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): + for i, t in enumerate(self.progress_bar(timesteps)): batched_t.fill_(t) step_output = self.step(batched_t, latents, guidance_scale, text_embeddings, unconditioned_embeddings, @@ -272,6 +278,68 @@ def _unet_forward(self, latents, t, text_embeddings): # predict the noise residual return self.unet(latents, t, encoder_hidden_states=text_embeddings).sample + def img2img_from_embeddings(self, + init_image: Union[torch.FloatTensor, PIL.Image.Image], + strength: float, + num_inference_steps: int, + text_embeddings: torch.Tensor, unconditioned_embeddings: torch.Tensor, + guidance_scale: float, + *, callback: Callable[[PipelineIntermediateState], None] = None, + extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo = None, + run_id=None, + noise_func=None, + **extra_step_kwargs) -> StableDiffusionPipelineOutput: + device = self.unet.device + latents_dtype = text_embeddings.dtype + batch_size = 1 + num_images_per_prompt = 1 + + if isinstance(init_image, PIL.Image.Image): + init_image = preprocess(init_image.convert('RGB')) + + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self._diffusers08_get_timesteps(num_inference_steps, strength) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 6. Prepare latent variables + latents = self.prepare_latents_from_image(init_image, latent_timestep, latents_dtype, device, noise_func) + + result = None + for result in self.generate_from_embeddings( + latents, text_embeddings, unconditioned_embeddings, guidance_scale, + extra_conditioning_info=extra_conditioning_info, + timesteps=timesteps, + run_id=run_id, **extra_step_kwargs): + if callback is not None and isinstance(result, PipelineIntermediateState): + callback(result) + if result is None: + raise AssertionError("why was that an empty generator?") + return result + + def prepare_latents_from_image(self, init_image, timestep, dtype, device, noise_func) -> torch.FloatTensor: + # can't quite use upstream StableDiffusionImg2ImgPipeline.prepare_latents + # because we have our own noise function + init_image = init_image.to(device=device, dtype=dtype) + with torch.inference_mode(): + init_latent_dist = self.vae.encode(init_image).latent_dist + init_latents = init_latent_dist.sample() # FIXME: uses torch.randn. make reproducible! 
+ init_latents = 0.18215 * init_latents + + noise = noise_func(init_latents) + + return self.scheduler.add_noise(init_latents, noise, timestep) + + def _diffusers08_get_timesteps(self, num_inference_steps, strength): + # get the original timestep using init_timestep + offset = self.scheduler.config.get("steps_offset", 0) + init_timestep = int(num_inference_steps * strength) + offset + init_timestep = min(init_timestep, num_inference_steps) + + t_start = max(num_inference_steps - init_timestep + offset, 0) + timesteps = self.scheduler.timesteps[t_start:] + + return timesteps + @torch.inference_mode() def check_for_safety(self, output): if not getattr(self, 'feature_extractor') or not getattr(self, 'safety_checker'): diff --git a/ldm/invoke/generator/img2img.py b/ldm/invoke/generator/img2img.py index edcc855a290..6ea41fda33c 100644 --- a/ldm/invoke/generator/img2img.py +++ b/ldm/invoke/generator/img2img.py @@ -2,14 +2,10 @@ ldm.invoke.generator.img2img descends from ldm.invoke.generator ''' -import PIL -import numpy as np import torch -from PIL import Image -from torch import Tensor -from ldm.invoke.devices import choose_autocast from ldm.invoke.generator.base import Generator +from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline class Img2Img(Generator): @@ -25,66 +21,51 @@ def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta, """ self.perlin = perlin - sampler.make_schedule( - ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False - ) - - if isinstance(init_image, PIL.Image.Image): - init_image = self._image_to_tensor(init_image.convert('RGB')) - - scope = choose_autocast(self.precision) - with scope(self.model.device.type): - self.init_latent = self.model.get_first_stage_encoding( - self.model.encode_first_stage(init_image) - ) # move to latent space - - t_enc = int(strength * steps) uc, c, extra_conditioning_info = conditioning + # noinspection PyTypeChecker + pipeline: StableDiffusionGeneratorPipeline = self.model + pipeline.scheduler = sampler + def make_image(x_T): - # encode (scaled latent) - z_enc = sampler.stochastic_encode( - self.init_latent, - torch.tensor([t_enc]).to(self.model.device), - noise=x_T - ) - # decode it - samples = sampler.decode( - z_enc, - c, - t_enc, - img_callback = step_callback, - unconditional_guidance_scale=cfg_scale, - unconditional_conditioning=uc, - init_latent = self.init_latent, # changes how noising is performed in ksampler - extra_conditioning_info = extra_conditioning_info, - all_timesteps_count = steps + # FIXME: use x_T for initial seeded noise + pipeline_output = pipeline.img2img_from_embeddings( + init_image, strength, steps, c, uc, cfg_scale, + extra_conditioning_info=extra_conditioning_info, + noise_func=self.get_noise_like, + callback=step_callback ) - return self.sample_to_image(samples) + return pipeline.numpy_to_pil(pipeline_output.images)[0] return make_image - def get_noise(self,width,height): - device = self.model.device - init_latent = self.init_latent - assert init_latent is not None,'call to get_noise() when init_latent not set' + def get_noise_like(self, like: torch.Tensor): + device = like.device if device.type == 'mps': - x = torch.randn_like(init_latent, device='cpu').to(device) + x = torch.randn_like(like, device='cpu').to(device) else: - x = torch.randn_like(init_latent, device=device) + x = torch.randn_like(like, device=device) if self.perlin > 0.0: - shape = init_latent.shape + shape = like.shape x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(shape[3], shape[2]) return x - 
def _image_to_tensor(self, image:Image, normalize:bool=True)->Tensor: - image = np.array(image).astype(np.float32) / 255.0 - if len(image.shape) == 2: # 'L' image, as in a mask - image = image[None,None] - else: # 'RGB' image - image = image[None].transpose(0, 3, 1, 2) - image = torch.from_numpy(image) - if normalize: - image = 2.0 * image - 1.0 - return image.to(self.model.device) + def get_noise(self,width,height): + # copy of the Txt2Img.get_noise + device = self.model.device + if self.use_mps_noise or device.type == 'mps': + x = torch.randn([1, + self.latent_channels, + height // self.downsampling_factor, + width // self.downsampling_factor], + device='cpu').to(device) + else: + x = torch.randn([1, + self.latent_channels, + height // self.downsampling_factor, + width // self.downsampling_factor], + device=device) + if self.perlin > 0.0: + x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(width // self.downsampling_factor, height // self.downsampling_factor) + return x From da5fee494c32c40f124aafc9e9b6f60c35ffbdca Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Wed, 23 Nov 2022 20:30:06 -0800 Subject: [PATCH 31/42] refactor: remove backported img2img.get_timesteps now that we can use it directly from diffusers 0.8.1 --- ldm/invoke/generator/diffusers_pipeline.py | 18 ++++-------------- 1 file changed, 4 insertions(+), 14 deletions(-) diff --git a/ldm/invoke/generator/diffusers_pipeline.py b/ldm/invoke/generator/diffusers_pipeline.py index 4a4700232ca..5d6b66ccf01 100644 --- a/ldm/invoke/generator/diffusers_pipeline.py +++ b/ldm/invoke/generator/diffusers_pipeline.py @@ -5,7 +5,7 @@ import PIL.Image import torch -from diffusers import StableDiffusionPipeline +from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline from diffusers.models import AutoencoderKL, UNet2DConditionModel from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import preprocess @@ -260,8 +260,9 @@ def img2img_from_embeddings(self, if isinstance(init_image, PIL.Image.Image): init_image = preprocess(init_image.convert('RGB')) - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self._diffusers08_get_timesteps(num_inference_steps, strength) + img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components) + img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = img2img_pipeline.get_timesteps(num_inference_steps, strength, device=device) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) # 6. 
Prepare latent variables @@ -292,17 +293,6 @@ def prepare_latents_from_image(self, init_image, timestep, dtype, device, noise_ return self.scheduler.add_noise(init_latents, noise, timestep) - def _diffusers08_get_timesteps(self, num_inference_steps, strength): - # get the original timestep using init_timestep - offset = self.scheduler.config.get("steps_offset", 0) - init_timestep = int(num_inference_steps * strength) + offset - init_timestep = min(init_timestep, num_inference_steps) - - t_start = max(num_inference_steps - init_timestep + offset, 0) - timesteps = self.scheduler.timesteps[t_start:] - - return timesteps - def check_for_safety(self, output, dtype): with torch.inference_mode(): screened_images, has_nsfw_concept = self.run_safety_checker( From 9f49f540b6ec3416fedf011c4be4d2a24598a0f8 Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Thu, 24 Nov 2022 18:38:08 -0800 Subject: [PATCH 32/42] CI: use huggingface cache for test-invoke-pip --- .github/workflows/test-invoke-conda.yml | 5 +- .github/workflows/test-invoke-pip.yml | 61 +++++++++++++------------ 2 files changed, 35 insertions(+), 31 deletions(-) diff --git a/.github/workflows/test-invoke-conda.yml b/.github/workflows/test-invoke-conda.yml index a61e648a36f..fb625759fc6 100644 --- a/.github/workflows/test-invoke-conda.yml +++ b/.github/workflows/test-invoke-conda.yml @@ -61,7 +61,7 @@ jobs: id: cache-sd-model uses: actions/cache@v3 env: - cache-name: huggingface-${{ matrix.stable-diffusion-model-switch }} + cache-name: huggingface-${{ matrix.stable-diffusion-model }} with: path: ~/.cache/huggingface key: ${{ env.cache-name }} @@ -105,12 +105,13 @@ jobs: echo -n '${{ secrets.HUGGINGFACE_TOKEN }}' > ~/.huggingface/token fi python scripts/configure_invokeai.py \ - --no-interactive \ + --no-interactive --yes \ --full-precision # can't use fp16 weights without a GPU - name: Run the tests id: run-tests env: + # Set offline mode to make sure configure preloaded successfully. 
HF_HUB_OFFLINE: 1 HF_DATASETS_OFFLINE: 1 TRANSFORMERS_OFFLINE: 1 diff --git a/.github/workflows/test-invoke-pip.yml b/.github/workflows/test-invoke-pip.yml index 69fc7c9ce97..a101e7d99cb 100644 --- a/.github/workflows/test-invoke-pip.yml +++ b/.github/workflows/test-invoke-pip.yml @@ -1,16 +1,16 @@ name: Test invoke.py pip -on: - push: - branches: - - 'main' - - 'development' - pull_request: - branches: - - 'main' - - 'development' +on: [push, pull_request] jobs: matrix: + # Run on: + # - pull requests + # - pushes to forks (will run in the forked project with that fork's secrets) + # - pushes to branches that are *not* pull requests + if: | + github.event_name == 'pull_request' + || github.repository != 'invoke-ai/InvokeAI' + || github.ref_protected strategy: fail-fast: false matrix: @@ -44,6 +44,8 @@ jobs: shell: ${{ matrix.default-shell }} env: INVOKEAI_ROOT: '${{ github.workspace }}/invokeai' + PYTHONUNBUFFERED: 1 + HAVE_SECRETS: ${{ secrets.HUGGINGFACE_TOKEN != '' }} steps: - name: Checkout sources id: checkout-sources @@ -54,6 +56,19 @@ jobs: mkdir -p ${{ env.INVOKEAI_ROOT }}/configs cp configs/models.yaml.example ${{ env.INVOKEAI_ROOT }}/configs/models.yaml + - name: Use Cached Stable Diffusion Model + id: cache-sd-model + uses: actions/cache@v3 + env: + cache-name: cache-${{ matrix.stable-diffusion-model }} + with: + path: ${{ matrix.stable-diffusion-model-dl-path }} + key: ${{ env.cache-name }} + + - name: Check model availability + if: steps.cache-sd-model.outputs.cache-hit != true && env.HAVE_SECRETS != 'true' + run: echo -e '\a ⛔ GitHub model cache not found, and no HUGGINGFACE_TOKEN is available. Will not be able to load Stable Diffusion.' ; exit 1 + - name: set test prompt to main branch validation if: ${{ github.ref == 'refs/heads/main' }} run: echo "TEST_PROMPTS=tests/preflight_prompts.txt" >> $GITHUB_ENV @@ -84,32 +99,20 @@ jobs: ${{ env.pythonLocation }}/bin/pip install --upgrade -r '${{ matrix.requirements-file }}' ${{ env.pythonLocation }}/bin/pip install -e . - - name: Use Cached Stable Diffusion Model - id: cache-sd-model - uses: actions/cache@v3 - env: - cache-name: cache-${{ matrix.stable-diffusion-model }} - with: - path: ${{ matrix.stable-diffusion-model-dl-path }} - key: ${{ env.cache-name }} - - - name: Download ${{ matrix.stable-diffusion-model }} - id: download-stable-diffusion-model - if: ${{ steps.cache-sd-model.outputs.cache-hit != 'true' }} - run: | - mkdir -p "${{ env.INVOKEAI_ROOT }}/${{ matrix.stable-diffusion-model-dl-path }}" - curl \ - -H "Authorization: Bearer ${{ secrets.HUGGINGFACE_TOKEN }}" \ - -o "${{ env.INVOKEAI_ROOT }}/${{ matrix.stable-diffusion-model-dl-path }}/${{ matrix.stable-diffusion-model-dl-name }}" \ - -L ${{ matrix.stable-diffusion-model-url }} - - name: run configure_invokeai.py id: run-preload-models run: | - ${{ env.pythonLocation }}/bin/python scripts/configure_invokeai.py --no-interactive --yes + ${{ env.pythonLocation }}/bin/python scripts/configure_invokeai.py \ + --no-interactive --yes \ + --full-precision # can't use fp16 weights without a GPU - name: Run the tests id: run-tests + env: + # Set offline mode to make sure configure preloaded successfully. 
+ HF_HUB_OFFLINE: 1 + HF_DATASETS_OFFLINE: 1 + TRANSFORMERS_OFFLINE: 1 run: | time ${{ env.pythonLocation }}/bin/python scripts/invoke.py \ --model ${{ matrix.stable-diffusion-model }} \ From afae108cec548794de1fb47d690ab555d5b9ce97 Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Thu, 24 Nov 2022 20:12:48 -0800 Subject: [PATCH 33/42] ci: use diffusers model --- .github/workflows/test-invoke-conda.yml | 6 +----- .github/workflows/test-invoke-pip.yml | 10 +++------- 2 files changed, 4 insertions(+), 12 deletions(-) diff --git a/.github/workflows/test-invoke-conda.yml b/.github/workflows/test-invoke-conda.yml index fb625759fc6..92ab6b5a6e2 100644 --- a/.github/workflows/test-invoke-conda.yml +++ b/.github/workflows/test-invoke-conda.yml @@ -15,7 +15,7 @@ jobs: fail-fast: false matrix: stable-diffusion-model: - - 'stable-diffusion-1.5' + - 'diffusers-1.5' environment-yaml: - environment-lin-amd.yml - environment-lin-cuda.yml @@ -30,10 +30,6 @@ jobs: - environment-yaml: environment-mac.yml os: macos-12 default-shell: bash -l {0} - - stable-diffusion-model: stable-diffusion-1.5 - stable-diffusion-model-url: https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt - stable-diffusion-model-dl-path: models/ldm/stable-diffusion-v1 - stable-diffusion-model-dl-name: v1-5-pruned-emaonly.ckpt name: ${{ matrix.environment-yaml }} on ${{ matrix.os }} runs-on: ${{ matrix.os }} env: diff --git a/.github/workflows/test-invoke-pip.yml b/.github/workflows/test-invoke-pip.yml index a101e7d99cb..10c9c7d7f64 100644 --- a/.github/workflows/test-invoke-pip.yml +++ b/.github/workflows/test-invoke-pip.yml @@ -15,7 +15,7 @@ jobs: fail-fast: false matrix: stable-diffusion-model: - - stable-diffusion-1.5 + - diffusers-1.5 requirements-file: - requirements-lin-cuda.txt - requirements-lin-amd.txt @@ -33,10 +33,6 @@ jobs: - requirements-file: requirements-mac-mps-cpu.txt os: macOS-12 default-shell: bash -l {0} - - stable-diffusion-model: stable-diffusion-1.5 - stable-diffusion-model-url: https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt - stable-diffusion-model-dl-path: models/ldm/stable-diffusion-v1 - stable-diffusion-model-dl-name: v1-5-pruned-emaonly.ckpt name: ${{ matrix.requirements-file }} on ${{ matrix.python-version }} runs-on: ${{ matrix.os }} defaults: @@ -60,9 +56,9 @@ jobs: id: cache-sd-model uses: actions/cache@v3 env: - cache-name: cache-${{ matrix.stable-diffusion-model }} + cache-name: huggingface-${{ matrix.stable-diffusion-model }} with: - path: ${{ matrix.stable-diffusion-model-dl-path }} + path: ~/.cache/huggingface key: ${{ env.cache-name }} - name: Check model availability From 1cbfa5e9b39217a2589953dcd65f07656f1fd302 Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Thu, 24 Nov 2022 20:32:18 -0800 Subject: [PATCH 34/42] fixup! 
ci: use diffusers model --- .github/workflows/test-invoke-pip.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/test-invoke-pip.yml b/.github/workflows/test-invoke-pip.yml index 10c9c7d7f64..705636a876c 100644 --- a/.github/workflows/test-invoke-pip.yml +++ b/.github/workflows/test-invoke-pip.yml @@ -98,6 +98,10 @@ jobs: - name: run configure_invokeai.py id: run-preload-models run: | + if [ "${HAVE_SECRETS}" == true ] ; then + mkdir -p ~/.huggingface + echo -n '${{ secrets.HUGGINGFACE_TOKEN }}' > ~/.huggingface/token + fi ${{ env.pythonLocation }}/bin/python scripts/configure_invokeai.py \ --no-interactive --yes \ --full-precision # can't use fp16 weights without a GPU From 5a6f236961d655f368eb35514ace22b360fb0b5a Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Fri, 25 Nov 2022 13:31:56 -0800 Subject: [PATCH 35/42] dev: upgrade to diffusers 0.9 (from 0.8.1) --- environments-and-requirements/environment-lin-amd.yml | 2 +- environments-and-requirements/environment-lin-cuda.yml | 2 +- environments-and-requirements/environment-mac.yml | 2 +- environments-and-requirements/environment-win-cuda.yml | 2 +- environments-and-requirements/requirements-base.txt | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/environments-and-requirements/environment-lin-amd.yml b/environments-and-requirements/environment-lin-amd.yml index 1ca7877abbb..4949a912bda 100644 --- a/environments-and-requirements/environment-lin-amd.yml +++ b/environments-and-requirements/environment-lin-amd.yml @@ -11,7 +11,7 @@ dependencies: - --extra-index-url https://download.pytorch.org/whl/rocm5.2/ - albumentations==0.4.3 - dependency_injector==4.40.0 - - diffusers~=0.8 + - diffusers~=0.9 - einops==0.3.0 - eventlet - flask==2.1.3 diff --git a/environments-and-requirements/environment-lin-cuda.yml b/environments-and-requirements/environment-lin-cuda.yml index fad21502897..c36a9d2ba33 100644 --- a/environments-and-requirements/environment-lin-cuda.yml +++ b/environments-and-requirements/environment-lin-cuda.yml @@ -15,7 +15,7 @@ dependencies: - accelerate~=0.13 - albumentations==0.4.3 - dependency_injector==4.40.0 - - diffusers~=0.8 + - diffusers~=0.9 - einops==0.3.0 - eventlet - flask==2.1.3 diff --git a/environments-and-requirements/environment-mac.yml b/environments-and-requirements/environment-mac.yml index a9fe0cd4448..f08b4364cfa 100644 --- a/environments-and-requirements/environment-mac.yml +++ b/environments-and-requirements/environment-mac.yml @@ -22,7 +22,7 @@ dependencies: - albumentations=1.2 - coloredlogs=15.0 - - diffusers~=0.8 + - diffusers~=0.9 - einops=0.3 - eventlet - grpcio=1.46 diff --git a/environments-and-requirements/environment-win-cuda.yml b/environments-and-requirements/environment-win-cuda.yml index 88624cae0b5..039e07807ab 100644 --- a/environments-and-requirements/environment-win-cuda.yml +++ b/environments-and-requirements/environment-win-cuda.yml @@ -15,7 +15,7 @@ dependencies: - albumentations==0.4.3 - basicsr==1.4.1 - dependency_injector==4.40.0 - - diffusers~=0.8 + - diffusers~=0.9 - einops==0.3.0 - eventlet - flask==2.1.3 diff --git a/environments-and-requirements/requirements-base.txt b/environments-and-requirements/requirements-base.txt index 3770ebf2a5c..6b12b07067b 100644 --- a/environments-and-requirements/requirements-base.txt +++ b/environments-and-requirements/requirements-base.txt @@ -1,7 +1,7 @@ # pip will resolve the version which matches torch albumentations dependency_injector==4.40.0 
-diffusers[torch]~=0.8 +diffusers[torch]~=0.9 einops eventlet facexlib From 15e99aef47bc4eb7fcf90266a266c0c850bea5cc Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Fri, 25 Nov 2022 13:52:29 -0800 Subject: [PATCH 36/42] lint: correct annotations for Python 3.9. --- ldm/invoke/generator/base.py | 2 ++ ldm/invoke/generator/diffusers_pipeline.py | 2 ++ ldm/invoke/model_cache.py | 8 +++++--- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/ldm/invoke/generator/base.py b/ldm/invoke/generator/base.py index e49f50e7c3c..2ff49a42c3c 100644 --- a/ldm/invoke/generator/base.py +++ b/ldm/invoke/generator/base.py @@ -2,6 +2,8 @@ Base class for ldm.invoke.generator.* including img2img, txt2img, and inpaint ''' +from __future__ import annotations + import os import random import traceback diff --git a/ldm/invoke/generator/diffusers_pipeline.py b/ldm/invoke/generator/diffusers_pipeline.py index 5d6b66ccf01..2d3f694687a 100644 --- a/ldm/invoke/generator/diffusers_pipeline.py +++ b/ldm/invoke/generator/diffusers_pipeline.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import secrets import warnings from dataclasses import dataclass diff --git a/ldm/invoke/model_cache.py b/ldm/invoke/model_cache.py index 80e0d26f3b8..60b9b06d38b 100644 --- a/ldm/invoke/model_cache.py +++ b/ldm/invoke/model_cache.py @@ -4,18 +4,20 @@ below a preset minimum, the least recently used model will be cleared and loaded from disk when next needed. ''' +from __future__ import annotations + +import contextlib import gc import hashlib import io import os import sys +import textwrap import time import traceback -import textwrap -import contextlib -from typing import Union import warnings from pathlib import Path +from typing import Union import torch import transformers From 66c5689cf964fdfa922b62ac3d98ed2f7adc2a3e Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Fri, 25 Nov 2022 14:11:19 -0800 Subject: [PATCH 37/42] lint: correct AttributeError.name reference for Python 3.9. 
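
Background for this fix: AttributeError only gained its "name" attribute in Python 3.10, so code that inspects it has to guard the access on Python 3.9. A minimal standalone sketch of the two behaviors (illustrative only, not code from this patch):

    import sys

    try:
        object().missing_attribute
    except AttributeError as e:
        if sys.version_info >= (3, 10):
            # Python 3.10+ exposes the missing attribute's name directly.
            print(e.name)                         # -> 'missing_attribute'
        else:
            # Python 3.9 only has the message text, so fall back to substring
            # matching, which is the approach the helper added below takes.
            print('missing_attribute' in str(e))  # -> True
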
--- ldm/models/diffusion/cross_attention_control.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/ldm/models/diffusion/cross_attention_control.py b/ldm/models/diffusion/cross_attention_control.py index cd741397a19..ec7c3c215cc 100644 --- a/ldm/models/diffusion/cross_attention_control.py +++ b/ldm/models/diffusion/cross_attention_control.py @@ -4,6 +4,7 @@ import torch + # adapted from bloc97's CrossAttentionControl colab # https://github.com/bloc97/CrossAttentionControl @@ -255,7 +256,7 @@ def attention_slice_wrangler(module, suggested_attention_slice:torch.Tensor, dim lambda module: context.get_slicing_strategy(identifier) ) except AttributeError as e: - if e.name == 'set_attention_slice_wrangler': + if is_attribute_error_about(e, 'set_attention_slice_wrangler'): warnings.warn(f"TODO: implement for {type(module)}") # TODO else: raise @@ -270,7 +271,14 @@ def remove_attention_function(unet): module.set_attention_slice_wrangler(None) module.set_slicing_strategy_getter(None) except AttributeError as e: - if e.name == 'set_attention_slice_wrangler': + if is_attribute_error_about(e, 'set_attention_slice_wrangler'): warnings.warn(f"TODO: implement for {type(module)}") # TODO else: raise + + +def is_attribute_error_about(error: AttributeError, attribute: str): + if hasattr(error, 'name'): # Python 3.10 + return error.name == attribute + else: # Python 3.9 + return attribute in str(error) From f72b0c84538eb79e109a5cd968e07bb5a615eb10 Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Fri, 25 Nov 2022 14:24:00 -0800 Subject: [PATCH 38/42] CI: prefer diffusers-1.4 because it no longer requires a token The RunwayML models still do. --- .github/workflows/test-invoke-conda.yml | 6 +----- .github/workflows/test-invoke-pip.yml | 6 +----- configs/models.yaml.example | 6 +++++- 3 files changed, 7 insertions(+), 11 deletions(-) diff --git a/.github/workflows/test-invoke-conda.yml b/.github/workflows/test-invoke-conda.yml index 92ab6b5a6e2..e5b91a59573 100644 --- a/.github/workflows/test-invoke-conda.yml +++ b/.github/workflows/test-invoke-conda.yml @@ -15,7 +15,7 @@ jobs: fail-fast: false matrix: stable-diffusion-model: - - 'diffusers-1.5' + - diffusers-1.4 environment-yaml: - environment-lin-amd.yml - environment-lin-cuda.yml @@ -62,10 +62,6 @@ jobs: path: ~/.cache/huggingface key: ${{ env.cache-name }} - - name: Check model availability - if: steps.cache-sd-model.outputs.cache-hit != true && env.HAVE_SECRETS != 'true' - run: echo -e '\a ⛔ GitHub model cache not found, and no HUGGINGFACE_TOKEN is available. Will not be able to load Stable Diffusion.' ; exit 1 - - name: Use cached conda packages id: use-cached-conda-packages uses: actions/cache@v3 diff --git a/.github/workflows/test-invoke-pip.yml b/.github/workflows/test-invoke-pip.yml index 705636a876c..4be586cb95b 100644 --- a/.github/workflows/test-invoke-pip.yml +++ b/.github/workflows/test-invoke-pip.yml @@ -15,7 +15,7 @@ jobs: fail-fast: false matrix: stable-diffusion-model: - - diffusers-1.5 + - diffusers-1.4 requirements-file: - requirements-lin-cuda.txt - requirements-lin-amd.txt @@ -61,10 +61,6 @@ jobs: path: ~/.cache/huggingface key: ${{ env.cache-name }} - - name: Check model availability - if: steps.cache-sd-model.outputs.cache-hit != true && env.HAVE_SECRETS != 'true' - run: echo -e '\a ⛔ GitHub model cache not found, and no HUGGINGFACE_TOKEN is available. Will not be able to load Stable Diffusion.' 
; exit 1 - - name: set test prompt to main branch validation if: ${{ github.ref == 'refs/heads/main' }} run: echo "TEST_PROMPTS=tests/preflight_prompts.txt" >> $GITHUB_ENV diff --git a/configs/models.yaml.example b/configs/models.yaml.example index 1eb2781f4a8..87bc13645d1 100644 --- a/configs/models.yaml.example +++ b/configs/models.yaml.example @@ -5,11 +5,15 @@ # model requires a model config file, a weights file, # and the width and height of the images it # was trained on. +diffusers-1.4: + description: Diffusers version of Stable Diffusion version 1.4 + format: diffusers + repo_name: CompVis/stable-diffusion-v1-4 + default: true diffusers-1.5: description: Diffusers version of Stable Diffusion version 1.5 format: diffusers repo_name: runwayml/stable-diffusion-v1-5 - default: true stable-diffusion-1.5: description: The newest Stable Diffusion version 1.5 weight file (4.27 GB) weights: models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt From 50ef6ef18cdfc5b496c790a87234a9b86334b63f Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Sat, 26 Nov 2022 10:52:40 -0800 Subject: [PATCH 39/42] build: there's yet another place to update requirements? --- environments-and-requirements/requirements-base.txt | 4 ++-- installer/requirements.in | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/environments-and-requirements/requirements-base.txt b/environments-and-requirements/requirements-base.txt index ac4ef8fbc93..c502c1a4846 100644 --- a/environments-and-requirements/requirements-base.txt +++ b/environments-and-requirements/requirements-base.txt @@ -30,10 +30,10 @@ taming-transformers-rom1504 test-tube>=0.7.5 torch-fidelity torchmetrics -transformers==4.21.* +transformers~=4.24 picklescan git+https://github.com/openai/CLIP.git@main#egg=clip git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k-diffusion git+https://github.com/invoke-ai/clipseg.git@relaxed-python-requirement#egg=clipseg git+https://github.com/invoke-ai/GFPGAN@basicsr-1.4.2#egg=gfpgan -git+https://github.com/invoke-ai/PyPatchMatch@0.1.1#egg=pypatchmatch \ No newline at end of file +git+https://github.com/invoke-ai/PyPatchMatch@0.1.1#egg=pypatchmatch diff --git a/installer/requirements.in b/installer/requirements.in index ab6b2a1ff5c..de97f06f1b9 100644 --- a/installer/requirements.in +++ b/installer/requirements.in @@ -3,7 +3,7 @@ --trusted-host https://download.pytorch.org accelerate~=0.14 albumentations -diffusers +diffusers[torch]~=0.9 einops eventlet facexlib From 03c057ad923bc53019266fcf9e8bcd8e025769e2 Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Sat, 26 Nov 2022 10:58:12 -0800 Subject: [PATCH 40/42] configure: try to download models even without token Models in the CompVis and stabilityai repos no longer require them. (But runwayml still does.) 
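
A rough sketch of the behavior this message describes, assuming the diffusers-0.9-era use_auth_token argument (illustrative, not code from the patch): ungated repos such as CompVis/stable-diffusion-v1-4 download anonymously, while gated repos such as runwayml/stable-diffusion-v1-5 still need an accepted license and a stored Hugging Face token.

    from diffusers import StableDiffusionPipeline

    # Ungated repo: downloads with no Hugging Face token at all.
    pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")

    # Gated repo: without a stored token this raises an authorization error,
    # which is why the configure script now warns and continues instead of
    # skipping all downloads when no token is present.
    # pipe = StableDiffusionPipeline.from_pretrained(
    #     "runwayml/stable-diffusion-v1-5", use_auth_token=True)
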
--- scripts/configure_invokeai.py | 32 +++++++++++++++++--------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/scripts/configure_invokeai.py b/scripts/configure_invokeai.py index 284841e96e2..753478d7e05 100755 --- a/scripts/configure_invokeai.py +++ b/scripts/configure_invokeai.py @@ -8,27 +8,28 @@ # print('Loading Python libraries...\n') import argparse -import sys import os import re -from typing import Dict import shutil +import sys +import traceback +import warnings +from pathlib import Path +from typing import Dict from urllib import request -from tqdm import tqdm + +import requests +import transformers from diffusers import StableDiffusionPipeline, AutoencoderKL -from omegaconf import OmegaConf -from huggingface_hub import HfFolder, hf_hub_url, whoami as hf_whoami -from pathlib import Path from getpass_asterisk import getpass_asterisk +from huggingface_hub import HfFolder, hf_hub_url, whoami as hf_whoami +from omegaconf import OmegaConf +from tqdm import tqdm from transformers import CLIPTokenizer, CLIPTextModel + from ldm.invoke.globals import Globals from ldm.invoke.readline import generic_completer -import traceback -import requests -import clip -import transformers -import warnings warnings.filterwarnings('ignore') import torch transformers.logging.set_verbosity_error() @@ -386,6 +387,9 @@ def download_diffusers_in_config(config_path: Path, full_precision: bool): # This is a minimal implementation until https://github.com/invoke-ai/InvokeAI/pull/1490 lands, # which moves a bunch of stuff. # We can be more complete after we know it won't be all merge conflicts. + if not is_huggingface_authenticated(): + print("*⚠ No Hugging Face access token; some downloads may be blocked.") + precision = 'full' if full_precision else 'float16' cache = ModelCache(OmegaConf.load(config_path), precision=precision, device_type='cpu', max_loaded_models=1) @@ -471,7 +475,7 @@ def download_bert(): print('Installing bert tokenizer (ignore deprecation errors)...', end='',file=sys.stderr) with warnings.catch_warnings(): warnings.filterwarnings('ignore', category=DeprecationWarning) - from transformers import BertTokenizerFast, AutoFeatureExtractor + from transformers import BertTokenizerFast download_from_hf(BertTokenizerFast,'bert-base-uncased') print('...success',file=sys.stderr) @@ -761,14 +765,12 @@ def main(): if opt.interactive: print('** DOWNLOADING DIFFUSION WEIGHTS **') download_weights(opt) - elif is_huggingface_authenticated(): + else: config_path = Path(opt.config_file or Default_config_file) if config_path.exists(): download_diffusers_in_config(config_path, full_precision=opt.full_precision) else: print("*⚠ No config file found; downloading no weights.") - else: - print("*⚠ No Hugging Face access; downloading no weights.") print('\n** DOWNLOADING SUPPORT MODELS **') download_bert() download_clip() From 18cd56a0be093c04b024e378ef48e0960e8a93a1 Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Sun, 27 Nov 2022 08:27:03 -0800 Subject: [PATCH 41/42] configure: add troubleshooting info for config-not-found --- scripts/configure_invokeai.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/configure_invokeai.py b/scripts/configure_invokeai.py index ecf7f6eaa19..77a485d59d6 100755 --- a/scripts/configure_invokeai.py +++ b/scripts/configure_invokeai.py @@ -770,7 +770,7 @@ def main(): if config_path.exists(): download_diffusers_in_config(config_path, full_precision=opt.full_precision) else: - print("*⚠ No config 
file found; downloading no weights.") + print(f"*⚠ No config file found; downloading no weights. Looked in {config_path}") print('\n** DOWNLOADING SUPPORT MODELS **') download_bert() download_clip() From a7dd76f21452112412740ed188de3fd93f930372 Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Sun, 27 Nov 2022 09:09:38 -0800 Subject: [PATCH 42/42] fix(configure): prepend root to config path --- scripts/configure_invokeai.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/configure_invokeai.py b/scripts/configure_invokeai.py index 77a485d59d6..7089795db02 100755 --- a/scripts/configure_invokeai.py +++ b/scripts/configure_invokeai.py @@ -766,7 +766,7 @@ def main(): print('** DOWNLOADING DIFFUSION WEIGHTS **') download_weights(opt) else: - config_path = Path(opt.config_file or Default_config_file) + config_path = Path(Globals.root, opt.config_file or Default_config_file) if config_path.exists(): download_diffusers_in_config(config_path, full_precision=opt.full_precision) else:
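
For reference, the effect of this last fix, shown with stand-in values for Globals.root and Default_config_file (illustrative only): pathlib joins the root onto a relative config path, so the lookup no longer depends on the current working directory.

    from pathlib import Path

    root = "/home/user/invokeai"          # stand-in for Globals.root
    config_file = "configs/models.yaml"   # stand-in for Default_config_file

    print(Path(config_file))        # configs/models.yaml -- resolved against the CWD
    print(Path(root, config_file))  # /home/user/invokeai/configs/models.yaml

If opt.config_file is already an absolute path, Path(root, config_file) returns it unchanged, so explicitly supplied config paths keep working.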