Support ControlNet #153

Merged (10 commits) on Apr 18, 2023
14 changes: 14 additions & 0 deletions README.md
@@ -139,6 +139,10 @@ This generally takes 15-20 minutes on an M1 MacBook Pro. Upon successful executi

- `--check-output-correctness`: Compares original PyTorch model's outputs to final Core ML model's outputs. This flag increases RAM consumption significantly so it is recommended only for debugging purposes.

- `--convert-controlnet`: Converts the ControlNet models specified after this option. Multiple models can be converted at once by listing them, e.g. `--convert-controlnet lllyasviel/sd-controlnet-mlsd lllyasviel/sd-controlnet-depth`.

- `--unet-support-controlnet`: Enables the converted UNet model to receive additional inputs from ControlNet. This is required for image generation with ControlNet. The model is saved under a different name, `*_control-unet.mlpackage`, to distinguish it from the regular UNet; note that this variant cannot run without ControlNet inputs, so use the regular UNet for plain txt2img. A full conversion command is sketched below.
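
For example, a conversion run that also produces a ControlNet-capable UNet and the MLSD ControlNet might look like the following. This is a sketch: the base `--convert-*` flags and `-o` are assumed to match the standard conversion command shown earlier in this README, and the output directory is a placeholder.

```bash
python -m python_coreml_stable_diffusion.torch2coreml \
    --convert-unet --convert-text-encoder --convert-vae-decoder \
    --unet-support-controlnet \
    --convert-controlnet lllyasviel/sd-controlnet-mlsd \
    -o <output-mlpackages-directory>
```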

</details>

## <a name="image-generation-with-python"></a> Image Generation with Python
@@ -157,6 +161,8 @@ Please refer to the help menu for all available arguments: `python -m python_cor
- `--model-version`: If you overrode the default model version while converting models to Core ML, you will need to specify the same model version here.
- `--compute-unit`: Note that the most performant compute unit for this particular implementation may differ across different hardware. `CPU_AND_GPU` or `CPU_AND_NE` may be faster than `ALL`. Please refer to the [Performance Benchmark](#performance-benchmark) section for further guidance.
- `--scheduler`: If you would like to experiment with different schedulers, you may specify it here. For available options, please see the help menu. You may also specify a custom number of inference steps by `--num-inference-steps` which defaults to 50.
- `--controlnet`: ControlNet models to use during image generation, specified in the format `--controlnet lllyasviel/sd-controlnet-mlsd lllyasviel/sd-controlnet-depth`. This option must be used together with `--controlnet-inputs`.
- `--controlnet-inputs`: Conditioning image inputs corresponding to each ControlNet model. Provide the image paths in the same order as the models in `--controlnet`, for example: `--controlnet-inputs image_mlsd image_depth`. A full generation command is sketched below.
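
For example, a generation run with a single ControlNet might look like the following sketch. Only the two `--controlnet*` options are introduced by this change; the `--prompt`, `-i`, `-o`, and `--compute-unit` arguments are assumed to follow the standard pipeline invocation described in the help menu, and the paths are placeholders.

```bash
python -m python_coreml_stable_diffusion.pipeline \
    --prompt "a photo of a modern glass house" \
    -i <output-mlpackages-directory> -o <output-image-directory> \
    --compute-unit CPU_AND_NE \
    --controlnet lllyasviel/sd-controlnet-mlsd \
    --controlnet-inputs <path-to-mlsd-conditioning-image>
```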

</details>

@@ -228,6 +234,14 @@ Optionally, it may also include the safety checker model that some versions of S

- `SafetyChecker.mlmodelc`

Optionally, for ControlNet:

- `ControlledUNet.mlmodelc` or `ControlledUnetChunk1.mlmodelc` & `ControlledUnetChunk2.mlmodelc` (a UNet enabled to receive the additional ControlNet inputs)
- `controlnet/` (directory containing ControlNet models)
- `LllyasvielSdControlnetMlsd.mlmodelc` (for example, from lllyasviel/sd-controlnet-mlsd)
- `LllyasvielSdControlnetDepth.mlmodelc` (for example, from lllyasviel/sd-controlnet-depth)
- Any other ControlNet models you have converted

Note that the chunked version of the UNet is checked for first. Only if it is not present will the full `Unet.mlmodelc` be loaded. Chunking is required for iOS and iPadOS but is not necessary for macOS.

</details>
244 changes: 244 additions & 0 deletions python_coreml_stable_diffusion/controlnet.py
@@ -0,0 +1,244 @@
#
# For licensing see accompanying LICENSE.md file.
# Copyright (C) 2022 Apple Inc. All Rights Reserved.
#

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers import ModelMixin

import torch
import torch.nn as nn
import torch.nn.functional as F

from .unet import Timesteps, TimestepEmbedding, get_down_block, UNetMidBlock2DCrossAttn, linear_to_conv2d_map

class ControlNetConditioningEmbedding(nn.Module):
    """Encodes the conditioning image (e.g. an edge or depth map) into a feature map at the
    latent resolution so it can be added to the UNet's input hidden states."""

    def __init__(
        self,
        conditioning_embedding_channels,
        conditioning_channels=3,
        block_out_channels=(16, 32, 96, 256),
    ):
        super().__init__()

        self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)

        self.blocks = nn.ModuleList([])

        for i in range(len(block_out_channels) - 1):
            channel_in = block_out_channels[i]
            channel_out = block_out_channels[i + 1]
            self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
            self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))

        self.conv_out = nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)

    def forward(self, conditioning):
        embedding = self.conv_in(conditioning)
        embedding = F.silu(embedding)

        for block in self.blocks:
            embedding = block(embedding)
            embedding = F.silu(embedding)

        embedding = self.conv_out(embedding)

        return embedding

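# Hypothetical shape check for the class above (not part of this file): with the default
# channel widths, a 512x512 RGB conditioning image becomes a 320-channel feature map at the
# 64x64 latent resolution after the three stride-2 convolutions, e.g.
#   ControlNetConditioningEmbedding(320)(torch.randn(1, 3, 512, 512)).shape == (1, 320, 64, 64)
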
class ControlNetModel(ModelMixin, ConfigMixin):
    """ControlNet: a copy of the Stable Diffusion UNet's down and mid blocks, plus 1x1
    projection convolutions that emit one additive residual per UNet skip connection and
    one for the mid block (see forward())."""

    @register_to_config
    def __init__(
        self,
        in_channels=4,
        flip_sin_to_cos=True,
        freq_shift=0,
        down_block_types=(
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "DownBlock2D",
        ),
        only_cross_attention=False,
        block_out_channels=(320, 640, 1280, 1280),
        layers_per_block=2,
        downsample_padding=1,
        mid_block_scale_factor=1,
        act_fn="silu",
        norm_num_groups=32,
        norm_eps=1e-5,
        cross_attention_dim=1280,
        attention_head_dim=8,
        use_linear_projection=False,
        upcast_attention=False,
        resnet_time_scale_shift="default",
        conditioning_embedding_out_channels=(16, 32, 96, 256),
        **kwargs,
    ):
        super().__init__()

        # Check inputs
        if len(block_out_channels) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
            )

        self._register_load_state_dict_pre_hook(linear_to_conv2d_map)

        # input
        conv_in_kernel = 3
        conv_in_padding = (conv_in_kernel - 1) // 2
        self.conv_in = nn.Conv2d(
            in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
        )

        # time
        time_embed_dim = block_out_channels[0] * 4

        self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
        timestep_input_dim = block_out_channels[0]

        self.time_embedding = TimestepEmbedding(
            timestep_input_dim,
            time_embed_dim,
        )

        # control net conditioning embedding
        self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
            conditioning_embedding_channels=block_out_channels[0],
            block_out_channels=conditioning_embedding_out_channels,
        )

        self.down_blocks = nn.ModuleList([])
        self.controlnet_down_blocks = nn.ModuleList([])

        if isinstance(only_cross_attention, bool):
            only_cross_attention = [only_cross_attention] * len(down_block_types)

        if isinstance(attention_head_dim, int):
            attention_head_dim = (attention_head_dim,) * len(down_block_types)

        # down
        output_channel = block_out_channels[0]

        controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
        self.controlnet_down_blocks.append(controlnet_block)

        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                cross_attention_dim=cross_attention_dim,
                attn_num_head_channels=attention_head_dim[i],
                downsample_padding=downsample_padding,
            )
            self.down_blocks.append(down_block)

            for _ in range(layers_per_block):
                controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
                self.controlnet_down_blocks.append(controlnet_block)

            if not is_final_block:
                controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
                self.controlnet_down_blocks.append(controlnet_block)

        # mid
        mid_block_channel = block_out_channels[-1]

        controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
        self.controlnet_mid_block = controlnet_block

        self.mid_block = UNetMidBlock2DCrossAttn(
            in_channels=mid_block_channel,
            temb_channels=time_embed_dim,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            output_scale_factor=mid_block_scale_factor,
            resnet_time_scale_shift=resnet_time_scale_shift,
            cross_attention_dim=cross_attention_dim,
            attn_num_head_channels=attention_head_dim[-1],
            resnet_groups=norm_num_groups,
            use_linear_projection=use_linear_projection,
            upcast_attention=upcast_attention,
        )

    def get_num_residuals(self):
        num_res = 2  # initial sample + mid block
        for down_block in self.down_blocks:
            num_res += len(down_block.resnets)
            if hasattr(down_block, "downsamplers") and down_block.downsamplers is not None:
                num_res += len(down_block.downsamplers)
        return num_res

    def forward(
        self,
        sample,
        timestep,
        encoder_hidden_states,
        controlnet_cond,
    ):
        # 1. time
        t_emb = self.time_proj(timestep)
        emb = self.time_embedding(t_emb)

        # 2. pre-process
        sample = self.conv_in(sample)

        controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)

        sample += controlnet_cond

        # 3. down
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples

        # 4. mid
        if self.mid_block is not None:
            sample = self.mid_block(
                sample,
                emb,
                encoder_hidden_states=encoder_hidden_states,
            )

        # 5. Control net blocks
        controlnet_down_block_res_samples = ()

        for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
            down_block_res_sample = controlnet_block(down_block_res_sample)
            controlnet_down_block_res_samples += (down_block_res_sample,)

        down_block_res_samples = controlnet_down_block_res_samples

        mid_block_res_sample = self.controlnet_mid_block(sample)

        return down_block_res_samples, mid_block_res_sample
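
The two return values are additive residuals: one tensor per UNet skip connection plus one for the mid block. A minimal sketch of how a pipeline might combine the outputs of several ControlNets before handing them to a controlled UNet; the helper name and the elementwise summation are illustrative assumptions, not part of this diff.

```python
# Hypothetical helper: sum the residuals from several ControlNets so a controlled UNet
# can consume a single set of additive inputs.
def merge_controlnet_residuals(controlnet_outputs):
    # controlnet_outputs: list of (down_block_res_samples, mid_block_res_sample) tuples,
    # as returned by ControlNetModel.forward above.
    down_res, mid_res = controlnet_outputs[0]
    down_res = list(down_res)
    for extra_down, extra_mid in controlnet_outputs[1:]:
        down_res = [d + e for d, e in zip(down_res, extra_down)]
        mid_res = mid_res + extra_mid
    return down_res, mid_res
```
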
17 changes: 17 additions & 0 deletions python_coreml_stable_diffusion/coreml_model.py
@@ -98,5 +98,22 @@ def _load_mlpackage(submodule_name, mlpackages_dir, model_version,

    return CoreMLModel(mlpackage_path, compute_unit)

def _load_mlpackage_controlnet(mlpackages_dir, model_version, compute_unit):
    """ Load a Core ML ControlNet model (mlpackage) from disk, as exported by torch2coreml.py
    """
    model_name = model_version.replace("/", "_")

    logger.info(f"Loading controlnet_{model_name} mlpackage")

    fname = f"ControlNet_{model_name}.mlpackage"

    mlpackage_path = os.path.join(mlpackages_dir, fname)

    if not os.path.exists(mlpackage_path):
        raise FileNotFoundError(
            f"controlnet_{model_name} CoreML model doesn't exist at {mlpackage_path}")

    return CoreMLModel(mlpackage_path, compute_unit)

def get_available_compute_units():
    return tuple(cu for cu in ct.ComputeUnit._member_names_)
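
A minimal usage sketch for `_load_mlpackage_controlnet`, assuming the pipeline passes the same packages directory and compute unit it uses for its other models; the version list, path, and compute unit below are placeholders.

```python
# Hypothetical call site: load one Core ML ControlNet per model version requested on the CLI.
controlnet_versions = ["lllyasviel/sd-controlnet-mlsd", "lllyasviel/sd-controlnet-depth"]
controlnets = [
    _load_mlpackage_controlnet("/path/to/mlpackages", version, "CPU_AND_NE")
    for version in controlnet_versions
]
```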