Commit bebb04d

Add paint by example (huggingface#1533)
* add paint by example
* mkae loading possibel
* up
* Update src/diffusers/models/attention.py
* up
* finalize weight structure
* make example work
* make it work
* up
* up
* fix
* del
* add
* update
* Apply suggestions from code review
* correct transformer 2d
* finish
* up
* up
* up
* up
* fix
* Apply suggestions from code review
  Co-authored-by: Pedro Cuenca <[email protected]>
* Apply suggestions from code review
* up
* finish

Co-authored-by: Pedro Cuenca <[email protected]>
1 parent ab04d1c commit bebb04d
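
This commit adds PaintByExamplePipeline, an exemplar-guided inpainting pipeline: a masked region of an image is filled in to match a reference example image instead of a text prompt. A hedged usage sketch follows; the checkpoint id Fantasy-Studio/Paint-by-Example, the placeholder file paths, and the exact call signature are assumptions based on the released pipeline and are not shown in this diff.

# Hedged usage sketch -- not part of this diff. Checkpoint id, file paths
# and call signature are assumptions; image, mask_image and example_image
# are PIL images supplied by the caller.
import torch
from PIL import Image
from diffusers import PaintByExamplePipeline

init_image = Image.open("scene.png")         # placeholder: image to edit
mask_image = Image.open("mask.png")          # placeholder: white where the object goes
example_image = Image.open("reference.png")  # placeholder: exemplar object to paint in

pipe = PaintByExamplePipeline.from_pretrained(
    "Fantasy-Studio/Paint-by-Example", torch_dtype=torch.float16
).to("cuda")

result = pipe(image=init_image, mask_image=mask_image, example_image=example_image)
result.images[0].save("paint_by_example_result.png")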

7 files changed: +711 additions, −24 deletions

__init__.py

Lines changed: 1 addition & 0 deletions
@@ -72,6 +72,7 @@
         AltDiffusionPipeline,
         CycleDiffusionPipeline,
         LDMTextToImagePipeline,
+        PaintByExamplePipeline,
         StableDiffusionImageVariationPipeline,
         StableDiffusionImg2ImgPipeline,
         StableDiffusionInpaintPipeline,

models/attention.py

Lines changed: 56 additions & 24 deletions
@@ -406,6 +406,9 @@ def __init__(
     ):
         super().__init__()
         self.only_cross_attention = only_cross_attention
+        self.use_ada_layer_norm = num_embeds_ada_norm is not None
+
+        # 1. Self-Attn
         self.attn1 = CrossAttention(
             query_dim=dim,
             heads=num_attention_heads,
@@ -415,23 +418,28 @@ def __init__(
             cross_attention_dim=cross_attention_dim if only_cross_attention else None,
         )  # is a self-attention
         self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
-        self.attn2 = CrossAttention(
-            query_dim=dim,
-            cross_attention_dim=cross_attention_dim,
-            heads=num_attention_heads,
-            dim_head=attention_head_dim,
-            dropout=dropout,
-            bias=attention_bias,
-        )  # is self-attn if context is none
 
-        # layer norms
-        self.use_ada_layer_norm = num_embeds_ada_norm is not None
-        if self.use_ada_layer_norm:
-            self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
-            self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
+        # 2. Cross-Attn
+        if cross_attention_dim is not None:
+            self.attn2 = CrossAttention(
+                query_dim=dim,
+                cross_attention_dim=cross_attention_dim,
+                heads=num_attention_heads,
+                dim_head=attention_head_dim,
+                dropout=dropout,
+                bias=attention_bias,
+            )  # is self-attn if context is none
         else:
-            self.norm1 = nn.LayerNorm(dim)
-            self.norm2 = nn.LayerNorm(dim)
+            self.attn2 = None
+
+        self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
+
+        if cross_attention_dim is not None:
+            self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
+        else:
+            self.norm2 = None
+
+        # 3. Feed-forward
         self.norm3 = nn.LayerNorm(dim)
 
         # if xformers is installed try to use memory_efficient_attention by default
@@ -481,11 +489,12 @@ def forward(self, hidden_states, context=None, timestep=None):
         else:
             hidden_states = self.attn1(norm_hidden_states) + hidden_states
 
-        # 2. Cross-Attention
-        norm_hidden_states = (
-            self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
-        )
-        hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states
+        if self.attn2 is not None:
+            # 2. Cross-Attention
+            norm_hidden_states = (
+                self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
+            )
+            hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states
 
         # 3. Feed-forward
         hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
@@ -666,14 +675,16 @@ def __init__(
         inner_dim = int(dim * mult)
         dim_out = dim_out if dim_out is not None else dim
 
-        if activation_fn == "geglu":
-            geglu = GEGLU(dim, inner_dim)
+        if activation_fn == "gelu":
+            act_fn = GELU(dim, inner_dim)
+        elif activation_fn == "geglu":
+            act_fn = GEGLU(dim, inner_dim)
         elif activation_fn == "geglu-approximate":
-            geglu = ApproximateGELU(dim, inner_dim)
+            act_fn = ApproximateGELU(dim, inner_dim)
 
         self.net = nn.ModuleList([])
         # project in
-        self.net.append(geglu)
+        self.net.append(act_fn)
         # project dropout
         self.net.append(nn.Dropout(dropout))
         # project out
@@ -685,6 +696,27 @@ def forward(self, hidden_states):
         return hidden_states
 
 
+class GELU(nn.Module):
+    r"""
+    GELU activation function
+    """
+
+    def __init__(self, dim_in: int, dim_out: int):
+        super().__init__()
+        self.proj = nn.Linear(dim_in, dim_out)
+
+    def gelu(self, gate):
+        if gate.device.type != "mps":
+            return F.gelu(gate)
+        # mps: gelu is not implemented for float16
+        return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
+
+    def forward(self, hidden_states):
+        hidden_states = self.proj(hidden_states)
+        hidden_states = self.gelu(hidden_states)
+        return hidden_states
+
+
 # feedforward
 class GEGLU(nn.Module):
     r"""

pipelines/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -28,6 +28,7 @@
 if is_torch_available() and is_transformers_available():
     from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline
     from .latent_diffusion import LDMTextToImagePipeline
+    from .paint_by_example import PaintByExamplePipeline
     from .stable_diffusion import (
         CycleDiffusionPipeline,
         StableDiffusionImageVariationPipeline,
pipelines/paint_by_example/__init__.py

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+from dataclasses import dataclass
+from typing import List, Optional, Union
+
+import numpy as np
+
+import PIL
+from PIL import Image
+
+from ...utils import is_torch_available, is_transformers_available
+
+
+if is_transformers_available() and is_torch_available():
+    from .image_encoder import PaintByExampleImageEncoder
+    from .pipeline_paint_by_example import PaintByExamplePipeline
pipelines/paint_by_example/image_encoder.py

Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+from torch import nn
+
+from transformers import CLIPPreTrainedModel, CLIPVisionModel
+
+from ...models.attention import BasicTransformerBlock
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+class PaintByExampleImageEncoder(CLIPPreTrainedModel):
+    def __init__(self, config, proj_size=768):
+        super().__init__(config)
+        self.proj_size = proj_size
+
+        self.model = CLIPVisionModel(config)
+        self.mapper = PaintByExampleMapper(config)
+        self.final_layer_norm = nn.LayerNorm(config.hidden_size)
+        self.proj_out = nn.Linear(config.hidden_size, self.proj_size)
+
+        # uncondition for scaling
+        self.uncond_vector = nn.Parameter(torch.rand((1, 1, self.proj_size)))
+
+    def forward(self, pixel_values):
+        clip_output = self.model(pixel_values=pixel_values)
+        latent_states = clip_output.pooler_output
+        latent_states = self.mapper(latent_states[:, None])
+        latent_states = self.final_layer_norm(latent_states)
+        latent_states = self.proj_out(latent_states)
+        return latent_states
+
+
+class PaintByExampleMapper(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        num_layers = (config.num_hidden_layers + 1) // 5
+        hid_size = config.hidden_size
+        num_heads = 1
+        self.blocks = nn.ModuleList(
+            [
+                BasicTransformerBlock(hid_size, num_heads, hid_size, activation_fn="gelu", attention_bias=True)
+                for _ in range(num_layers)
+            ]
+        )
+
+    def forward(self, hidden_states):
+        for block in self.blocks:
+            hidden_states = block(hidden_states)
+
+        return hidden_states
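
A hedged sketch of the new encoder end to end: pooled CLIP vision features pass through the mapper's self-attention-only transformer blocks, are layer-normalized, and are projected to the conditioning width. The CLIPVisionConfig values below are illustrative, not the ones shipped with the Paint-by-Example checkpoint; the import path relies on the package __init__ shown above.

# Hedged sketch, not part of the diff: constructing the encoder from an
# illustrative CLIPVisionConfig. With num_hidden_layers=12 the mapper stacks
# (12 + 1) // 5 = 2 self-attention-only BasicTransformerBlocks.
import torch
from transformers import CLIPVisionConfig
from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder

config = CLIPVisionConfig(hidden_size=768, num_hidden_layers=12, image_size=224, patch_size=16)
encoder = PaintByExampleImageEncoder(config, proj_size=768)

pixel_values = torch.randn(1, 3, 224, 224)  # preprocessed example image
cond = encoder(pixel_values)                # pooled CLIP features -> mapper -> norm -> proj
print(cond.shape)                           # torch.Size([1, 1, 768])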
