import torch
import typing as tp
from audiocraft.models import MusicGen, CompressionModel, LMModel
import audiocraft.quantization as qt
from .autoencoders import AudioAutoencoder
from .bottleneck import DACRVQBottleneck, DACRVQVAEBottleneck

from audiocraft.modules.codebooks_patterns import (
    DelayedPatternProvider,
    MusicLMPattern,
    ParallelPatternProvider,
    UnrolledPatternProvider,
    VALLEPattern,
)

from audiocraft.modules.conditioners import (
    ConditionFuser,
    ConditioningProvider,
    T5Conditioner,
)

def create_musicgen_from_config(config):
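    """Build a MusicGen model from a config dict.

    If model_config['pretrained'] is set, a pretrained MusicGen is loaded
    (optionally re-initializing its LM weights); otherwise the compression
    model and language model are assembled from the config. A sketch of the
    expected shape, inferred from the lookups below (other keys omitted):

        {
            "sample_rate": 44100,
            "model": {
                "compression": {"type": "dac_rvq_ae", "config": {...}, "ckpt_path": "..."},
                "lm": {"codebook_pattern": "delay", ...},
                "conditioning": {"output_dim": 768},
            }
        }
    """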
    model_config = config.get('model', None)
    assert model_config is not None, 'model config must be specified in config'

    if model_config.get("pretrained", False):
        model = MusicGen.get_pretrained(model_config["pretrained"], device="cpu")

        if model_config.get("reinit_lm", False):
            model.lm._init_weights("gaussian", "current", True)

        return model

    # Create MusicGen model from scratch
    compression_config = model_config.get('compression', None)
    assert compression_config is not None, 'compression config must be specified in model config'

    compression_type = compression_config.get('type', None)
    assert compression_type is not None, 'type must be specified in compression config'

    if compression_type == 'pretrained':
        compression_model = CompressionModel.get_pretrained(compression_config["config"]["name"])
    elif compression_type == "dac_rvq_ae":
        from .autoencoders import create_autoencoder_from_config
        autoencoder = create_autoencoder_from_config({"model": compression_config["config"], "sample_rate": config["sample_rate"]})
        autoencoder.load_state_dict(torch.load(compression_config["ckpt_path"], map_location="cpu")["state_dict"])
        compression_model = DACRVQCompressionModel(autoencoder)
    else:
        raise ValueError(f"Unknown compression type: {compression_type}")

    lm_config = model_config.get('lm', None)
    assert lm_config is not None, 'lm config must be specified in model config'

    codebook_pattern = lm_config.pop("codebook_pattern", "delay")

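    # Choose how codebook tokens are interleaved for autoregressive modeling;
    # these providers come from audiocraft's codebooks_patterns module.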
    pattern_providers = {
        'parallel': ParallelPatternProvider,
        'delay': DelayedPatternProvider,
        'unroll': UnrolledPatternProvider,
        'valle': VALLEPattern,
        'musiclm': MusicLMPattern,
    }

    pattern_provider = pattern_providers[codebook_pattern](n_q=compression_model.num_codebooks)

    conditioning_config = model_config.get("conditioning", {})

    condition_output_dim = conditioning_config.get("output_dim", 768)

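    # Text conditioning: a frozen T5 encoder embeds the "description" field,
    # which the fuser below routes into the LM via cross-attention.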
    condition_provider = ConditioningProvider(
        conditioners={
            "description": T5Conditioner(
                name="t5-base",
                output_dim=condition_output_dim,
                word_dropout=0.3,
                normalize_text=False,
                finetune=False,
                device="cpu"
            )
        }
    )

    condition_fuser = ConditionFuser(fuse2cond={
        "cross": ["description"],
        "prepend": [],
        "sum": []
    })

    lm = LMModel(
        pattern_provider=pattern_provider,
        condition_provider=condition_provider,
        fuser=condition_fuser,
        n_q=compression_model.num_codebooks,
        card=compression_model.cardinality,
        **lm_config
    )

    model = MusicGen(
        name=model_config.get("name", "musicgen-scratch"),
        compression_model=compression_model,
        lm=lm,
        max_duration=30
    )

    return model

class DACRVQCompressionModel(CompressionModel):
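    """Adapter exposing an AudioAutoencoder with a DAC-style RVQ bottleneck
    through audiocraft's CompressionModel interface (audio to codes and back)."""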
    def __init__(self, autoencoder: AudioAutoencoder):
        super().__init__()
        self.model = autoencoder.eval()

        assert isinstance(self.model.bottleneck, (DACRVQBottleneck, DACRVQVAEBottleneck)), \
            "Autoencoder must have a DACRVQBottleneck or DACRVQVAEBottleneck"

        self.n_quantizers = self.model.bottleneck.num_quantizers

    def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
        raise NotImplementedError("Forward and training with DAC RVQ not supported")

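    # encode() maps audio to integer RVQ code indices (expected shape
    # [B, K, T_frames]); no per-frame scale is used, hence the None.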
    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
        _, info = self.model.encode(x, return_info=True, n_quantizers=self.n_quantizers)
        codes = info["codes"]
        return codes, None

    def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
        assert scale is None
        z_q = self.decode_latent(codes)
        return self.model.decode(z_q)

    def decode_latent(self, codes: torch.Tensor):
        """Decode from the discrete codes to continuous latent space."""
        return self.model.bottleneck.quantizer.from_codes(codes)[0]

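    # The properties below surface the wrapped autoencoder's geometry
    # (channels, rates, codebook size) through the CompressionModel API.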
    @property
    def channels(self) -> int:
        return self.model.io_channels

    @property
    def frame_rate(self) -> float:
        return self.model.sample_rate / self.model.downsampling_ratio

    @property
    def sample_rate(self) -> int:
        return self.model.sample_rate

    @property
    def cardinality(self) -> int:
        return self.model.bottleneck.quantizer.codebook_size

    @property
    def num_codebooks(self) -> int:
        return self.n_quantizers

    @property
    def total_codebooks(self) -> int:
        return self.model.bottleneck.num_quantizers

    def set_num_codebooks(self, n: int):
        """Set the active number of codebooks used by the quantizer."""
        assert n >= 1
        assert n <= self.total_codebooks
        self.n_quantizers = n
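
# Minimal usage sketch (hypothetical config; 'facebook/musicgen-small' is one
# of audiocraft's published checkpoint names, shown here only as an example):
#
#   config = {
#       "sample_rate": 32000,
#       "model": {"pretrained": "facebook/musicgen-small"},
#   }
#   model = create_musicgen_from_config(config)
#   wav = model.generate(["warm lo-fi beat with soft piano"])  # [B, C, T] audio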