
Commit 0cce037

Merge pull request #74 from sekstini/finite-scalar-quantization

Add FSQ implementation

2 parents 2e0773b + e02ebab

File tree

7 files changed: +198 -3 lines

Diff for: README.md (+34)

@@ -251,6 +251,40 @@

indices = quantizer(x) # (1, 1024, 16) - (batch, seq, num_codebooks)

This repository should also automatically synchronize the codebooks in a multi-process setting. If somehow it isn't, please open an issue. You can override whether to synchronize codebooks or not by setting `sync_codebook = True | False`
### Finite Scalar Quantization

<img src="./fsq.png" width="500px"></img>

|                  | VQ                                                     | FSQ         |
|------------------|--------------------------------------------------------|-------------|
| Quantization     | argmin_c \|\| z-c \|\|                                 | round(f(z)) |
| Gradients        | Straight Through Estimation (STE)                      | STE         |
| Auxiliary Losses | Commitment, codebook, entropy loss, ...                | N/A         |
| Tricks           | EMA on codebook, codebook splitting, projections, ...  | N/A         |
| Parameters       | Codebook                                               | N/A         |
[This](https://arxiv.org/abs/2309.15505) work out of Google DeepMind aims to vastly simplify how vector quantization is done for generative modeling, removing the need for commitment losses and EMA updating of the codebook, and tackling codebook collapse and insufficient utilization. Each scalar is simply rounded to one of a fixed number of discrete levels with straight-through gradients; the codes then become uniform points in a hypercube.
```python
import torch
from vector_quantize_pytorch import FSQ

levels = [8, 5, 5, 5] # see 4.1 and A.4.1 in the paper
quantizer = FSQ(levels)

x = torch.randn(1, 1024, quantizer.dim)
xhat, indices = quantizer(x)

print(xhat.shape)    # (1, 1024, 4) - (batch, seq, dim)
print(indices.shape) # (1, 1024)    - (batch, seq)

assert torch.all(xhat == quantizer.indices_to_codes(indices))
assert torch.all(xhat == quantizer.implicit_codebook[indices])
```
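A quick note, not part of the diff itself: since FSQ has no learned codebook, the number of implicit codes is simply the product of the chosen levels.

```python
import math

levels = [8, 5, 5, 5]
print(math.prod(levels))  # 1000 implicit codes, zero codebook parameters
```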
## Todo

- [x] allow for multi-headed codebooks

Diff for: examples/autoencoder.py (+1 -1)

```diff
@@ -41,7 +41,7 @@ def __init__(self, **vq_kwargs):
     def forward(self, x):
         for layer in self.layers:
             if isinstance(layer, VectorQuantize):
-                x_flat, indices, commit_loss = layer(x)
+                x, indices, commit_loss = layer(x)
             else:
                 x = layer(x)
```
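This fixes a bug in the example: the quantized output was previously assigned to `x_flat` and then discarded, so the layers after the quantizer received the unquantized activations.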

Diff for: examples/autoencoder_fsq.py (+94, new file)

```python
# FashionMnist VQ experiment with various settings, using FSQ.
# From https://github.com/minyoungg/vqtorch/blob/main/examples/autoencoder.py

from tqdm.auto import trange

import math
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

from vector_quantize_pytorch import FSQ


lr = 3e-4
train_iter = 1000
levels = [8, 6, 5]  # target size 2^8, actual size 240
num_codes = math.prod(levels)
seed = 1234
device = "cuda" if torch.cuda.is_available() else "cpu"


class SimpleFSQAutoEncoder(nn.Module):
    def __init__(self, levels: list[int]):
        super().__init__()
        self.layers = nn.ModuleList(
            [
                nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
                nn.MaxPool2d(kernel_size=2, stride=2),
                nn.GELU(),
                nn.Conv2d(16, 8, kernel_size=3, stride=1, padding=1),
                nn.Conv2d(8, 8, kernel_size=6, stride=3, padding=0),
                FSQ(levels),
                nn.ConvTranspose2d(8, 8, kernel_size=6, stride=3, padding=0),
                nn.Conv2d(8, 16, kernel_size=4, stride=1, padding=2),
                nn.GELU(),
                nn.Upsample(scale_factor=2, mode="nearest"),
                nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=2),
            ]
        )
        return

    def forward(self, x):
        for layer in self.layers:
            if isinstance(layer, FSQ):
                x, indices = layer(x)
            else:
                x = layer(x)

        return x.clamp(-1, 1), indices


def train(model, train_loader, train_iterations=1000):
    def iterate_dataset(data_loader):
        data_iter = iter(data_loader)
        while True:
            try:
                x, y = next(data_iter)
            except StopIteration:
                data_iter = iter(data_loader)
                x, y = next(data_iter)
            yield x.to(device), y.to(device)

    for _ in (pbar := trange(train_iterations)):
        opt.zero_grad()
        x, _ = next(iterate_dataset(train_loader))
        out, indices = model(x)
        rec_loss = (out - x).abs().mean()
        rec_loss.backward()

        opt.step()
        pbar.set_description(
            f"rec loss: {rec_loss.item():.3f} | "
            + f"active %: {indices.unique().numel() / num_codes * 100:.3f}"
        )
    return


transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)
train_dataset = DataLoader(
    datasets.FashionMNIST(
        root="~/data/fashion_mnist", train=True, download=True, transform=transform
    ),
    batch_size=256,
    shuffle=True,
)

print("baseline")
torch.random.manual_seed(seed)
model = SimpleFSQAutoEncoder(levels).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=lr)
train(model, train_dataset, train_iterations=train_iter)
```
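On the `levels` choice above: `math.prod([8, 6, 5]) = 240`, so the implicit codebook approximates the 2^8 = 256 codes an 8-bit VQ codebook would provide, which is what the "target size 2^8, actual size 240" comment refers to.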

Diff for: fsq.png

Binary file added (288 KB)

Diff for: setup.py (+1 -1)

```diff
@@ -3,7 +3,7 @@
 setup(
   name = 'vector_quantize_pytorch',
   packages = find_packages(),
-  version = '1.7.1',
+  version = '1.8.0',
   license='MIT',
   description = 'Vector Quantization - Pytorch',
   long_description_content_type = 'text/markdown',
```

Diff for: vector_quantize_pytorch/__init__.py (+2 -1)

```diff
@@ -1,3 +1,4 @@
 from vector_quantize_pytorch.vector_quantize_pytorch import VectorQuantize
 from vector_quantize_pytorch.residual_vq import ResidualVQ, GroupedResidualVQ
-from vector_quantize_pytorch.random_projection_quantizer import RandomProjectionQuantizer
+from vector_quantize_pytorch.random_projection_quantizer import RandomProjectionQuantizer
+from vector_quantize_pytorch.finite_scalar_quantization import FSQ
```
Diff for: vector_quantize_pytorch/finite_scalar_quantization.py (+66, new file)

```python
"""
Finite Scalar Quantization: VQ-VAE Made Simple - https://arxiv.org/abs/2309.15505
Code adapted from Jax version in Appendix A.1
"""

import torch
import torch.nn as nn


def round_ste(z: torch.Tensor) -> torch.Tensor:
    """Round with straight through gradients."""
    zhat = z.round()
    return z + (zhat - z).detach()


class FSQ(nn.Module):
    def __init__(self, levels: list[int]):
        super().__init__()
        _levels = torch.tensor(levels, dtype=torch.int32)
        self.register_buffer("_levels", _levels)

        _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=torch.int32)
        self.register_buffer("_basis", _basis)

        self.dim = len(levels)
        self.n_codes = self._levels.prod().item()
        implicit_codebook = self.indices_to_codes(torch.arange(self.n_codes))
        self.register_buffer("implicit_codebook", implicit_codebook)

    def forward(self, z: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        zhat = self.quantize(z)
        indices = self.codes_to_indices(zhat)
        return zhat, indices

    def bound(self, z: torch.Tensor, eps: float = 1e-3) -> torch.Tensor:
        """Bound `z`, an array of shape (..., d)."""
        half_l = (self._levels - 1) * (1 - eps) / 2
        offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
        shift = (offset / half_l).tan()
        return (z + shift).tanh() * half_l - offset

    def quantize(self, z: torch.Tensor) -> torch.Tensor:
        """Quantizes z, returns quantized zhat, same shape as z."""
        quantized = round_ste(self.bound(z))
        half_width = self._levels // 2  # Renormalize to [-1, 1].
        return quantized / half_width

    def _scale_and_shift(self, zhat_normalized: torch.Tensor) -> torch.Tensor:
        half_width = self._levels // 2
        return (zhat_normalized * half_width) + half_width

    def _scale_and_shift_inverse(self, zhat: torch.Tensor) -> torch.Tensor:
        half_width = self._levels // 2
        return (zhat - half_width) / half_width

    def codes_to_indices(self, zhat: torch.Tensor) -> torch.Tensor:
        """Converts a `code` to an index in the codebook."""
        assert zhat.shape[-1] == self.dim
        zhat = self._scale_and_shift(zhat)
        return (zhat * self._basis).sum(dim=-1).to(torch.int32)

    def indices_to_codes(self, indices: torch.Tensor) -> torch.Tensor:
        """Inverse of `codes_to_indices`."""
        indices = indices.unsqueeze(-1)
        codes_non_centered = (indices // self._basis) % self._levels
        return self._scale_and_shift_inverse(codes_non_centered)
```
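Not part of the commit, but a quick sketch of how the pieces compose, assuming the module layout above: `round_ste` passes gradients through the rounding unchanged, and `codes_to_indices` / `indices_to_codes` form an exact mixed-radix round trip (for `levels = [8, 5, 5, 5]`, the `_basis` buffer holds the place values `[1, 8, 40, 200]`).

```python
import torch
from vector_quantize_pytorch.finite_scalar_quantization import FSQ, round_ste

# straight-through estimator: rounding contributes no gradient,
# so d round_ste(z) / dz == 1 everywhere
z = torch.randn(4, requires_grad=True)
round_ste(z).sum().backward()
assert torch.all(z.grad == 1.0)

# quantize -> indices -> codes is an exact round trip
fsq = FSQ(levels=[8, 5, 5, 5])
z = torch.randn(2, 16, fsq.dim)
zhat, indices = fsq(z)
assert torch.all(fsq.indices_to_codes(indices) == zhat)
assert indices.min() >= 0 and indices.max() < fsq.n_codes
```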
