diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index b4b0f4bb3bd6..f41d0ffe72e3 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -436,7 +436,10 @@ def _maybe_convert_prompt(self, prompt: str, tokenizer: "PreTrainedTokenizer"):
         return prompt
 
     def load_textual_inversion(
-        self, pretrained_model_name_or_path: Union[str, Dict[str, torch.Tensor]], token: Optional[str] = None, **kwargs
+        self,
+        pretrained_model_name_or_path: Union[str, List[str]],
+        token: Optional[Union[str, List[str]]] = None,
+        **kwargs,
     ):
         r"""
         Load textual inversion embeddings into the text encoder of stable diffusion pipelines. Both `diffusers` and
@@ -449,7 +452,7 @@ def load_textual_inversion(
 
         Parameters:
-            pretrained_model_name_or_path (`str` or `os.PathLike`):
+            pretrained_model_name_or_path (`str` or `os.PathLike` or `List[str or os.PathLike]`):
                 Can be either:
 
                     - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
@@ -457,6 +460,12 @@ def load_textual_inversion(
                       `"sd-concepts-library/low-poly-hd-logos-icons"`.
                     - A path to a *directory* containing textual inversion weights, e.g.
                       `./my_text_inversion_directory/`.
+                    - A path to a *file* containing textual inversion weights, e.g. `./my_text_inversions.pt`.
+
+                Or a list of those elements.
+            token (`str` or `List[str]`, *optional*):
+                Override the token to use for the textual inversion weights. If `pretrained_model_name_or_path` is a
+                list, then `token` must also be a list of equal length.
 
             weight_name (`str`, *optional*):
                 Name of a custom weight file. This should be used in two cases:
@@ -576,16 +585,62 @@ def load_textual_inversion(
             "framework": "pytorch",
         }
 
-        # 1. Load textual inversion file
-        model_file = None
-        # Let's first try to load .safetensors weights
-        if (use_safetensors and weight_name is None) or (
-            weight_name is not None and weight_name.endswith(".safetensors")
-        ):
-            try:
+        if isinstance(pretrained_model_name_or_path, str):
+            pretrained_model_name_or_paths = [pretrained_model_name_or_path]
+        else:
+            pretrained_model_name_or_paths = pretrained_model_name_or_path
+
+        if isinstance(token, str):
+            tokens = [token]
+        elif token is None:
+            tokens = [None] * len(pretrained_model_name_or_paths)
+        else:
+            tokens = token
+
+        if len(pretrained_model_name_or_paths) != len(tokens):
+            raise ValueError(
+                f"You have passed a list of models of length {len(pretrained_model_name_or_paths)}, and a list of tokens of length {len(tokens)}. "
+                f"Make sure both lists have the same length."
+            )
+
+        valid_tokens = [t for t in tokens if t is not None]
+        if len(set(valid_tokens)) < len(valid_tokens):
+            raise ValueError(f"You have passed a list of tokens that contains duplicates: {tokens}")
+
+        token_ids_and_embeddings = []
+
+        for pretrained_model_name_or_path, token in zip(pretrained_model_name_or_paths, tokens):
+            # 1. Load textual inversion file
+            model_file = None
+            # Let's first try to load .safetensors weights
+            if (use_safetensors and weight_name is None) or (
+                weight_name is not None and weight_name.endswith(".safetensors")
+            ):
+                try:
+                    model_file = _get_model_file(
+                        pretrained_model_name_or_path,
+                        weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
+                        cache_dir=cache_dir,
+                        force_download=force_download,
+                        resume_download=resume_download,
+                        proxies=proxies,
+                        local_files_only=local_files_only,
+                        use_auth_token=use_auth_token,
+                        revision=revision,
+                        subfolder=subfolder,
+                        user_agent=user_agent,
+                    )
+                    state_dict = safetensors.torch.load_file(model_file, device="cpu")
+                except Exception as e:
+                    if not allow_pickle:
+                        raise e
+
+                    model_file = None
+
+            if model_file is None:
                 model_file = _get_model_file(
                     pretrained_model_name_or_path,
-                    weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
+                    weights_name=weight_name or TEXT_INVERSION_NAME,
                     cache_dir=cache_dir,
                     force_download=force_download,
                     resume_download=resume_download,
@@ -596,88 +651,68 @@ def load_textual_inversion(
                     subfolder=subfolder,
                     user_agent=user_agent,
                 )
-                state_dict = safetensors.torch.load_file(model_file, device="cpu")
-            except Exception as e:
-                if not allow_pickle:
-                    raise e
+            state_dict = torch.load(model_file, map_location="cpu")
 
-                model_file = None
+            # 2. Load token and embedding correctly from file
+            if isinstance(state_dict, torch.Tensor):
+                if token is None:
+                    raise ValueError(
+                        "You are trying to load a textual inversion embedding that has been saved as a PyTorch tensor. Make sure to pass the name of the corresponding token in this case: `token=...`."
+                    )
+                embedding = state_dict
+            elif len(state_dict) == 1:
+                # diffusers
+                loaded_token, embedding = next(iter(state_dict.items()))
+            elif "string_to_param" in state_dict:
+                # A1111
+                loaded_token = state_dict["name"]
+                embedding = state_dict["string_to_param"]["*"]
+
+            if token is not None and loaded_token != token:
+                logger.info(f"The loaded token: {loaded_token} is overwritten by the passed token {token}.")
+            else:
+                token = loaded_token
 
-        if model_file is None:
-            model_file = _get_model_file(
-                pretrained_model_name_or_path,
-                weights_name=weight_name or TEXT_INVERSION_NAME,
-                cache_dir=cache_dir,
-                force_download=force_download,
-                resume_download=resume_download,
-                proxies=proxies,
-                local_files_only=local_files_only,
-                use_auth_token=use_auth_token,
-                revision=revision,
-                subfolder=subfolder,
-                user_agent=user_agent,
-            )
-            state_dict = torch.load(model_file, map_location="cpu")
+            embedding = embedding.to(dtype=self.text_encoder.dtype, device=self.text_encoder.device)
 
-        # 2. Load token and embedding correcly from file
-        if isinstance(state_dict, torch.Tensor):
-            if token is None:
+            # 3. Make sure we don't mess up the tokenizer or text encoder
+            vocab = self.tokenizer.get_vocab()
+            if token in vocab:
                 raise ValueError(
-                    "You are trying to load a textual inversion embedding that has been saved as a PyTorch tensor. Make sure to pass the name of the corresponding token in this case: `token=...`."
+                    f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder."
                 )
-            embedding = state_dict
-        elif len(state_dict) == 1:
-            # diffusers
-            loaded_token, embedding = next(iter(state_dict.items()))
-        elif "string_to_param" in state_dict:
-            # A1111
-            loaded_token = state_dict["name"]
-            embedding = state_dict["string_to_param"]["*"]
-
-        if token is not None and loaded_token != token:
-            logger.warn(f"The loaded token: {loaded_token} is overwritten by the passed token {token}.")
-        else:
-            token = loaded_token
-
-        embedding = embedding.to(dtype=self.text_encoder.dtype, device=self.text_encoder.device)
+            elif f"{token}_1" in vocab:
+                multi_vector_tokens = [token]
+                i = 1
+                while f"{token}_{i}" in self.tokenizer.added_tokens_encoder:
+                    multi_vector_tokens.append(f"{token}_{i}")
+                    i += 1
 
-        # 3. Make sure we don't mess up the tokenizer or text encoder
-        vocab = self.tokenizer.get_vocab()
-        if token in vocab:
-            raise ValueError(
-                f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder."
-            )
-        elif f"{token}_1" in vocab:
-            multi_vector_tokens = [token]
-            i = 1
-            while f"{token}_{i}" in self.tokenizer.added_tokens_encoder:
-                multi_vector_tokens.append(f"{token}_{i}")
-                i += 1
+                raise ValueError(
+                    f"Multi-vector Token {multi_vector_tokens} already in tokenizer vocabulary. Please choose a different token name or remove the {multi_vector_tokens} and embedding from the tokenizer and text encoder."
+                )
 
-            raise ValueError(
-                f"Multi-vector Token {multi_vector_tokens} already in tokenizer vocabulary. Please choose a different token name or remove the {multi_vector_tokens} and embedding from the tokenizer and text encoder."
-            )
+            is_multi_vector = len(embedding.shape) > 1 and embedding.shape[0] > 1
 
-        is_multi_vector = len(embedding.shape) > 1 and embedding.shape[0] > 1
+            if is_multi_vector:
+                tokens = [token] + [f"{token}_{i}" for i in range(1, embedding.shape[0])]
+                embeddings = [e for e in embedding]  # noqa: C416
+            else:
+                tokens = [token]
+                embeddings = [embedding[0]] if len(embedding.shape) > 1 else [embedding]
 
-        if is_multi_vector:
-            tokens = [token] + [f"{token}_{i}" for i in range(1, embedding.shape[0])]
-            embeddings = [e for e in embedding]  # noqa: C416
-        else:
-            tokens = [token]
-            embeddings = [embedding[0]] if len(embedding.shape) > 1 else [embedding]
+            # add tokens and get ids
+            self.tokenizer.add_tokens(tokens)
+            token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
+            token_ids_and_embeddings += zip(token_ids, embeddings)
 
-        # add tokens and get ids
-        self.tokenizer.add_tokens(tokens)
-        token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
+            logger.info(f"Loaded textual inversion embedding for {token}.")
 
-        # resize token embeddings and set new embeddings
+        # resize token embeddings and set all new embeddings
         self.text_encoder.resize_token_embeddings(len(self.tokenizer))
-        for token_id, embedding in zip(token_ids, embeddings):
+        for token_id, embedding in token_ids_and_embeddings:
            self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding
 
-        logger.info(f"Loaded textual inversion embedding for {token}.")
-
 
 class LoraLoaderMixin:
     r"""
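The loader changes above can be exercised end to end. A minimal usage sketch follows; the base checkpoint "runwayml/stable-diffusion-v1-5" and the "<my-style>" token are illustrative assumptions and not part of this diff, while the concepts-library repo id and the `./my_text_inversions.pt` path come from the docstring above:

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# One call loads several embeddings. Tokens must be unique, and a token list
# must match the list of paths/ids in length; a None entry keeps the token
# name stored in the corresponding file.
pipe.load_textual_inversion(
    ["sd-concepts-library/low-poly-hd-logos-icons", "./my_text_inversions.pt"],
    token=[None, "<my-style>"],
)
image = pipe("a temple in ruins, in the style of <my-style>").images[0]
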
diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py
index 168ff8106c52..70b1431d630a 100644
--- a/tests/pipelines/test_pipelines.py
+++ b/tests/pipelines/test_pipelines.py
@@ -575,6 +575,31 @@ def test_text_inversion_download(self):
         out = pipe(prompt, num_inference_steps=1, output_type="numpy").images
         assert out.shape == (1, 128, 128, 3)
 
+        # multi embedding load
+        with tempfile.TemporaryDirectory() as tmpdirname1:
+            with tempfile.TemporaryDirectory() as tmpdirname2:
+                ten = {"<*****>": torch.ones((32,))}
+                torch.save(ten, os.path.join(tmpdirname1, "learned_embeds.bin"))
+
+                ten = {"<******>": 2 * torch.ones((1, 32))}
+                torch.save(ten, os.path.join(tmpdirname2, "learned_embeds.bin"))
+
+                pipe.load_textual_inversion([tmpdirname1, tmpdirname2])
+
+                token = pipe.tokenizer.convert_tokens_to_ids("<*****>")
+                assert token == num_tokens + 8, "Added token must be at spot `num_tokens`"
+                assert pipe.text_encoder.get_input_embeddings().weight[-2].sum().item() == 32
+                assert pipe._maybe_convert_prompt("<*****>", pipe.tokenizer) == "<*****>"
+
+                token = pipe.tokenizer.convert_tokens_to_ids("<******>")
+                assert token == num_tokens + 9, "Added token must be at spot `num_tokens`"
+                assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 64
+                assert pipe._maybe_convert_prompt("<******>", pipe.tokenizer) == "<******>"
+
+                prompt = "hey <*****> <******>"
+                out = pipe(prompt, num_inference_steps=1, output_type="numpy").images
+                assert out.shape == (1, 128, 128, 3)
+
     def test_download_ignore_files(self):
         # Check https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe-ignore-files/blob/72f58636e5508a218c6b3f60550dc96445547817/model_index.json#L4
         with tempfile.TemporaryDirectory() as tmpdirname:
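The test above covers the `diffusers` single-key dict format; the A1111 branch (`"string_to_param"` in the state dict) and the multi-vector path can be exercised the same way. A sketch, assuming `pipe` is the tiny test pipeline with 32-dim text embeddings from `test_text_inversion_download` above; the file contents and the "<a1111-style>"/"<az>" token names are made up for illustration:

import os
import tempfile

import torch

with tempfile.TemporaryDirectory() as tmpdirname:
    # A1111 layout: token name under "name", a (num_vectors, dim) tensor
    # under "string_to_param" -> "*"
    ten = {"name": "<a1111-style>", "string_to_param": {"*": torch.ones((2, 32))}}
    torch.save(ten, os.path.join(tmpdirname, "learned_embeds.bin"))

    # token=... overrides the stored "<a1111-style>" name (logged at info level);
    # the two rows in the tensor add both "<az>" and "<az>_1" to the tokenizer.
    pipe.load_textual_inversion(tmpdirname, token="<az>")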