@@ -35,10 +35,10 @@ def __init__(
        super().__init__()
        self.levels = levels

-       _levels = torch.tensor(levels, dtype=int32)
+       _levels = torch.tensor(levels, dtype=torch.int64)
        self.register_buffer("_levels", _levels, persistent=False)

-       _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=int32)
+       _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=torch.int64)
        self.register_buffer("_basis", _basis, persistent=False)

        codebook_dim = len(levels)
@@ -90,7 +90,9 @@ def _indices_to_codes(self, indices):

    def _codes_to_indices(self, zhat):
        zhat = self._scale_and_shift(zhat)
-       return (zhat * self._basis).sum(dim=-1).to(int32)
+       zhat = zhat.round().to(torch.int64)
+       out = (zhat * self._basis).sum(dim=-1)
+       return out

    def _indices_to_level_indices(self, indices):
        indices = rearrange(indices, '... -> ... 1')
@@ -100,7 +102,7 @@ def _indices_to_level_indices(self, indices):
    def indices_to_codes(self, indices):
        # Expects input of batch x sequence x num_codebooks
        assert indices.shape[-1] == self.num_codebooks, f'expected last dimension of {self.num_codebooks} but found last dimension of {indices.shape[-1]}'
-       codes = self._indices_to_codes(indices)
+       codes = self._indices_to_codes(indices.to(torch.int64))
        codes = rearrange(codes, '... c d -> ... (c d)')
        return codes

@@ -116,7 +118,7 @@ def forward(self, z, skip_tanh: bool = False):
        # make sure allowed dtype before quantizing

        if z.dtype not in self.allowed_dtypes:
-           z = z.float()
+           z = z.to(torch.float64)

        codes = self.quantize(z, skip_tanh=skip_tanh)
        indices = self._codes_to_indices(codes)
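
Note on the rationale: below is a minimal standalone sketch of the overflow this int32-to-int64 switch guards against. The `levels` configuration is an illustrative assumption, not from this commit; only `torch` is needed to reproduce it.

import torch

# Hypothetical config: 8**11 = 2**33 codes, beyond the int32 range (2**31 - 1).
levels = [8] * 11

# Place-value basis as built in __init__ above: [1, 8, 64, ..., 8**10].
basis32 = torch.cumprod(torch.tensor([1] + levels[:-1], dtype=torch.int32), dim=0)
basis64 = torch.cumprod(torch.tensor([1] + levels[:-1], dtype=torch.int64), dim=0)

# Highest per-level code in every slot flattens to index 8**11 - 1.
top32 = torch.tensor(levels, dtype=torch.int32) - 1
top64 = top32.to(torch.int64)

print((top32 * basis32).sum(dim=-1))  # wrong: the int32 multiply wraps silently
print((top64 * basis64).sum(dim=-1))  # tensor(8589934591), the expected index

# The old float path is lossy too: float32 holds exact integers only up to 2**24,
# so large indices get rounded before the cast back to an integer dtype.
print(torch.tensor(8589934591., dtype=torch.float32).to(torch.int64))  # 8589934592

This is also why `_codes_to_indices` now rounds and casts to int64 before multiplying by the basis, rather than casting the float result afterwards.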