
Add paint by example #1533

Merged

30 commits merged on Dec 7, 2022
Changes from 16 commits

Commits (30)
57f632e
add paint by example
patrickvonplaten Dec 3, 2022
7fbb2ec
mkae loading possibel
patrickvonplaten Dec 3, 2022
5695196
up
patrickvonplaten Dec 3, 2022
d13dd70
Update src/diffusers/models/attention.py
patrickvonplaten Dec 3, 2022
c1df752
up
patrickvonplaten Dec 4, 2022
c7aa57b
finalize weight structure
patrickvonplaten Dec 4, 2022
3041771
make example work
patrickvonplaten Dec 4, 2022
1193cfa
make it work
patrickvonplaten Dec 4, 2022
03cbee1
up
patrickvonplaten Dec 4, 2022
98c14fb
up
patrickvonplaten Dec 6, 2022
938b4b9
Merge branch 'main' into add_paint_by_example
patrickvonplaten Dec 6, 2022
c7d5a9e
fix
patrickvonplaten Dec 6, 2022
396bc5c
del
patrickvonplaten Dec 6, 2022
314f79b
add
patrickvonplaten Dec 6, 2022
0513534
Merge branch 'add_paint_by_example' of https://github.com/huggingface…
patrickvonplaten Dec 6, 2022
735311d
update
patrickvonplaten Dec 6, 2022
b98a476
Apply suggestions from code review
patrickvonplaten Dec 6, 2022
07eb1d2
correct transformer 2d
patrickvonplaten Dec 6, 2022
7250d58
Merge branch 'add_paint_by_example' of https://github.com/huggingface…
patrickvonplaten Dec 6, 2022
5d2a7ca
finish
patrickvonplaten Dec 6, 2022
8ea23f8
up
patrickvonplaten Dec 6, 2022
7396992
up
patrickvonplaten Dec 6, 2022
54f4cc1
up
patrickvonplaten Dec 6, 2022
1e14802
up
patrickvonplaten Dec 6, 2022
d8eb5ad
fix
patrickvonplaten Dec 7, 2022
3eb4817
fix
patrickvonplaten Dec 7, 2022
634211e
Apply suggestions from code review
patrickvonplaten Dec 7, 2022
6ad90d6
Apply suggestions from code review
patrickvonplaten Dec 7, 2022
dc5cfa6
up
patrickvonplaten Dec 7, 2022
4668b12
finish
patrickvonplaten Dec 7, 2022
106 changes: 102 additions & 4 deletions scripts/convert_original_stable_diffusion_to_diffusers.py
@@ -40,9 +40,10 @@
StableDiffusionPipeline,
UNet2DConditionModel,
)
from diffusers.pipelines.inpaint_by_example import InpaintByExampleImageEncoder, InpaintByExamplePipeline
from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer
from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer, CLIPVisionConfig


def shave_segments(path, n_shave_prefix_segments=1):
@@ -647,6 +648,73 @@ def convert_ldm_clip_checkpoint(checkpoint):
return text_model


def convert_inpaint_by_example_checkpoint(checkpoint):
config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14")
model = InpaintByExampleImageEncoder(config)

keys = list(checkpoint.keys())

text_model_dict = {}

for key in keys:
if key.startswith("cond_stage_model.transformer"):
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]

# load clip vision
model.model.load_state_dict(text_model_dict)

# load mapper
keys_mapper = {
k[len("cond_stage_model.mapper.res") :]: v
for k, v in checkpoint.items()
if k.startswith("cond_stage_model.mapper")
}

MAPPING = {
"attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"],
"attn.c_proj": ["attn1.to_out.0"],
"ln_1": ["norm1"],
"ln_2": ["norm3"],
"mlp.c_fc": ["ff.net.0.proj"],
"mlp.c_proj": ["ff.net.2"],
}

mapped_weights = {}
for key, value in keys_mapper.items():
prefix = key[: len("blocks.i")]
suffix = key.split(prefix)[-1].split(".")[-1]
name = key.split(prefix)[-1].split(suffix)[0][1:-1]
mapped_names = MAPPING[name]

num_splits = len(mapped_names)
for i, mapped_name in enumerate(mapped_names):
new_name = ".".join([prefix, mapped_name, suffix])
shape = value.shape[0] // num_splits
mapped_weights[new_name] = value[i * shape : (i + 1) * shape]

model.mapper.load_state_dict(mapped_weights)

# load final layer norm
model.final_layer_norm.load_state_dict(
{
"bias": checkpoint["cond_stage_model.final_ln.bias"],
"weight": checkpoint["cond_stage_model.final_ln.weight"],
}
)

# load final proj
model.proj_out.load_state_dict(
{
"bias": checkpoint["proj_out.bias"],
"weight": checkpoint["proj_out.weight"],
}
)

# load uncond vector
model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"])
return model


def convert_open_clip_checkpoint(checkpoint):
text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")

@@ -676,12 +744,24 @@ def convert_open_clip_checkpoint(checkpoint):
type=str,
help="The YAML config file corresponding to the original architecture.",
)
parser.add_argument(
"--num_in_channels",
default=None,
type=int,
help="The number of input channels. If `None` number of input channels will be automatically inferred.",
)
parser.add_argument(
"--scheduler_type",
default="pndm",
type=str,
help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim', 'euler', 'euler-ancest', 'dpm']",
)
parser.add_argument(
"--pipeline_type",
default=None,
type=str,
help="The pipeline type. If `None` pipeline will be automatically inferred.",
)
parser.add_argument(
"--image_size",
default=None,
@@ -737,6 +817,9 @@ def convert_open_clip_checkpoint(checkpoint):

original_config = OmegaConf.load(args.original_config_file)

if args.num_in_channels is not None:
original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = args.num_in_channels

if (
"parameterization" in original_config["model"]["params"]
and original_config["model"]["params"]["parameterization"] == "v"
@@ -806,8 +889,11 @@ def convert_open_clip_checkpoint(checkpoint):
vae.load_state_dict(converted_vae_checkpoint)

# Convert the text model.
text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
if text_model_type == "FrozenOpenCLIPEmbedder":
model_type = args.pipeline_type
if model_type is None:
model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]

if model_type == "FrozenOpenCLIPEmbedder":
text_model = convert_open_clip_checkpoint(checkpoint)
tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer")
pipe = StableDiffusionPipeline(
@@ -820,7 +906,19 @@ def convert_open_clip_checkpoint(checkpoint):
feature_extractor=None,
requires_safety_checker=False,
)
elif text_model_type == "FrozenCLIPEmbedder":
elif model_type == "InpaintByExample":
vision_model = convert_inpaint_by_example_checkpoint(checkpoint)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
pipe = InpaintByExamplePipeline(
vae=vae,
image_encoder=vision_model,
unet=unet,
scheduler=scheduler,
safety_checker=None,
feature_extractor=feature_extractor,
)
elif model_type == "FrozenCLIPEmbedder":
text_model = convert_ldm_clip_checkpoint(checkpoint)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
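The core of the new convert_inpaint_by_example_checkpoint function is the mapper remapping: the original checkpoint stores each block's query, key and value projections as a single fused attn.c_qkv tensor, and the MAPPING loop slices it into equal chunks for attn1.to_q, attn1.to_k and attn1.to_v (biases are split the same way). A minimal sketch of that split with a random stand-in tensor; the names and sizes below are illustrative, not taken from the script:

import torch

hidden = 1024
fused_qkv = torch.randn(3 * hidden, hidden)  # stand-in for a blocks.<i>.attn.c_qkv weight
mapped_names = ["attn1.to_q", "attn1.to_k", "attn1.to_v"]

# same arithmetic as the conversion loop: chunk size = rows // number of target tensors
chunk = fused_qkv.shape[0] // len(mapped_names)
split = {name: fused_qkv[i * chunk : (i + 1) * chunk] for i, name in enumerate(mapped_names)}

assert all(w.shape == (hidden, hidden) for w in split.values())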
1 change: 1 addition & 0 deletions src/diffusers/__init__.py
@@ -73,6 +73,7 @@
AltDiffusionImg2ImgPipeline,
AltDiffusionPipeline,
CycleDiffusionPipeline,
InpaintByExamplePipeline,
LDMTextToImagePipeline,
StableDiffusionImageVariationPipeline,
StableDiffusionImg2ImgPipeline,
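With InpaintByExamplePipeline re-exported at the package root, a checkpoint converted by the script above can be loaded straight from diffusers. A minimal sketch; the local directory name is a placeholder for wherever the conversion output was saved:

from diffusers import InpaintByExamplePipeline

pipe = InpaintByExamplePipeline.from_pretrained("./inpaint-by-example-converted")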
79 changes: 53 additions & 26 deletions src/diffusers/models/attention.py
@@ -406,6 +406,9 @@ def __init__(
):
super().__init__()
self.only_cross_attention = only_cross_attention
self.use_ada_layer_norm = num_embeds_ada_norm is not None

# 1. Self-Attn
self.attn1 = CrossAttention(
query_dim=dim,
heads=num_attention_heads,
@@ -414,25 +417,25 @@ def __init__(
bias=attention_bias,
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
) # is a self-attention
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
self.attn2 = CrossAttention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
) # is self-attn if context is none

# layer norms
self.use_ada_layer_norm = num_embeds_ada_norm is not None
if self.use_ada_layer_norm:
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)

# 2. Cross-Attn
if cross_attention_dim is not None:
self.attn2 = CrossAttention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
) # is self-attn if context is none
self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
else:
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.attn2 = None

# 3. Feed-forward
self.norm3 = nn.LayerNorm(dim)
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)

# if xformers is installed try to use memory_efficient_attention by default
if is_xformers_available():
@@ -481,11 +484,12 @@ def forward(self, hidden_states, context=None, timestep=None):
else:
hidden_states = self.attn1(norm_hidden_states) + hidden_states

# 2. Cross-Attention
norm_hidden_states = (
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
)
hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states
if self.attn2 is not None:
# 2. Cross-Attention
norm_hidden_states = (
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
)
hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states

# 3. Feed-forward
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
@@ -666,14 +670,16 @@ def __init__(
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim

if activation_fn == "geglu":
geglu = GEGLU(dim, inner_dim)
if activation_fn == "gelu":
act_fn = GELU(dim, inner_dim)
elif activation_fn == "geglu":
act_fn = GEGLU(dim, inner_dim)
elif activation_fn == "geglu-approximate":
geglu = ApproximateGELU(dim, inner_dim)
act_fn = ApproximateGELU(dim, inner_dim)

self.net = nn.ModuleList([])
# project in
self.net.append(geglu)
self.net.append(act_fn)
# project dropout
self.net.append(nn.Dropout(dropout))
# project out
@@ -685,6 +691,27 @@ def forward(self, hidden_states):
return hidden_states


class GELU(nn.Module):
r"""
GELU activation function
"""

def __init__(self, dim_in: int, dim_out: int):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out)

def gelu(self, gate):
if gate.device.type != "mps":
return F.gelu(gate)
# mps: gelu is not implemented for float16
return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)

def forward(self, hidden_states):
hidden_states = self.proj(hidden_states)
hidden_states = self.gelu(hidden_states)
return hidden_states


# feedforward
class GEGLU(nn.Module):
r"""
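The net effect of the BasicTransformerBlock changes is that cross-attention (attn2) is only built and run when cross_attention_dim is given, and FeedForward now accepts activation_fn="gelu" via the new GELU projection. A quick sketch exercising the pure self-attention path; the dimensions are illustrative:

import torch
from diffusers.models.attention import BasicTransformerBlock

# cross_attention_dim is left as None, so attn2 is None and the forward pass skips
# the cross-attention branch; activation_fn="gelu" selects the new GELU class.
block = BasicTransformerBlock(320, 8, 40, activation_fn="gelu", attention_bias=True)

hidden_states = torch.randn(2, 64, 320)  # (batch, sequence length, dim)
out = block(hidden_states)               # context=None: self-attention + feed-forward only
assert out.shape == hidden_states.shape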
1 change: 1 addition & 0 deletions src/diffusers/pipelines/__init__.py
@@ -27,6 +27,7 @@

if is_torch_available() and is_transformers_available():
from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline
from .inpaint_by_example import InpaintByExamplePipeline
from .latent_diffusion import LDMTextToImagePipeline
from .stable_diffusion import (
CycleDiffusionPipeline,
15 changes: 15 additions & 0 deletions src/diffusers/pipelines/inpaint_by_example/__init__.py
@@ -0,0 +1,15 @@
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional, Union

import numpy as np

import PIL
from PIL import Image

from ...utils import BaseOutput, is_torch_available, is_transformers_available


if is_transformers_available() and is_torch_available():
from .image_encoder import InpaintByExampleImageEncoder
from .pipeline_inpaint_by_example import InpaintByExamplePipeline
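The classes above are only exposed when both PyTorch and transformers are installed. Code that wants to degrade gracefully can mirror the same guard; this is a sketch, not part of the PR:

from diffusers.utils import is_torch_available, is_transformers_available

if is_torch_available() and is_transformers_available():
    from diffusers.pipelines.inpaint_by_example import InpaintByExamplePipeline
else:
    InpaintByExamplePipeline = None  # optional backends missing; handle as needed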
65 changes: 65 additions & 0 deletions src/diffusers/pipelines/inpaint_by_example/image_encoder.py
@@ -0,0 +1,65 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn

from transformers import CLIPPreTrainedModel, CLIPVisionModel

from ...models.attention import BasicTransformerBlock
from ...utils import logging


logger = logging.get_logger(__name__) # pylint: disable=invalid-name


class InpaintByExampleImageEncoder(CLIPPreTrainedModel):
def __init__(self, config, proj_size=768):
super().__init__(config)
self.proj_size = proj_size

self.model = CLIPVisionModel(config)
self.mapper = InpaintByExampleMapper(config)
self.final_layer_norm = nn.LayerNorm(config.hidden_size)
self.proj_out = nn.Linear(config.hidden_size, self.proj_size)

# uncondition for scaling
self.uncond_vector = nn.Parameter(torch.rand((1, 1, self.proj_size)))

def forward(self, pixel_values):
clip_output = self.model(pixel_values=pixel_values)
latent_states = clip_output.pooler_output
latent_states = self.mapper(latent_states[:, None])
latent_states = self.final_layer_norm(latent_states)
latent_states = self.proj_out(latent_states)
return latent_states


class InpaintByExampleMapper(nn.Module):
def __init__(self, config):
super().__init__()
num_layers = (config.num_hidden_layers + 1) // 5
hid_size = config.hidden_size
num_heads = 1
self.blocks = nn.ModuleList(
[
BasicTransformerBlock(hid_size, num_heads, hid_size, activation_fn="gelu", attention_bias=True)
for _ in range(num_layers)
]
)

def forward(self, hidden_states):
for block in self.blocks:
hidden_states = block(hidden_states)

return hidden_states
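InpaintByExampleImageEncoder pools the CLIP vision embedding, runs it through the small transformer mapper ((num_hidden_layers + 1) // 5 blocks, i.e. 2 blocks for a 12-layer config), applies the final layer norm and projects to proj_size. A randomly initialized smoke test using default CLIPVisionConfig values; shapes are illustrative and no real weights are loaded:

import torch
from transformers import CLIPVisionConfig
from diffusers.pipelines.inpaint_by_example import InpaintByExampleImageEncoder

config = CLIPVisionConfig()  # defaults: hidden_size=768, num_hidden_layers=12, image_size=224
encoder = InpaintByExampleImageEncoder(config)  # random init; real use loads converted weights

pixel_values = torch.randn(1, 3, config.image_size, config.image_size)
latents = encoder(pixel_values)  # pooled CLIP output -> mapper -> final layer norm -> proj_out
print(latents.shape)             # torch.Size([1, 1, 768]): one conditioning vector per image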