From c8ae3cb548f773d32b1a83d8622aec2f6974dcde Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Sun, 23 Jul 2023 10:39:56 -0700 Subject: [PATCH 01/46] Add inference helpers & tests --- pytest.ini | 3 + sgm/inference/helpers.py | 476 ++++++++++++++++++++++++++++++ tests/inference/test_inference.py | 145 +++++++++ 3 files changed, 624 insertions(+) create mode 100644 pytest.ini create mode 100644 sgm/inference/helpers.py create mode 100644 tests/inference/test_inference.py diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 000000000..d79bd9b2e --- /dev/null +++ b/pytest.ini @@ -0,0 +1,3 @@ +[pytest] +markers = + inference: mark as inference test (deselect with '-m "not inference"') \ No newline at end of file diff --git a/sgm/inference/helpers.py b/sgm/inference/helpers.py new file mode 100644 index 000000000..f46ef291b --- /dev/null +++ b/sgm/inference/helpers.py @@ -0,0 +1,476 @@ +import os +from typing import Union, List + +import math +import numpy as np +import torch +from PIL import Image +from einops import rearrange, repeat +from imwatermark import WatermarkEncoder +from omegaconf import OmegaConf, ListConfig +from torch import autocast +from torchvision import transforms +from torchvision.utils import make_grid +from safetensors.torch import load_file as load_safetensors + +from sgm.modules.diffusionmodules.sampling import ( + EulerEDMSampler, + HeunEDMSampler, + EulerAncestralSampler, + DPMPP2SAncestralSampler, + DPMPP2MSampler, + LinearMultistepSampler, +) +from sgm.util import append_dims +from sgm.util import instantiate_from_config + + +class WatermarkEmbedder: + def __init__(self, watermark): + self.watermark = watermark + self.num_bits = len(WATERMARK_BITS) + self.encoder = WatermarkEncoder() + self.encoder.set_watermark("bits", self.watermark) + + def __call__(self, image: torch.Tensor): + """ + Adds a predefined watermark to the input image + + Args: + image: ([N,] B, C, H, W) in range [0, 1] + + Returns: + same as input but watermarked + """ + # watermarking libary expects input as cv2 BGR format + squeeze = len(image.shape) == 4 + if squeeze: + image = image[None, ...] + n = image.shape[0] + image_np = rearrange( + (255 * image).detach().cpu(), "n b c h w -> (n b) h w c" + ).numpy()[:, :, :, ::-1] + # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255] + for k in range(image_np.shape[0]): + image_np[k] = self.encoder.encode(image_np[k], "dwtDct") + image = torch.from_numpy( + rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n) + ).to(image.device) + image = torch.clamp(image / 255, min=0.0, max=1.0) + if squeeze: + image = image[0] + return image + + +# A fixed 48-bit message that was choosen at random +# WATERMARK_MESSAGE = 0xB3EC907BB19E +WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110 +# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1 +WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]] +embed_watemark = WatermarkEmbedder(WATERMARK_BITS) + + +def load_model_from_config(config, ckpt=None, verbose=True): + model = instantiate_from_config(config.model) + + if ckpt is not None: + print(f"Loading model from {ckpt}") + if ckpt.endswith("ckpt"): + pl_sd = torch.load(ckpt, map_location="cpu") + if "global_step" in pl_sd: + print(f"Global Step: {pl_sd['global_step']}") + sd = pl_sd["state_dict"] + elif ckpt.endswith("safetensors"): + sd = load_safetensors(ckpt) + else: + raise NotImplementedError + + msg = None + + m, u = model.load_state_dict(sd, strict=False) + + if len(m) > 0 and verbose: + print("missing keys:") + print(m) + if len(u) > 0 and verbose: + print("unexpected keys:") + print(u) + else: + msg = None + + model.cuda() + model.eval() + return model, msg + + +def get_unique_embedder_keys_from_conditioner(conditioner): + return list(set([x.input_key for x in conditioner.embedders])) + + +def perform_save_locally(save_path, samples): + os.makedirs(os.path.join(save_path), exist_ok=True) + base_count = len(os.listdir(os.path.join(save_path))) + samples = embed_watemark(samples) + for sample in samples: + sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c") + Image.fromarray(sample.astype(np.uint8)).save( + os.path.join(save_path, f"{base_count:09}.png") + ) + base_count += 1 + + +class Img2ImgDiscretizationWrapper: + """ + wraps a discretizer, and prunes the sigmas + params: + strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned) + """ + + def __init__(self, discretization, strength: float = 1.0): + self.discretization = discretization + self.strength = strength + assert 0.0 <= self.strength <= 1.0 + + def __call__(self, *args, **kwargs): + # sigmas start large first, and decrease then + sigmas = self.discretization(*args, **kwargs) + print(f"sigmas after discretization, before pruning img2img: ", sigmas) + sigmas = torch.flip(sigmas, (0,)) + sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)] + print("prune index:", max(int(self.strength * len(sigmas)), 1)) + sigmas = torch.flip(sigmas, (0,)) + print(f"sigmas after pruning: ", sigmas) + return sigmas + +def get_guider(guider, **kwargs): + + if guider == "IdentityGuider": + guider_config = { + "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider" + } + elif guider == "VanillaCFG": + scale = max(0.0, min(100.0, kwargs.pop("scale", 5.0))) + + thresholder = kwargs.pop("thresholder", "None") + + if thresholder == "None": + dyn_thresh_config = { + "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding" + } + else: + raise NotImplementedError + + guider_config = { + "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG", + "params": {"scale": scale, "dyn_thresh_config": dyn_thresh_config}, + } + else: + raise NotImplementedError + return guider_config + +def get_discretization(discretization, **kwargs): + if discretization == "LegacyDDPMDiscretization": + discretization_config = { + "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization", + } + elif discretization == "EDMDiscretization": + sigma_min = kwargs.pop("sigma_min", 0.03) # 0.0292 + sigma_max = kwargs.pop("sigma_max", 14.61) # 14.6146 + rho = kwargs.pop("rho", 3.0) + discretization_config = { + "target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization", + "params": { + "sigma_min": sigma_min, + "sigma_max": sigma_max, + "rho": rho, + }, + } + else: + raise ValueError(f'unknown discertization {discretization}') + return discretization_config + +def get_sampler(sampler_name, steps, discretization_config, guider_config, **kwargs): + if sampler_name == "EulerEDMSampler" or sampler_name == "HeunEDMSampler": + s_churn = kwargs.pop("s_churn", 0.0) + s_tmin = kwargs.pop("s_tmin", 0.0) + s_tmax = kwargs.pop("s_tmax", 999.0) + s_noise = kwargs.pop("s_noise", 1.0) + + if sampler_name == "EulerEDMSampler": + sampler = EulerEDMSampler( + num_steps=steps, + discretization_config=discretization_config, + guider_config=guider_config, + s_churn=s_churn, + s_tmin=s_tmin, + s_tmax=s_tmax, + s_noise=s_noise, + verbose=True, + ) + elif sampler_name == "HeunEDMSampler": + sampler = HeunEDMSampler( + num_steps=steps, + discretization_config=discretization_config, + guider_config=guider_config, + s_churn=s_churn, + s_tmin=s_tmin, + s_tmax=s_tmax, + s_noise=s_noise, + verbose=True, + ) + elif ( + sampler_name == "EulerAncestralSampler" + or sampler_name == "DPMPP2SAncestralSampler" + ): + s_noise = kwargs.pop("s_noise", 1.0) + eta = kwargs.pop("eta", 1.0) + + if sampler_name == "EulerAncestralSampler": + sampler = EulerAncestralSampler( + num_steps=steps, + discretization_config=discretization_config, + guider_config=guider_config, + eta=eta, + s_noise=s_noise, + verbose=True, + ) + elif sampler_name == "DPMPP2SAncestralSampler": + sampler = DPMPP2SAncestralSampler( + num_steps=steps, + discretization_config=discretization_config, + guider_config=guider_config, + eta=eta, + s_noise=s_noise, + verbose=True, + ) + elif sampler_name == "DPMPP2MSampler": + sampler = DPMPP2MSampler( + num_steps=steps, + discretization_config=discretization_config, + guider_config=guider_config, + verbose=True, + ) + elif sampler_name == "LinearMultistepSampler": + order = kwargs.pop("order", 4) + sampler = LinearMultistepSampler( + num_steps=steps, + discretization_config=discretization_config, + guider_config=guider_config, + order=order, + verbose=True, + ) + else: + raise ValueError(f"unknown sampler {sampler_name}!") + + return sampler + +def do_sample( + model, + sampler, + value_dict, + num_samples, + H, + W, + C, + F, + force_uc_zero_embeddings: List = None, + batch2model_input: List = None, + return_latents=False, + filter=None, +): + if force_uc_zero_embeddings is None: + force_uc_zero_embeddings = [] + if batch2model_input is None: + batch2model_input = [] + + precision_scope = autocast + with torch.no_grad(): + with precision_scope("cuda"): + with model.ema_scope(): + num_samples = [num_samples] + batch, batch_uc = get_batch( + get_unique_embedder_keys_from_conditioner(model.conditioner), + value_dict, + num_samples, + ) + for key in batch: + if isinstance(batch[key], torch.Tensor): + print(key, batch[key].shape) + elif isinstance(batch[key], list): + print(key, [len(l) for l in batch[key]]) + else: + print(key, batch[key]) + c, uc = model.conditioner.get_unconditional_conditioning( + batch, + batch_uc=batch_uc, + force_uc_zero_embeddings=force_uc_zero_embeddings, + ) + + for k in c: + if not k == "crossattn": + c[k], uc[k] = map( + lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc) + ) + + additional_model_inputs = {} + for k in batch2model_input: + additional_model_inputs[k] = batch[k] + + shape = (math.prod(num_samples), C, H // F, W // F) + randn = torch.randn(shape).to("cuda") + + def denoiser(input, sigma, c): + return model.denoiser( + model.model, input, sigma, c, **additional_model_inputs + ) + + samples_z = sampler(denoiser, randn, cond=c, uc=uc) + samples_x = model.decode_first_stage(samples_z) + samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) + + if filter is not None: + samples = filter(samples) + + if return_latents: + return samples, samples_z + return samples + + +def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"): + # Hardcoded demo setups; might undergo some changes in the future + + batch = {} + batch_uc = {} + + for key in keys: + if key == "txt": + batch["txt"] = ( + np.repeat([value_dict["prompt"]], repeats=math.prod(N)) + .reshape(N) + .tolist() + ) + batch_uc["txt"] = ( + np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)) + .reshape(N) + .tolist() + ) + elif key == "original_size_as_tuple": + batch["original_size_as_tuple"] = ( + torch.tensor([value_dict["orig_height"], value_dict["orig_width"]]) + .to(device) + .repeat(*N, 1) + ) + elif key == "crop_coords_top_left": + batch["crop_coords_top_left"] = ( + torch.tensor( + [value_dict["crop_coords_top"], value_dict["crop_coords_left"]] + ) + .to(device) + .repeat(*N, 1) + ) + elif key == "aesthetic_score": + batch["aesthetic_score"] = ( + torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1) + ) + batch_uc["aesthetic_score"] = ( + torch.tensor([value_dict["negative_aesthetic_score"]]) + .to(device) + .repeat(*N, 1) + ) + + elif key == "target_size_as_tuple": + batch["target_size_as_tuple"] = ( + torch.tensor([value_dict["target_height"], value_dict["target_width"]]) + .to(device) + .repeat(*N, 1) + ) + else: + batch[key] = value_dict[key] + + for key in batch.keys(): + if key not in batch_uc and isinstance(batch[key], torch.Tensor): + batch_uc[key] = torch.clone(batch[key]) + return batch, batch_uc + + +def get_input_image_tensor(image: Image, device="cuda"): + w, h = image.size + print(f"loaded input image of size ({w}, {h})") + width, height = map( + lambda x: x - x % 64, (w, h) + ) # resize to integer multiple of 64 + image = image.resize((width, height)) + image = np.array(image.convert("RGB")) + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 + return image.to(device) + +@torch.no_grad() +def do_img2img( + img, + model, + sampler, + value_dict, + num_samples, + force_uc_zero_embeddings=[], + additional_kwargs={}, + offset_noise_level: int = 0.0, + return_latents=False, + skip_encode=False, + filter=None, + logger=None, +): + precision_scope = autocast + with torch.no_grad(): + with precision_scope("cuda"): + with model.ema_scope(): + batch, batch_uc = get_batch( + get_unique_embedder_keys_from_conditioner(model.conditioner), + value_dict, + [num_samples], + ) + c, uc = model.conditioner.get_unconditional_conditioning( + batch, + batch_uc=batch_uc, + force_uc_zero_embeddings=force_uc_zero_embeddings, + ) + + for k in c: + c[k], uc[k] = map(lambda y: y[k][:num_samples].to("cuda"), (c, uc)) + + for k in additional_kwargs: + c[k] = uc[k] = additional_kwargs[k] + if skip_encode: + z = img + else: + z = model.encode_first_stage(img) + noise = torch.randn_like(z) + sigmas = sampler.discretization(sampler.num_steps) + sigma = sigmas[0].to(z.device) + + if logger is not None: + logger.info(f"all sigmas: {sigmas}") + logger.info(f"noising sigma: {sigma}") + + if offset_noise_level > 0.0: + noise = noise + offset_noise_level * append_dims( + torch.randn(z.shape[0], device=z.device), z.ndim + ) + noised_z = z + noise * append_dims(sigma, z.ndim) + noised_z = noised_z / torch.sqrt( + 1.0 + sigmas[0] ** 2.0 + ) # Note: hardcoded to DDPM-like scaling. need to generalize later. + + def denoiser(x, sigma, c): + return model.denoiser(model.model, x, sigma, c) + + samples_z = sampler(denoiser, noised_z, cond=c, uc=uc) + samples_x = model.decode_first_stage(samples_z) + samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) + + if filter is not None: + samples = filter(samples) + + if return_latents: + return samples, samples_z + return samples \ No newline at end of file diff --git a/tests/inference/test_inference.py b/tests/inference/test_inference.py new file mode 100644 index 000000000..0a88e2f42 --- /dev/null +++ b/tests/inference/test_inference.py @@ -0,0 +1,145 @@ +import numpy +from PIL import Image +import pytest +from pytest import fixture +from omegaconf import OmegaConf +import torch + +import sgm.inference.helpers as helpers + +VERSION2SPECS = { + "SD-XL base": { + "H": 1024, + "W": 1024, + "C": 4, + "f": 8, + "is_legacy": False, + "config": "configs/inference/sd_xl_base.yaml", + "ckpt": "checkpoints/sd_xl_base_0.9.safetensors", + "is_guided": True, + }, + "sd-2.1": { + "H": 512, + "W": 512, + "C": 4, + "f": 8, + "is_legacy": True, + "config": "configs/inference/sd_2_1.yaml", + "ckpt": "checkpoints/v2-1_512-ema-pruned.safetensors", + "is_guided": True, + }, + "sd-2.1-768": { + "H": 768, + "W": 768, + "C": 4, + "f": 8, + "is_legacy": True, + "config": "configs/inference/sd_2_1_768.yaml", + "ckpt": "checkpoints/v2-1_768-ema-pruned.safetensors", + }, + "SDXL-Refiner": { + "H": 1024, + "W": 1024, + "C": 4, + "f": 8, + "is_legacy": True, + "config": "configs/inference/sd_xl_refiner.yaml", + "ckpt": "checkpoints/sd_xl_refiner_0.9.safetensors", + "is_guided": True, + }, +} + +samplers = [ + "EulerEDMSampler", + "HeunEDMSampler", + "EulerAncestralSampler", + "DPMPP2SAncestralSampler", + "DPMPP2MSampler", + "LinearMultistepSampler" +] + +@pytest.mark.inference +class TestInference: + @fixture(scope="class", params=['SD-XL base', 'sd-2.1', 'sd-2.1-768', 'SDXL-Refiner']) + def model(self, request): + specs = VERSION2SPECS[request.param] + config = OmegaConf.load(specs['config']) + model, _ = helpers.load_model_from_config(config, specs['ckpt']) + model.conditioner.half() + model.model.half() + yield model, specs + del model + torch.cuda.empty_cache() + + def create_init_image(self, h, w): + image_array = numpy.random.rand(h,w,3) * 255 + image = Image.fromarray(image_array.astype('uint8')).convert('RGB') + return helpers.get_input_image_tensor(image) + + + @pytest.mark.parametrize("sampler_name", samplers) + def test_txt2img(self, model, sampler_name): + specs = model[1] + model = model[0] + value_dict = { + "prompt": "A professional photograph of an astronaut riding a pig", + "negative_prompt": "", + "aesthetic_score": 6.0, + "negative_aesthetic_score": 2.5, + "orig_height": specs['H'], + "orig_width": specs['W'], + "target_height": specs['H'], + "target_width": specs['W'], + "crop_coords_top": 0, + "crop_coords_left": 0 + } + sampler = helpers.get_sampler(sampler_name=sampler_name, + steps=10, + discretization_config=helpers.get_discretization("LegacyDDPMDiscretization"), + guider_config=helpers.get_guider(guider="VanillaCFG", scale=7.0), + ) + output = helpers.do_sample( + model=model, + sampler=sampler, + value_dict=value_dict, + num_samples=1, + H=specs['H'], + W=specs['W'], + C=specs['C'], + F=specs['f'] + ) + + assert output is not None + + @pytest.mark.parametrize("sampler_name", samplers) + def test_img2img(self, model, sampler_name): + specs = model[1] + model = model[0] + init_image = self.create_init_image(specs['H'], specs['W']).to('cuda') + value_dict = { + "prompt": "A professional photograph of an astronaut riding a pig", + "negative_prompt": "", + "aesthetic_score": 6.0, + "negative_aesthetic_score": 2.5, + "orig_height": specs['H'], + "orig_width": specs['W'], + "target_height": specs['H'], + "target_width": specs['W'], + "crop_coords_top": 0, + "crop_coords_left": 0 + } + + sampler = helpers.get_sampler(sampler_name=sampler_name, + steps=10, + discretization_config=helpers.get_discretization("LegacyDDPMDiscretization"), + guider_config=helpers.get_guider(guider="VanillaCFG", scale=7.0), + ) + + + output = helpers.do_img2img( + img=init_image, + model=model, + sampler=sampler, + value_dict=value_dict, + num_samples=1 + ) \ No newline at end of file From dbad98ce958aa74afbf4cbf843ed6851c7d2d471 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Sun, 23 Jul 2023 11:17:03 -0700 Subject: [PATCH 02/46] Support testing with hatch --- pyproject.toml | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 3b790a8aa..787d7d14a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,3 +32,15 @@ include = [ [tool.hatch.build.targets.wheel.force-include] "./configs" = "sgm/configs" + +[tool.hatch.envs.ci] +dependencies = [ + "pytest" +] + +[tool.hatch.envs.ci.scripts] +test-inference = [ + "pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 torchaudio==2.0.2+cu118 --index-url https://download.pytorch.org/whl/cu118", + "pip install -r requirements_pt2.txt", + "pytest -v tests/test_inference.py" +] \ No newline at end of file From 7e8fbd7696e9e42978c52af4702121f37fd94a52 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Sun, 23 Jul 2023 11:43:14 -0700 Subject: [PATCH 03/46] fixes to hatch script --- pyproject.toml | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 787d7d14a..7a330468e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,10 @@ include = [ "./configs" = "sgm/configs" [tool.hatch.envs.ci] +# Skip for now, since requirements.txt is used by scripts and includes the project +# This should be changed when dependencies are handled by Hatch +skip-install = true + dependencies = [ "pytest" ] @@ -42,5 +46,5 @@ dependencies = [ test-inference = [ "pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 torchaudio==2.0.2+cu118 --index-url https://download.pytorch.org/whl/cu118", "pip install -r requirements_pt2.txt", - "pytest -v tests/test_inference.py" + "pytest -v tests/inference/test_inference.py" ] \ No newline at end of file From 53338361852699bc0b6d947c3217569bf07a1025 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Sun, 23 Jul 2023 11:57:35 -0700 Subject: [PATCH 04/46] add inference test action --- .github/workflows/test-inference.yml | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 .github/workflows/test-inference.yml diff --git a/.github/workflows/test-inference.yml b/.github/workflows/test-inference.yml new file mode 100644 index 000000000..9261b449a --- /dev/null +++ b/.github/workflows/test-inference.yml @@ -0,0 +1,21 @@ +name: Test inference + +on: + pull_request: + branches: [ main ] + push: + branches: [ main ] + +jobs: + test: + # This action is designed only to run on the Stability research cluster at this time, so many assumptions are made about the envrionment + if: github.repository == 'stability-ai/generative-models' + runs-on: [self-hosted, slurm, g40] + steps: + - uses: actions/checkout@v2 + - name: "Symlink checkpoints" + run: ln -s ~/sgm-checkpoints checkpoints + - name: "Install Hatch" + run: pip install hatch + - name: "Run inference tests" + run: hatch run ci:test-inference \ No newline at end of file From 74d261da72065f6a095f905b5efc43cdb9a8fb21 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Sun, 23 Jul 2023 12:09:25 -0700 Subject: [PATCH 05/46] change workflow trigger --- .github/workflows/test-inference.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test-inference.yml b/.github/workflows/test-inference.yml index 9261b449a..e4bee7426 100644 --- a/.github/workflows/test-inference.yml +++ b/.github/workflows/test-inference.yml @@ -2,9 +2,9 @@ name: Test inference on: pull_request: - branches: [ main ] push: - branches: [ main ] + branches: + - main jobs: test: From e02de5a52269270b8b8b47b94942cf7f70d76dc5 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Sun, 23 Jul 2023 12:10:46 -0700 Subject: [PATCH 06/46] widen trigger to test --- .github/workflows/test-inference.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.github/workflows/test-inference.yml b/.github/workflows/test-inference.yml index e4bee7426..21a43b455 100644 --- a/.github/workflows/test-inference.yml +++ b/.github/workflows/test-inference.yml @@ -3,8 +3,6 @@ name: Test inference on: pull_request: push: - branches: - - main jobs: test: From 5cc8f7cd0ff5d2073f0ddefa439cb38d4a6e95a7 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Sun, 23 Jul 2023 12:14:39 -0700 Subject: [PATCH 07/46] revert changes to workflow triggers --- .github/workflows/test-inference.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/test-inference.yml b/.github/workflows/test-inference.yml index 21a43b455..15f3dc9e1 100644 --- a/.github/workflows/test-inference.yml +++ b/.github/workflows/test-inference.yml @@ -2,7 +2,11 @@ name: Test inference on: pull_request: + branches: + - main push: + branches: + - main jobs: test: From 927ffffbb64e995de81fb0518029a69e5df81d4f Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Sun, 23 Jul 2023 12:19:16 -0700 Subject: [PATCH 08/46] Install local python in action --- .github/workflows/test-inference.yml | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test-inference.yml b/.github/workflows/test-inference.yml index 15f3dc9e1..e814cc4b6 100644 --- a/.github/workflows/test-inference.yml +++ b/.github/workflows/test-inference.yml @@ -16,7 +16,11 @@ jobs: steps: - uses: actions/checkout@v2 - name: "Symlink checkpoints" - run: ln -s ~/sgm-checkpoints checkpoints + run: ln -s /admin/home-palp/sgm-checkpoints checkpoints + - name: "Setup python" + uses: actions/setup-python@v2 + with: + python-version: 3.10 - name: "Install Hatch" run: pip install hatch - name: "Run inference tests" From f1eb78691ac56d1027ef0e5dd73ada29917b7ba5 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Sun, 23 Jul 2023 12:20:12 -0700 Subject: [PATCH 09/46] Trigger on push again --- .github/workflows/test-inference.yml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/.github/workflows/test-inference.yml b/.github/workflows/test-inference.yml index e814cc4b6..ef7941779 100644 --- a/.github/workflows/test-inference.yml +++ b/.github/workflows/test-inference.yml @@ -2,11 +2,7 @@ name: Test inference on: pull_request: - branches: - - main push: - branches: - - main jobs: test: From 4e9ffe50bd45c73e04ce4009676b955626064082 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Sun, 23 Jul 2023 12:21:55 -0700 Subject: [PATCH 10/46] fix python version --- .github/workflows/test-inference.yml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test-inference.yml b/.github/workflows/test-inference.yml index ef7941779..391fed5e9 100644 --- a/.github/workflows/test-inference.yml +++ b/.github/workflows/test-inference.yml @@ -5,7 +5,8 @@ on: push: jobs: - test: + test: + name: "Test inference" # This action is designed only to run on the Stability research cluster at this time, so many assumptions are made about the envrionment if: github.repository == 'stability-ai/generative-models' runs-on: [self-hosted, slurm, g40] @@ -16,7 +17,7 @@ jobs: - name: "Setup python" uses: actions/setup-python@v2 with: - python-version: 3.10 + python-version: "3.10" - name: "Install Hatch" run: pip install hatch - name: "Run inference tests" From cc7e9836d581f115ec4dfbe3a29a091b7a8d8831 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Sun, 23 Jul 2023 12:35:15 -0700 Subject: [PATCH 11/46] add CODEOWNERS and change triggers --- .github/workflows/CODEOWNERS | 1 + .github/workflows/test-inference.yml | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/CODEOWNERS diff --git a/.github/workflows/CODEOWNERS b/.github/workflows/CODEOWNERS new file mode 100644 index 000000000..ae8a7a268 --- /dev/null +++ b/.github/workflows/CODEOWNERS @@ -0,0 +1 @@ +.github @Stability-AI/infrastructure \ No newline at end of file diff --git a/.github/workflows/test-inference.yml b/.github/workflows/test-inference.yml index 391fed5e9..6ff09f85b 100644 --- a/.github/workflows/test-inference.yml +++ b/.github/workflows/test-inference.yml @@ -1,8 +1,10 @@ name: Test inference on: - pull_request: + pull_request: push: + branches: + - main jobs: test: From 29a39d0df7e46beadb35b1628ebe24e55498509a Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Sun, 23 Jul 2023 12:40:50 -0700 Subject: [PATCH 12/46] Report tests results --- .github/workflows/test-inference.yml | 10 +++++++++- pyproject.toml | 2 +- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test-inference.yml b/.github/workflows/test-inference.yml index 6ff09f85b..c037ce090 100644 --- a/.github/workflows/test-inference.yml +++ b/.github/workflows/test-inference.yml @@ -23,4 +23,12 @@ jobs: - name: "Install Hatch" run: pip install hatch - name: "Run inference tests" - run: hatch run ci:test-inference \ No newline at end of file + run: hatch run ci:test-inference --junit-xml test-results.xml + - name: Surface failing tests + if: always() + uses: pmeier/pytest-results-action@main + with: + path: test-results.xml + summary: true + display-options: fEX + fail-on-empty: true \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 7a330468e..2c65bc103 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,5 +46,5 @@ dependencies = [ test-inference = [ "pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 torchaudio==2.0.2+cu118 --index-url https://download.pytorch.org/whl/cu118", "pip install -r requirements_pt2.txt", - "pytest -v tests/inference/test_inference.py" + "pytest -v tests/inference/test_inference.py {args}", ] \ No newline at end of file From 06b9c768969ad141e0e32b2f5836346eac253c60 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Sun, 23 Jul 2023 12:56:14 -0700 Subject: [PATCH 13/46] update action versions --- .github/workflows/test-inference.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test-inference.yml b/.github/workflows/test-inference.yml index c037ce090..76204c75d 100644 --- a/.github/workflows/test-inference.yml +++ b/.github/workflows/test-inference.yml @@ -13,11 +13,11 @@ jobs: if: github.repository == 'stability-ai/generative-models' runs-on: [self-hosted, slurm, g40] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: "Symlink checkpoints" run: ln -s /admin/home-palp/sgm-checkpoints checkpoints - name: "Setup python" - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: "3.10" - name: "Install Hatch" From 0703ef8fc3588771f9b376d4d58c4d5dc154aade Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Sun, 23 Jul 2023 13:47:15 -0700 Subject: [PATCH 14/46] format --- sgm/inference/helpers.py | 26 +++++---- tests/inference/test_inference.py | 91 +++++++++++++++++-------------- 2 files changed, 64 insertions(+), 53 deletions(-) diff --git a/sgm/inference/helpers.py b/sgm/inference/helpers.py index f46ef291b..5a4a2941c 100644 --- a/sgm/inference/helpers.py +++ b/sgm/inference/helpers.py @@ -77,7 +77,7 @@ def load_model_from_config(config, ckpt=None, verbose=True): print(f"Loading model from {ckpt}") if ckpt.endswith("ckpt"): pl_sd = torch.load(ckpt, map_location="cpu") - if "global_step" in pl_sd: + if "global_step" in pl_sd: print(f"Global Step: {pl_sd['global_step']}") sd = pl_sd["state_dict"] elif ckpt.endswith("safetensors"): @@ -142,17 +142,17 @@ def __call__(self, *args, **kwargs): print(f"sigmas after pruning: ", sigmas) return sigmas + def get_guider(guider, **kwargs): - if guider == "IdentityGuider": guider_config = { "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider" } elif guider == "VanillaCFG": scale = max(0.0, min(100.0, kwargs.pop("scale", 5.0))) - + thresholder = kwargs.pop("thresholder", "None") - + if thresholder == "None": dyn_thresh_config = { "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding" @@ -168,13 +168,14 @@ def get_guider(guider, **kwargs): raise NotImplementedError return guider_config + def get_discretization(discretization, **kwargs): if discretization == "LegacyDDPMDiscretization": discretization_config = { "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization", } elif discretization == "EDMDiscretization": - sigma_min = kwargs.pop("sigma_min", 0.03) # 0.0292 + sigma_min = kwargs.pop("sigma_min", 0.03) # 0.0292 sigma_max = kwargs.pop("sigma_max", 14.61) # 14.6146 rho = kwargs.pop("rho", 3.0) discretization_config = { @@ -186,9 +187,10 @@ def get_discretization(discretization, **kwargs): }, } else: - raise ValueError(f'unknown discertization {discretization}') + raise ValueError(f"unknown discertization {discretization}") return discretization_config + def get_sampler(sampler_name, steps, discretization_config, guider_config, **kwargs): if sampler_name == "EulerEDMSampler" or sampler_name == "HeunEDMSampler": s_churn = kwargs.pop("s_churn", 0.0) @@ -264,6 +266,7 @@ def get_sampler(sampler_name, steps, discretization_config, guider_config, **kwa return sampler + def do_sample( model, sampler, @@ -405,6 +408,7 @@ def get_input_image_tensor(image: Image, device="cuda"): image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 return image.to(device) + @torch.no_grad() def do_img2img( img, @@ -417,7 +421,7 @@ def do_img2img( offset_noise_level: int = 0.0, return_latents=False, skip_encode=False, - filter=None, + filter=None, logger=None, ): precision_scope = autocast @@ -449,8 +453,8 @@ def do_img2img( sigma = sigmas[0].to(z.device) if logger is not None: - logger.info(f"all sigmas: {sigmas}") - logger.info(f"noising sigma: {sigma}") + logger.info(f"all sigmas: {sigmas}") + logger.info(f"noising sigma: {sigma}") if offset_noise_level > 0.0: noise = noise + offset_noise_level * append_dims( @@ -470,7 +474,7 @@ def denoiser(x, sigma, c): if filter is not None: samples = filter(samples) - + if return_latents: return samples, samples_z - return samples \ No newline at end of file + return samples diff --git a/tests/inference/test_inference.py b/tests/inference/test_inference.py index 0a88e2f42..331029bac 100644 --- a/tests/inference/test_inference.py +++ b/tests/inference/test_inference.py @@ -55,58 +55,63 @@ "EulerAncestralSampler", "DPMPP2SAncestralSampler", "DPMPP2MSampler", - "LinearMultistepSampler" + "LinearMultistepSampler", ] + @pytest.mark.inference -class TestInference: - @fixture(scope="class", params=['SD-XL base', 'sd-2.1', 'sd-2.1-768', 'SDXL-Refiner']) +class TestInference: + @fixture( + scope="class", params=["SD-XL base", "sd-2.1", "sd-2.1-768", "SDXL-Refiner"] + ) def model(self, request): - specs = VERSION2SPECS[request.param] - config = OmegaConf.load(specs['config']) - model, _ = helpers.load_model_from_config(config, specs['ckpt']) + specs = VERSION2SPECS[request.param] + config = OmegaConf.load(specs["config"]) + model, _ = helpers.load_model_from_config(config, specs["ckpt"]) model.conditioner.half() - model.model.half() + model.model.half() yield model, specs del model - torch.cuda.empty_cache() - + torch.cuda.empty_cache() + def create_init_image(self, h, w): - image_array = numpy.random.rand(h,w,3) * 255 - image = Image.fromarray(image_array.astype('uint8')).convert('RGB') + image_array = numpy.random.rand(h, w, 3) * 255 + image = Image.fromarray(image_array.astype("uint8")).convert("RGB") return helpers.get_input_image_tensor(image) - @pytest.mark.parametrize("sampler_name", samplers) def test_txt2img(self, model, sampler_name): specs = model[1] - model = model[0] + model = model[0] value_dict = { "prompt": "A professional photograph of an astronaut riding a pig", "negative_prompt": "", "aesthetic_score": 6.0, "negative_aesthetic_score": 2.5, - "orig_height": specs['H'], - "orig_width": specs['W'], - "target_height": specs['H'], - "target_width": specs['W'], + "orig_height": specs["H"], + "orig_width": specs["W"], + "target_height": specs["H"], + "target_width": specs["W"], "crop_coords_top": 0, - "crop_coords_left": 0 + "crop_coords_left": 0, } - sampler = helpers.get_sampler(sampler_name=sampler_name, - steps=10, - discretization_config=helpers.get_discretization("LegacyDDPMDiscretization"), - guider_config=helpers.get_guider(guider="VanillaCFG", scale=7.0), - ) + sampler = helpers.get_sampler( + sampler_name=sampler_name, + steps=10, + discretization_config=helpers.get_discretization( + "LegacyDDPMDiscretization" + ), + guider_config=helpers.get_guider(guider="VanillaCFG", scale=7.0), + ) output = helpers.do_sample( model=model, sampler=sampler, value_dict=value_dict, num_samples=1, - H=specs['H'], - W=specs['W'], - C=specs['C'], - F=specs['f'] + H=specs["H"], + W=specs["W"], + C=specs["C"], + F=specs["f"], ) assert output is not None @@ -114,32 +119,34 @@ def test_txt2img(self, model, sampler_name): @pytest.mark.parametrize("sampler_name", samplers) def test_img2img(self, model, sampler_name): specs = model[1] - model = model[0] - init_image = self.create_init_image(specs['H'], specs['W']).to('cuda') + model = model[0] + init_image = self.create_init_image(specs["H"], specs["W"]).to("cuda") value_dict = { "prompt": "A professional photograph of an astronaut riding a pig", "negative_prompt": "", "aesthetic_score": 6.0, "negative_aesthetic_score": 2.5, - "orig_height": specs['H'], - "orig_width": specs['W'], - "target_height": specs['H'], - "target_width": specs['W'], + "orig_height": specs["H"], + "orig_width": specs["W"], + "target_height": specs["H"], + "target_width": specs["W"], "crop_coords_top": 0, - "crop_coords_left": 0 + "crop_coords_left": 0, } - sampler = helpers.get_sampler(sampler_name=sampler_name, - steps=10, - discretization_config=helpers.get_discretization("LegacyDDPMDiscretization"), - guider_config=helpers.get_guider(guider="VanillaCFG", scale=7.0), - ) - + sampler = helpers.get_sampler( + sampler_name=sampler_name, + steps=10, + discretization_config=helpers.get_discretization( + "LegacyDDPMDiscretization" + ), + guider_config=helpers.get_guider(guider="VanillaCFG", scale=7.0), + ) output = helpers.do_img2img( img=init_image, model=model, sampler=sampler, value_dict=value_dict, - num_samples=1 - ) \ No newline at end of file + num_samples=1, + ) From 0cedb25f776e6edefd5681e0a6931e5d44555e2b Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Sun, 23 Jul 2023 20:44:38 -0700 Subject: [PATCH 15/46] Fix typo and add refiner helper --- sgm/inference/helpers.py | 43 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 41 insertions(+), 2 deletions(-) diff --git a/sgm/inference/helpers.py b/sgm/inference/helpers.py index 5a4a2941c..953e77157 100644 --- a/sgm/inference/helpers.py +++ b/sgm/inference/helpers.py @@ -67,7 +67,7 @@ def __call__(self, image: torch.Tensor): WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110 # bin(x)[2:] gives bits of x as str, use int to convert them to 0/1 WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]] -embed_watemark = WatermarkEmbedder(WATERMARK_BITS) +embed_watermark = WatermarkEmbedder(WATERMARK_BITS) def load_model_from_config(config, ckpt=None, verbose=True): @@ -110,7 +110,7 @@ def get_unique_embedder_keys_from_conditioner(conditioner): def perform_save_locally(save_path, samples): os.makedirs(os.path.join(save_path), exist_ok=True) base_count = len(os.listdir(os.path.join(save_path))) - samples = embed_watemark(samples) + samples = embed_watermark(samples) for sample in samples: sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c") Image.fromarray(sample.astype(np.uint8)).save( @@ -409,6 +409,45 @@ def get_input_image_tensor(image: Image, device="cuda"): return image.to(device) +def apply_refiner( + input, + model, + sampler, + num_samples, + prompt, + negative_prompt, + filter=None, +): + init_dict = { + "orig_width": input.shape[3] * 8, + "orig_height": input.shape[2] * 8, + "target_width": input.shape[3] * 8, + "target_height": input.shape[2] * 8, + } + + value_dict = init_dict + value_dict["prompt"] = prompt + value_dict["negative_prompt"] = negative_prompt + + value_dict["crop_coords_top"] = 0 + value_dict["crop_coords_left"] = 0 + + value_dict["aesthetic_score"] = 6.0 + value_dict["negative_aesthetic_score"] = 2.5 + + samples = do_img2img( + input, + model, + sampler, + value_dict, + num_samples, + skip_encode=True, + filter=filter, + ) + + return samples + + @torch.no_grad() def do_img2img( img, From 9f27cc67eac79a720849b735a33e2a74818b870b Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Tue, 25 Jul 2023 05:06:10 -0700 Subject: [PATCH 16/46] use a shared path loaded from a secret for checkpoints source --- .github/workflows/test-inference.yml | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/.github/workflows/test-inference.yml b/.github/workflows/test-inference.yml index 76204c75d..1d65def64 100644 --- a/.github/workflows/test-inference.yml +++ b/.github/workflows/test-inference.yml @@ -1,21 +1,21 @@ name: Test inference -on: - pull_request: +on: + pull_request: push: branches: - main jobs: - test: - name: "Test inference" + test: + name: "Test inference" # This action is designed only to run on the Stability research cluster at this time, so many assumptions are made about the envrionment if: github.repository == 'stability-ai/generative-models' runs-on: [self-hosted, slurm, g40] steps: - uses: actions/checkout@v3 - name: "Symlink checkpoints" - run: ln -s /admin/home-palp/sgm-checkpoints checkpoints + run: ln -s ${{secrets.SGM_CHECKPOINTS_PATH}} checkpoints - name: "Setup python" uses: actions/setup-python@v4 with: @@ -29,6 +29,6 @@ jobs: uses: pmeier/pytest-results-action@main with: path: test-results.xml - summary: true + summary: true display-options: fEX - fail-on-empty: true \ No newline at end of file + fail-on-empty: true From 2ebd30aeb76c03a8db045f29548a8bca2e2a1e36 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Tue, 25 Jul 2023 05:06:55 -0700 Subject: [PATCH 17/46] typo fix --- .github/workflows/test-inference.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test-inference.yml b/.github/workflows/test-inference.yml index 1d65def64..9687d7ea5 100644 --- a/.github/workflows/test-inference.yml +++ b/.github/workflows/test-inference.yml @@ -9,7 +9,7 @@ on: jobs: test: name: "Test inference" - # This action is designed only to run on the Stability research cluster at this time, so many assumptions are made about the envrionment + # This action is designed only to run on the Stability research cluster at this time, so many assumptions are made about the environment if: github.repository == 'stability-ai/generative-models' runs-on: [self-hosted, slurm, g40] steps: From 8086691275488fbfb574314f7bc9e6420a5285a6 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Tue, 25 Jul 2023 17:04:52 +0000 Subject: [PATCH 18/46] Use device from input and remove duplicated code --- sgm/inference/helpers.py | 46 +++++-------------------------- tests/inference/test_inference.py | 3 +- 2 files changed, 9 insertions(+), 40 deletions(-) diff --git a/sgm/inference/helpers.py b/sgm/inference/helpers.py index 953e77157..de5344e5c 100644 --- a/sgm/inference/helpers.py +++ b/sgm/inference/helpers.py @@ -69,40 +69,6 @@ def __call__(self, image: torch.Tensor): WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]] embed_watermark = WatermarkEmbedder(WATERMARK_BITS) - -def load_model_from_config(config, ckpt=None, verbose=True): - model = instantiate_from_config(config.model) - - if ckpt is not None: - print(f"Loading model from {ckpt}") - if ckpt.endswith("ckpt"): - pl_sd = torch.load(ckpt, map_location="cpu") - if "global_step" in pl_sd: - print(f"Global Step: {pl_sd['global_step']}") - sd = pl_sd["state_dict"] - elif ckpt.endswith("safetensors"): - sd = load_safetensors(ckpt) - else: - raise NotImplementedError - - msg = None - - m, u = model.load_state_dict(sd, strict=False) - - if len(m) > 0 and verbose: - print("missing keys:") - print(m) - if len(u) > 0 and verbose: - print("unexpected keys:") - print(u) - else: - msg = None - - model.cuda() - model.eval() - return model, msg - - def get_unique_embedder_keys_from_conditioner(conditioner): return list(set([x.input_key for x in conditioner.embedders])) @@ -280,6 +246,7 @@ def do_sample( batch2model_input: List = None, return_latents=False, filter=None, + device="cuda", ): if force_uc_zero_embeddings is None: force_uc_zero_embeddings = [] @@ -288,7 +255,7 @@ def do_sample( precision_scope = autocast with torch.no_grad(): - with precision_scope("cuda"): + with precision_scope(device): with model.ema_scope(): num_samples = [num_samples] batch, batch_uc = get_batch( @@ -312,7 +279,7 @@ def do_sample( for k in c: if not k == "crossattn": c[k], uc[k] = map( - lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc) + lambda y: y[k][: math.prod(num_samples)].to(device), (c, uc) ) additional_model_inputs = {} @@ -320,7 +287,7 @@ def do_sample( additional_model_inputs[k] = batch[k] shape = (math.prod(num_samples), C, H // F, W // F) - randn = torch.randn(shape).to("cuda") + randn = torch.randn(shape).to(device) def denoiser(input, sigma, c): return model.denoiser( @@ -462,10 +429,11 @@ def do_img2img( skip_encode=False, filter=None, logger=None, + device="cuda" ): precision_scope = autocast with torch.no_grad(): - with precision_scope("cuda"): + with precision_scope(device): with model.ema_scope(): batch, batch_uc = get_batch( get_unique_embedder_keys_from_conditioner(model.conditioner), @@ -479,7 +447,7 @@ def do_img2img( ) for k in c: - c[k], uc[k] = map(lambda y: y[k][:num_samples].to("cuda"), (c, uc)) + c[k], uc[k] = map(lambda y: y[k][:num_samples].to(device), (c, uc)) for k in additional_kwargs: c[k] = uc[k] = additional_kwargs[k] diff --git a/tests/inference/test_inference.py b/tests/inference/test_inference.py index 331029bac..414a1e1e6 100644 --- a/tests/inference/test_inference.py +++ b/tests/inference/test_inference.py @@ -5,6 +5,7 @@ from omegaconf import OmegaConf import torch +from sgm.util import load_model_from_config import sgm.inference.helpers as helpers VERSION2SPECS = { @@ -67,7 +68,7 @@ class TestInference: def model(self, request): specs = VERSION2SPECS[request.param] config = OmegaConf.load(specs["config"]) - model, _ = helpers.load_model_from_config(config, specs["ckpt"]) + model, _ = load_model_from_config(config, specs["ckpt"]) model.conditioner.half() model.model.half() yield model, specs From 2ac4c50941868e8900e6c520edb3e01dad7b1346 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Tue, 25 Jul 2023 17:19:13 +0000 Subject: [PATCH 19/46] PR feedback --- sgm/inference/helpers.py | 35 +++++++++++++---------------------- 1 file changed, 13 insertions(+), 22 deletions(-) diff --git a/sgm/inference/helpers.py b/sgm/inference/helpers.py index de5344e5c..4f92da6a7 100644 --- a/sgm/inference/helpers.py +++ b/sgm/inference/helpers.py @@ -5,13 +5,10 @@ import numpy as np import torch from PIL import Image -from einops import rearrange, repeat +from einops import rearrange from imwatermark import WatermarkEncoder -from omegaconf import OmegaConf, ListConfig +from omegaconf import ListConfig from torch import autocast -from torchvision import transforms -from torchvision.utils import make_grid -from safetensors.torch import load_file as load_safetensors from sgm.modules.diffusionmodules.sampling import ( EulerEDMSampler, @@ -22,8 +19,6 @@ LinearMultistepSampler, ) from sgm.util import append_dims -from sgm.util import instantiate_from_config - class WatermarkEmbedder: def __init__(self, watermark): @@ -70,7 +65,7 @@ def __call__(self, image: torch.Tensor): embed_watermark = WatermarkEmbedder(WATERMARK_BITS) def get_unique_embedder_keys_from_conditioner(conditioner): - return list(set([x.input_key for x in conditioner.embedders])) + return list({x.input_key for x in conditioner.embedders}) def perform_save_locally(save_path, samples): @@ -141,8 +136,8 @@ def get_discretization(discretization, **kwargs): "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization", } elif discretization == "EDMDiscretization": - sigma_min = kwargs.pop("sigma_min", 0.03) # 0.0292 - sigma_max = kwargs.pop("sigma_max", 14.61) # 14.6146 + sigma_min = kwargs.pop("sigma_min", 0.0292) + sigma_max = kwargs.pop("sigma_max", 14.6146) rho = kwargs.pop("rho", 3.0) discretization_config = { "target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization", @@ -385,23 +380,19 @@ def apply_refiner( negative_prompt, filter=None, ): - init_dict = { + value_dict = { "orig_width": input.shape[3] * 8, "orig_height": input.shape[2] * 8, "target_width": input.shape[3] * 8, "target_height": input.shape[2] * 8, + "prompt": prompt, + "negative_prompt": negative_prompt, + "crop_coords_top": 0, + "crop_coords_left": 0, + "aesthetic_score": 6.0, + "negative_aesthetic_score": 2.5 } - value_dict = init_dict - value_dict["prompt"] = prompt - value_dict["negative_prompt"] = negative_prompt - - value_dict["crop_coords_top"] = 0 - value_dict["crop_coords_left"] = 0 - - value_dict["aesthetic_score"] = 6.0 - value_dict["negative_aesthetic_score"] = 2.5 - samples = do_img2img( input, model, @@ -424,7 +415,7 @@ def do_img2img( num_samples, force_uc_zero_embeddings=[], additional_kwargs={}, - offset_noise_level: int = 0.0, + offset_noise_level: float = 0.0, return_latents=False, skip_encode=False, filter=None, From 81e1047098ee20327c9ec4c49845c9db1370ac2b Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Tue, 25 Jul 2023 17:26:35 +0000 Subject: [PATCH 20/46] fix call to load_model_from_config --- tests/inference/test_inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/inference/test_inference.py b/tests/inference/test_inference.py index 414a1e1e6..4ded6a3c4 100644 --- a/tests/inference/test_inference.py +++ b/tests/inference/test_inference.py @@ -68,7 +68,7 @@ class TestInference: def model(self, request): specs = VERSION2SPECS[request.param] config = OmegaConf.load(specs["config"]) - model, _ = load_model_from_config(config, specs["ckpt"]) + model = load_model_from_config(config, specs["ckpt"]) model.conditioner.half() model.model.half() yield model, specs From 5648dce8d11c7cddb3b99ae250b8389543d27329 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Tue, 25 Jul 2023 17:39:23 +0000 Subject: [PATCH 21/46] Move model to gpu --- tests/inference/test_inference.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/inference/test_inference.py b/tests/inference/test_inference.py index 4ded6a3c4..38e738fc3 100644 --- a/tests/inference/test_inference.py +++ b/tests/inference/test_inference.py @@ -69,6 +69,7 @@ def model(self, request): specs = VERSION2SPECS[request.param] config = OmegaConf.load(specs["config"]) model = load_model_from_config(config, specs["ckpt"]) + model.cuda() model.conditioner.half() model.model.half() yield model, specs From ed78819021a5df761b4a23adfe2c5b1493cfe745 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Wed, 26 Jul 2023 00:48:11 +0000 Subject: [PATCH 22/46] Refactor helpers --- sgm/inference/api.py | 360 ++++++++++++++++++++++++++++++ sgm/inference/helpers.py | 170 +------------- tests/inference/test_inference.py | 151 ++----------- 3 files changed, 383 insertions(+), 298 deletions(-) create mode 100644 sgm/inference/api.py diff --git a/sgm/inference/api.py b/sgm/inference/api.py new file mode 100644 index 000000000..1cf8b0909 --- /dev/null +++ b/sgm/inference/api.py @@ -0,0 +1,360 @@ +from enum import Enum +from omegaconf import OmegaConf +from pydantic import BaseModel +import pathlib +from sgm.inference.helpers import ( + do_sample, + do_img2img, + Img2ImgDiscretizationWrapper, +) +from sgm.modules.diffusionmodules.sampling import ( + EulerEDMSampler, + HeunEDMSampler, + EulerAncestralSampler, + DPMPP2SAncestralSampler, + DPMPP2MSampler, + LinearMultistepSampler, +) +from sgm.util import load_model_from_config + + +class ModelIdentifier(str, Enum): + SD_2_1 = "StableDiffusion2.1" + SD_2_1_768 = "StableDiffusion2.1-768" + SDXL_BASE = "SDXLBase" + SDXL_REFINER = "SDXLRefiner" + + +class Sampler(str, Enum): + EULER_EDM = "EulerEDMSampler" + HEUN_EDM = "HeunEDMSampler" + EULER_ANCESTRAL = "EulerAncestralSampler" + DPMPP2S_ANCESTRAL = "DPMPP2SAncestralSampler" + DPMPP2M = "DPMPP2MSampler" + LINEAR_MULTISTEP = "LinearMultistepSampler" + + +class Discretization(str, Enum): + LEGACY_DDPM = "LegacyDDPMDiscretization" + EDM = "EDMDiscretization" + + +class Guider(str, Enum): + VANILLA = "VanillaCFG" + IDENTITY = "IdentityGuider" + + +class Thresholder(str, Enum): + NONE = "None" + + +class SamplingParams(BaseModel): + width: int = 1024 + height: int = 1024 + steps: int = 50 + sampler: Sampler = Sampler.DPMPP2M + discretization: Discretization = Discretization.LEGACY_DDPM + guider: Guider = Guider.VANILLA + thresholder: Thresholder = Thresholder.NONE + scale: float = 6.0 + aesthetic_score: float = 5.0 + negative_aesthetic_score: float = 5.0 + img2img_strength: float = 1.0 + orig_width: int = 1024 + orig_height: int = 1024 + crop_coords_top: int = 0 + crop_coords_left: int = 0 + sigma_min: float = 0.0292 + sigma_max: float = 14.6146 + rho: float = 3.0 + s_churn: float = 0.0 + s_tmin: float = 0.0 + s_tmax: float = 999.0 + s_noise: float = 1.0 + eta: float = 1.0 + order: int = 4 + + +class SamplingSpec(BaseModel): + width: int + height: int + channels: int + factor: int + is_legacy: bool + config: str + ckpt: str + is_guided: bool + + +model_specs = { + ModelIdentifier.SD_2_1: SamplingSpec( + height=512, + width=512, + channels=4, + factor=8, + is_legacy=True, + config="sd_2_1.yaml", + ckpt="v2-1_512-ema-pruned.safetensors", + is_guided=True, + ), + ModelIdentifier.SD_2_1_768: SamplingSpec( + height=768, + width=768, + channels=4, + factor=8, + is_legacy=True, + config="sd_2_1_768.yaml", + ckpt="v2-1_768-ema-pruned.safetensors", + is_guided=True, + ), + ModelIdentifier.SDXL_BASE: SamplingSpec( + height=1024, + width=1024, + channels=4, + factor=8, + is_legacy=False, + config="sd_xl_base.yaml", + ckpt="sd_xl_base_0.9.safetensors", + is_guided=True, + ), + ModelIdentifier.SDXL_REFINER: SamplingSpec( + height=1024, + width=1024, + channels=4, + factor=8, + is_legacy=True, + config="sd_xl_refiner.yaml", + ckpt="sd_xl_refiner_0.9.safetensors", + is_guided=True, + ), +} + + +class SamplingPipeline: + def __init__( + self, + model_id: ModelIdentifier, + model_path="checkpoints", + config_path="configs/inference", + device="cuda", + ) -> None: + if model_id not in model_specs: + raise ValueError(f"Model {model_id} not supported") + self.model_id = model_id + self.specs = model_specs[self.model_id] + self.specs.config = str(pathlib.Path(config_path, self.specs.config)) + self.specs.ckpt = str(pathlib.Path(model_path, self.specs.ckpt)) + self.device = device + self.model = self._load_model() + + def _load_model(self, device="cuda"): + config = OmegaConf.load(self.specs.config) + model = load_model_from_config(config, self.specs.ckpt) + model.to(device) + model.conditioner.half() + model.model.half() + return model + + def text_to_image( + self, + params: SamplingParams, + prompt: str, + negative_prompt: str = "", + samples: int = 1, + return_latents: bool = False, + ): + sampler = get_sampler_config(params) + value_dict = dict(params) + value_dict["prompt"] = prompt + value_dict["negative_prompt"] = negative_prompt + return do_sample( + self.model, + sampler, + value_dict, + samples, + params.height, + params.width, + self.specs.channels, + self.specs.factor, + force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [], + return_latents=return_latents, + filter=None, + ) + + def image_to_image( + self, + params: SamplingParams, + image, + prompt: str, + negative_prompt: str = "", + samples: int = 1, + return_latents: bool = False, + ): + sampler = get_sampler_config(params) + + if params.img2img_strength < 1.0: + sampler.discretization = Img2ImgDiscretizationWrapper( + sampler.discretization, + strength=params.img2img_strength, + ) + + value_dict = dict(params) + value_dict["prompt"] = prompt + value_dict["negative_prompt"] = negative_prompt + + return do_img2img( + image, + self.model, + sampler, + value_dict, + samples, + force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [], + return_latents=return_latents, + filter=None, + ) + + def refiner( + self, + params: SamplingParams, + image, + prompt: str, + negative_prompt: str = None, + samples: int = 1, + return_latents: bool = False, + ): + sampler = get_sampler_config(params) + value_dict = { + "orig_width": image.shape[3] * 8, + "orig_height": image.shape[2] * 8, + "target_width": image.shape[3] * 8, + "target_height": image.shape[2] * 8, + "prompt": prompt, + "negative_prompt": negative_prompt, + "crop_coords_top": 0, + "crop_coords_left": 0, + "aesthetic_score": 6.0, + "negative_aesthetic_score": 2.5, + } + + return do_img2img( + image, + self.model, + sampler, + value_dict, + samples, + skip_encode=True, + return_latents=return_latents, + filter=None, + ) + +def get_guider_config(params:SamplingParams): + if params.guider == Guider.IDENTITY: + guider_config = { + "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider" + } + elif params.guider == Guider.VANILLA: + scale = params.scale + + thresholder = params.thresholder + + if thresholder == Thresholder.NONE: + dyn_thresh_config = { + "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding" + } + else: + raise NotImplementedError + + guider_config = { + "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG", + "params": {"scale": scale, "dyn_thresh_config": dyn_thresh_config}, + } + else: + raise NotImplementedError + return guider_config + + +def get_discretization_config(params: SamplingParams): + if params.discretization == Discretization.LEGACY_DDPM: + discretization_config = { + "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization", + } + elif params.discretization == Discretization.EDM: + discretization_config = { + "target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization", + "params": { + "sigma_min": params.sigma_min, + "sigma_max": params.sigma_max, + "rho": params.rho, + }, + } + else: + raise ValueError(f"unknown discertization {params.discretization}") + return discretization_config + + +def get_sampler_config(params: SamplingParams): + discretization_config = get_discretization_config(params) + guider_config = get_guider_config(params) + if params.sampler == Sampler.EULER_EDM or params.sampler == Sampler.HEUN_EDM: + if params.sampler == Sampler.EULER_EDM: + sampler = EulerEDMSampler( + num_steps=params.steps, + discretization_config=discretization_config, + guider_config=guider_config, + s_churn=params.s_churn, + s_tmin=params.s_tmin, + s_tmax=params.s_tmax, + s_noise=params.s_noise, + verbose=True, + ) + elif params.sampler == Sampler.HEUN_EDM: + sampler = HeunEDMSampler( + num_steps=params.steps, + discretization_config=discretization_config, + guider_config=guider_config, + s_churn=params.s_churn, + s_tmin=params.s_tmin, + s_tmax=params.s_tmax, + s_noise=params.s_noise, + verbose=True, + ) + elif ( + params.sampler == Sampler.EULER_ANCESTRAL or params.sampler == Sampler.DPMPP2S_ANCESTRAL + ): + if params.sampler == Sampler.EULER_ANCESTRAL: + sampler = EulerAncestralSampler( + num_steps=params.steps, + discretization_config=discretization_config, + guider_config=guider_config, + eta=params.eta, + s_noise=params.s_noise, + verbose=True, + ) + elif params.sampler == Sampler.DPMPP2S_ANCESTRAL: + sampler = DPMPP2SAncestralSampler( + num_steps=params.steps, + discretization_config=discretization_config, + guider_config=guider_config, + eta=params.eta, + s_noise=params.s_noise, + verbose=True, + ) + elif params.sampler == Sampler.DPMPP2M: + sampler = DPMPP2MSampler( + num_steps=params.steps, + discretization_config=discretization_config, + guider_config=guider_config, + verbose=True, + ) + elif params.sampler == Sampler.LINEAR_MULTISTEP: + sampler = LinearMultistepSampler( + num_steps=params.steps, + discretization_config=discretization_config, + guider_config=guider_config, + order=params.order, + verbose=True, + ) + else: + raise ValueError(f"unknown sampler {sampler}!") + + return sampler diff --git a/sgm/inference/helpers.py b/sgm/inference/helpers.py index 4f92da6a7..ddba67832 100644 --- a/sgm/inference/helpers.py +++ b/sgm/inference/helpers.py @@ -10,14 +10,6 @@ from omegaconf import ListConfig from torch import autocast -from sgm.modules.diffusionmodules.sampling import ( - EulerEDMSampler, - HeunEDMSampler, - EulerAncestralSampler, - DPMPP2SAncestralSampler, - DPMPP2MSampler, - LinearMultistepSampler, -) from sgm.util import append_dims class WatermarkEmbedder: @@ -104,130 +96,6 @@ def __call__(self, *args, **kwargs): return sigmas -def get_guider(guider, **kwargs): - if guider == "IdentityGuider": - guider_config = { - "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider" - } - elif guider == "VanillaCFG": - scale = max(0.0, min(100.0, kwargs.pop("scale", 5.0))) - - thresholder = kwargs.pop("thresholder", "None") - - if thresholder == "None": - dyn_thresh_config = { - "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding" - } - else: - raise NotImplementedError - - guider_config = { - "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG", - "params": {"scale": scale, "dyn_thresh_config": dyn_thresh_config}, - } - else: - raise NotImplementedError - return guider_config - - -def get_discretization(discretization, **kwargs): - if discretization == "LegacyDDPMDiscretization": - discretization_config = { - "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization", - } - elif discretization == "EDMDiscretization": - sigma_min = kwargs.pop("sigma_min", 0.0292) - sigma_max = kwargs.pop("sigma_max", 14.6146) - rho = kwargs.pop("rho", 3.0) - discretization_config = { - "target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization", - "params": { - "sigma_min": sigma_min, - "sigma_max": sigma_max, - "rho": rho, - }, - } - else: - raise ValueError(f"unknown discertization {discretization}") - return discretization_config - - -def get_sampler(sampler_name, steps, discretization_config, guider_config, **kwargs): - if sampler_name == "EulerEDMSampler" or sampler_name == "HeunEDMSampler": - s_churn = kwargs.pop("s_churn", 0.0) - s_tmin = kwargs.pop("s_tmin", 0.0) - s_tmax = kwargs.pop("s_tmax", 999.0) - s_noise = kwargs.pop("s_noise", 1.0) - - if sampler_name == "EulerEDMSampler": - sampler = EulerEDMSampler( - num_steps=steps, - discretization_config=discretization_config, - guider_config=guider_config, - s_churn=s_churn, - s_tmin=s_tmin, - s_tmax=s_tmax, - s_noise=s_noise, - verbose=True, - ) - elif sampler_name == "HeunEDMSampler": - sampler = HeunEDMSampler( - num_steps=steps, - discretization_config=discretization_config, - guider_config=guider_config, - s_churn=s_churn, - s_tmin=s_tmin, - s_tmax=s_tmax, - s_noise=s_noise, - verbose=True, - ) - elif ( - sampler_name == "EulerAncestralSampler" - or sampler_name == "DPMPP2SAncestralSampler" - ): - s_noise = kwargs.pop("s_noise", 1.0) - eta = kwargs.pop("eta", 1.0) - - if sampler_name == "EulerAncestralSampler": - sampler = EulerAncestralSampler( - num_steps=steps, - discretization_config=discretization_config, - guider_config=guider_config, - eta=eta, - s_noise=s_noise, - verbose=True, - ) - elif sampler_name == "DPMPP2SAncestralSampler": - sampler = DPMPP2SAncestralSampler( - num_steps=steps, - discretization_config=discretization_config, - guider_config=guider_config, - eta=eta, - s_noise=s_noise, - verbose=True, - ) - elif sampler_name == "DPMPP2MSampler": - sampler = DPMPP2MSampler( - num_steps=steps, - discretization_config=discretization_config, - guider_config=guider_config, - verbose=True, - ) - elif sampler_name == "LinearMultistepSampler": - order = kwargs.pop("order", 4) - sampler = LinearMultistepSampler( - num_steps=steps, - discretization_config=discretization_config, - guider_config=guider_config, - order=order, - verbose=True, - ) - else: - raise ValueError(f"unknown sampler {sampler_name}!") - - return sampler - - def do_sample( model, sampler, @@ -370,42 +238,6 @@ def get_input_image_tensor(image: Image, device="cuda"): image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 return image.to(device) - -def apply_refiner( - input, - model, - sampler, - num_samples, - prompt, - negative_prompt, - filter=None, -): - value_dict = { - "orig_width": input.shape[3] * 8, - "orig_height": input.shape[2] * 8, - "target_width": input.shape[3] * 8, - "target_height": input.shape[2] * 8, - "prompt": prompt, - "negative_prompt": negative_prompt, - "crop_coords_top": 0, - "crop_coords_left": 0, - "aesthetic_score": 6.0, - "negative_aesthetic_score": 2.5 - } - - samples = do_img2img( - input, - model, - sampler, - value_dict, - num_samples, - skip_encode=True, - filter=filter, - ) - - return samples - - @torch.no_grad() def do_img2img( img, @@ -475,4 +307,4 @@ def denoiser(x, sigma, c): if return_latents: return samples, samples_z - return samples + return samples \ No newline at end of file diff --git a/tests/inference/test_inference.py b/tests/inference/test_inference.py index 38e738fc3..baba31025 100644 --- a/tests/inference/test_inference.py +++ b/tests/inference/test_inference.py @@ -6,74 +6,17 @@ import torch from sgm.util import load_model_from_config +from sgm.inference.api import model_specs, SamplingParams, SamplingPipeline, Sampler import sgm.inference.helpers as helpers -VERSION2SPECS = { - "SD-XL base": { - "H": 1024, - "W": 1024, - "C": 4, - "f": 8, - "is_legacy": False, - "config": "configs/inference/sd_xl_base.yaml", - "ckpt": "checkpoints/sd_xl_base_0.9.safetensors", - "is_guided": True, - }, - "sd-2.1": { - "H": 512, - "W": 512, - "C": 4, - "f": 8, - "is_legacy": True, - "config": "configs/inference/sd_2_1.yaml", - "ckpt": "checkpoints/v2-1_512-ema-pruned.safetensors", - "is_guided": True, - }, - "sd-2.1-768": { - "H": 768, - "W": 768, - "C": 4, - "f": 8, - "is_legacy": True, - "config": "configs/inference/sd_2_1_768.yaml", - "ckpt": "checkpoints/v2-1_768-ema-pruned.safetensors", - }, - "SDXL-Refiner": { - "H": 1024, - "W": 1024, - "C": 4, - "f": 8, - "is_legacy": True, - "config": "configs/inference/sd_xl_refiner.yaml", - "ckpt": "checkpoints/sd_xl_refiner_0.9.safetensors", - "is_guided": True, - }, -} - -samplers = [ - "EulerEDMSampler", - "HeunEDMSampler", - "EulerAncestralSampler", - "DPMPP2SAncestralSampler", - "DPMPP2MSampler", - "LinearMultistepSampler", -] - @pytest.mark.inference class TestInference: - @fixture( - scope="class", params=["SD-XL base", "sd-2.1", "sd-2.1-768", "SDXL-Refiner"] - ) - def model(self, request): - specs = VERSION2SPECS[request.param] - config = OmegaConf.load(specs["config"]) - model = load_model_from_config(config, specs["ckpt"]) - model.cuda() - model.conditioner.half() - model.model.half() - yield model, specs - del model + @fixture(scope="class", params=model_specs.keys()) + def pipeline(self, request) -> SamplingPipeline: + pipeline = SamplingPipeline(request.param) + yield pipeline + del pipeline torch.cuda.empty_cache() def create_init_image(self, h, w): @@ -81,74 +24,24 @@ def create_init_image(self, h, w): image = Image.fromarray(image_array.astype("uint8")).convert("RGB") return helpers.get_input_image_tensor(image) - @pytest.mark.parametrize("sampler_name", samplers) - def test_txt2img(self, model, sampler_name): - specs = model[1] - model = model[0] - value_dict = { - "prompt": "A professional photograph of an astronaut riding a pig", - "negative_prompt": "", - "aesthetic_score": 6.0, - "negative_aesthetic_score": 2.5, - "orig_height": specs["H"], - "orig_width": specs["W"], - "target_height": specs["H"], - "target_width": specs["W"], - "crop_coords_top": 0, - "crop_coords_left": 0, - } - sampler = helpers.get_sampler( - sampler_name=sampler_name, - steps=10, - discretization_config=helpers.get_discretization( - "LegacyDDPMDiscretization" - ), - guider_config=helpers.get_guider(guider="VanillaCFG", scale=7.0), - ) - output = helpers.do_sample( - model=model, - sampler=sampler, - value_dict=value_dict, - num_samples=1, - H=specs["H"], - W=specs["W"], - C=specs["C"], - F=specs["f"], + @pytest.mark.parametrize("sampler_enum", Sampler) + def test_txt2img(self, pipeline: SamplingPipeline, sampler_enum): + output = pipeline.text_to_image( + params=SamplingParams(sampler=sampler_enum.value, steps=10), + prompt="A professional photograph of an astronaut riding a pig", + negative_prompt="", + samples=1, ) assert output is not None - @pytest.mark.parametrize("sampler_name", samplers) - def test_img2img(self, model, sampler_name): - specs = model[1] - model = model[0] - init_image = self.create_init_image(specs["H"], specs["W"]).to("cuda") - value_dict = { - "prompt": "A professional photograph of an astronaut riding a pig", - "negative_prompt": "", - "aesthetic_score": 6.0, - "negative_aesthetic_score": 2.5, - "orig_height": specs["H"], - "orig_width": specs["W"], - "target_height": specs["H"], - "target_width": specs["W"], - "crop_coords_top": 0, - "crop_coords_left": 0, - } - - sampler = helpers.get_sampler( - sampler_name=sampler_name, - steps=10, - discretization_config=helpers.get_discretization( - "LegacyDDPMDiscretization" - ), - guider_config=helpers.get_guider(guider="VanillaCFG", scale=7.0), - ) - - output = helpers.do_img2img( - img=init_image, - model=model, - sampler=sampler, - value_dict=value_dict, - num_samples=1, + @pytest.mark.parametrize("sampler_enum", Sampler) + def test_img2img(self, pipeline: SamplingPipeline, sampler_enum): + output = pipeline.image_to_image( + params=SamplingParams(sampler=sampler_enum.value, steps=10), + image=self.create_init_image(pipeline.specs.height, pipeline.specs.width), + prompt="A professional photograph of an astronaut riding a pig", + negative_prompt="", + samples=1, ) + assert output is not None From 9283e34a89981c4788b516c0a02d90c8f5b29b17 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Wed, 26 Jul 2023 01:20:46 +0000 Subject: [PATCH 23/46] cleanup --- sgm/inference/api.py | 12 +++++++----- sgm/inference/helpers.py | 7 +++++-- tests/inference/test_inference.py | 2 -- 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/sgm/inference/api.py b/sgm/inference/api.py index 1cf8b0909..b5477a8e4 100644 --- a/sgm/inference/api.py +++ b/sgm/inference/api.py @@ -2,7 +2,7 @@ from omegaconf import OmegaConf from pydantic import BaseModel import pathlib -from sgm.inference.helpers import ( +from sgm.inference.helpers import ( do_sample, do_img2img, Img2ImgDiscretizationWrapper, @@ -247,7 +247,8 @@ def refiner( filter=None, ) -def get_guider_config(params:SamplingParams): + +def get_guider_config(params: SamplingParams): if params.guider == Guider.IDENTITY: guider_config = { "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider" @@ -319,8 +320,9 @@ def get_sampler_config(params: SamplingParams): verbose=True, ) elif ( - params.sampler == Sampler.EULER_ANCESTRAL or params.sampler == Sampler.DPMPP2S_ANCESTRAL - ): + params.sampler == Sampler.EULER_ANCESTRAL + or params.sampler == Sampler.DPMPP2S_ANCESTRAL + ): if params.sampler == Sampler.EULER_ANCESTRAL: sampler = EulerAncestralSampler( num_steps=params.steps, @@ -346,7 +348,7 @@ def get_sampler_config(params: SamplingParams): guider_config=guider_config, verbose=True, ) - elif params.sampler == Sampler.LINEAR_MULTISTEP: + elif params.sampler == Sampler.LINEAR_MULTISTEP: sampler = LinearMultistepSampler( num_steps=params.steps, discretization_config=discretization_config, diff --git a/sgm/inference/helpers.py b/sgm/inference/helpers.py index ddba67832..8e1131844 100644 --- a/sgm/inference/helpers.py +++ b/sgm/inference/helpers.py @@ -12,6 +12,7 @@ from sgm.util import append_dims + class WatermarkEmbedder: def __init__(self, watermark): self.watermark = watermark @@ -56,6 +57,7 @@ def __call__(self, image: torch.Tensor): WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]] embed_watermark = WatermarkEmbedder(WATERMARK_BITS) + def get_unique_embedder_keys_from_conditioner(conditioner): return list({x.input_key for x in conditioner.embedders}) @@ -238,6 +240,7 @@ def get_input_image_tensor(image: Image, device="cuda"): image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 return image.to(device) + @torch.no_grad() def do_img2img( img, @@ -252,7 +255,7 @@ def do_img2img( skip_encode=False, filter=None, logger=None, - device="cuda" + device="cuda", ): precision_scope = autocast with torch.no_grad(): @@ -307,4 +310,4 @@ def denoiser(x, sigma, c): if return_latents: return samples, samples_z - return samples \ No newline at end of file + return samples diff --git a/tests/inference/test_inference.py b/tests/inference/test_inference.py index baba31025..efdc25ffa 100644 --- a/tests/inference/test_inference.py +++ b/tests/inference/test_inference.py @@ -2,10 +2,8 @@ from PIL import Image import pytest from pytest import fixture -from omegaconf import OmegaConf import torch -from sgm.util import load_model_from_config from sgm.inference.api import model_specs, SamplingParams, SamplingPipeline, Sampler import sgm.inference.helpers as helpers From da771082c70ebdb5df5a08b9113ce79b4cc05aa8 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Wed, 26 Jul 2023 03:36:54 +0000 Subject: [PATCH 24/46] test refiner, prep for 1.0, align with metadata --- sgm/inference/api.py | 60 ++++++++++++++++++++-------- sgm/inference/helpers.py | 16 ++++---- tests/inference/test_inference.py | 66 ++++++++++++++++++++++++++++++- 3 files changed, 117 insertions(+), 25 deletions(-) diff --git a/sgm/inference/api.py b/sgm/inference/api.py index b5477a8e4..752d9c849 100644 --- a/sgm/inference/api.py +++ b/sgm/inference/api.py @@ -16,13 +16,15 @@ LinearMultistepSampler, ) from sgm.util import load_model_from_config +from typing import Optional - -class ModelIdentifier(str, Enum): - SD_2_1 = "StableDiffusion2.1" - SD_2_1_768 = "StableDiffusion2.1-768" - SDXL_BASE = "SDXLBase" - SDXL_REFINER = "SDXLRefiner" +class ModelArchitecture(str, Enum): + SD_2_1 = "stable-diffusion-v2-1" + SD_2_1_768 = "stable-diffusion-v2-1-768" + SDXL_V0_9_BASE = "stable-diffusion-xl-v0-9-base" + SDXL_V0_9_REFINER = "stable-diffusion-xl-v0-9-refiner" + SDXL_V1_BASE = "stable-diffusion-xl-v1-base" + SDXL_V1_REFINER = "stable-diffusion-xl-v1-refiner" class Sampler(str, Enum): @@ -87,7 +89,7 @@ class SamplingSpec(BaseModel): model_specs = { - ModelIdentifier.SD_2_1: SamplingSpec( + ModelArchitecture.SD_2_1: SamplingSpec( height=512, width=512, channels=4, @@ -97,7 +99,7 @@ class SamplingSpec(BaseModel): ckpt="v2-1_512-ema-pruned.safetensors", is_guided=True, ), - ModelIdentifier.SD_2_1_768: SamplingSpec( + ModelArchitecture.SD_2_1_768: SamplingSpec( height=768, width=768, channels=4, @@ -107,7 +109,7 @@ class SamplingSpec(BaseModel): ckpt="v2-1_768-ema-pruned.safetensors", is_guided=True, ), - ModelIdentifier.SDXL_BASE: SamplingSpec( + ModelArchitecture.SDXL_V0_9_BASE: SamplingSpec( height=1024, width=1024, channels=4, @@ -117,7 +119,7 @@ class SamplingSpec(BaseModel): ckpt="sd_xl_base_0.9.safetensors", is_guided=True, ), - ModelIdentifier.SDXL_REFINER: SamplingSpec( + ModelArchitecture.SDXL_V0_9_REFINER: SamplingSpec( height=1024, width=1024, channels=4, @@ -127,13 +129,33 @@ class SamplingSpec(BaseModel): ckpt="sd_xl_refiner_0.9.safetensors", is_guided=True, ), + ModelArchitecture.SDXL_V1_BASE: SamplingSpec( + height=1024, + width=1024, + channels=4, + factor=8, + is_legacy=False, + config="sd_xl_base.yaml", + ckpt="sd_xl_base_1.0.safetensors", + is_guided=True + ), + ModelArchitecture.SDXL_V1_REFINER: SamplingSpec( + height=1024, + width=1024, + channels=4, + factor=8, + is_legacy=True, + config="sd_xl_refiner.yaml", + ckpt="sd_xl_refiner_0.9.safetensors", + is_guided=True, + ) } class SamplingPipeline: def __init__( self, - model_id: ModelIdentifier, + model_id: ModelArchitecture, model_path="checkpoints", config_path="configs/inference", device="cuda", @@ -150,6 +172,8 @@ def __init__( def _load_model(self, device="cuda"): config = OmegaConf.load(self.specs.config) model = load_model_from_config(config, self.specs.ckpt) + if model is None: + raise ValueError(f"Model {self.model_id} could not be loaded") model.to(device) model.conditioner.half() model.model.half() @@ -167,6 +191,8 @@ def text_to_image( value_dict = dict(params) value_dict["prompt"] = prompt value_dict["negative_prompt"] = negative_prompt + value_dict["target_width"] = params.width + value_dict["target_height"] = params.height return do_sample( self.model, sampler, @@ -197,11 +223,12 @@ def image_to_image( sampler.discretization, strength=params.img2img_strength, ) - + height, width = image.shape[2], image.shape[3] value_dict = dict(params) value_dict["prompt"] = prompt value_dict["negative_prompt"] = negative_prompt - + value_dict["target_width"] = width + value_dict["target_height"] = height return do_img2img( image, self.model, @@ -218,7 +245,7 @@ def refiner( params: SamplingParams, image, prompt: str, - negative_prompt: str = None, + negative_prompt: Optional[str] = None, samples: int = 1, return_latents: bool = False, ): @@ -296,6 +323,7 @@ def get_discretization_config(params: SamplingParams): def get_sampler_config(params: SamplingParams): discretization_config = get_discretization_config(params) guider_config = get_guider_config(params) + sampler = None if params.sampler == Sampler.EULER_EDM or params.sampler == Sampler.HEUN_EDM: if params.sampler == Sampler.EULER_EDM: sampler = EulerEDMSampler( @@ -356,7 +384,7 @@ def get_sampler_config(params: SamplingParams): order=params.order, verbose=True, ) - else: - raise ValueError(f"unknown sampler {sampler}!") + if sampler is None: + raise ValueError(f"unknown sampler {params.sampler}!") return sampler diff --git a/sgm/inference/helpers.py b/sgm/inference/helpers.py index 8e1131844..b5f30531f 100644 --- a/sgm/inference/helpers.py +++ b/sgm/inference/helpers.py @@ -1,5 +1,5 @@ import os -from typing import Union, List +from typing import Union, List, Optional import math import numpy as np @@ -107,8 +107,8 @@ def do_sample( W, C, F, - force_uc_zero_embeddings: List = None, - batch2model_input: List = None, + force_uc_zero_embeddings: Optional[List] = None, + batch2model_input: Optional[List] = None, return_latents=False, filter=None, device="cuda", @@ -228,17 +228,17 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"): return batch, batch_uc -def get_input_image_tensor(image: Image, device="cuda"): +def get_input_image_tensor(image: Image.Image, device="cuda"): w, h = image.size print(f"loaded input image of size ({w}, {h})") width, height = map( lambda x: x - x % 64, (w, h) ) # resize to integer multiple of 64 image = image.resize((width, height)) - image = np.array(image.convert("RGB")) - image = image[None].transpose(0, 3, 1, 2) - image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 - return image.to(device) + image_array = np.array(image.convert("RGB")) + image_array = image_array[None].transpose(0, 3, 1, 2) + image_tensor = torch.from_numpy(image_array).to(dtype=torch.float32) / 127.5 - 1.0 + return image_tensor.to(device) @torch.no_grad() diff --git a/tests/inference/test_inference.py b/tests/inference/test_inference.py index efdc25ffa..3d75a041f 100644 --- a/tests/inference/test_inference.py +++ b/tests/inference/test_inference.py @@ -3,8 +3,15 @@ import pytest from pytest import fixture import torch +from typing import Tuple -from sgm.inference.api import model_specs, SamplingParams, SamplingPipeline, Sampler +from sgm.inference.api import ( + model_specs, + SamplingParams, + SamplingPipeline, + Sampler, + ModelArchitecture, +) import sgm.inference.helpers as helpers @@ -17,6 +24,22 @@ def pipeline(self, request) -> SamplingPipeline: del pipeline torch.cuda.empty_cache() + @fixture( + scope="class", + params=[ + [ModelArchitecture.SDXL_V1_BASE, ModelArchitecture.SDXL_V1_REFINER], + [ModelArchitecture.SDXL_V0_9_BASE, ModelArchitecture.SDXL_V0_9_REFINER], + ], + ids=["SDXL_V1", "SDXL_V0_9"], + ) + def sdxl_pipelines(self, request) -> Tuple[SamplingPipeline, SamplingPipeline]: + base_pipeline = SamplingPipeline(request.param[0]) + refiner_pipeline = SamplingPipeline(request.param[1]) + yield base_pipeline, refiner_pipeline + del base_pipeline + del refiner_pipeline + torch.cuda.empty_cache() + def create_init_image(self, h, w): image_array = numpy.random.rand(h, w, 3) * 255 image = Image.fromarray(image_array.astype("uint8")).convert("RGB") @@ -43,3 +66,44 @@ def test_img2img(self, pipeline: SamplingPipeline, sampler_enum): samples=1, ) assert output is not None + + @pytest.mark.parametrize("sampler_enum", Sampler) + @pytest.mark.parametrize("use_init_image", [True, False], ids=["img2img", "txt2img"]) + def test_sdxl_with_refiner( + self, + sdxl_pipelines: Tuple[SamplingPipeline, SamplingPipeline], + sampler_enum, + use_init_image, + ): + base_pipeline, refiner_pipeline = sdxl_pipelines + if use_init_image: + output = base_pipeline.image_to_image( + params=SamplingParams(sampler=sampler_enum.value, steps=10), + image=self.create_init_image( + base_pipeline.specs.height, base_pipeline.specs.width + ), + prompt="A professional photograph of an astronaut riding a pig", + negative_prompt="", + samples=1, + return_latents=True, + ) + else: + output = base_pipeline.text_to_image( + params=SamplingParams(sampler=sampler_enum.value, steps=10), + prompt="A professional photograph of an astronaut riding a pig", + negative_prompt="", + samples=1, + return_latents=True, + ) + + assert isinstance(output, (tuple, list)) + samples, samples_z = output + assert samples is not None + assert samples_z is not None + refiner_pipeline.refiner( + params=SamplingParams(sampler=sampler_enum.value, steps=10), + image=samples_z, + prompt="A professional photograph of an astronaut riding a pig", + negative_prompt="", + samples=1, + ) From 7c26bdea624eef2a260225746589d5eaa2127364 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Wed, 26 Jul 2023 04:19:49 +0000 Subject: [PATCH 25/46] fix paths on second load --- sgm/inference/api.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sgm/inference/api.py b/sgm/inference/api.py index 752d9c849..f9487151c 100644 --- a/sgm/inference/api.py +++ b/sgm/inference/api.py @@ -164,14 +164,14 @@ def __init__( raise ValueError(f"Model {model_id} not supported") self.model_id = model_id self.specs = model_specs[self.model_id] - self.specs.config = str(pathlib.Path(config_path, self.specs.config)) - self.specs.ckpt = str(pathlib.Path(model_path, self.specs.ckpt)) + self.config = str(pathlib.Path(config_path, self.specs.config)) + self.ckpt = str(pathlib.Path(model_path, self.specs.ckpt)) self.device = device self.model = self._load_model() def _load_model(self, device="cuda"): - config = OmegaConf.load(self.specs.config) - model = load_model_from_config(config, self.specs.ckpt) + config = OmegaConf.load(self.config) + model = load_model_from_config(config, self.ckpt) if model is None: raise ValueError(f"Model {self.model_id} could not be loaded") model.to(device) From b094614b1d4f37ddd906bd1da9edbc8847ff0636 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Wed, 26 Jul 2023 04:49:12 +0000 Subject: [PATCH 26/46] deduplicate streamlit code --- scripts/demo/sampling.py | 15 +- scripts/demo/streamlit_helpers.py | 343 +----------------------------- 2 files changed, 19 insertions(+), 339 deletions(-) diff --git a/scripts/demo/sampling.py b/scripts/demo/sampling.py index 98d0af30f..d64d49618 100644 --- a/scripts/demo/sampling.py +++ b/scripts/demo/sampling.py @@ -1,6 +1,7 @@ from pytorch_lightning import seed_everything from scripts.demo.streamlit_helpers import * from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering +from sgm.inference.helpers import do_img2img, do_sample, get_unique_embedder_keys_from_conditioner, perform_save_locally SAVE_PATH = "outputs/demo/txt2img/" @@ -130,7 +131,9 @@ def run_txt2img( if st.button("Sample"): st.write(f"**Model I:** {version}") - out = do_sample( + outputs = outputs = st.empty() + st.text("Sampling") + samples = do_sample( state["model"], sampler, value_dict, @@ -143,6 +146,10 @@ def run_txt2img( return_latents=return_latents, filter=filter, ) + grid = torch.stack([samples]) + grid = rearrange(grid, "n b c h w -> (n h) (b w) c") + outputs.image(grid.cpu().numpy()) + return out @@ -174,6 +181,7 @@ def run_img2img( num_samples = num_rows * num_cols if st.button("Sample"): + st.text("Sampling") out = do_img2img( repeat(img, "1 ... -> n ...", n=num_samples), state["model"], @@ -183,6 +191,7 @@ def run_img2img( force_uc_zero_embeddings=["txt"] if not is_legacy else [], return_latents=return_latents, filter=filter, + logger=st ) return out @@ -247,9 +256,7 @@ def apply_refiner( save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, version)) - state = init_st(version_dict) - if state["msg"]: - st.info(state["msg"]) + state = init_st(version_dict) model = state["model"] is_legacy = version_dict["is_legacy"] diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py index 8f53b5db3..35a068352 100644 --- a/scripts/demo/streamlit_helpers.py +++ b/scripts/demo/streamlit_helpers.py @@ -22,53 +22,10 @@ DPMPP2MSampler, LinearMultistepSampler, ) +from sgm.inference.helpers import Img2ImgDiscretizationWrapper from sgm.util import append_dims -from sgm.util import instantiate_from_config - - -class WatermarkEmbedder: - def __init__(self, watermark): - self.watermark = watermark - self.num_bits = len(WATERMARK_BITS) - self.encoder = WatermarkEncoder() - self.encoder.set_watermark("bits", self.watermark) - - def __call__(self, image: torch.Tensor): - """ - Adds a predefined watermark to the input image - - Args: - image: ([N,] B, C, H, W) in range [0, 1] - - Returns: - same as input but watermarked - """ - # watermarking libary expects input as cv2 BGR format - squeeze = len(image.shape) == 4 - if squeeze: - image = image[None, ...] - n = image.shape[0] - image_np = rearrange( - (255 * image).detach().cpu(), "n b c h w -> (n b) h w c" - ).numpy()[:, :, :, ::-1] - # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255] - for k in range(image_np.shape[0]): - image_np[k] = self.encoder.encode(image_np[k], "dwtDct") - image = torch.from_numpy( - rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n) - ).to(image.device) - image = torch.clamp(image / 255, min=0.0, max=1.0) - if squeeze: - image = image[0] - return image - +from sgm.util import instantiate_from_config, load_model_from_config -# A fixed 48-bit message that was choosen at random -# WATERMARK_MESSAGE = 0xB3EC907BB19E -WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110 -# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1 -WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]] -embed_watemark = WatermarkEmbedder(WATERMARK_BITS) @st.cache_resource() @@ -79,54 +36,17 @@ def init_st(version_dict, load_ckpt=True): ckpt = version_dict["ckpt"] config = OmegaConf.load(config) - model, msg = load_model_from_config(config, ckpt if load_ckpt else None) - - state["msg"] = msg + model = load_model_from_config(config, ckpt if load_ckpt else None) + model = model.to("cuda") + model.conditioner.half() + model.model.half() + state["model"] = model state["ckpt"] = ckpt if load_ckpt else None state["config"] = config return state -def load_model_from_config(config, ckpt=None, verbose=True): - model = instantiate_from_config(config.model) - - if ckpt is not None: - print(f"Loading model from {ckpt}") - if ckpt.endswith("ckpt"): - pl_sd = torch.load(ckpt, map_location="cpu") - if "global_step" in pl_sd: - global_step = pl_sd["global_step"] - st.info(f"loaded ckpt from global step {global_step}") - print(f"Global Step: {pl_sd['global_step']}") - sd = pl_sd["state_dict"] - elif ckpt.endswith("safetensors"): - sd = load_safetensors(ckpt) - else: - raise NotImplementedError - - msg = None - - m, u = model.load_state_dict(sd, strict=False) - - if len(m) > 0 and verbose: - print("missing keys:") - print(m) - if len(u) > 0 and verbose: - print("unexpected keys:") - print(u) - else: - msg = None - - model.cuda() - model.eval() - return model, msg - - -def get_unique_embedder_keys_from_conditioner(conditioner): - return list(set([x.input_key for x in conditioner.embedders])) - - def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None): # Hardcoded demo settings; might undergo some changes in the future @@ -187,18 +107,6 @@ def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None): return value_dict -def perform_save_locally(save_path, samples): - os.makedirs(os.path.join(save_path), exist_ok=True) - base_count = len(os.listdir(os.path.join(save_path))) - samples = embed_watemark(samples) - for sample in samples: - sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c") - Image.fromarray(sample.astype(np.uint8)).save( - os.path.join(save_path, f"{base_count:09}.png") - ) - base_count += 1 - - def init_save_locally(_dir, init_value: bool = False): save_locally = st.sidebar.checkbox("Save images locally", value=init_value) if save_locally: @@ -209,30 +117,6 @@ def init_save_locally(_dir, init_value: bool = False): return save_locally, save_path -class Img2ImgDiscretizationWrapper: - """ - wraps a discretizer, and prunes the sigmas - params: - strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned) - """ - - def __init__(self, discretization, strength: float = 1.0): - self.discretization = discretization - self.strength = strength - assert 0.0 <= self.strength <= 1.0 - - def __call__(self, *args, **kwargs): - # sigmas start large first, and decrease then - sigmas = self.discretization(*args, **kwargs) - print(f"sigmas after discretization, before pruning img2img: ", sigmas) - sigmas = torch.flip(sigmas, (0,)) - sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)] - print("prune index:", max(int(self.strength * len(sigmas)), 1)) - sigmas = torch.flip(sigmas, (0,)) - print(f"sigmas after pruning: ", sigmas) - return sigmas - - def get_guider(key): guider = st.sidebar.selectbox( f"Discretization #{key}", @@ -452,215 +336,4 @@ def load_img(display=True, key=None): def get_init_img(batch_size=1, key=None): init_image = load_img(key=key).cuda() init_image = repeat(init_image, "1 ... -> b ...", b=batch_size) - return init_image - - -def do_sample( - model, - sampler, - value_dict, - num_samples, - H, - W, - C, - F, - force_uc_zero_embeddings: List = None, - batch2model_input: List = None, - return_latents=False, - filter=None, -): - if force_uc_zero_embeddings is None: - force_uc_zero_embeddings = [] - if batch2model_input is None: - batch2model_input = [] - - st.text("Sampling") - - outputs = st.empty() - precision_scope = autocast - with torch.no_grad(): - with precision_scope("cuda"): - with model.ema_scope(): - num_samples = [num_samples] - batch, batch_uc = get_batch( - get_unique_embedder_keys_from_conditioner(model.conditioner), - value_dict, - num_samples, - ) - for key in batch: - if isinstance(batch[key], torch.Tensor): - print(key, batch[key].shape) - elif isinstance(batch[key], list): - print(key, [len(l) for l in batch[key]]) - else: - print(key, batch[key]) - c, uc = model.conditioner.get_unconditional_conditioning( - batch, - batch_uc=batch_uc, - force_uc_zero_embeddings=force_uc_zero_embeddings, - ) - - for k in c: - if not k == "crossattn": - c[k], uc[k] = map( - lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc) - ) - - additional_model_inputs = {} - for k in batch2model_input: - additional_model_inputs[k] = batch[k] - - shape = (math.prod(num_samples), C, H // F, W // F) - randn = torch.randn(shape).to("cuda") - - def denoiser(input, sigma, c): - return model.denoiser( - model.model, input, sigma, c, **additional_model_inputs - ) - - samples_z = sampler(denoiser, randn, cond=c, uc=uc) - samples_x = model.decode_first_stage(samples_z) - samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) - - if filter is not None: - samples = filter(samples) - - grid = torch.stack([samples]) - grid = rearrange(grid, "n b c h w -> (n h) (b w) c") - outputs.image(grid.cpu().numpy()) - - if return_latents: - return samples, samples_z - return samples - - -def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"): - # Hardcoded demo setups; might undergo some changes in the future - - batch = {} - batch_uc = {} - - for key in keys: - if key == "txt": - batch["txt"] = ( - np.repeat([value_dict["prompt"]], repeats=math.prod(N)) - .reshape(N) - .tolist() - ) - batch_uc["txt"] = ( - np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)) - .reshape(N) - .tolist() - ) - elif key == "original_size_as_tuple": - batch["original_size_as_tuple"] = ( - torch.tensor([value_dict["orig_height"], value_dict["orig_width"]]) - .to(device) - .repeat(*N, 1) - ) - elif key == "crop_coords_top_left": - batch["crop_coords_top_left"] = ( - torch.tensor( - [value_dict["crop_coords_top"], value_dict["crop_coords_left"]] - ) - .to(device) - .repeat(*N, 1) - ) - elif key == "aesthetic_score": - batch["aesthetic_score"] = ( - torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1) - ) - batch_uc["aesthetic_score"] = ( - torch.tensor([value_dict["negative_aesthetic_score"]]) - .to(device) - .repeat(*N, 1) - ) - - elif key == "target_size_as_tuple": - batch["target_size_as_tuple"] = ( - torch.tensor([value_dict["target_height"], value_dict["target_width"]]) - .to(device) - .repeat(*N, 1) - ) - else: - batch[key] = value_dict[key] - - for key in batch.keys(): - if key not in batch_uc and isinstance(batch[key], torch.Tensor): - batch_uc[key] = torch.clone(batch[key]) - return batch, batch_uc - - -@torch.no_grad() -def do_img2img( - img, - model, - sampler, - value_dict, - num_samples, - force_uc_zero_embeddings=[], - additional_kwargs={}, - offset_noise_level: int = 0.0, - return_latents=False, - skip_encode=False, - filter=None, -): - st.text("Sampling") - - outputs = st.empty() - precision_scope = autocast - with torch.no_grad(): - with precision_scope("cuda"): - with model.ema_scope(): - batch, batch_uc = get_batch( - get_unique_embedder_keys_from_conditioner(model.conditioner), - value_dict, - [num_samples], - ) - c, uc = model.conditioner.get_unconditional_conditioning( - batch, - batch_uc=batch_uc, - force_uc_zero_embeddings=force_uc_zero_embeddings, - ) - - for k in c: - c[k], uc[k] = map(lambda y: y[k][:num_samples].to("cuda"), (c, uc)) - - for k in additional_kwargs: - c[k] = uc[k] = additional_kwargs[k] - if skip_encode: - z = img - else: - z = model.encode_first_stage(img) - noise = torch.randn_like(z) - sigmas = sampler.discretization(sampler.num_steps) - sigma = sigmas[0] - - st.info(f"all sigmas: {sigmas}") - st.info(f"noising sigma: {sigma}") - - if offset_noise_level > 0.0: - noise = noise + offset_noise_level * append_dims( - torch.randn(z.shape[0], device=z.device), z.ndim - ) - noised_z = z + noise * append_dims(sigma, z.ndim) - noised_z = noised_z / torch.sqrt( - 1.0 + sigmas[0] ** 2.0 - ) # Note: hardcoded to DDPM-like scaling. need to generalize later. - - def denoiser(x, sigma, c): - return model.denoiser(model.model, x, sigma, c) - - samples_z = sampler(denoiser, noised_z, cond=c, uc=uc) - samples_x = model.decode_first_stage(samples_z) - samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) - - if filter is not None: - samples = filter(samples) - - grid = embed_watemark(torch.stack([samples])) - grid = rearrange(grid, "n b c h w -> (n h) (b w) c") - outputs.image(grid.cpu().numpy()) - if return_latents: - return samples, samples_z - return samples + return init_image \ No newline at end of file From 8ae88886d4dcf6ebef85ec2ab16a121471e2cfd6 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Wed, 26 Jul 2023 04:49:34 +0000 Subject: [PATCH 27/46] filenames --- sgm/inference/api.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sgm/inference/api.py b/sgm/inference/api.py index f9487151c..9a7f431ff 100644 --- a/sgm/inference/api.py +++ b/sgm/inference/api.py @@ -136,7 +136,7 @@ class SamplingSpec(BaseModel): factor=8, is_legacy=False, config="sd_xl_base.yaml", - ckpt="sd_xl_base_1.0.safetensors", + ckpt="sd_xl_base_1.0-metadata.safetensors", is_guided=True ), ModelArchitecture.SDXL_V1_REFINER: SamplingSpec( @@ -146,7 +146,7 @@ class SamplingSpec(BaseModel): factor=8, is_legacy=True, config="sd_xl_refiner.yaml", - ckpt="sd_xl_refiner_0.9.safetensors", + ckpt="sd_xl_refiner_1.0-metadata.safetensors", is_guided=True, ) } From 83d6c66a2b3a32dbaf68c795efdd20b68f829adb Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Wed, 26 Jul 2023 05:02:54 +0000 Subject: [PATCH 28/46] fixes --- scripts/demo/sampling.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/scripts/demo/sampling.py b/scripts/demo/sampling.py index d64d49618..f39ae1c4f 100644 --- a/scripts/demo/sampling.py +++ b/scripts/demo/sampling.py @@ -1,7 +1,7 @@ from pytorch_lightning import seed_everything from scripts.demo.streamlit_helpers import * from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering -from sgm.inference.helpers import do_img2img, do_sample, get_unique_embedder_keys_from_conditioner, perform_save_locally +from sgm.inference.helpers import do_img2img, do_sample, get_unique_embedder_keys_from_conditioner, perform_save_locally, embed_watermark SAVE_PATH = "outputs/demo/txt2img/" @@ -146,11 +146,11 @@ def run_txt2img( return_latents=return_latents, filter=filter, ) - grid = torch.stack([samples]) + grid = embed_watermark(torch.stack([samples])) grid = rearrange(grid, "n b c h w -> (n h) (b w) c") outputs.image(grid.cpu().numpy()) - return out + return samples def run_img2img( @@ -181,8 +181,9 @@ def run_img2img( num_samples = num_rows * num_cols if st.button("Sample"): + outputs = outputs = st.empty() st.text("Sampling") - out = do_img2img( + samples = do_img2img( repeat(img, "1 ... -> n ...", n=num_samples), state["model"], sampler, @@ -193,7 +194,10 @@ def run_img2img( filter=filter, logger=st ) - return out + grid = embed_watermark(torch.stack([samples])) + grid = rearrange(grid, "n b c h w -> (n h) (b w) c") + outputs.image(grid.cpu().numpy()) + return samples def apply_refiner( From 89c74b25170b6cbdf370dd9c6104b00b42426568 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Wed, 26 Jul 2023 05:18:39 +0000 Subject: [PATCH 29/46] add pydantic to requirements --- requirements_pt13.txt | 1 + requirements_pt2.txt | 1 + 2 files changed, 2 insertions(+) diff --git a/requirements_pt13.txt b/requirements_pt13.txt index 3d5b117c3..c1f6ffdc2 100644 --- a/requirements_pt13.txt +++ b/requirements_pt13.txt @@ -13,6 +13,7 @@ torchvision==0.14.1+cu117 torchmetrics opencv-python==4.6.0.66 fairscale +pydantic pytorch-lightning==1.8.5 fsspec kornia==0.6.9 diff --git a/requirements_pt2.txt b/requirements_pt2.txt index 9988b9084..47fec0d9e 100644 --- a/requirements_pt2.txt +++ b/requirements_pt2.txt @@ -13,6 +13,7 @@ torchmetrics torchvision>=0.15.2 opencv-python==4.6.0.66 fairscale +pydantic pytorch-lightning==2.0.1 fire fsspec From def2cd4b847a0304bc93a42a3f0ed018c19faaa5 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Wed, 26 Jul 2023 05:48:44 +0000 Subject: [PATCH 30/46] fix usage of `msg` in demo script --- scripts/demo/sampling.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/scripts/demo/sampling.py b/scripts/demo/sampling.py index f39ae1c4f..2934579f7 100644 --- a/scripts/demo/sampling.py +++ b/scripts/demo/sampling.py @@ -284,8 +284,7 @@ def apply_refiner( st.write("**Refiner Options:**") version_dict2 = VERSION2SPECS[version2] - state2 = init_st(version_dict2) - st.info(state2["msg"]) + state2 = init_st(version_dict2) stage2strength = st.number_input( "**Refinement strength**", value=0.3, min_value=0.0, max_value=1.0 From dd9a1a724d3a34e21f36424f77adc71bb51dc710 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Wed, 26 Jul 2023 06:02:23 +0000 Subject: [PATCH 31/46] remove double text --- scripts/demo/sampling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/demo/sampling.py b/scripts/demo/sampling.py index 2934579f7..d7aeb8b63 100644 --- a/scripts/demo/sampling.py +++ b/scripts/demo/sampling.py @@ -131,7 +131,7 @@ def run_txt2img( if st.button("Sample"): st.write(f"**Model I:** {version}") - outputs = outputs = st.empty() + outputs = st.empty() st.text("Sampling") samples = do_sample( state["model"], @@ -181,7 +181,7 @@ def run_img2img( num_samples = num_rows * num_cols if st.button("Sample"): - outputs = outputs = st.empty() + outputs = st.empty() st.text("Sampling") samples = do_img2img( repeat(img, "1 ... -> n ...", n=num_samples), From 00b8f102e6c0fea39400d5f54b0d7e0f21b2f684 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Wed, 26 Jul 2023 06:05:05 +0000 Subject: [PATCH 32/46] run black --- scripts/demo/sampling.py | 14 ++++++++++---- scripts/demo/streamlit_helpers.py | 5 ++--- sgm/inference/api.py | 7 ++++--- tests/inference/test_inference.py | 4 +++- 4 files changed, 19 insertions(+), 11 deletions(-) diff --git a/scripts/demo/sampling.py b/scripts/demo/sampling.py index d7aeb8b63..19a7e127f 100644 --- a/scripts/demo/sampling.py +++ b/scripts/demo/sampling.py @@ -1,7 +1,13 @@ from pytorch_lightning import seed_everything from scripts.demo.streamlit_helpers import * from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering -from sgm.inference.helpers import do_img2img, do_sample, get_unique_embedder_keys_from_conditioner, perform_save_locally, embed_watermark +from sgm.inference.helpers import ( + do_img2img, + do_sample, + get_unique_embedder_keys_from_conditioner, + perform_save_locally, + embed_watermark, +) SAVE_PATH = "outputs/demo/txt2img/" @@ -192,7 +198,7 @@ def run_img2img( force_uc_zero_embeddings=["txt"] if not is_legacy else [], return_latents=return_latents, filter=filter, - logger=st + logger=st, ) grid = embed_watermark(torch.stack([samples])) grid = rearrange(grid, "n b c h w -> (n h) (b w) c") @@ -260,7 +266,7 @@ def apply_refiner( save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, version)) - state = init_st(version_dict) + state = init_st(version_dict) model = state["model"] is_legacy = version_dict["is_legacy"] @@ -284,7 +290,7 @@ def apply_refiner( st.write("**Refiner Options:**") version_dict2 = VERSION2SPECS[version2] - state2 = init_st(version_dict2) + state2 = init_st(version_dict2) stage2strength = st.number_input( "**Refinement strength**", value=0.3, min_value=0.0, max_value=1.0 diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py index 35a068352..4a1ad1405 100644 --- a/scripts/demo/streamlit_helpers.py +++ b/scripts/demo/streamlit_helpers.py @@ -27,7 +27,6 @@ from sgm.util import instantiate_from_config, load_model_from_config - @st.cache_resource() def init_st(version_dict, load_ckpt=True): state = dict() @@ -40,7 +39,7 @@ def init_st(version_dict, load_ckpt=True): model = model.to("cuda") model.conditioner.half() model.model.half() - + state["model"] = model state["ckpt"] = ckpt if load_ckpt else None state["config"] = config @@ -336,4 +335,4 @@ def load_img(display=True, key=None): def get_init_img(batch_size=1, key=None): init_image = load_img(key=key).cuda() init_image = repeat(init_image, "1 ... -> b ...", b=batch_size) - return init_image \ No newline at end of file + return init_image diff --git a/sgm/inference/api.py b/sgm/inference/api.py index 9a7f431ff..b65b523ca 100644 --- a/sgm/inference/api.py +++ b/sgm/inference/api.py @@ -18,10 +18,11 @@ from sgm.util import load_model_from_config from typing import Optional + class ModelArchitecture(str, Enum): SD_2_1 = "stable-diffusion-v2-1" SD_2_1_768 = "stable-diffusion-v2-1-768" - SDXL_V0_9_BASE = "stable-diffusion-xl-v0-9-base" + SDXL_V0_9_BASE = "stable-diffusion-xl-v0-9-base" SDXL_V0_9_REFINER = "stable-diffusion-xl-v0-9-refiner" SDXL_V1_BASE = "stable-diffusion-xl-v1-base" SDXL_V1_REFINER = "stable-diffusion-xl-v1-refiner" @@ -137,7 +138,7 @@ class SamplingSpec(BaseModel): is_legacy=False, config="sd_xl_base.yaml", ckpt="sd_xl_base_1.0-metadata.safetensors", - is_guided=True + is_guided=True, ), ModelArchitecture.SDXL_V1_REFINER: SamplingSpec( height=1024, @@ -148,7 +149,7 @@ class SamplingSpec(BaseModel): config="sd_xl_refiner.yaml", ckpt="sd_xl_refiner_1.0-metadata.safetensors", is_guided=True, - ) + ), } diff --git a/tests/inference/test_inference.py b/tests/inference/test_inference.py index 3d75a041f..2b2af11e4 100644 --- a/tests/inference/test_inference.py +++ b/tests/inference/test_inference.py @@ -68,7 +68,9 @@ def test_img2img(self, pipeline: SamplingPipeline, sampler_enum): assert output is not None @pytest.mark.parametrize("sampler_enum", Sampler) - @pytest.mark.parametrize("use_init_image", [True, False], ids=["img2img", "txt2img"]) + @pytest.mark.parametrize( + "use_init_image", [True, False], ids=["img2img", "txt2img"] + ) def test_sdxl_with_refiner( self, sdxl_pipelines: Tuple[SamplingPipeline, SamplingPipeline], From f77d9e546097eb5811ba5562c77ece9c67512ed7 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Wed, 26 Jul 2023 06:32:38 +0000 Subject: [PATCH 33/46] fix streamlit sampling when returning latents --- scripts/demo/sampling.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/scripts/demo/sampling.py b/scripts/demo/sampling.py index 19a7e127f..341fba2a7 100644 --- a/scripts/demo/sampling.py +++ b/scripts/demo/sampling.py @@ -139,7 +139,7 @@ def run_txt2img( st.write(f"**Model I:** {version}") outputs = st.empty() st.text("Sampling") - samples = do_sample( + out = do_sample( state["model"], sampler, value_dict, @@ -152,11 +152,13 @@ def run_txt2img( return_latents=return_latents, filter=filter, ) + if return_latents: + samples, _ = out grid = embed_watermark(torch.stack([samples])) grid = rearrange(grid, "n b c h w -> (n h) (b w) c") outputs.image(grid.cpu().numpy()) - return samples + return out def run_img2img( @@ -189,7 +191,7 @@ def run_img2img( if st.button("Sample"): outputs = st.empty() st.text("Sampling") - samples = do_img2img( + out = do_img2img( repeat(img, "1 ... -> n ...", n=num_samples), state["model"], sampler, @@ -200,10 +202,12 @@ def run_img2img( filter=filter, logger=st, ) + if return_latents: + samples, _ = out grid = embed_watermark(torch.stack([samples])) grid = rearrange(grid, "n b c h w -> (n h) (b w) c") outputs.image(grid.cpu().numpy()) - return samples + return out def apply_refiner( From 733dfb388383ee88cbe757362b2ebb091e00a7c7 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Wed, 26 Jul 2023 07:01:36 +0000 Subject: [PATCH 34/46] extract function for streamlit output --- scripts/demo/sampling.py | 21 +++++++-------------- scripts/demo/streamlit_helpers.py | 22 ++++++++++------------ 2 files changed, 17 insertions(+), 26 deletions(-) diff --git a/scripts/demo/sampling.py b/scripts/demo/sampling.py index 341fba2a7..46448f512 100644 --- a/scripts/demo/sampling.py +++ b/scripts/demo/sampling.py @@ -5,8 +5,7 @@ do_img2img, do_sample, get_unique_embedder_keys_from_conditioner, - perform_save_locally, - embed_watermark, + perform_save_locally, ) SAVE_PATH = "outputs/demo/txt2img/" @@ -152,11 +151,7 @@ def run_txt2img( return_latents=return_latents, filter=filter, ) - if return_latents: - samples, _ = out - grid = embed_watermark(torch.stack([samples])) - grid = rearrange(grid, "n b c h w -> (n h) (b w) c") - outputs.image(grid.cpu().numpy()) + show_samples(samples, outputs) return out @@ -201,12 +196,8 @@ def run_img2img( return_latents=return_latents, filter=filter, logger=st, - ) - if return_latents: - samples, _ = out - grid = embed_watermark(torch.stack([samples])) - grid = rearrange(grid, "n b c h w -> (n h) (b w) c") - outputs.image(grid.cpu().numpy()) + ) + show_samples(samples, outputs) return out @@ -334,6 +325,7 @@ def apply_refiner( samples_z = None if add_pipeline and samples_z is not None: + outputs = st.empty() st.write("**Running Refinement Stage**") samples = apply_refiner( samples_z, @@ -343,7 +335,8 @@ def apply_refiner( prompt=prompt, negative_prompt=negative_prompt if is_legacy else "", filter=filter, - ) + ) + show_samples(samples, outputs) if save_locally and samples is not None: perform_save_locally(save_path, samples) diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py index 4a1ad1405..f81578a9b 100644 --- a/scripts/demo/streamlit_helpers.py +++ b/scripts/demo/streamlit_helpers.py @@ -1,18 +1,11 @@ import os -from typing import Union, List - -import math -import numpy as np import streamlit as st import torch from PIL import Image from einops import rearrange, repeat -from imwatermark import WatermarkEncoder -from omegaconf import OmegaConf, ListConfig -from torch import autocast +from omegaconf import OmegaConf from torchvision import transforms -from torchvision.utils import make_grid -from safetensors.torch import load_file as load_safetensors + from sgm.modules.diffusionmodules.sampling import ( EulerEDMSampler, @@ -22,9 +15,8 @@ DPMPP2MSampler, LinearMultistepSampler, ) -from sgm.inference.helpers import Img2ImgDiscretizationWrapper -from sgm.util import append_dims -from sgm.util import instantiate_from_config, load_model_from_config +from sgm.inference.helpers import Img2ImgDiscretizationWrapper, embed_watermark +from sgm.util import load_model_from_config @st.cache_resource() @@ -115,6 +107,12 @@ def init_save_locally(_dir, init_value: bool = False): return save_locally, save_path +def show_samples(samples, outputs): + if isinstance(samples, tuple): + samples, _ = samples + grid = embed_watermark(torch.stack([samples])) + grid = rearrange(grid, "n b c h w -> (n h) (b w) c") + outputs.image(grid.cpu().numpy()) def get_guider(key): guider = st.sidebar.selectbox( From 959a7ee7dea15b53ef84ea34e2235bca4ece4649 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Wed, 26 Jul 2023 07:13:01 +0000 Subject: [PATCH 35/46] another fix for streamlit outputs --- scripts/demo/sampling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/demo/sampling.py b/scripts/demo/sampling.py index 46448f512..6ff0bb90d 100644 --- a/scripts/demo/sampling.py +++ b/scripts/demo/sampling.py @@ -151,7 +151,7 @@ def run_txt2img( return_latents=return_latents, filter=filter, ) - show_samples(samples, outputs) + show_samples(out, outputs) return out @@ -197,7 +197,7 @@ def run_img2img( filter=filter, logger=st, ) - show_samples(samples, outputs) + show_samples(out, outputs) return out From cca46d33da968f50ba52e516406dfd9bf42500c5 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Wed, 26 Jul 2023 07:47:10 +0000 Subject: [PATCH 36/46] fix img2img in streamlit --- scripts/demo/sampling.py | 1 + 1 file changed, 1 insertion(+) diff --git a/scripts/demo/sampling.py b/scripts/demo/sampling.py index 6ff0bb90d..711320201 100644 --- a/scripts/demo/sampling.py +++ b/scripts/demo/sampling.py @@ -1,3 +1,4 @@ +import numpy as np from pytorch_lightning import seed_everything from scripts.demo.streamlit_helpers import * from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering From 31056b978bef6ec243c6c1a88d53a6cf880e4fd4 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Wed, 26 Jul 2023 01:39:04 -0700 Subject: [PATCH 37/46] Make fp16 optional and fix device param --- sgm/inference/api.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/sgm/inference/api.py b/sgm/inference/api.py index b65b523ca..c2e00210f 100644 --- a/sgm/inference/api.py +++ b/sgm/inference/api.py @@ -160,6 +160,7 @@ def __init__( model_path="checkpoints", config_path="configs/inference", device="cuda", + use_fp16=True, ) -> None: if model_id not in model_specs: raise ValueError(f"Model {model_id} not supported") @@ -168,16 +169,17 @@ def __init__( self.config = str(pathlib.Path(config_path, self.specs.config)) self.ckpt = str(pathlib.Path(model_path, self.specs.ckpt)) self.device = device - self.model = self._load_model() + self.model = self._load_model(device=device, use_fp16=use_fp16) - def _load_model(self, device="cuda"): + def _load_model(self, device="cuda", use_fp16=True): config = OmegaConf.load(self.config) model = load_model_from_config(config, self.ckpt) if model is None: raise ValueError(f"Model {self.model_id} could not be loaded") model.to(device) - model.conditioner.half() - model.model.half() + if use_fp16: + model.conditioner.half() + model.model.half() return model def text_to_image( @@ -317,7 +319,7 @@ def get_discretization_config(params: SamplingParams): }, } else: - raise ValueError(f"unknown discertization {params.discretization}") + raise ValueError(f"unknown discretization {params.discretization}") return discretization_config From 42d11ffa81c2775841c3f9bbd8aa07399eb6c810 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Wed, 26 Jul 2023 02:24:21 -0700 Subject: [PATCH 38/46] PR feedback --- requirements_pt13.txt | 1 - requirements_pt2.txt | 1 - scripts/demo/sampling.py | 3 +-- sgm/inference/api.py | 8 +++++--- sgm/inference/helpers.py | 9 ++------- 5 files changed, 8 insertions(+), 14 deletions(-) diff --git a/requirements_pt13.txt b/requirements_pt13.txt index c1f6ffdc2..3d5b117c3 100644 --- a/requirements_pt13.txt +++ b/requirements_pt13.txt @@ -13,7 +13,6 @@ torchvision==0.14.1+cu117 torchmetrics opencv-python==4.6.0.66 fairscale -pydantic pytorch-lightning==1.8.5 fsspec kornia==0.6.9 diff --git a/requirements_pt2.txt b/requirements_pt2.txt index 47fec0d9e..9988b9084 100644 --- a/requirements_pt2.txt +++ b/requirements_pt2.txt @@ -13,7 +13,6 @@ torchmetrics torchvision>=0.15.2 opencv-python==4.6.0.66 fairscale -pydantic pytorch-lightning==2.0.1 fire fsspec diff --git a/scripts/demo/sampling.py b/scripts/demo/sampling.py index 711320201..c1e97f026 100644 --- a/scripts/demo/sampling.py +++ b/scripts/demo/sampling.py @@ -195,8 +195,7 @@ def run_img2img( num_samples, force_uc_zero_embeddings=["txt"] if not is_legacy else [], return_latents=return_latents, - filter=filter, - logger=st, + filter=filter ) show_samples(out, outputs) return out diff --git a/sgm/inference/api.py b/sgm/inference/api.py index c2e00210f..30de42138 100644 --- a/sgm/inference/api.py +++ b/sgm/inference/api.py @@ -1,6 +1,6 @@ +from dataclasses import dataclass from enum import Enum from omegaconf import OmegaConf -from pydantic import BaseModel import pathlib from sgm.inference.helpers import ( do_sample, @@ -51,7 +51,8 @@ class Thresholder(str, Enum): NONE = "None" -class SamplingParams(BaseModel): +@dataclass +class SamplingParams(): width: int = 1024 height: int = 1024 steps: int = 50 @@ -78,7 +79,8 @@ class SamplingParams(BaseModel): order: int = 4 -class SamplingSpec(BaseModel): +@dataclass +class SamplingSpec(): width: int height: int channels: int diff --git a/sgm/inference/helpers.py b/sgm/inference/helpers.py index b5f30531f..a21da10f2 100644 --- a/sgm/inference/helpers.py +++ b/sgm/inference/helpers.py @@ -253,8 +253,7 @@ def do_img2img( offset_noise_level: float = 0.0, return_latents=False, skip_encode=False, - filter=None, - logger=None, + filter=None, device="cuda", ): precision_scope = autocast @@ -284,11 +283,7 @@ def do_img2img( noise = torch.randn_like(z) sigmas = sampler.discretization(sampler.num_steps) sigma = sigmas[0].to(z.device) - - if logger is not None: - logger.info(f"all sigmas: {sigmas}") - logger.info(f"noising sigma: {sigma}") - + if offset_noise_level > 0.0: noise = noise + offset_noise_level * append_dims( torch.randn(z.shape[0], device=z.device), z.ndim From 733d38b994602edb14bdfcce5373d27fc4fe1ba0 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Wed, 26 Jul 2023 03:05:28 -0700 Subject: [PATCH 39/46] fix dict cast for dataclass --- sgm/inference/api.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sgm/inference/api.py b/sgm/inference/api.py index 30de42138..b02349ebb 100644 --- a/sgm/inference/api.py +++ b/sgm/inference/api.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass +from dataclasses import dataclass, asdict from enum import Enum from omegaconf import OmegaConf import pathlib @@ -193,7 +193,7 @@ def text_to_image( return_latents: bool = False, ): sampler = get_sampler_config(params) - value_dict = dict(params) + value_dict = asdict(params) value_dict["prompt"] = prompt value_dict["negative_prompt"] = negative_prompt value_dict["target_width"] = params.width @@ -229,7 +229,7 @@ def image_to_image( strength=params.img2img_strength, ) height, width = image.shape[2], image.shape[3] - value_dict = dict(params) + value_dict = asdict(params) value_dict["prompt"] = prompt value_dict["negative_prompt"] = negative_prompt value_dict["target_width"] = width From ba60896b8b118a4a3159a47e5c62d514e08954cf Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Wed, 26 Jul 2023 03:21:43 -0700 Subject: [PATCH 40/46] run black, update ci script --- pyproject.toml | 2 +- scripts/demo/sampling.py | 10 +++++----- scripts/demo/streamlit_helpers.py | 2 ++ sgm/inference/api.py | 4 ++-- sgm/inference/helpers.py | 4 ++-- 5 files changed, 12 insertions(+), 10 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2c65bc103..f9e311f50 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,6 @@ dependencies = [ [tool.hatch.envs.ci.scripts] test-inference = [ "pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 torchaudio==2.0.2+cu118 --index-url https://download.pytorch.org/whl/cu118", - "pip install -r requirements_pt2.txt", + "pip install -r requirements/pt2.txt", "pytest -v tests/inference/test_inference.py {args}", ] \ No newline at end of file diff --git a/scripts/demo/sampling.py b/scripts/demo/sampling.py index cbb311ae2..3f3e7072e 100644 --- a/scripts/demo/sampling.py +++ b/scripts/demo/sampling.py @@ -7,7 +7,7 @@ do_img2img, do_sample, get_unique_embedder_keys_from_conditioner, - perform_save_locally, + perform_save_locally, ) SAVE_PATH = "outputs/demo/txt2img/" @@ -196,8 +196,8 @@ def run_img2img( num_samples, force_uc_zero_embeddings=["txt"] if not is_legacy else [], return_latents=return_latents, - filter=filter - ) + filter=filter, + ) show_samples(out, outputs) return out @@ -336,8 +336,8 @@ def apply_refiner( prompt=prompt, negative_prompt=negative_prompt if is_legacy else "", filter=filter, - ) - show_samples(samples, outputs) + ) + show_samples(samples, outputs) if save_locally and samples is not None: perform_save_locally(save_path, samples) diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py index 64422047d..4b752a7a0 100644 --- a/scripts/demo/streamlit_helpers.py +++ b/scripts/demo/streamlit_helpers.py @@ -107,6 +107,7 @@ def init_save_locally(_dir, init_value: bool = False): return save_locally, save_path + def show_samples(samples, outputs): if isinstance(samples, tuple): samples, _ = samples @@ -114,6 +115,7 @@ def show_samples(samples, outputs): grid = rearrange(grid, "n b c h w -> (n h) (b w) c") outputs.image(grid.cpu().numpy()) + def get_guider(key): guider = st.sidebar.selectbox( f"Discretization #{key}", diff --git a/sgm/inference/api.py b/sgm/inference/api.py index b02349ebb..ca5189cb8 100644 --- a/sgm/inference/api.py +++ b/sgm/inference/api.py @@ -52,7 +52,7 @@ class Thresholder(str, Enum): @dataclass -class SamplingParams(): +class SamplingParams: width: int = 1024 height: int = 1024 steps: int = 50 @@ -80,7 +80,7 @@ class SamplingParams(): @dataclass -class SamplingSpec(): +class SamplingSpec: width: int height: int channels: int diff --git a/sgm/inference/helpers.py b/sgm/inference/helpers.py index a21da10f2..2b8822ebb 100644 --- a/sgm/inference/helpers.py +++ b/sgm/inference/helpers.py @@ -253,7 +253,7 @@ def do_img2img( offset_noise_level: float = 0.0, return_latents=False, skip_encode=False, - filter=None, + filter=None, device="cuda", ): precision_scope = autocast @@ -283,7 +283,7 @@ def do_img2img( noise = torch.randn_like(z) sigmas = sampler.discretization(sampler.num_steps) sigma = sigmas[0].to(z.device) - + if offset_noise_level > 0.0: noise = noise + offset_noise_level * append_dims( torch.randn(z.shape[0], device=z.device), z.ndim From f811542fe9e2fbe540ce8e77cb0d57e646dd8101 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Wed, 26 Jul 2023 03:25:42 -0700 Subject: [PATCH 41/46] cache pip dependencies on hosted runners, remove extra runs --- .github/workflows/black.yml | 2 +- .github/workflows/test-build.yaml | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/black.yml b/.github/workflows/black.yml index 80823b44b..ab6526013 100644 --- a/.github/workflows/black.yml +++ b/.github/workflows/black.yml @@ -1,5 +1,5 @@ name: Run black -on: [push, pull_request] +on: [pull_request] jobs: lint: diff --git a/.github/workflows/test-build.yaml b/.github/workflows/test-build.yaml index ffbeff460..224251762 100644 --- a/.github/workflows/test-build.yaml +++ b/.github/workflows/test-build.yaml @@ -2,6 +2,7 @@ name: Build package on: push: + branches: [ main ] pull_request: jobs: @@ -19,6 +20,7 @@ jobs: uses: actions/setup-python@v2 with: python-version: ${{ matrix.python-version }} + cache: pip - name: Install dependencies run: | python -m pip install --upgrade pip From 7e86a690f21ffba2e99ebb1a80060d4e4cc7704b Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Wed, 26 Jul 2023 03:32:05 -0700 Subject: [PATCH 42/46] install package in ci env --- pyproject.toml | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f9e311f50..93d00c8a5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,9 +34,7 @@ include = [ "./configs" = "sgm/configs" [tool.hatch.envs.ci] -# Skip for now, since requirements.txt is used by scripts and includes the project -# This should be changed when dependencies are handled by Hatch -skip-install = true +skip-install = false dependencies = [ "pytest" @@ -45,6 +43,6 @@ dependencies = [ [tool.hatch.envs.ci.scripts] test-inference = [ "pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 torchaudio==2.0.2+cu118 --index-url https://download.pytorch.org/whl/cu118", - "pip install -r requirements/pt2.txt", + "pip install -r requirements/pt2.txt", "pytest -v tests/inference/test_inference.py {args}", ] \ No newline at end of file From 64b5c9d499e344a3240c7fc087e943c6340d03fa Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Wed, 26 Jul 2023 03:33:19 -0700 Subject: [PATCH 43/46] fix cache path --- .github/workflows/test-build.yaml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test-build.yaml b/.github/workflows/test-build.yaml index 224251762..fc3a66ae6 100644 --- a/.github/workflows/test-build.yaml +++ b/.github/workflows/test-build.yaml @@ -20,7 +20,8 @@ jobs: uses: actions/setup-python@v2 with: python-version: ${{ matrix.python-version }} - cache: pip + cache: pip + cache-dependency-path: requirements/${{ matrix.requirements-file }}.txt - name: Install dependencies run: | python -m pip install --upgrade pip From f471736fa097b42bba138370ef8b412f34664696 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Wed, 26 Jul 2023 04:00:04 -0700 Subject: [PATCH 44/46] PR cleanup --- sgm/inference/api.py | 85 +++++++++++++++++++--------------------- sgm/inference/helpers.py | 7 +--- 2 files changed, 42 insertions(+), 50 deletions(-) diff --git a/sgm/inference/api.py b/sgm/inference/api.py index ca5189cb8..4ffdb9419 100644 --- a/sgm/inference/api.py +++ b/sgm/inference/api.py @@ -329,51 +329,46 @@ def get_sampler_config(params: SamplingParams): discretization_config = get_discretization_config(params) guider_config = get_guider_config(params) sampler = None - if params.sampler == Sampler.EULER_EDM or params.sampler == Sampler.HEUN_EDM: - if params.sampler == Sampler.EULER_EDM: - sampler = EulerEDMSampler( - num_steps=params.steps, - discretization_config=discretization_config, - guider_config=guider_config, - s_churn=params.s_churn, - s_tmin=params.s_tmin, - s_tmax=params.s_tmax, - s_noise=params.s_noise, - verbose=True, - ) - elif params.sampler == Sampler.HEUN_EDM: - sampler = HeunEDMSampler( - num_steps=params.steps, - discretization_config=discretization_config, - guider_config=guider_config, - s_churn=params.s_churn, - s_tmin=params.s_tmin, - s_tmax=params.s_tmax, - s_noise=params.s_noise, - verbose=True, - ) - elif ( - params.sampler == Sampler.EULER_ANCESTRAL - or params.sampler == Sampler.DPMPP2S_ANCESTRAL - ): - if params.sampler == Sampler.EULER_ANCESTRAL: - sampler = EulerAncestralSampler( - num_steps=params.steps, - discretization_config=discretization_config, - guider_config=guider_config, - eta=params.eta, - s_noise=params.s_noise, - verbose=True, - ) - elif params.sampler == Sampler.DPMPP2S_ANCESTRAL: - sampler = DPMPP2SAncestralSampler( - num_steps=params.steps, - discretization_config=discretization_config, - guider_config=guider_config, - eta=params.eta, - s_noise=params.s_noise, - verbose=True, - ) + if params.sampler == Sampler.EULER_EDM: + sampler = EulerEDMSampler( + num_steps=params.steps, + discretization_config=discretization_config, + guider_config=guider_config, + s_churn=params.s_churn, + s_tmin=params.s_tmin, + s_tmax=params.s_tmax, + s_noise=params.s_noise, + verbose=True, + ) + elif params.sampler == Sampler.HEUN_EDM: + sampler = HeunEDMSampler( + num_steps=params.steps, + discretization_config=discretization_config, + guider_config=guider_config, + s_churn=params.s_churn, + s_tmin=params.s_tmin, + s_tmax=params.s_tmax, + s_noise=params.s_noise, + verbose=True, + ) + elif params.sampler == Sampler.EULER_ANCESTRAL: + sampler = EulerAncestralSampler( + num_steps=params.steps, + discretization_config=discretization_config, + guider_config=guider_config, + eta=params.eta, + s_noise=params.s_noise, + verbose=True, + ) + elif params.sampler == Sampler.DPMPP2S_ANCESTRAL: + sampler = DPMPP2SAncestralSampler( + num_steps=params.steps, + discretization_config=discretization_config, + guider_config=guider_config, + eta=params.eta, + s_noise=params.s_noise, + verbose=True, + ) elif params.sampler == Sampler.DPMPP2M: sampler = DPMPP2MSampler( num_steps=params.steps, diff --git a/sgm/inference/helpers.py b/sgm/inference/helpers.py index 2b8822ebb..1c653708b 100644 --- a/sgm/inference/helpers.py +++ b/sgm/inference/helpers.py @@ -118,9 +118,8 @@ def do_sample( if batch2model_input is None: batch2model_input = [] - precision_scope = autocast with torch.no_grad(): - with precision_scope(device): + with autocast(device) as precision_scope: with model.ema_scope(): num_samples = [num_samples] batch, batch_uc = get_batch( @@ -241,7 +240,6 @@ def get_input_image_tensor(image: Image.Image, device="cuda"): return image_tensor.to(device) -@torch.no_grad() def do_img2img( img, model, @@ -256,9 +254,8 @@ def do_img2img( filter=None, device="cuda", ): - precision_scope = autocast with torch.no_grad(): - with precision_scope(device): + with autocast(device) as precision_scope: with model.ema_scope(): batch, batch_uc = get_batch( get_unique_embedder_keys_from_conditioner(model.conditioner), From 197a074cebe5e96a91e08ccbf70c9847ce247048 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Wed, 26 Jul 2023 04:03:46 -0700 Subject: [PATCH 45/46] one more cleanup --- sgm/inference/api.py | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/sgm/inference/api.py b/sgm/inference/api.py index 4ffdb9419..0635d112f 100644 --- a/sgm/inference/api.py +++ b/sgm/inference/api.py @@ -330,7 +330,7 @@ def get_sampler_config(params: SamplingParams): guider_config = get_guider_config(params) sampler = None if params.sampler == Sampler.EULER_EDM: - sampler = EulerEDMSampler( + return EulerEDMSampler( num_steps=params.steps, discretization_config=discretization_config, guider_config=guider_config, @@ -340,8 +340,8 @@ def get_sampler_config(params: SamplingParams): s_noise=params.s_noise, verbose=True, ) - elif params.sampler == Sampler.HEUN_EDM: - sampler = HeunEDMSampler( + if params.sampler == Sampler.HEUN_EDM: + return HeunEDMSampler( num_steps=params.steps, discretization_config=discretization_config, guider_config=guider_config, @@ -351,8 +351,8 @@ def get_sampler_config(params: SamplingParams): s_noise=params.s_noise, verbose=True, ) - elif params.sampler == Sampler.EULER_ANCESTRAL: - sampler = EulerAncestralSampler( + if params.sampler == Sampler.EULER_ANCESTRAL: + return EulerAncestralSampler( num_steps=params.steps, discretization_config=discretization_config, guider_config=guider_config, @@ -360,8 +360,8 @@ def get_sampler_config(params: SamplingParams): s_noise=params.s_noise, verbose=True, ) - elif params.sampler == Sampler.DPMPP2S_ANCESTRAL: - sampler = DPMPP2SAncestralSampler( + if params.sampler == Sampler.DPMPP2S_ANCESTRAL: + return DPMPP2SAncestralSampler( num_steps=params.steps, discretization_config=discretization_config, guider_config=guider_config, @@ -369,22 +369,20 @@ def get_sampler_config(params: SamplingParams): s_noise=params.s_noise, verbose=True, ) - elif params.sampler == Sampler.DPMPP2M: - sampler = DPMPP2MSampler( + if params.sampler == Sampler.DPMPP2M: + return DPMPP2MSampler( num_steps=params.steps, discretization_config=discretization_config, guider_config=guider_config, verbose=True, ) - elif params.sampler == Sampler.LINEAR_MULTISTEP: - sampler = LinearMultistepSampler( + if params.sampler == Sampler.LINEAR_MULTISTEP: + return LinearMultistepSampler( num_steps=params.steps, discretization_config=discretization_config, guider_config=guider_config, order=params.order, verbose=True, ) - if sampler is None: - raise ValueError(f"unknown sampler {params.sampler}!") - return sampler + raise ValueError(f"unknown sampler {params.sampler}!") From 30a7f79ef5382d162bde4e89b80569474d7066c2 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Wed, 26 Jul 2023 04:05:14 -0700 Subject: [PATCH 46/46] don't cache, it filled up --- .github/workflows/test-build.yaml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.github/workflows/test-build.yaml b/.github/workflows/test-build.yaml index fc3a66ae6..8aabe3760 100644 --- a/.github/workflows/test-build.yaml +++ b/.github/workflows/test-build.yaml @@ -20,8 +20,6 @@ jobs: uses: actions/setup-python@v2 with: python-version: ${{ matrix.python-version }} - cache: pip - cache-dependency-path: requirements/${{ matrix.requirements-file }}.txt - name: Install dependencies run: | python -m pip install --upgrade pip