Skip to content

Commit 4fdc25f

Browse files
authored
Merge pull request #171 from julian-parker/patch-4
Use 64bit ints to guard against issues with very large FSQ codebook size
2 parents e43988d + f0def31 commit 4fdc25f

File tree

1 file changed

+7
-5
lines changed
  • stable_audio_tools/models

1 file changed

+7
-5
lines changed

stable_audio_tools/models/fsq.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,10 @@ def __init__(
3535
super().__init__()
3636
self.levels = levels
3737

38-
_levels = torch.tensor(levels, dtype=int32)
38+
_levels = torch.tensor(levels, dtype=torch.int64)
3939
self.register_buffer("_levels", _levels, persistent = False)
4040

41-
_basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=int32)
41+
_basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=torch.int64)
4242
self.register_buffer("_basis", _basis, persistent = False)
4343

4444
codebook_dim = len(levels)
@@ -90,7 +90,9 @@ def _indices_to_codes(self, indices):
9090

9191
def _codes_to_indices(self, zhat):
9292
zhat = self._scale_and_shift(zhat)
93-
return (zhat * self._basis).sum(dim=-1).to(int32)
93+
zhat = zhat.round().to(torch.int64)
94+
out = (zhat * self._basis).sum(dim=-1)
95+
return out
9496

9597
def _indices_to_level_indices(self, indices):
9698
indices = rearrange(indices, '... -> ... 1')
@@ -100,7 +102,7 @@ def _indices_to_level_indices(self, indices):
100102
def indices_to_codes(self, indices):
101103
# Expects input of batch x sequence x num_codebooks
102104
assert indices.shape[-1] == self.num_codebooks, f'expected last dimension of {self.num_codebooks} but found last dimension of {indices.shape[-1]}'
103-
codes = self._indices_to_codes(indices)
105+
codes = self._indices_to_codes(indices.to(torch.int64))
104106
codes = rearrange(codes, '... c d -> ... (c d)')
105107
return codes
106108

@@ -116,7 +118,7 @@ def forward(self, z, skip_tanh: bool = False):
116118
# make sure allowed dtype before quantizing
117119

118120
if z.dtype not in self.allowed_dtypes:
119-
z = z.float()
121+
z = z.to(torch.float64)
120122

121123
codes = self.quantize(z, skip_tanh=skip_tanh)
122124
indices = self._codes_to_indices(codes)

0 commit comments

Comments
 (0)