Skip to content

Commit e70e9d8

Browse files
nemonamelessJunnYu
andauthored
Support UniDiffuser model and pipeline (#5487)
* unidiffuser initial version * add 7 example codes * refine codes * simplify codes * fix caption_decoder uvit from_pretrained * unified all pipelines in one UniDiffuserPipeline * fix set_timesteps * fix preprocess and some hpyparams * fix scheduler params * delete sub-pipeline * add unittest * fix sub-pipeline * add einops in ppdiffusers/requirements.txt * fix encode_prefix * unidiffuser dpm solver++ * fix dpmsolver++ for unidiffuser * refator * layer_norm * missing use_beam_search * fix center_crop and remove CFG * support xformers * support purge float16 * skip_special_tokens * update import UVitModel * remove einops dep * update dummpy --------- Co-authored-by: yujun <[email protected]>
1 parent fc5b9b1 commit e70e9d8

29 files changed

+2407
-7
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from ppdiffusers import UniDiffuserPipeline
16+
from ppdiffusers.utils import load_image
17+
18+
pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser")
19+
image = load_image("https://bj.bcebos.com/v1/paddlenlp/models/community/thu-ml/data/space.jpg")
20+
result = pipe(mode="i2t", image=image, prompt=None)
21+
text = result.texts[0]
22+
with open("image_to_text_generation-unidiffuser-result.txt", "w") as f:
23+
print("{}\n".format(text), file=f)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
from ppdiffusers import UniDiffuserPipeline
17+
from ppdiffusers.utils import load_image
18+
19+
pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser")
20+
image = load_image("https://bj.bcebos.com/v1/paddlenlp/models/community/thu-ml/data/space.jpg")
21+
result = pipe(mode="i2t2i", image=image, prompt=None)
22+
image = result.images[0]
23+
image.save("image_variation-unidiffuser-result.png")
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
from ppdiffusers import UniDiffuserPipeline
17+
18+
pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser")
19+
prompt = "an elephant under the sea"
20+
result = pipe(mode="t2i", image=None, prompt=prompt)
21+
image = result.images[0]
22+
image.save("text_to_image_generation-unidiffuser-result.png")
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
from ppdiffusers import UniDiffuserPipeline
17+
18+
pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser")
19+
prompt = "an elephant under the sea"
20+
result = pipe(mode="t2i2t", image=None, prompt=prompt)
21+
text = result.texts[0]
22+
with open("text_variation-unidiffuser-result.txt", "w") as f:
23+
print("{}\n".format(text), file=f)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
from ppdiffusers import UniDiffuserPipeline
17+
18+
pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser")
19+
result = pipe(mode="i", image=None, prompt=None)
20+
image = result.images[0]
21+
image.save("unconditional_image_generation-unidiffuser-result.png")
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
from ppdiffusers import UniDiffuserPipeline
17+
18+
pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser")
19+
result = pipe(mode="joint", image=None, prompt=None)
20+
image = result.images[0]
21+
image.save("unconditional_image_text_generation-unidiffuser-result.png")
22+
text = result.texts[0]
23+
with open("unconditional_image_text_generation-unidiffuser-result.txt", "w") as f:
24+
print("{}\n".format(text), file=f)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
from ppdiffusers import UniDiffuserPipeline
17+
18+
pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser")
19+
result = pipe(mode="t", image=None, prompt=None)
20+
text = result.texts[0]
21+
with open("unconditional_text_generation-unidiffuser-result.txt", "w") as f:
22+
print("{}\n".format(text), file=f)

ppdiffusers/ppdiffusers/__init__.py

+21
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from .configuration_utils import ConfigMixin
2020
from .utils import (
2121
OptionalDependencyNotAvailable,
22+
is_einops_available,
2223
is_fastdeploy_available,
2324
is_inflect_available,
2425
is_k_diffusion_available,
@@ -88,6 +89,7 @@
8889
PNDMPipeline,
8990
RePaintPipeline,
9091
ScoreSdeVePipeline,
92+
TextPipelineOutput,
9193
)
9294
from .schedulers import (
9395
DDIMInverseScheduler,
@@ -96,6 +98,7 @@
9698
DEISMultistepScheduler,
9799
DPMSolverMultistepScheduler,
98100
DPMSolverSinglestepScheduler,
101+
DPMSolverUniDiffuserScheduler,
99102
EulerAncestralDiscreteScheduler,
100103
EulerDiscreteScheduler,
101104
HeunDiscreteScheduler,
@@ -161,13 +164,15 @@
161164
TextToVideoSDPipeline,
162165
UnCLIPImageVariationPipeline,
163166
UnCLIPPipeline,
167+
UniDiffuserPipeline,
164168
VersatileDiffusionDualGuidedPipeline,
165169
VersatileDiffusionImageVariationPipeline,
166170
VersatileDiffusionPipeline,
167171
VersatileDiffusionTextToImagePipeline,
168172
VQDiffusionPipeline,
169173
)
170174
from .pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertModel
175+
from .pipelines.unidiffuser.caption_decoder import CaptionDecoder
171176

172177
try:
173178
if not (is_paddle_available() and is_paddlenlp_available() and is_k_diffusion_available()):
@@ -200,3 +205,19 @@
200205
from .utils.dummy_paddle_and_librosa_objects import * # noqa F403
201206
else:
202207
from .pipelines import AudioDiffusionPipeline, Mel
208+
209+
try:
210+
if not (is_paddle_available() and is_paddlenlp_available() and is_einops_available()):
211+
raise OptionalDependencyNotAvailable()
212+
except OptionalDependencyNotAvailable:
213+
from .utils.dummy_paddle_and_paddlenlp_and_einops_objects import * # noqa F403
214+
else:
215+
from .pipelines import UniDiffuserPipeline
216+
217+
try:
218+
if not (is_paddle_available() and is_einops_available()):
219+
raise OptionalDependencyNotAvailable()
220+
except OptionalDependencyNotAvailable:
221+
from .utils.dummy_paddle_and_einops_objects import * # noqa F403
222+
else:
223+
from .models import UViTModel

ppdiffusers/ppdiffusers/models/__init__.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,11 @@
1515
# flake8: noqa
1616

1717

18-
from ..utils.import_utils import is_paddle_available
18+
from ..utils.import_utils import (
19+
OptionalDependencyNotAvailable,
20+
is_einops_available,
21+
is_paddle_available,
22+
)
1923

2024
if is_paddle_available():
2125
from .autoencoder_kl import AutoencoderKL
@@ -30,3 +34,11 @@
3034
from .unet_2d_condition import UNet2DConditionModel
3135
from .unet_3d_condition import UNet3DConditionModel
3236
from .vq_model import VQModel
37+
38+
try:
39+
if not (is_paddle_available() and is_einops_available()):
40+
raise OptionalDependencyNotAvailable()
41+
except OptionalDependencyNotAvailable:
42+
from ..utils.dummy_paddle_and_einops_objects import * # noqa F403
43+
else:
44+
from .uvit import UViTModel

ppdiffusers/ppdiffusers/models/attention.py

+53
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,59 @@
2424
from .embeddings import CombinedTimestepLabelEmbeddings
2525

2626

27+
def drop_path(input, drop_prob: float = 0.0, training: bool = False):
28+
"""
29+
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
30+
31+
Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
32+
however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
33+
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
34+
layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
35+
argument.
36+
"""
37+
if drop_prob == 0.0 or not training:
38+
return input
39+
keep_prob = 1 - drop_prob
40+
shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
41+
random_tensor = keep_prob + paddle.rand(shape, dtype=input.dtype)
42+
random_tensor = paddle.floor(random_tensor) # binarize
43+
output = (input / keep_prob) * random_tensor
44+
return output
45+
46+
47+
class DropPath(nn.Layer):
48+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
49+
50+
def __init__(self, drop_prob: Optional[float] = None) -> None:
51+
super().__init__()
52+
self.drop_prob = drop_prob
53+
54+
def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
55+
return drop_path(hidden_states, self.drop_prob, self.training)
56+
57+
def extra_repr(self) -> str:
58+
return "p={}".format(self.drop_prob)
59+
60+
61+
class Mlp(nn.Layer):
62+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
63+
super().__init__()
64+
out_features = out_features or in_features
65+
hidden_features = hidden_features or in_features
66+
self.fc1 = nn.Linear(in_features, hidden_features)
67+
self.act = act_layer()
68+
self.fc2 = nn.Linear(hidden_features, out_features)
69+
self.drop = nn.Dropout(drop)
70+
71+
def forward(self, x):
72+
x = self.fc1(x)
73+
x = self.act(x)
74+
x = self.drop(x)
75+
x = self.fc2(x)
76+
x = self.drop(x)
77+
return x
78+
79+
2780
class AttentionBlock(nn.Layer):
2881
"""
2982
An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted

ppdiffusers/ppdiffusers/models/autoencoder_kl.py

+2
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,8 @@ def disable_slicing(self):
160160

161161
@apply_forward_hook
162162
def encode(self, x: paddle.Tensor, return_dict: bool = True) -> AutoencoderKLOutput:
163+
# TODO junnyu, support float16
164+
x = x.cast(self.dtype)
163165
if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
164166
return self.tiled_encode(x, return_dict=return_dict)
165167

ppdiffusers/ppdiffusers/models/embeddings.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ def __init__(
125125
layer_norm=False,
126126
flatten=True,
127127
bias=True,
128+
add_pos_embed=True,
128129
):
129130
super().__init__()
130131

@@ -141,16 +142,23 @@ def __init__(
141142
else:
142143
self.norm = None
143144

144-
pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5))
145-
self.register_buffer("pos_embed", paddle.to_tensor(pos_embed).cast("float32").unsqueeze(0), persistable=False)
145+
self.add_pos_embed = add_pos_embed
146+
if add_pos_embed:
147+
pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5))
148+
self.register_buffer(
149+
"pos_embed", paddle.to_tensor(pos_embed).cast("float32").unsqueeze(0), persistable=False
150+
)
146151

147152
def forward(self, latent):
148153
latent = self.proj(latent)
149154
if self.flatten:
150155
latent = latent.flatten(2).transpose([0, 2, 1]) # BCHW -> BNC
151156
if self.layer_norm:
152157
latent = self.norm(latent)
153-
return latent + self.pos_embed
158+
if self.add_pos_embed:
159+
return latent + self.pos_embed
160+
else:
161+
return latent
154162

155163

156164
class TimestepEmbedding(nn.Layer):

0 commit comments

Comments
 (0)