
Commit 2d47195

initialized
0 parents  commit 2d47195

11 files changed, +499 -0 lines changed

.gitignore

Lines changed: 18 additions & 0 deletions

# OS specific
*.DS_Store

# Python
/build
/dist
__pycache__
*.ipynb_checkpoints
*.egg-info

# Vim
*.vim
*.swk
*.swl
*.swm
*.swn
*.swo
*.swp

LICENSE

Lines changed: 25 additions & 0 deletions

Modified MIT License

Software Copyright (c) 2021 OpenAI

We don’t claim ownership of the content you create with the DALL-E discrete VAE, so it is yours to
do with as you please. We only ask that you use the model responsibly and clearly indicate that it
was used.

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
associated documentation files (the "Software"), to deal in the Software without restriction,
including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so,
subject to the following conditions:

The above copyright notice and this permission notice shall be included
in all copies or substantial portions of the Software.
The above copyright notice and this permission notice need not be included
with content created by the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
OR OTHER DEALINGS IN THE SOFTWARE.

README.md

Lines changed: 11 additions & 0 deletions

# Overview

[[Blog]](https://openai.com/blog/dall-e/) [[Paper]](https://arxiv.org/abs/2102.12092) [[Model Card]](model_card.md) [[Usage]](notebooks/usage.ipynb)

This is the official PyTorch package for the discrete VAE used for DALL·E.

# Installation

Before running [the example notebook](notebooks/usage.ipynb), you will need to install the package using

    pip install git+https://github.com/openai/DALL-E.git
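
Once installed, the full round trip looks roughly like the sketch below. The checkpoint URLs are assumptions here (the example notebook points at the official ones), and the 256x256 input size matches the model's defaults.

    # A minimal encode/decode sketch; the cdn.openai.com URLs are assumed,
    # see notebooks/usage.ipynb for the canonical versions.
    import torch
    import torch.nn.functional as F
    from dall_e import load_model, map_pixels, unmap_pixels

    dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    enc = load_model('https://cdn.openai.com/dall-e/encoder.pkl', dev)
    dec = load_model('https://cdn.openai.com/dall-e/decoder.pkl', dev)

    # x: a batch of RGB images in [0, 1], shaped [N, 3, 256, 256].
    x = map_pixels(torch.rand(1, 3, 256, 256, device=dev))

    z_logits = enc(x)                                 # [N, 8192, 32, 32]
    tokens = torch.argmax(z_logits, dim=1)            # [N, 32, 32] token grid
    z = F.one_hot(tokens, num_classes=8192).permute(0, 3, 1, 2).float()

    x_stats = dec(z).float()                          # [N, 6, 256, 256]
    x_rec = unmap_pixels(torch.sigmoid(x_stats[:, :3]))  # reconstruction in [0, 1]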

dall_e/__init__.py

Lines changed: 18 additions & 0 deletions

import io, requests
import torch
import torch.nn as nn

from dall_e.encoder import Encoder
from dall_e.decoder import Decoder
from dall_e.utils import map_pixels, unmap_pixels

def load_model(path: str, device: torch.device = None) -> nn.Module:
    if path.startswith('http://') or path.startswith('https://'):
        resp = requests.get(path)
        resp.raise_for_status()

        with io.BytesIO(resp.content) as buf:
            return torch.load(buf, map_location=device)
    else:
        with open(path, 'rb') as f:
            return torch.load(f, map_location=device)
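
load_model accepts either an http(s) URL or a local filesystem path. A small sketch of the local case, with a hypothetical path:

    # './decoder.pkl' is a hypothetical local checkpoint; map_location pins it to CPU.
    import torch
    from dall_e import load_model

    dec = load_model('./decoder.pkl', torch.device('cpu'))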

dall_e/decoder.py

Lines changed: 94 additions & 0 deletions

import attr
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from collections import OrderedDict
from functools import partial
from dall_e.utils import Conv2d

@attr.s(eq=False, repr=False)
class DecoderBlock(nn.Module):
    n_in:     int = attr.ib(validator=lambda i, a, x: x >= 1)
    n_out:    int = attr.ib(validator=lambda i, a, x: x >= 1 and x % 4 == 0)
    n_layers: int = attr.ib(validator=lambda i, a, x: x >= 1)

    device:        torch.device = attr.ib(default=None)
    requires_grad: bool         = attr.ib(default=False)

    def __attrs_post_init__(self) -> None:
        super().__init__()
        self.n_hid = self.n_out // 4
        # Scale the residual branch by 1 / n_layers**2 to keep deep stacks stable.
        self.post_gain = 1 / (self.n_layers ** 2)

        make_conv     = partial(Conv2d, device=self.device, requires_grad=self.requires_grad)
        self.id_path  = make_conv(self.n_in, self.n_out, 1) if self.n_in != self.n_out else nn.Identity()
        self.res_path = nn.Sequential(OrderedDict([
            ('relu_1', nn.ReLU()),
            ('conv_1', make_conv(self.n_in, self.n_hid, 1)),
            ('relu_2', nn.ReLU()),
            ('conv_2', make_conv(self.n_hid, self.n_hid, 3)),
            ('relu_3', nn.ReLU()),
            ('conv_3', make_conv(self.n_hid, self.n_hid, 3)),
            ('relu_4', nn.ReLU()),
            ('conv_4', make_conv(self.n_hid, self.n_out, 3)),
        ]))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.id_path(x) + self.post_gain * self.res_path(x)

@attr.s(eq=False, repr=False)
class Decoder(nn.Module):
    group_count:     int = 4
    n_init:          int = attr.ib(default=128,  validator=lambda i, a, x: x >= 8)
    n_hid:           int = attr.ib(default=256,  validator=lambda i, a, x: x >= 64)
    n_blk_per_group: int = attr.ib(default=2,    validator=lambda i, a, x: x >= 1)
    output_channels: int = attr.ib(default=3,    validator=lambda i, a, x: x >= 1)
    vocab_size:      int = attr.ib(default=8192, validator=lambda i, a, x: x >= 512)

    device:              torch.device = attr.ib(default=torch.device('cpu'))
    requires_grad:       bool         = attr.ib(default=False)
    use_mixed_precision: bool         = attr.ib(default=True)

    def __attrs_post_init__(self) -> None:
        super().__init__()

        blk_range = range(self.n_blk_per_group)
        n_layers  = self.group_count * self.n_blk_per_group
        make_conv = partial(Conv2d, device=self.device, requires_grad=self.requires_grad)
        make_blk  = partial(DecoderBlock, n_layers=n_layers, device=self.device,
                            requires_grad=self.requires_grad)

        # Four groups of residual blocks; the first three end in a 2x nearest-neighbor
        # upsample, so a 32x32 token grid decodes to a 256x256 image.
        self.blocks = nn.Sequential(OrderedDict([
            ('input', make_conv(self.vocab_size, self.n_init, 1, use_float16=False)),
            ('group_1', nn.Sequential(OrderedDict([
                *[(f'block_{i + 1}', make_blk(self.n_init if i == 0 else 8 * self.n_hid, 8 * self.n_hid)) for i in blk_range],
                ('upsample', nn.Upsample(scale_factor=2, mode='nearest')),
            ]))),
            ('group_2', nn.Sequential(OrderedDict([
                *[(f'block_{i + 1}', make_blk(8 * self.n_hid if i == 0 else 4 * self.n_hid, 4 * self.n_hid)) for i in blk_range],
                ('upsample', nn.Upsample(scale_factor=2, mode='nearest')),
            ]))),
            ('group_3', nn.Sequential(OrderedDict([
                *[(f'block_{i + 1}', make_blk(4 * self.n_hid if i == 0 else 2 * self.n_hid, 2 * self.n_hid)) for i in blk_range],
                ('upsample', nn.Upsample(scale_factor=2, mode='nearest')),
            ]))),
            ('group_4', nn.Sequential(OrderedDict([
                *[(f'block_{i + 1}', make_blk(2 * self.n_hid if i == 0 else 1 * self.n_hid, 1 * self.n_hid)) for i in blk_range],
            ]))),
            ('output', nn.Sequential(OrderedDict([
                ('relu', nn.ReLU()),
                ('conv', make_conv(1 * self.n_hid, 2 * self.output_channels, 1)),
            ]))),
        ]))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if len(x.shape) != 4:
            raise ValueError(f'input shape {x.shape} is not 4d')
        if x.shape[1] != self.vocab_size:
            raise ValueError(f'input has {x.shape[1]} channels but model built for {self.vocab_size}')
        if x.dtype != torch.float32:
            raise ValueError('input must have dtype torch.float32')

        return self.blocks(x)
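
The forward checks above pin down the input contract: a float32 tensor of shape [N, vocab_size, H, W], i.e. a one-hot (or relaxed) encoding of the token grid along the channel axis. A shape sketch under the default hyperparameters:

    # Decode an arbitrary 32x32 token grid; the defaults give 2 * 3 = 6 output
    # channels, and each of the first three groups upsamples 2x.
    import torch
    import torch.nn.functional as F
    from dall_e.decoder import Decoder

    dec = Decoder()
    tokens = torch.randint(0, 8192, (1, 32, 32))
    z = F.one_hot(tokens, num_classes=8192).permute(0, 3, 1, 2).float()

    out = dec(z)
    assert out.shape == (1, 6, 256, 256)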

dall_e/encoder.py

Lines changed: 93 additions & 0 deletions

import attr
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from collections import OrderedDict
from functools import partial
from dall_e.utils import Conv2d

@attr.s(eq=False, repr=False)
class EncoderBlock(nn.Module):
    n_in:     int = attr.ib(validator=lambda i, a, x: x >= 1)
    n_out:    int = attr.ib(validator=lambda i, a, x: x >= 1 and x % 4 == 0)
    n_layers: int = attr.ib(validator=lambda i, a, x: x >= 1)

    device:        torch.device = attr.ib(default=None)
    requires_grad: bool         = attr.ib(default=False)

    def __attrs_post_init__(self) -> None:
        super().__init__()
        self.n_hid = self.n_out // 4
        # Scale the residual branch by 1 / n_layers**2 to keep deep stacks stable.
        self.post_gain = 1 / (self.n_layers ** 2)

        make_conv     = partial(Conv2d, device=self.device, requires_grad=self.requires_grad)
        self.id_path  = make_conv(self.n_in, self.n_out, 1) if self.n_in != self.n_out else nn.Identity()
        self.res_path = nn.Sequential(OrderedDict([
            ('relu_1', nn.ReLU()),
            ('conv_1', make_conv(self.n_in, self.n_hid, 3)),
            ('relu_2', nn.ReLU()),
            ('conv_2', make_conv(self.n_hid, self.n_hid, 3)),
            ('relu_3', nn.ReLU()),
            ('conv_3', make_conv(self.n_hid, self.n_hid, 3)),
            ('relu_4', nn.ReLU()),
            ('conv_4', make_conv(self.n_hid, self.n_out, 1)),
        ]))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.id_path(x) + self.post_gain * self.res_path(x)

@attr.s(eq=False, repr=False)
class Encoder(nn.Module):
    group_count:     int = 4
    n_hid:           int = attr.ib(default=256,  validator=lambda i, a, x: x >= 64)
    n_blk_per_group: int = attr.ib(default=2,    validator=lambda i, a, x: x >= 1)
    input_channels:  int = attr.ib(default=3,    validator=lambda i, a, x: x >= 1)
    vocab_size:      int = attr.ib(default=8192, validator=lambda i, a, x: x >= 512)

    device:              torch.device = attr.ib(default=torch.device('cpu'))
    requires_grad:       bool         = attr.ib(default=False)
    use_mixed_precision: bool         = attr.ib(default=True)

    def __attrs_post_init__(self) -> None:
        super().__init__()

        blk_range = range(self.n_blk_per_group)
        n_layers  = self.group_count * self.n_blk_per_group
        make_conv = partial(Conv2d, device=self.device, requires_grad=self.requires_grad)
        make_blk  = partial(EncoderBlock, n_layers=n_layers, device=self.device,
                            requires_grad=self.requires_grad)

        # Four groups of residual blocks; the first three end in a 2x max-pool,
        # so a 256x256 image encodes to a 32x32 grid of logits over the vocabulary.
        self.blocks = nn.Sequential(OrderedDict([
            ('input', make_conv(self.input_channels, 1 * self.n_hid, 7)),
            ('group_1', nn.Sequential(OrderedDict([
                *[(f'block_{i + 1}', make_blk(1 * self.n_hid, 1 * self.n_hid)) for i in blk_range],
                ('pool', nn.MaxPool2d(kernel_size=2)),
            ]))),
            ('group_2', nn.Sequential(OrderedDict([
                *[(f'block_{i + 1}', make_blk(1 * self.n_hid if i == 0 else 2 * self.n_hid, 2 * self.n_hid)) for i in blk_range],
                ('pool', nn.MaxPool2d(kernel_size=2)),
            ]))),
            ('group_3', nn.Sequential(OrderedDict([
                *[(f'block_{i + 1}', make_blk(2 * self.n_hid if i == 0 else 4 * self.n_hid, 4 * self.n_hid)) for i in blk_range],
                ('pool', nn.MaxPool2d(kernel_size=2)),
            ]))),
            ('group_4', nn.Sequential(OrderedDict([
                *[(f'block_{i + 1}', make_blk(4 * self.n_hid if i == 0 else 8 * self.n_hid, 8 * self.n_hid)) for i in blk_range],
            ]))),
            ('output', nn.Sequential(OrderedDict([
                ('relu', nn.ReLU()),
                ('conv', make_conv(8 * self.n_hid, self.vocab_size, 1, use_float16=False)),
            ]))),
        ]))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if len(x.shape) != 4:
            raise ValueError(f'input shape {x.shape} is not 4d')
        if x.shape[1] != self.input_channels:
            raise ValueError(f'input has {x.shape[1]} channels but model built for {self.input_channels}')
        if x.dtype != torch.float32:
            raise ValueError('input must have dtype torch.float32')

        return self.blocks(x)
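
Mirroring the decoder, the encoder maps a float32 [N, input_channels, H, W] batch to per-position logits over the vocabulary, shrinking each spatial dimension 8x via the three max-pools. A shape sketch under the defaults:

    # Encode a 256x256 RGB batch to an 8192-way logit grid; argmax picks discrete
    # tokens. Real inputs should go through map_pixels first (see dall_e/utils.py).
    import torch
    from dall_e.encoder import Encoder

    enc = Encoder()
    x = torch.rand(1, 3, 256, 256)

    z_logits = enc(x)
    tokens = torch.argmax(z_logits, dim=1)
    assert z_logits.shape == (1, 8192, 32, 32) and tokens.shape == (1, 32, 32)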

dall_e/utils.py

Lines changed: 59 additions & 0 deletions

import attr
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

logit_laplace_eps: float = 0.1

@attr.s(eq=False)
class Conv2d(nn.Module):
    n_in:  int = attr.ib(validator=lambda i, a, x: x >= 1)
    n_out: int = attr.ib(validator=lambda i, a, x: x >= 1)
    kw:    int = attr.ib(validator=lambda i, a, x: x >= 1 and x % 2 == 1)

    use_float16:   bool         = attr.ib(default=True)
    device:        torch.device = attr.ib(default=torch.device('cpu'))
    requires_grad: bool         = attr.ib(default=False)

    def __attrs_post_init__(self) -> None:
        super().__init__()

        # Weight init: zero-mean normal with std 1 / sqrt(fan_in).
        w = torch.empty((self.n_out, self.n_in, self.kw, self.kw), dtype=torch.float32,
                        device=self.device, requires_grad=self.requires_grad)
        w.normal_(std=1 / math.sqrt(self.n_in * self.kw ** 2))

        b = torch.zeros((self.n_out,), dtype=torch.float32, device=self.device,
                        requires_grad=self.requires_grad)
        self.w, self.b = nn.Parameter(w), nn.Parameter(b)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Run in float16 on CUDA when enabled; otherwise fall back to float32.
        if self.use_float16 and 'cuda' in self.w.device.type:
            if x.dtype != torch.float16:
                x = x.half()

            w, b = self.w.half(), self.b.half()
        else:
            if x.dtype != torch.float32:
                x = x.float()

            w, b = self.w, self.b

        return F.conv2d(x, w, b, padding=(self.kw - 1) // 2)

def map_pixels(x: torch.Tensor) -> torch.Tensor:
    if len(x.shape) != 4:
        raise ValueError('expected input to be 4d')
    if x.dtype != torch.float:
        raise ValueError('expected input to have type float')

    return (1 - 2 * logit_laplace_eps) * x + logit_laplace_eps

def unmap_pixels(x: torch.Tensor) -> torch.Tensor:
    if len(x.shape) != 4:
        raise ValueError('expected input to be 4d')
    if x.dtype != torch.float:
        raise ValueError('expected input to have type float')

    return torch.clamp((x - logit_laplace_eps) / (1 - 2 * logit_laplace_eps), 0, 1)
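
map_pixels and unmap_pixels exist because the decoder models pixels with a logit-Laplace likelihood (described in the paper), which needs values bounded away from 0 and 1; with eps = 0.1 they squeeze [0, 1] into [0.1, 0.9] and invert exactly. A round-trip sketch:

    # map_pixels: x -> 0.8 * x + 0.1; unmap_pixels inverts and clamps back to [0, 1].
    import torch
    from dall_e.utils import map_pixels, unmap_pixels

    x = torch.rand(1, 3, 8, 8)
    assert torch.allclose(unmap_pixels(map_pixels(x)), x, atol=1e-6)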

model_card.md

Lines changed: 41 additions & 0 deletions

# Model Card: DALL·E dVAE

Following [Model Cards for Model Reporting (Mitchell et al.)](https://arxiv.org/abs/1810.03993) and [Lessons from
Archives (Jo & Gebru)](https://arxiv.org/pdf/1912.10389.pdf), we're providing some information about the discrete
VAE (dVAE) that was used to train DALL·E.

## Model Details

The dVAE was developed by researchers at OpenAI to reduce the memory footprint of the transformer trained on the
text-to-image generation task. The details involved in training the dVAE are described in [the paper][dalle_paper]. This
model card describes the first version of the model, released in February 2021. The model consists of a convolutional
encoder and decoder whose architectures are described [here](dall_e/encoder.py) and [here](dall_e/decoder.py), respectively.
For questions or comments about the models or the code release, please file a GitHub issue.

## Model Use

### Intended Use

The model is intended for others to use for training their own generative models.

### Out-of-Scope Use Cases

This model is inappropriate for high-fidelity image processing applications. We also do not recommend its use as a
general-purpose image compressor.

## Training Data

The model was trained on publicly available text-image pairs collected from the internet. This data consists partly of
[Conceptual Captions][cc] and a filtered subset of [YFCC100M][yfcc100m]. We used a subset of the filters described in
[Sharma et al.][cc_paper] to construct this dataset; further details are described in [our paper][dalle_paper]. We will
not be releasing the dataset.

## Performance and Limitations

The heavy compression from the encoding process results in a noticeable loss of detail in the reconstructed images. This
renders it inappropriate for applications that require fine-grained details of the image to be preserved.

[dalle_paper]: https://arxiv.org/abs/2102.12092
[cc]: https://ai.google.com/research/ConceptualCaptions
[cc_paper]: https://www.aclweb.org/anthology/P18-1238/
[yfcc100m]: http://projects.dfki.uni-kl.de/yfcc100m/
