From 10e3c19a139a349b4aac1f658e846a8842aeec3c Mon Sep 17 00:00:00 2001 From: cjkangme Date: Sat, 2 Nov 2024 16:14:48 +0900 Subject: [PATCH 1/7] [Fix] fix bugs of regional_prompting pipeline --- .../regional_prompting_stable_diffusion.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/examples/community/regional_prompting_stable_diffusion.py b/examples/community/regional_prompting_stable_diffusion.py index 8a022987ba9d..16fa70bbced2 100644 --- a/examples/community/regional_prompting_stable_diffusion.py +++ b/examples/community/regional_prompting_stable_diffusion.py @@ -9,7 +9,6 @@ from diffusers.models import AutoencoderKL, UNet2DConditionModel from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from diffusers.schedulers import KarrasDiffusionSchedulers -from diffusers.utils import USE_PEFT_BACKEND try: @@ -119,7 +118,7 @@ def __call__( self.power = int(rp_args["power"]) if "power" in rp_args else 1 prompts = prompt if isinstance(prompt, list) else [prompt] - n_prompts = negative_prompt if isinstance(prompt, str) else [negative_prompt] + n_prompts = negative_prompt if isinstance(prompt, list) else [negative_prompt] self.batch = batch = num_images_per_prompt * len(prompts) all_prompts_cn, all_prompts_p = promptsmaker(prompts, num_images_per_prompt) all_n_prompts_cn, _ = promptsmaker(n_prompts, num_images_per_prompt) @@ -225,8 +224,6 @@ def forward( residual = hidden_states - args = () if USE_PEFT_BACKEND else (scale,) - if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) @@ -247,16 +244,15 @@ def forward( if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - args = () if USE_PEFT_BACKEND else (scale,) - query = attn.to_q(hidden_states, *args) + query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - key = attn.to_k(encoder_hidden_states, *args) - value = attn.to_v(encoder_hidden_states, *args) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads @@ -283,7 +279,7 @@ def forward( hidden_states = hidden_states.to(query.dtype) # linear proj - hidden_states = attn.to_out[0](hidden_states, *args) + hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) From fff00137ebcca59e03c364cbc9abdf412140e14f Mon Sep 17 00:00:00 2001 From: cjkangme Date: Sat, 2 Nov 2024 17:27:32 +0900 Subject: [PATCH 2/7] [Feat] add base prompt feature --- .../regional_prompting_stable_diffusion.py | 55 ++++++++++++++++++- 1 file changed, 53 insertions(+), 2 deletions(-) diff --git a/examples/community/regional_prompting_stable_diffusion.py b/examples/community/regional_prompting_stable_diffusion.py index 16fa70bbced2..aae3a617fe49 100644 --- a/examples/community/regional_prompting_stable_diffusion.py +++ b/examples/community/regional_prompting_stable_diffusion.py @@ -16,6 +16,7 @@ except ImportError: Compel = None +KBASE = "ADDBASE" KCOMM = "ADDCOMM" KBRK = "BREAK" @@ -33,6 +34,11 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline): Optional rp_args["save_mask"]: True/False (save masks in prompt mode) + rp_args["power"]: int (power for attention maps in prompt mode) + rp_args["base_ratio"]: + float (Sets the ratio of the base prompt) + ex) 0.2 
(20%*BASE_PROMPT + 80%*REGION_PROMPT) + [Use base prompt](https://github.com/hako-mikan/sd-webui-regional-prompter?tab=readme-ov-file#use-base-prompt) Pipeline for text-to-image generation using Stable Diffusion. @@ -109,20 +115,45 @@ def __call__( rp_args: Dict[str, str] = None, ): active = KBRK in prompt[0] if isinstance(prompt, list) else KBRK in prompt + use_base = ( + KBASE in prompt[0] if isinstance(prompt, list) else KBASE in prompt + ) if negative_prompt is None: negative_prompt = "" if isinstance(prompt, str) else [""] * len(prompt) device = self._execution_device regions = 0 + self.base_ratio = float(rp_args["base_ratio"]) if "base_ratio" in rp_args else 0.0 self.power = int(rp_args["power"]) if "power" in rp_args else 1 prompts = prompt if isinstance(prompt, list) else [prompt] n_prompts = negative_prompt if isinstance(prompt, list) else [negative_prompt] self.batch = batch = num_images_per_prompt * len(prompts) + all_prompts_cn, all_prompts_p = promptsmaker(prompts, num_images_per_prompt) all_n_prompts_cn, _ = promptsmaker(n_prompts, num_images_per_prompt) + if use_base: + bases = prompts.copy() + n_bases = n_prompts.copy() + + for i, prompt in enumerate(prompts): + parts = prompt.split(KBASE) + if len(parts) == 2: + bases[i], prompts[i] = parts + elif len(parts) > 2: + raise ValueError(f"Multiple instances of {KBASE} found in prompt: {prompt}") + for i, prompt in enumerate(n_prompts): + n_parts = prompt.split(KBASE) + if len(n_parts) == 2: + n_bases[i], n_prompts[i] = n_parts + elif len(n_parts) > 2: + raise ValueError(f"Multiple instances of {KBASE} found in negative prompt: {prompt}") + + all_bases_cn, _ = promptsmaker(bases, num_images_per_prompt) + all_n_bases_cn, _ = promptsmaker(n_bases, num_images_per_prompt) + equal = len(all_prompts_cn) == len(all_n_prompts_cn) if Compel: @@ -136,8 +167,16 @@ def getcompelembs(prps): conds = getcompelembs(all_prompts_cn) unconds = getcompelembs(all_n_prompts_cn) - embs = getcompelembs(prompts) - n_embs = getcompelembs(n_prompts) + base_embs = getcompelembs(all_bases_cn) if use_base else None + base_n_embs = getcompelembs(all_n_bases_cn) if use_base else None + # When using base, it seems more reasonable to use base prompts as prompt_embeddings rather than regional prompts + embs = getcompelembs(prompts) if not use_base else base_embs + n_embs = getcompelembs(n_prompts) if not use_base else base_n_embs + + if use_base and self.base_ratio > 0: + conds = self.base_ratio * base_embs + (1 - self.base_ratio) * conds + unconds = self.base_ratio * base_n_embs + (1 - self.base_ratio) * unconds + prompt = negative_prompt = None else: conds = self.encode_prompt(prompts, device, 1, True)[0] @@ -146,6 +185,18 @@ def getcompelembs(prps): if equal else self.encode_prompt(all_n_prompts_cn, device, 1, True)[0] ) + + if use_base and self.base_ratio > 0: + base_embs = self.encode_prompt(bases, device, 1, True)[0] + base_n_embs = ( + self.encode_prompt(n_bases, device, 1, True)[0] + if equal + else self.encode_prompt(all_n_bases_cn, device, 1, True)[0] + ) + + conds = self.base_ratio * base_embs + (1 - self.base_ratio) * conds + unconds = self.base_ratio * base_n_embs + (1 - self.base_ratio) * unconds + embs = n_embs = None if not active: From 0e34eac88feaf22368b1435fdae8831676231b02 Mon Sep 17 00:00:00 2001 From: cjkangme Date: Mon, 4 Nov 2024 20:09:46 +0900 Subject: [PATCH 3/7] [Fix] fix __init__ pipeline error --- .../community/regional_prompting_stable_diffusion.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git 
a/examples/community/regional_prompting_stable_diffusion.py b/examples/community/regional_prompting_stable_diffusion.py index aae3a617fe49..9de736c08904 100644 --- a/examples/community/regional_prompting_stable_diffusion.py +++ b/examples/community/regional_prompting_stable_diffusion.py @@ -3,7 +3,7 @@ import torch import torchvision.transforms.functional as FF -from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from diffusers import StableDiffusionPipeline from diffusers.models import AutoencoderKL, UNet2DConditionModel @@ -75,6 +75,7 @@ def __init__( scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection = None, requires_safety_checker: bool = True, ): super().__init__( @@ -85,6 +86,7 @@ def __init__( scheduler, safety_checker, feature_extractor, + image_encoder, requires_safety_checker, ) self.register_modules( @@ -95,6 +97,7 @@ def __init__( scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, + image_encoder=image_encoder, ) @torch.no_grad() @@ -450,7 +453,7 @@ def hook_forwards(root_module: torch.nn.Module): ### Make prompt list for each regions -def promptsmaker(prompts, batch): +def promptsmaker(prompts, batch, use_base=False): out_p = [] plen = len(prompts) for prompt in prompts: @@ -459,7 +462,7 @@ def promptsmaker(prompts, batch): add, prompt = prompt.split(KCOMM) add = add + " " prompts = prompt.split(KBRK) - out_p.append([add + p for p in prompts]) + out_p.append([add + p if i != 0 else p for i, p in enumerate(prompts)]) out = [None] * batch * len(out_p[0]) * len(out_p) for p, prs in enumerate(out_p): # inputs prompts for r, pr in enumerate(prs): # prompts for regions @@ -496,7 +499,6 @@ def startend(cells, array): add = [] startend(add, inratios[1:]) icells.append(add) - return ocells, icells, sum(len(cell) for cell in icells) From b83e70fbbc78628b5d5b72652e386b171fa24e73 Mon Sep 17 00:00:00 2001 From: cjkangme Date: Wed, 6 Nov 2024 17:21:21 +0900 Subject: [PATCH 4/7] [Fix] delete unused args --- examples/community/regional_prompting_stable_diffusion.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/community/regional_prompting_stable_diffusion.py b/examples/community/regional_prompting_stable_diffusion.py index 9de736c08904..ef4197b67e69 100644 --- a/examples/community/regional_prompting_stable_diffusion.py +++ b/examples/community/regional_prompting_stable_diffusion.py @@ -35,7 +35,7 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline): Optional rp_args["save_mask"]: True/False (save masks in prompt mode) rp_args["power"]: int (power for attention maps in prompt mode) - rp_args["base_ratio"]: + rp_args["base_ratio"]: float (Sets the ratio of the base prompt) ex) 0.2 (20%*BASE_PROMPT + 80%*REGION_PROMPT) [Use base prompt](https://github.com/hako-mikan/sd-webui-regional-prompter?tab=readme-ov-file#use-base-prompt) @@ -453,7 +453,7 @@ def hook_forwards(root_module: torch.nn.Module): ### Make prompt list for each regions -def promptsmaker(prompts, batch, use_base=False): +def promptsmaker(prompts, batch): out_p = [] plen = len(prompts) for prompt in prompts: From ee804c459511f88e593451288d9b9265979c84b9 Mon Sep 17 00:00:00 2001 From: cjkangme Date: Wed, 6 Nov 2024 19:08:34 +0900 Subject: [PATCH 5/7] [Fix] improve string handling --- 
.../community/regional_prompting_stable_diffusion.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/examples/community/regional_prompting_stable_diffusion.py b/examples/community/regional_prompting_stable_diffusion.py
index ef4197b67e69..bfcde82e02fd 100644
--- a/examples/community/regional_prompting_stable_diffusion.py
+++ b/examples/community/regional_prompting_stable_diffusion.py
@@ -134,9 +134,6 @@ def __call__(
         n_prompts = negative_prompt if isinstance(prompt, list) else [negative_prompt]
         self.batch = batch = num_images_per_prompt * len(prompts)

-        all_prompts_cn, all_prompts_p = promptsmaker(prompts, num_images_per_prompt)
-        all_n_prompts_cn, _ = promptsmaker(n_prompts, num_images_per_prompt)
-
         if use_base:
             bases = prompts.copy()
             n_bases = n_prompts.copy()
@@ -157,6 +154,9 @@ def __call__(
             all_bases_cn, _ = promptsmaker(bases, num_images_per_prompt)
             all_n_bases_cn, _ = promptsmaker(n_bases, num_images_per_prompt)

+        all_prompts_cn, all_prompts_p = promptsmaker(prompts, num_images_per_prompt)
+        all_n_prompts_cn, _ = promptsmaker(n_prompts, num_images_per_prompt)
+
         equal = len(all_prompts_cn) == len(all_n_prompts_cn)

         if Compel:
@@ -460,9 +460,9 @@ def promptsmaker(prompts, batch):
         add = ""
         if KCOMM in prompt:
             add, prompt = prompt.split(KCOMM)
-            add = add + " "
-        prompts = prompt.split(KBRK)
-        out_p.append([add + p if i != 0 else p for i, p in enumerate(prompts)])
+            add = add.strip() + " "
+        prompts = [p.strip() for p in prompt.split(KBRK)]
+        out_p.append([add + p for p in prompts])
     out = [None] * batch * len(out_p[0]) * len(out_p)
     for p, prs in enumerate(out_p):  # inputs prompts
         for r, pr in enumerate(prs):  # prompts for regions

From 692247597ddee603f1f20992cb2465e98e6789db Mon Sep 17 00:00:00 2001
From: cjkangme
Date: Fri, 8 Nov 2024 09:25:58 +0900
Subject: [PATCH 6/7] [Docs] add docs for use_base in regional_prompting

---
 examples/community/README.md | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/examples/community/README.md b/examples/community/README.md
index 743993eb44c3..84ec6627aa39 100755
--- a/examples/community/README.md
+++ b/examples/community/README.md
@@ -3220,6 +3220,20 @@ best quality, 3persons in garden, a boy blue shirt BREAK
 best quality, 3persons in garden, an old man red suit
 ```

+### Use base prompt
+
+You can use a base prompt that is applied to every region. Set the base prompt by ending it with `ADDBASE`. A base prompt can be combined with a common prompt (`ADDCOMM`), but the base prompt must be specified first.
+
+```
+2d animation style ADDBASE
+masterpiece, high quality ADDCOMM
+(blue sky)++ BREAK
+green hair twintail BREAK
+book shelf BREAK
+messy desk BREAK
+orange++ dress and sofa
+```
+
 ### Negative prompt

 Negative prompts are equally effective across all regions, but it is possible to set region-specific prompts for negative prompts as well. The number of BREAKs must be the same as the number of prompts. If the number of prompts does not match, the negative prompts will be used without being divided into regions.
@@ -3250,6 +3264,7 @@ pipe(prompt=prompt, rp_args=rp_args)
 ### Optional Parameters

 - `save_mask`: In `Prompt` mode, choose whether to output the generated mask along with the image. The default is `False`.
+- `base_ratio`: Used with `ADDBASE`. Sets the ratio of the base prompt; if `base_ratio` is set to 0.2, the resulting images will consist of `20%*BASE_PROMPT + 80%*REGION_PROMPT`.

 The Pipeline supports `compel` syntax.
Input prompts using the `compel` structure will be automatically applied and processed. From b054413e3d5c76f17faf93d25afeaae2b7a0ac46 Mon Sep 17 00:00:00 2001 From: cjkangme Date: Sat, 9 Nov 2024 10:09:43 +0900 Subject: [PATCH 7/7] make style --- examples/community/regional_prompting_stable_diffusion.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/examples/community/regional_prompting_stable_diffusion.py b/examples/community/regional_prompting_stable_diffusion.py index bfcde82e02fd..95f6cebb0190 100644 --- a/examples/community/regional_prompting_stable_diffusion.py +++ b/examples/community/regional_prompting_stable_diffusion.py @@ -118,9 +118,7 @@ def __call__( rp_args: Dict[str, str] = None, ): active = KBRK in prompt[0] if isinstance(prompt, list) else KBRK in prompt - use_base = ( - KBASE in prompt[0] if isinstance(prompt, list) else KBASE in prompt - ) + use_base = KBASE in prompt[0] if isinstance(prompt, list) else KBASE in prompt if negative_prompt is None: negative_prompt = "" if isinstance(prompt, str) else [""] * len(prompt)
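
For reference, a minimal end-to-end sketch of the `ADDBASE`/`base_ratio` feature introduced by this series. The checkpoint id is illustrative (any Stable Diffusion 1.x checkpoint should work), and the `mode`/`div` values follow the `rows` syntax documented elsewhere in `examples/community/README.md`:

```python
import torch
from diffusers import DiffusionPipeline

# Load the community pipeline; "regional_prompting_stable_diffusion" resolves
# to examples/community/regional_prompting_stable_diffusion.py.
pipe = DiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",  # illustrative checkpoint
    custom_pipeline="regional_prompting_stable_diffusion",
    torch_dtype=torch.float16,
).to("cuda")

# Everything before ADDBASE is the base prompt; it is blended into every
# region's embedding at base_ratio. The remainder is split into one prompt
# per region by BREAK.
prompt = """
2d animation style ADDBASE
green hair twintail BREAK
book shelf BREAK
messy desk
"""

rp_args = {
    "mode": "rows",       # split the canvas into horizontal rows
    "div": "1;1;1",       # three equal rows, one per BREAK section
    "base_ratio": "0.2",  # 20%*BASE_PROMPT + 80%*REGION_PROMPT
}

image = pipe(prompt=prompt, rp_args=rp_args, num_inference_steps=30).images[0]
image.save("regional_base.png")
```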