Commit 371f143: Add custom MusicGen training, fix DAC RVQ demo

1 parent: 40f1158 · 11 files changed, +321 -54 lines
defaults.ini (+1 -1)

```diff
@@ -17,7 +17,7 @@ num_nodes = 1
 strategy = ""
 
 # Precision to use for training
-precision = 16
+precision = "16"
 
 # number of CPU workers for the DataLoader
 num_workers = 8
```

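Quoting the value keeps it a string through ini parsing; recent PyTorch Lightning releases accept string precision flags. A minimal sketch, assuming Lightning 2.x (where "16" is still accepted as a legacy alias of "16-mixed"):

```python
# Sketch assuming PyTorch Lightning 2.x: precision may be passed as a
# string such as "16" (legacy alias) or "16-mixed".
import pytorch_lightning as pl

trainer = pl.Trainer(precision="16")  # 16-bit mixed-precision training
```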
harmonai_tools/configs/model_configs/txt2audio/44k_vae_1024_64_stereo_adp_t5_prompts_12s.json (+1 -1)

```diff
@@ -71,7 +71,7 @@
         "io_channels": 64
     },
     "training": {
-        "learning_rate": 2e-5,
+        "learning_rate": 4e-5,
         "demo": {
             "demo_every": 2000,
             "demo_steps": 250,
```
New model config JSON (new file, +92)

```json
{
    "model_type": "diffusion_cond",
    "sample_size": 262144,
    "sample_rate": 44100,
    "audio_channels": 2,
    "model": {
        "pretransform": {
            "type": "autoencoder",
            "config": {
                "encoder": {
                    "type": "dac",
                    "config": {
                        "in_channels": 2,
                        "latent_dim": 128,
                        "d_model": 128,
                        "strides": [4, 4, 8, 8]
                    }
                },
                "decoder": {
                    "type": "dac",
                    "config": {
                        "out_channels": 2,
                        "latent_dim": 64,
                        "channels": 1536,
                        "rates": [8, 8, 4, 4]
                    }
                },
                "bottleneck": {
                    "type": "vae"
                },
                "latent_dim": 64,
                "downsampling_ratio": 1024,
                "io_channels": 2
            }
        },
        "conditioning": {
            "configs": [
                {
                    "id": "prompt",
                    "type": "t5",
                    "config": {
                        "t5_model_name": "t5-base",
                        "max_length": 77
                    }
                }
            ],
            "cond_dim": 768
        },
        "diffusion": {
            "type": "adp_cfg_1d",
            "cross_attention_cond_ids": ["prompt"],
            "config": {
                "in_channels": 64,
                "context_embedding_features": 768,
                "context_embedding_max_length": 77,
                "channels": 256,
                "resnet_groups": 8,
                "kernel_multiplier_downsample": 2,
                "multipliers": [2, 3, 4, 5],
                "factors": [1, 2, 4],
                "num_blocks": [3, 3, 3],
                "attentions": [1, 1, 1, 1],
                "attention_heads": 16,
                "attention_features": 64,
                "attention_multiplier": 4,
                "use_nearest_upsample": false,
                "use_skip_scale": true,
                "use_context_time": true
            }
        },
        "io_channels": 64
    },
    "training": {
        "learning_rate": 4e-5,
        "demo": {
            "demo_every": 2000,
            "demo_steps": 250,
            "num_demos": 8,
            "demo_cond": [
                {"prompt": "Amen break 174 BPM"},
                {"prompt": "A car honking on a busy street"},
                {"prompt": "People talking in a crowded cafe"},
                {"prompt": "A short, beautiful piano riff in C minor"},
                {"prompt": "Tight Snare Drum"},
                {"prompt": "Calm, meditative ambient drone"},
                {"prompt": "Rattling snare"},
                {"prompt": "Clean bright guitar loop"}
            ],
            "demo_cfg_scales": [3, 6, 9]
        }
    }
}
```
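A config like this is consumed by the factory updated later in this commit; a minimal loading sketch (the JSON path is a placeholder):

```python
# Minimal sketch: instantiate a model from a config JSON via the repo's
# factory. "path/to/model_config.json" is a placeholder path.
import json

from harmonai_tools.models.factory import create_model_from_config

with open("path/to/model_config.json") as f:
    model_config = json.load(f)

model = create_model_from_config(model_config)
```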

harmonai_tools/interface/gradio.py (+3 -3)

```diff
@@ -33,8 +33,8 @@ def load_model(model_config, model_ckpt_path, pretransform_ckpt_path=None, devic
     print(f"Loading model checkpoint from {model_ckpt_path}")
 
     # Load checkpoint
-    #copy_state_dict(model, torch.load(model_ckpt_path)["state_dict"])
-    model.load_state_dict(torch.load(model_ckpt_path)["state_dict"])
+    copy_state_dict(model, torch.load(model_ckpt_path)["state_dict"])
+    #model.load_state_dict(torch.load(model_ckpt_path)["state_dict"])
 
     if pretransform_ckpt_path is not None:
         print(f"Loading pretransform checkpoint from {pretransform_ckpt_path}")
@@ -392,7 +392,7 @@ def create_autoencoder_ui(model_config):
     is_dac_rvq = "model" in model_config and "bottleneck" in model_config["model"] and model_config["model"]["bottleneck"]["type"] in ["dac_rvq","dac_rvq_vae"]
 
     if is_dac_rvq:
-        n_quantizers = model["bottleneck"]["config"]["num_quantizers"]
+        n_quantizers = model_config["model"]["bottleneck"]["config"]["n_codebooks"]
     else:
         n_quantizers = 0
 
```

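The switch back to copy_state_dict presumably tolerates checkpoints whose keys don't line up exactly with the model. The helper's implementation isn't part of this diff; a plausible sketch of such a helper (name reused, behavior assumed):

```python
# Assumed behavior, not the repo's actual implementation: copy only the
# checkpoint entries whose names and shapes match the target module.
import torch

def copy_state_dict(model: torch.nn.Module, state_dict: dict) -> None:
    own_state = model.state_dict()
    matched = {
        k: v for k, v in state_dict.items()
        if k in own_state and v.shape == own_state[k].shape
    }
    model.load_state_dict(matched, strict=False)
```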
harmonai_tools/models/autoencoders.py (+1 -1)

```diff
@@ -328,7 +328,7 @@ def encode(self, audio, return_info=False, skip_pretransform=False, **kwargs):
         latents = rearrange(latents, 'b c t -> b t c')
         latents = self.latent_pca.transform(latents)
         latents = rearrange(latents, 'b t c -> b c t')
-        
+
         if return_info:
             return latents, info
 
```

harmonai_tools/models/bottleneck.py (+10 -2)

```diff
@@ -192,7 +192,15 @@ def encode(self, x, return_info=False, **kwargs):
         if self.quantize_on_decode:
             return x, info if return_info else x
 
-        output = self.quantizer(x, **kwargs)
+        z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, **kwargs)
+
+        output = {
+            "z": z,
+            "codes": codes,
+            "latents": latents,
+            "vq/commitment_loss": commitment_loss,
+            "vq/codebook_loss": codebook_loss,
+        }
 
         output["vq/commitment_loss"] /= self.num_quantizers
         output["vq/codebook_loss"] /= self.num_quantizers
@@ -207,7 +215,7 @@ def encode(self, x, return_info=False, **kwargs):
     def decode(self, x):
 
         if self.quantize_on_decode:
-            x = self.quantizer(x)["z"]
+            x = self.quantizer(x)[0]
 
         return x
 
```

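This rewrite reflects that the DAC quantizer's forward returns a plain tuple rather than a dict, which is also why decode now indexes [0] for the quantized tensor z. A small sketch, assuming the descript-audio-codec (`dac`) package and its documented return order:

```python
# Sketch assuming descript-audio-codec's RVQ API: forward returns a 5-tuple
# (z, codes, latents, commitment_loss, codebook_loss).
import torch
from dac.nn.quantize import ResidualVectorQuantize

quantizer = ResidualVectorQuantize(input_dim=64, n_codebooks=9, codebook_size=1024)
x = torch.randn(2, 64, 100)  # (batch, latent_dim, timesteps)
z, codes, latents, commitment_loss, codebook_loss = quantizer(x)
```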
harmonai_tools/models/factory.py (+2 -2)

```diff
@@ -16,8 +16,8 @@ def create_model_from_config(model_config):
         from .autoencoders import create_diffAE_from_config
         return create_diffAE_from_config(model_config)
     elif model_type == 'musicgen':
-        from audiocraft.models import MusicGen
-        return MusicGen.get_pretrained(model_config["model"]["pretrained"], device="cpu")
+        from .musicgen import create_musicgen_from_config
+        return create_musicgen_from_config(model_config)
     else:
         raise NotImplementedError(f'Unknown model type: {model_type}')
 
```

harmonai_tools/models/musicgen.py (new file, +161)

```python
import torch
import typing as tp
from audiocraft.models import MusicGen, CompressionModel, LMModel
import audiocraft.quantization as qt
from .autoencoders import AudioAutoencoder
from .bottleneck import DACRVQBottleneck, DACRVQVAEBottleneck

from audiocraft.modules.codebooks_patterns import (
    DelayedPatternProvider,
    MusicLMPattern,
    ParallelPatternProvider,
    UnrolledPatternProvider,
    VALLEPattern,
)

from audiocraft.modules.conditioners import (
    ConditionFuser,
    ConditioningProvider,
    T5Conditioner,
)

def create_musicgen_from_config(config):
    model_config = config.get('model', None)
    assert model_config is not None, 'model config must be specified in config'

    # Fine-tuning path: start from a pretrained MusicGen checkpoint
    if model_config.get("pretrained", False):
        model = MusicGen.get_pretrained(model_config["pretrained"], device="cpu")

        if model_config.get("reinit_lm", False):
            model.lm._init_weights("gaussian", "current", True)

        return model

    # Create MusicGen model from scratch
    compression_config = model_config.get('compression', None)
    assert compression_config is not None, 'compression config must be specified in model config'

    compression_type = compression_config.get('type', None)
    assert compression_type is not None, 'type must be specified in compression config'

    if compression_type == 'pretrained':
        compression_model = CompressionModel.get_pretrained(compression_config["config"]["name"])
    elif compression_type == "dac_rvq_ae":
        from .autoencoders import create_autoencoder_from_config
        autoencoder = create_autoencoder_from_config({"model": compression_config["config"], "sample_rate": config["sample_rate"]})
        autoencoder.load_state_dict(torch.load(compression_config["ckpt_path"], map_location="cpu")["state_dict"])
        compression_model = DACRVQCompressionModel(autoencoder)

    lm_config = model_config.get('lm', None)
    assert lm_config is not None, 'lm config must be specified in model config'

    codebook_pattern = lm_config.pop("codebook_pattern", "delay")

    pattern_providers = {
        'parallel': ParallelPatternProvider,
        'delay': DelayedPatternProvider,
        'unroll': UnrolledPatternProvider,
        'valle': VALLEPattern,
        'musiclm': MusicLMPattern,
    }

    pattern_provider = pattern_providers[codebook_pattern](n_q=compression_model.num_codebooks)

    conditioning_config = model_config.get("conditioning", {})

    condition_output_dim = conditioning_config.get("output_dim", 768)

    condition_provider = ConditioningProvider(
        conditioners={
            "description": T5Conditioner(
                name="t5-base",
                output_dim=condition_output_dim,
                word_dropout=0.3,
                normalize_text=False,
                finetune=False,
                device="cpu"
            )
        }
    )

    condition_fuser = ConditionFuser(fuse2cond={
        "cross": ["description"],
        "prepend": [],
        "sum": []
    })

    lm = LMModel(
        pattern_provider=pattern_provider,
        condition_provider=condition_provider,
        fuser=condition_fuser,
        n_q=compression_model.num_codebooks,
        card=compression_model.cardinality,
        **lm_config
    )

    model = MusicGen(
        name=model_config.get("name", "musicgen-scratch"),
        compression_model=compression_model,
        lm=lm,
        max_duration=30
    )

    return model

class DACRVQCompressionModel(CompressionModel):
    """Wraps a DAC-style RVQ AudioAutoencoder as an audiocraft CompressionModel."""

    def __init__(self, autoencoder: AudioAutoencoder):
        super().__init__()
        self.model = autoencoder.eval()

        assert isinstance(self.model.bottleneck, DACRVQBottleneck) or isinstance(self.model.bottleneck, DACRVQVAEBottleneck), "Autoencoder must have a DACRVQBottleneck or DACRVQVAEBottleneck"

        self.n_quantizers = self.model.bottleneck.num_quantizers

    def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
        raise NotImplementedError("Forward and training with DAC RVQ not supported")

    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
        _, info = self.model.encode(x, return_info=True, n_quantizers=self.n_quantizers)
        codes = info["codes"]
        return codes, None

    def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
        assert scale is None
        z_q = self.decode_latent(codes)
        return self.model.decode(z_q)

    def decode_latent(self, codes: torch.Tensor):
        """Decode from the discrete codes to continuous latent space."""
        return self.model.bottleneck.quantizer.from_codes(codes)[0]

    @property
    def channels(self) -> int:
        return self.model.io_channels

    @property
    def frame_rate(self) -> float:
        return self.model.sample_rate / self.model.downsampling_ratio

    @property
    def sample_rate(self) -> int:
        return self.model.sample_rate

    @property
    def cardinality(self) -> int:
        return self.model.bottleneck.quantizer.codebook_size

    @property
    def total_codebooks(self) -> int:
        return self.model.bottleneck.num_quantizers

    @property
    def num_codebooks(self) -> int:
        return self.n_quantizers

    def set_num_codebooks(self, n: int):
        """Set the active number of codebooks used by the quantizer."""
        assert n >= 1
        assert n <= self.total_codebooks
        self.n_quantizers = n
```

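To exercise the from-scratch path above, a hypothetical config sketch follows; the pretrained codec name and the LM dimensions are illustrative placeholders, and the `lm` keys other than `codebook_pattern` are forwarded verbatim to audiocraft's LMModel:

```python
# Hypothetical config for create_musicgen_from_config; values are
# placeholders, not settings from this commit.
from harmonai_tools.models.factory import create_model_from_config

config = {
    "model_type": "musicgen",
    "sample_rate": 32000,
    "model": {
        "compression": {
            "type": "pretrained",
            "config": {"name": "facebook/encodec_32khz"}  # audiocraft codec
        },
        "lm": {
            "codebook_pattern": "delay",  # parallel/delay/unroll/valle/musiclm
            "dim": 512,                   # forwarded to audiocraft's LMModel
            "num_heads": 8,
            "num_layers": 8
        }
    }
}

model = create_model_from_config(config)  # dispatches on "model_type"
```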