Skip to content

Commit 69c83d6

Browse files
cjkangmesayakpaul
andauthored
[Community Pipeline] Add some feature for regional prompting pipeline (#9874)
* [Fix] fix bugs of regional_prompting pipeline * [Feat] add base prompt feature * [Fix] fix __init__ pipeline error * [Fix] delete unused args * [Fix] improve string handling * [Docs] docs to use_base in regional_prompting * make style --------- Co-authored-by: Sayak Paul <[email protected]>
1 parent e44fc75 commit 69c83d6

File tree

2 files changed

+78
-16
lines changed

2 files changed

+78
-16
lines changed

examples/community/README.md

+15
Original file line numberDiff line numberDiff line change
@@ -3379,6 +3379,20 @@ best quality, 3persons in garden, a boy blue shirt BREAK
33793379
best quality, 3persons in garden, an old man red suit
33803380
```
33813381

3382+
### Use base prompt
3383+
3384+
You can use a base prompt to apply the prompt to all areas. You can set a base prompt by adding `ADDBASE` at the end. Base prompts can also be combined with common prompts, but the base prompt must be specified first.
3385+
3386+
```
3387+
2d animation style ADDBASE
3388+
masterpiece, high quality ADDCOMM
3389+
(blue sky)++ BREAK
3390+
green hair twintail BREAK
3391+
book shelf BREAK
3392+
messy desk BREAK
3393+
orange++ dress and sofa
3394+
```
3395+
33823396
### Negative prompt
33833397

33843398
Negative prompts are equally effective across all regions, but it is possible to set region-specific prompts for negative prompts as well. The number of BREAKs must be the same as the number of prompts. If the number of prompts does not match, the negative prompts will be used without being divided into regions.
@@ -3409,6 +3423,7 @@ pipe(prompt=prompt, rp_args=rp_args)
34093423
### Optional Parameters
34103424

34113425
- `save_mask`: In `Prompt` mode, choose whether to output the generated mask along with the image. The default is `False`.
3426+
- `base_ratio`: Used with `ADDBASE`. Sets the ratio of the base prompt; if base ratio is set to 0.2, then resulting images will consist of `20%*BASE_PROMPT + 80%*REGION_PROMPT`
34123427

34133428
The Pipeline supports `compel` syntax. Input prompts using the `compel` structure will be automatically applied and processed.
34143429

examples/community/regional_prompting_stable_diffusion.py

+63-16
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,20 @@
33

44
import torch
55
import torchvision.transforms.functional as FF
6-
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
6+
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
77

88
from diffusers import StableDiffusionPipeline
99
from diffusers.models import AutoencoderKL, UNet2DConditionModel
1010
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
1111
from diffusers.schedulers import KarrasDiffusionSchedulers
12-
from diffusers.utils import USE_PEFT_BACKEND
1312

1413

1514
try:
1615
from compel import Compel
1716
except ImportError:
1817
Compel = None
1918

19+
KBASE = "ADDBASE"
2020
KCOMM = "ADDCOMM"
2121
KBRK = "BREAK"
2222

@@ -34,6 +34,11 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
3434
3535
Optional
3636
rp_args["save_mask"]: True/False (save masks in prompt mode)
37+
rp_args["power"]: int (power for attention maps in prompt mode)
38+
rp_args["base_ratio"]:
39+
float (Sets the ratio of the base prompt)
40+
ex) 0.2 (20%*BASE_PROMPT + 80%*REGION_PROMPT)
41+
[Use base prompt](https://github.com/hako-mikan/sd-webui-regional-prompter?tab=readme-ov-file#use-base-prompt)
3742
3843
Pipeline for text-to-image generation using Stable Diffusion.
3944
@@ -70,6 +75,7 @@ def __init__(
7075
scheduler: KarrasDiffusionSchedulers,
7176
safety_checker: StableDiffusionSafetyChecker,
7277
feature_extractor: CLIPImageProcessor,
78+
image_encoder: CLIPVisionModelWithProjection = None,
7379
requires_safety_checker: bool = True,
7480
):
7581
super().__init__(
@@ -80,6 +86,7 @@ def __init__(
8086
scheduler,
8187
safety_checker,
8288
feature_extractor,
89+
image_encoder,
8390
requires_safety_checker,
8491
)
8592
self.register_modules(
@@ -90,6 +97,7 @@ def __init__(
9097
scheduler=scheduler,
9198
safety_checker=safety_checker,
9299
feature_extractor=feature_extractor,
100+
image_encoder=image_encoder,
93101
)
94102

95103
@torch.no_grad()
@@ -110,17 +118,40 @@ def __call__(
110118
rp_args: Dict[str, str] = None,
111119
):
112120
active = KBRK in prompt[0] if isinstance(prompt, list) else KBRK in prompt
121+
use_base = KBASE in prompt[0] if isinstance(prompt, list) else KBASE in prompt
113122
if negative_prompt is None:
114123
negative_prompt = "" if isinstance(prompt, str) else [""] * len(prompt)
115124

116125
device = self._execution_device
117126
regions = 0
118127

128+
self.base_ratio = float(rp_args["base_ratio"]) if "base_ratio" in rp_args else 0.0
119129
self.power = int(rp_args["power"]) if "power" in rp_args else 1
120130

121131
prompts = prompt if isinstance(prompt, list) else [prompt]
122-
n_prompts = negative_prompt if isinstance(prompt, str) else [negative_prompt]
132+
n_prompts = negative_prompt if isinstance(prompt, list) else [negative_prompt]
123133
self.batch = batch = num_images_per_prompt * len(prompts)
134+
135+
if use_base:
136+
bases = prompts.copy()
137+
n_bases = n_prompts.copy()
138+
139+
for i, prompt in enumerate(prompts):
140+
parts = prompt.split(KBASE)
141+
if len(parts) == 2:
142+
bases[i], prompts[i] = parts
143+
elif len(parts) > 2:
144+
raise ValueError(f"Multiple instances of {KBASE} found in prompt: {prompt}")
145+
for i, prompt in enumerate(n_prompts):
146+
n_parts = prompt.split(KBASE)
147+
if len(n_parts) == 2:
148+
n_bases[i], n_prompts[i] = n_parts
149+
elif len(n_parts) > 2:
150+
raise ValueError(f"Multiple instances of {KBASE} found in negative prompt: {prompt}")
151+
152+
all_bases_cn, _ = promptsmaker(bases, num_images_per_prompt)
153+
all_n_bases_cn, _ = promptsmaker(n_bases, num_images_per_prompt)
154+
124155
all_prompts_cn, all_prompts_p = promptsmaker(prompts, num_images_per_prompt)
125156
all_n_prompts_cn, _ = promptsmaker(n_prompts, num_images_per_prompt)
126157

@@ -137,8 +168,16 @@ def getcompelembs(prps):
137168

138169
conds = getcompelembs(all_prompts_cn)
139170
unconds = getcompelembs(all_n_prompts_cn)
140-
embs = getcompelembs(prompts)
141-
n_embs = getcompelembs(n_prompts)
171+
base_embs = getcompelembs(all_bases_cn) if use_base else None
172+
base_n_embs = getcompelembs(all_n_bases_cn) if use_base else None
173+
# When using base, it seems more reasonable to use base prompts as prompt_embeddings rather than regional prompts
174+
embs = getcompelembs(prompts) if not use_base else base_embs
175+
n_embs = getcompelembs(n_prompts) if not use_base else base_n_embs
176+
177+
if use_base and self.base_ratio > 0:
178+
conds = self.base_ratio * base_embs + (1 - self.base_ratio) * conds
179+
unconds = self.base_ratio * base_n_embs + (1 - self.base_ratio) * unconds
180+
142181
prompt = negative_prompt = None
143182
else:
144183
conds = self.encode_prompt(prompts, device, 1, True)[0]
@@ -147,6 +186,18 @@ def getcompelembs(prps):
147186
if equal
148187
else self.encode_prompt(all_n_prompts_cn, device, 1, True)[0]
149188
)
189+
190+
if use_base and self.base_ratio > 0:
191+
base_embs = self.encode_prompt(bases, device, 1, True)[0]
192+
base_n_embs = (
193+
self.encode_prompt(n_bases, device, 1, True)[0]
194+
if equal
195+
else self.encode_prompt(all_n_bases_cn, device, 1, True)[0]
196+
)
197+
198+
conds = self.base_ratio * base_embs + (1 - self.base_ratio) * conds
199+
unconds = self.base_ratio * base_n_embs + (1 - self.base_ratio) * unconds
200+
150201
embs = n_embs = None
151202

152203
if not active:
@@ -225,8 +276,6 @@ def forward(
225276

226277
residual = hidden_states
227278

228-
args = () if USE_PEFT_BACKEND else (scale,)
229-
230279
if attn.spatial_norm is not None:
231280
hidden_states = attn.spatial_norm(hidden_states, temb)
232281

@@ -247,16 +296,15 @@ def forward(
247296
if attn.group_norm is not None:
248297
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
249298

250-
args = () if USE_PEFT_BACKEND else (scale,)
251-
query = attn.to_q(hidden_states, *args)
299+
query = attn.to_q(hidden_states)
252300

253301
if encoder_hidden_states is None:
254302
encoder_hidden_states = hidden_states
255303
elif attn.norm_cross:
256304
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
257305

258-
key = attn.to_k(encoder_hidden_states, *args)
259-
value = attn.to_v(encoder_hidden_states, *args)
306+
key = attn.to_k(encoder_hidden_states)
307+
value = attn.to_v(encoder_hidden_states)
260308

261309
inner_dim = key.shape[-1]
262310
head_dim = inner_dim // attn.heads
@@ -283,7 +331,7 @@ def forward(
283331
hidden_states = hidden_states.to(query.dtype)
284332

285333
# linear proj
286-
hidden_states = attn.to_out[0](hidden_states, *args)
334+
hidden_states = attn.to_out[0](hidden_states)
287335
# dropout
288336
hidden_states = attn.to_out[1](hidden_states)
289337

@@ -410,9 +458,9 @@ def promptsmaker(prompts, batch):
410458
add = ""
411459
if KCOMM in prompt:
412460
add, prompt = prompt.split(KCOMM)
413-
add = add + " "
414-
prompts = prompt.split(KBRK)
415-
out_p.append([add + p for p in prompts])
461+
add = add.strip() + " "
462+
prompts = [p.strip() for p in prompt.split(KBRK)]
463+
out_p.append([add + p for i, p in enumerate(prompts)])
416464
out = [None] * batch * len(out_p[0]) * len(out_p)
417465
for p, prs in enumerate(out_p): # inputs prompts
418466
for r, pr in enumerate(prs): # prompts for regions
@@ -449,7 +497,6 @@ def startend(cells, array):
449497
add = []
450498
startend(add, inratios[1:])
451499
icells.append(add)
452-
453500
return ocells, icells, sum(len(cell) for cell in icells)
454501

455502

0 commit comments

Comments
 (0)