
Commit 8ec60b3

diffusers: restore prompt weighting feature
1 parent d121406 commit 8ec60b3

2 files changed: +18 -22 lines changed

ldm/invoke/generator/diffusers_pipeline.py (+10, -18)
@@ -11,6 +11,8 @@
 from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
+from ldm.modules.encoders.modules import WeightedFrozenCLIPEmbedder
+
 
 @dataclass
 class PipelineIntermediateState:
@@ -76,6 +78,11 @@ def __init__(
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
         )
+        # InvokeAI's interface for text embeddings and whatnot
+        self.clip_embedder = WeightedFrozenCLIPEmbedder(
+            tokenizer=self.tokenizer,
+            transformer=self.text_encoder
+        )
 
     def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
         r"""
@@ -312,27 +319,12 @@ def get_text_embeddings(self,
         text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
         return text_embeddings
 
-    def get_learned_conditioning(self, c: List[List[str]], return_tokens=True,
-                                 fragment_weights=None, **kwargs):
+    @torch.inference_mode()
+    def get_learned_conditioning(self, c: List[List[str]], *, return_tokens=True, fragment_weights=None):
         """
         Compatibility function for ldm.models.diffusion.ddpm.LatentDiffusion.
         """
-        assert return_tokens == True
-        if fragment_weights:
-            weights = fragment_weights[0]
-            if any(weight != 1.0 for weight in weights):
-                warnings.warn(f"fragment weights not implemented yet {fragment_weights}", stacklevel=2)
-
-        if kwargs:
-            warnings.warn(f"unsupported args {kwargs}", stacklevel=2)
-
-        text_fragments = c[0]
-        text_input = self._tokenize(text_fragments)
-
-        with torch.inference_mode():
-            token_ids = text_input.input_ids.to(self.text_encoder.device)
-            text_embeddings = self.text_encoder(token_ids)[0]
-            return text_embeddings, text_input.input_ids
+        return self.clip_embedder.encode(c, return_tokens=return_tokens, fragment_weights=fragment_weights)
 
     @torch.inference_mode()
     def _tokenize(self, prompt: Union[str, List[str]]):
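For orientation, a minimal sketch of how the restored compatibility shim above is expected to be called. The prompt fragments and weights are hypothetical, `pipeline` stands for an already-constructed instance of the pipeline class in this file, and the method is assumed (like the code it replaces) to return an (embeddings, tokens) pair when return_tokens=True:

# Hypothetical usage of the restored get_learned_conditioning() path.
fragments = [["a cat", "wearing a red hat"]]   # one prompt, split into weighted fragments
weights = [[1.0, 1.5]]                         # per-fragment weights, same shape as fragments

embeddings, tokens = pipeline.get_learned_conditioning(
    fragments,
    return_tokens=True,
    fragment_weights=weights,
)
# The call now routes through self.clip_embedder (a WeightedFrozenCLIPEmbedder),
# so non-unit weights are applied instead of triggering the old
# "fragment weights not implemented yet" warning.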

ldm/modules/encoders/modules.py (+8, -4)
@@ -240,17 +240,17 @@ class FrozenCLIPEmbedder(AbstractEncoder):
     def __init__(
         self,
         version='openai/clip-vit-large-patch14',
-        device=choose_torch_device(),
         max_length=77,
+        tokenizer=None,
+        transformer=None,
     ):
         super().__init__()
-        self.tokenizer = CLIPTokenizer.from_pretrained(
+        self.tokenizer = tokenizer or CLIPTokenizer.from_pretrained(
             version, local_files_only=True
         )
-        self.transformer = CLIPTextModel.from_pretrained(
+        self.transformer = transformer or CLIPTextModel.from_pretrained(
             version, local_files_only=True
         )
-        self.device = device
         self.max_length = max_length
         self.freeze()
 
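The constructor change above lets FrozenCLIPEmbedder (and therefore WeightedFrozenCLIPEmbedder) either load its own CLIP modules or reuse ones that already exist. A short sketch of both paths, where `pipeline` is a placeholder for an object exposing tokenizer and text_encoder attributes, as in diffusers_pipeline.py:

# Standalone path: nothing passed in, so the CLIP tokenizer and text model
# are loaded from local files exactly as before this change.
embedder = WeightedFrozenCLIPEmbedder()

# Shared-weights path (what diffusers_pipeline.py now does): reuse the
# pipeline's already-loaded modules instead of loading a second copy of CLIP.
embedder = WeightedFrozenCLIPEmbedder(
    tokenizer=pipeline.tokenizer,       # existing CLIPTokenizer
    transformer=pipeline.text_encoder,  # existing CLIPTextModel
)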
@@ -456,6 +456,10 @@ def forward(self, text, **kwargs):
     def encode(self, text, **kwargs):
         return self(text, **kwargs)
 
+    @property
+    def device(self):
+        return self.transformer.device
+
 class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
 
     fragment_weights_key = "fragment_weights"
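With device now a read-only property derived from the transformer, the embedder reports whichever device its (possibly shared) text encoder actually lives on, rather than a value fixed at construction time. A small illustrative check, assuming the shared setup from diffusers_pipeline.py and an available CUDA device:

# Moving the shared text encoder also moves the embedder's notion of device.
pipeline.text_encoder.to("cuda")
assert pipeline.clip_embedder.device == pipeline.text_encoder.device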
