|
11 | 11 | from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
12 | 12 | from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
13 | 13 |
|
| 14 | +from ldm.modules.encoders.modules import WeightedFrozenCLIPEmbedder |
| 15 | + |
14 | 16 |
|
15 | 17 | @dataclass
|
16 | 18 | class PipelineIntermediateState:
|
@@ -76,6 +78,11 @@ def __init__(
|
76 | 78 | safety_checker=safety_checker,
|
77 | 79 | feature_extractor=feature_extractor,
|
78 | 80 | )
|
| 81 | + # InvokeAI's interface for text embeddings and whatnot |
| 82 | + self.clip_embedder = WeightedFrozenCLIPEmbedder( |
| 83 | + tokenizer=self.tokenizer, |
| 84 | + transformer=self.text_encoder |
| 85 | + ) |
79 | 86 |
|
80 | 87 | def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
|
81 | 88 | r"""
|
@@ -312,27 +319,12 @@ def get_text_embeddings(self,
|
312 | 319 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
313 | 320 | return text_embeddings
|
314 | 321 |
|
@torch.inference_mode()
def get_learned_conditioning(self, c: List[List[str]], *, return_tokens=True,
                             fragment_weights=None, **kwargs):
    """
    Compatibility shim for ldm.models.diffusion.ddpm.LatentDiffusion.

    Delegates prompt encoding to ``self.clip_embedder`` (a
    WeightedFrozenCLIPEmbedder wired up in ``__init__``).

    :param c: batch of prompts, each a list of text fragments.
    :param return_tokens: forwarded to the embedder; when true the embedder
        also returns the token ids alongside the embeddings.
    :param fragment_weights: optional per-fragment weights, forwarded to the
        embedder (one weight list per prompt).
    :param kwargs: tolerated for backward compatibility with older callers;
        any extras are reported with a warning rather than a TypeError.
    :return: whatever ``clip_embedder.encode`` returns (embeddings, and the
        token ids when ``return_tokens`` is true).
    """
    if kwargs:
        # The pre-refactor implementation accepted and warned about unknown
        # keyword arguments instead of raising; keep that contract so legacy
        # call sites don't start crashing with TypeError.
        import warnings
        warnings.warn(f"unsupported args {kwargs}", stacklevel=2)
    return self.clip_embedder.encode(c, return_tokens=return_tokens,
                                     fragment_weights=fragment_weights)
336 | 328 |
|
337 | 329 | @torch.inference_mode()
|
338 | 330 | def _tokenize(self, prompt: Union[str, List[str]]):
|
|
0 commit comments