Commit 7904d0c

diffusers: restore prompt weighting feature

1 parent d121406 · commit 7904d0c

2 files changed: +18 -23 lines

ldm/invoke/generator/diffusers_pipeline.py (+10 -19)

@@ -1,5 +1,4 @@
 import secrets
-import warnings
 from dataclasses import dataclass
 from typing import List, Optional, Union, Callable
 
@@ -11,6 +10,8 @@
 from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
+from ldm.modules.encoders.modules import WeightedFrozenCLIPEmbedder
+
 
 @dataclass
 class PipelineIntermediateState:
@@ -76,6 +77,11 @@ def __init__(
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
         )
+        # InvokeAI's interface for text embeddings and whatnot
+        self.clip_embedder = WeightedFrozenCLIPEmbedder(
+            tokenizer=self.tokenizer,
+            transformer=self.text_encoder
+        )
 
     def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
         r"""
@@ -312,27 +318,12 @@ def get_text_embeddings(self,
         text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
         return text_embeddings
 
-    def get_learned_conditioning(self, c: List[List[str]], return_tokens=True,
-                                 fragment_weights=None, **kwargs):
+    @torch.inference_mode()
+    def get_learned_conditioning(self, c: List[List[str]], *, return_tokens=True, fragment_weights=None):
         """
         Compatibility function for ldm.models.diffusion.ddpm.LatentDiffusion.
         """
-        assert return_tokens == True
-        if fragment_weights:
-            weights = fragment_weights[0]
-            if any(weight != 1.0 for weight in weights):
-                warnings.warn(f"fragment weights not implemented yet {fragment_weights}", stacklevel=2)
-
-        if kwargs:
-            warnings.warn(f"unsupported args {kwargs}", stacklevel=2)
-
-        text_fragments = c[0]
-        text_input = self._tokenize(text_fragments)
-
-        with torch.inference_mode():
-            token_ids = text_input.input_ids.to(self.text_encoder.device)
-            text_embeddings = self.text_encoder(token_ids)[0]
-        return text_embeddings, text_input.input_ids
+        return self.clip_embedder.encode(c, return_tokens=return_tokens, fragment_weights=fragment_weights)
 
     @torch.inference_mode()
     def _tokenize(self, prompt: Union[str, List[str]]):
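Usage note (not part of the commit): the pipeline now builds a WeightedFrozenCLIPEmbedder from its own tokenizer and text encoder and delegates get_learned_conditioning() to its encode() method. Below is a minimal standalone sketch of that call path; the model name and fragment weights are illustrative assumptions, and it assumes encode(..., return_tokens=True) keeps returning the (embeddings, token ids) pair that the old compatibility shim returned.

# Sketch only -- illustrative values, not from the commit.
from transformers import CLIPTextModel, CLIPTokenizer
from ldm.modules.encoders.modules import WeightedFrozenCLIPEmbedder

tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')
text_encoder = CLIPTextModel.from_pretrained('openai/clip-vit-large-patch14')

# Mirrors what the pipeline __init__ now does with self.tokenizer / self.text_encoder.
embedder = WeightedFrozenCLIPEmbedder(tokenizer=tokenizer, transformer=text_encoder)

fragments = [["a cat wearing a hat", "dramatic studio lighting"]]  # List[List[str]], as in get_learned_conditioning
weights = [[1.0, 1.5]]                                             # per-fragment weights (assumed values)

# Same call that get_learned_conditioning() now forwards to.
embeddings, tokens = embedder.encode(fragments, return_tokens=True, fragment_weights=weights)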

ldm/modules/encoders/modules.py (+8 -4)

@@ -240,17 +240,17 @@ class FrozenCLIPEmbedder(AbstractEncoder):
     def __init__(
         self,
         version='openai/clip-vit-large-patch14',
-        device=choose_torch_device(),
         max_length=77,
+        tokenizer=None,
+        transformer=None,
     ):
         super().__init__()
-        self.tokenizer = CLIPTokenizer.from_pretrained(
+        self.tokenizer = tokenizer or CLIPTokenizer.from_pretrained(
             version, local_files_only=True
         )
-        self.transformer = CLIPTextModel.from_pretrained(
+        self.transformer = transformer or CLIPTextModel.from_pretrained(
             version, local_files_only=True
         )
-        self.device = device
         self.max_length = max_length
         self.freeze()
 
@@ -456,6 +456,10 @@ def forward(self, text, **kwargs):
     def encode(self, text, **kwargs):
         return self(text, **kwargs)
 
+    @property
+    def device(self):
+        return self.transformer.device
+
 class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
 
     fragment_weights_key = "fragment_weights"
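A second sketch (again an assumption, not from the commit) of what replacing the stored device attribute with a property changes: FrozenCLIPEmbedder no longer pins a device at construction time, it simply reports wherever its wrapped transformer currently lives, so moving the text encoder moves the embedder with it.

# Sketch only -- the model name is an illustrative assumption.
import torch
from transformers import CLIPTextModel, CLIPTokenizer
from ldm.modules.encoders.modules import FrozenCLIPEmbedder

embedder = FrozenCLIPEmbedder(
    tokenizer=CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14'),
    transformer=CLIPTextModel.from_pretrained('openai/clip-vit-large-patch14'),
)
print(embedder.device)         # expected: cpu -- read from the transformer, no stored attribute

if torch.cuda.is_available():
    embedder.transformer.to('cuda')
    print(embedder.device)     # expected: cuda:0 -- the property follows the transformer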
