
Commit 171cf29

ablattmann committed
add configs for training unconditional/class-conditional ldms
1 parent f8b4a07 commit 171cf29

13 files changed: +562 -53 lines

README.md: +85 -15

@@ -55,18 +55,7 @@ bash scripts/download_first_stages.sh
```

The first stage models can then be found in `models/first_stage_models/<model_spec>`
-### Training autoencoder models

-Configs for training a KL-regularized autoencoder on ImageNet are provided at `configs/autoencoder`.
-Training can be started by running
-```
-CUDA_VISIBLE_DEVICES=<GPU_ID> python main.py --base configs/autoencoder/<config_spec> -t --gpus 0,
-```
-where `config_spec` is one of {`autoencoder_kl_8x8x64.yaml`(f=32, d=64), `autoencoder_kl_16x16x16.yaml`(f=16, d=16),
-`autoencoder_kl_32x32x4`(f=8, d=4), `autoencoder_kl_64x64x3`(f=4, d=3)}.
-
-For training VQ-regularized models, see the [taming-transformers](https://github.com/CompVis/taming-transformers)
-repository.


## Pretrained LDMs
@@ -78,9 +67,10 @@ repository.
| LSUN-Bedrooms | Unconditional Image Synthesis | LDM-VQ-4 (200 DDIM steps, eta=1)| 2.95 (3.0) | 2.22 (2.23)| 0.66 | 0.48 | https://ommer-lab.com/files/latent-diffusion/lsun_bedrooms.zip | |
| ImageNet | Class-conditional Image Synthesis | LDM-VQ-8 (200 DDIM steps, eta=1) | 7.77(7.76)* /15.82** | 201.56(209.52)* /78.82** | 0.84* / 0.65** | 0.35* / 0.63** | https://ommer-lab.com/files/latent-diffusion/cin.zip | *: w/ guiding, classifier_scale 10 **: w/o guiding, scores in bracket calculated with script provided by [ADM](https://github.com/openai/guided-diffusion) |
| Conceptual Captions | Text-conditional Image Synthesis | LDM-VQ-f4 (100 DDIM steps, eta=0) | 16.79 | 13.89 | N/A | N/A | https://ommer-lab.com/files/latent-diffusion/text2img.zip | finetuned from LAION |
-| OpenImages | Super-resolution | N/A | N/A | N/A | N/A | N/A | https://ommer-lab.com/files/latent-diffusion/sr_bsr.zip | BSR image degradation |
+| OpenImages | Super-resolution | LDM-VQ-4 | N/A | N/A | N/A | N/A | https://ommer-lab.com/files/latent-diffusion/sr_bsr.zip | BSR image degradation |
| OpenImages | Layout-to-Image Synthesis | LDM-VQ-4 (200 DDIM steps, eta=0) | 32.02 | 15.92 | N/A | N/A | https://ommer-lab.com/files/latent-diffusion/layout2img_model.zip | |
-| Landscapes (finetuned 512) | Semantic Image Synthesis | LDM-VQ-4 (100 DDIM steps, eta=1) | N/A | N/A | N/A | N/A | https://ommer-lab.com/files/latent-diffusion/semantic_synthesis.zip | |
+| Landscapes | Semantic Image Synthesis | LDM-VQ-4 | N/A | N/A | N/A | N/A | https://ommer-lab.com/files/latent-diffusion/semantic_synthesis256.zip | |
+| Landscapes | Semantic Image Synthesis | LDM-VQ-4 | N/A | N/A | N/A | N/A | https://ommer-lab.com/files/latent-diffusion/semantic_synthesis.zip | finetuned on resolution 512x512 |


### Get the models
@@ -116,10 +106,90 @@ python scripts/inpaint.py --indir data/inpainting_examples/ --outdir outputs/inp
`indir` should contain images `*.png` and masks `<image_fname>_mask.png` like
the examples provided in `data/inpainting_examples`.

Added:

# Train your own LDMs

## Data preparation

### Faces
For downloading the CelebA-HQ and FFHQ datasets, proceed as described in the [taming-transformers](https://github.com/CompVis/taming-transformers#celeba-hq)
repository.

### LSUN

The LSUN datasets can be conveniently downloaded via the script available [here](https://github.com/fyu/lsun).
We performed a custom split into training and validation images, and provide the corresponding filenames
at [https://ommer-lab.com/files/lsun.zip](https://ommer-lab.com/files/lsun.zip).
After downloading, extract them to `./data/lsun`. The beds/cats/churches subsets should
also be placed/symlinked at `./data/lsun/bedrooms`/`./data/lsun/cats`/`./data/lsun/churches`, respectively.

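The layout described above could be set up along these lines (a sketch: `/path/to/lsun_raw` is a hypothetical placeholder for wherever the LSUN exports were downloaded, and `lsun.zip` is assumed to unpack the train/validation filename lists directly into `./data/lsun`):

```
mkdir -p ./data/lsun
# filename lists for the custom train/validation split
wget https://ommer-lab.com/files/lsun.zip && unzip lsun.zip -d ./data/lsun
# symlink the locally downloaded LSUN subsets; /path/to/lsun_raw is a placeholder
ln -s /path/to/lsun_raw/bedrooms ./data/lsun/bedrooms
ln -s /path/to/lsun_raw/cats     ./data/lsun/cats
ln -s /path/to/lsun_raw/churches ./data/lsun/churches
```
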
### ImageNet
The code will try to download (through [Academic
Torrents](http://academictorrents.com/)) and prepare ImageNet the first time it
is used. However, since ImageNet is quite large, this requires a lot of disk
space and time. If you already have ImageNet on your disk, you can speed things
up by putting the data into
`${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data/` (which defaults to
`~/.cache/autoencoders/data/ILSVRC2012_{split}/data/`), where `{split}` is one
of `train`/`validation`. It should have the following structure:

```
${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data/
├── n01440764
│   ├── n01440764_10026.JPEG
│   ├── n01440764_10027.JPEG
│   ├── ...
├── n01443537
│   ├── n01443537_10007.JPEG
│   ├── n01443537_10014.JPEG
│   ├── ...
├── ...
```

If you haven't extracted the data, you can also place
`ILSVRC2012_img_train.tar`/`ILSVRC2012_img_val.tar` (or symlinks to them) into
`${XDG_CACHE}/autoencoders/data/ILSVRC2012_train/` /
`${XDG_CACHE}/autoencoders/data/ILSVRC2012_validation/`, which will then be
extracted into the above structure without downloading it again. Note that this
will only happen if neither a folder
`${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data/` nor a file
`${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/.ready` exists. Remove them
if you want to force running the dataset preparation again.

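For instance, an already-extracted local ImageNet copy can be wired into this cache layout with symlinks; a minimal sketch, assuming the hypothetical local paths `/data/ILSVRC2012/train` and `/data/ILSVRC2012/val` each contain one folder per synset:

```
XDG_CACHE="${XDG_CACHE:-$HOME/.cache}"
mkdir -p "$XDG_CACHE/autoencoders/data/ILSVRC2012_train/data" \
         "$XDG_CACHE/autoencoders/data/ILSVRC2012_validation/data"
# one symlink per synset folder (n01440764/, n01443537/, ...)
ln -s /data/ILSVRC2012/train/* "$XDG_CACHE/autoencoders/data/ILSVRC2012_train/data/"
ln -s /data/ILSVRC2012/val/*   "$XDG_CACHE/autoencoders/data/ILSVRC2012_validation/data/"
```
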
## Model Training

Logs and checkpoints for trained models are saved to `logs/<START_DATE_AND_TIME>_<config_spec>`.

### Training autoencoder models

Configs for training a KL-regularized autoencoder on ImageNet are provided at `configs/autoencoder`.
Training can be started by running
```
CUDA_VISIBLE_DEVICES=<GPU_ID> python main.py --base configs/autoencoder/<config_spec>.yaml -t --gpus 0,
```
where `config_spec` is one of {`autoencoder_kl_8x8x64`(f=32, d=64), `autoencoder_kl_16x16x16`(f=16, d=16),
`autoencoder_kl_32x32x4`(f=8, d=4), `autoencoder_kl_64x64x3`(f=4, d=3)}.

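For example, the f=4, d=3 variant can be trained on GPU 0 as follows; the two-GPU line is a sketch that assumes the usual PyTorch Lightning comma-separated device list:

```
# KL-regularized autoencoder, f=4, d=3, single GPU
CUDA_VISIBLE_DEVICES=0 python main.py --base configs/autoencoder/autoencoder_kl_64x64x3.yaml -t --gpus 0,
# two-GPU variant (device-list syntax assumed from PyTorch Lightning)
CUDA_VISIBLE_DEVICES=0,1 python main.py --base configs/autoencoder/autoencoder_kl_64x64x3.yaml -t --gpus 0,1,
```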

For training VQ-regularized models, see the [taming-transformers](https://github.com/CompVis/taming-transformers)
repository.

### Training LDMs

In ``configs/latent-diffusion/`` we provide configs for training LDMs on the LSUN, CelebA-HQ, FFHQ and ImageNet datasets.
Training can be started by running

```shell script
CUDA_VISIBLE_DEVICES=<GPU_ID> python main.py --base configs/latent-diffusion/<config_spec>.yaml -t --gpus 0,
```

where ``<config_spec>`` is one of {`celebahq-ldm-vq-4`(f=4, VQ-reg. autoencoder, spatial size 64x64x3), `ffhq-ldm-vq-4`(f=4, VQ-reg. autoencoder, spatial size 64x64x3),
`lsun_bedrooms-ldm-vq-4`(f=4, VQ-reg. autoencoder, spatial size 64x64x3),
`lsun_churches-ldm-vq-4`(f=8, KL-reg. autoencoder, spatial size 32x32x4), `cin-ldm-vq-8`(f=8, VQ-reg. autoencoder, spatial size 32x32x4)}.

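As a concrete instance, the unconditional LSUN-Bedrooms model would be launched as shown below; logs and checkpoints then land in `logs/<START_DATE_AND_TIME>_lsun_bedrooms-ldm-vq-4`, following the pattern noted above:

```
# unconditional LSUN-Bedrooms LDM on top of the f=4 VQ first stage
CUDA_VISIBLE_DEVICES=0 python main.py --base configs/latent-diffusion/lsun_bedrooms-ldm-vq-4.yaml -t --gpus 0,
```
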
## Coming Soon...

-* Code for training LDMs and the corresponding compression models.
-* Inference scripts for conditional LDMs for various conditioning modalities.
+* More inference scripts for conditional LDMs.
* In the meantime, you can play with our colab notebook https://colab.research.google.com/drive/1xqzUi2iXQXDqXBHQGP9Mqt2YrYW6cx-J?usp=sharing
* We will also release some further pretrained models.

New file (+86 lines): an unconditional LDM config trained on CelebA-HQ on top of a VQ-regularized f=4 first stage.

@@ -0,0 +1,86 @@
```yaml
model:
  base_learning_rate: 2.0e-06
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.0015
    linear_end: 0.0195
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: image
    image_size: 64
    channels: 3
    monitor: val/loss_simple_ema

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 64
        in_channels: 3
        out_channels: 3
        model_channels: 224
        attention_resolutions:
        # note: this isn't actually the resolution but
        # the downsampling factor, i.e. this corresponds to
        # attention on spatial resolution 8, 16, 32, as the
        # spatial resolution of the latents is 64 for f4
        - 8
        - 4
        - 2
        num_res_blocks: 2
        channel_mult:
        - 1
        - 2
        - 3
        - 4
        num_head_channels: 32
    first_stage_config:
      target: ldm.models.autoencoder.VQModelInterface
      params:
        embed_dim: 3
        n_embed: 8192
        ckpt_path: models/first_stage_models/vq-f4/model.ckpt
        ddconfig:
          double_z: false
          z_channels: 3
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
    cond_stage_config: __is_unconditional__
data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 48
    num_workers: 5
    wrap: false
    train:
      target: taming.data.faceshq.CelebAHQTrain
      params:
        size: 256
    validation:
      target: taming.data.faceshq.CelebAHQValidation
      params:
        size: 256


lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 5000
        max_images: 8
        increase_log_steps: False

  trainer:
    benchmark: True
```
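
A minimal launch sketch for this config, assuming it is the `celebahq-ldm-vq-4` file referenced in the README section above and that the download script fetches the vq-f4 first stage referenced by `ckpt_path`:

```
# fetch the pretrained first-stage models (assumed to include vq-f4)
bash scripts/download_first_stages.sh
# train the unconditional CelebA-HQ LDM
CUDA_VISIBLE_DEVICES=0 python main.py --base configs/latent-diffusion/celebahq-ldm-vq-4.yaml -t --gpus 0,
```
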
New file (+98 lines): a class-conditional ImageNet LDM config that conditions on class labels via cross-attention, on top of a VQ-regularized f=8 first stage.

@@ -0,0 +1,98 @@
```yaml
model:
  base_learning_rate: 1.0e-06
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.0015
    linear_end: 0.0195
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: image
    cond_stage_key: class_label
    image_size: 32
    channels: 4
    cond_stage_trainable: true
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32
        in_channels: 4
        out_channels: 4
        model_channels: 256
        attention_resolutions:
        # note: this isn't actually the resolution but
        # the downsampling factor, i.e. this corresponds to
        # attention on spatial resolution 8, 16, 32, as the
        # spatial resolution of the latents is 32 for f8
        - 4
        - 2
        - 1
        num_res_blocks: 2
        channel_mult:
        - 1
        - 2
        - 4
        num_head_channels: 32
        use_spatial_transformer: true
        transformer_depth: 1
        context_dim: 512
    first_stage_config:
      target: ldm.models.autoencoder.VQModelInterface
      params:
        embed_dim: 4
        n_embed: 16384
        ckpt_path: configs/first_stage_models/vq-f8/model.yaml
        ddconfig:
          double_z: false
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 2
          - 4
          num_res_blocks: 2
          attn_resolutions:
          - 32
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
    cond_stage_config:
      target: ldm.modules.encoders.modules.ClassEmbedder
      params:
        embed_dim: 512
        key: class_label
data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 64
    num_workers: 12
    wrap: false
    train:
      target: ldm.data.imagenet.ImageNetTrain
      params:
        config:
          size: 256
    validation:
      target: ldm.data.imagenet.ImageNetValidation
      params:
        config:
          size: 256


lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 5000
        max_images: 8
        increase_log_steps: False

  trainer:
    benchmark: True
```
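
A launch sketch for this class-conditional config, assuming it is the `cin-ldm-vq-8` file from the README list and that ImageNet has been prepared as described above; the device indices are illustrative:

```
# class-conditional ImageNet LDM (f=8 first stage, 32x32x4 latents), multi-GPU example
CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --base configs/latent-diffusion/cin-ldm-vq-8.yaml -t --gpus 0,1,2,3,
```
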
New file (+85 lines): an unconditional LDM config trained on FFHQ, analogous to the CelebA-HQ config above.

@@ -0,0 +1,85 @@
```yaml
model:
  base_learning_rate: 2.0e-06
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.0015
    linear_end: 0.0195
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: image
    image_size: 64
    channels: 3
    monitor: val/loss_simple_ema
    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 64
        in_channels: 3
        out_channels: 3
        model_channels: 224
        attention_resolutions:
        # note: this isn't actually the resolution but
        # the downsampling factor, i.e. this corresponds to
        # attention on spatial resolution 8, 16, 32, as the
        # spatial resolution of the latents is 64 for f4
        - 8
        - 4
        - 2
        num_res_blocks: 2
        channel_mult:
        - 1
        - 2
        - 3
        - 4
        num_head_channels: 32
    first_stage_config:
      target: ldm.models.autoencoder.VQModelInterface
      params:
        embed_dim: 3
        n_embed: 8192
        ckpt_path: configs/first_stage_models/vq-f4/model.yaml
        ddconfig:
          double_z: false
          z_channels: 3
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
    cond_stage_config: __is_unconditional__
data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 42
    num_workers: 5
    wrap: false
    train:
      target: taming.data.faceshq.FFHQTrain
      params:
        size: 256
    validation:
      target: taming.data.faceshq.FFHQValidation
      params:
        size: 256


lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 5000
        max_images: 8
        increase_log_steps: False

  trainer:
    benchmark: True
```
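
Runs started from any of these configs write to `logs/<START_DATE_AND_TIME>_<config_spec>` and can usually be resumed from there; a sketch that assumes the `-r/--resume` option of `main.py` and uses an illustrative log-directory name:

```
# resume an interrupted run from its log directory (directory name is illustrative)
python main.py -r logs/2022-01-15T10-00-00_ffhq-ldm-vq-4 -t --gpus 0,
```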
