
Commit a2dd95f

Update LICENSE, update MusicGen code, update dependencies
1 parent 371f143

File tree

7 files changed: +109 −14

LICENSE (+1 −1)

@@ -1,6 +1,6 @@
 MIT License
 
-Copyright (c) 2023 Harmonai-org
+Copyright (c) 2023 Stability AI
 
 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal

harmonai_tools/models/conditioners.py (+39 −5)

@@ -5,6 +5,8 @@
 import typing as tp
 import gc
 import os
+from ..training.utils import copy_state_dict
+from laion_clap.clap_module.factory import load_state_dict as clap_load_state_dict
 
 from audio_diffusion_pytorch_fork import NumberEmbedder
 
@@ -91,11 +93,13 @@ def __init__(self,
                  feature_layer_ix: int = -1,
                  audio_model_type="HTSAT-base",
                  enable_fusion=True,
-                 project_out: bool = False):
+                 project_out: bool = False,
+                 finetune: bool = False):
         super().__init__(768 if use_text_features else 512, output_dim, 1, project_out=project_out)
 
         self.use_text_features = use_text_features
         self.feature_layer_ix = feature_layer_ix
+        self.finetune = finetune
 
         # Suppress logging from transformers
         previous_level = logging.root.manager.disable
@@ -105,8 +109,23 @@ def __init__(self,
         try:
             import laion_clap
 
-            self.__dict__["model"] = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=audio_model_type, device='cpu').requires_grad_(False).eval()
-            self.model.load_ckpt(clap_ckpt_path)
+            model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=audio_model_type, device='cpu')
+
+            if self.finetune:
+                self.model = model
+            else:
+                self.__dict__["model"] = model
+
+            state_dict = clap_load_state_dict(clap_ckpt_path)
+            self.model.model.load_state_dict(state_dict, strict=False)
+
+            if self.finetune:
+                self.model.model.text_branch.requires_grad_(True)
+                self.model.model.text_branch.train()
+            else:
+                self.model.model.text_branch.requires_grad_(False)
+                self.model.model.text_branch.eval()
+
         finally:
             logging.disable(previous_level)
 
@@ -167,8 +186,23 @@ def __init__(self,
         try:
             import laion_clap
 
-            self.__dict__["model"] = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=audio_model_type, device=device).requires_grad_(False).eval()
-            self.model.load_ckpt(clap_ckpt_path)
+            model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=audio_model_type, device='cpu')
+
+            if self.finetune:
+                self.model = model
+            else:
+                self.__dict__["model"] = model
+
+            state_dict = clap_load_state_dict(clap_ckpt_path)
+            self.model.model.load_state_dict(state_dict, strict=False)
+
+            if self.finetune:
+                self.model.model.audio_branch.requires_grad_(True)
+                self.model.model.audio_branch.train()
+            else:
+                self.model.model.audio_branch.requires_grad_(False)
+                self.model.model.audio_branch.eval()
+
         finally:
             logging.disable(previous_level)
 
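
Note: the `self.__dict__["model"] = model` branch above is the standard trick for holding a module on an `nn.Module` without registering it. Writing to `__dict__` bypasses `nn.Module.__setattr__`, so the frozen CLAP weights stay out of `parameters()` and `state_dict()`; the `finetune=True` branch assigns normally, so the unfrozen branch is trained and checkpointed. A minimal sketch of the pattern (the `Wrapper`/`inner` names are hypothetical, with `nn.Linear` standing in for CLAP):

import torch.nn as nn

class Wrapper(nn.Module):
    def __init__(self, finetune: bool = False):
        super().__init__()
        inner = nn.Linear(4, 4)  # stand-in for the CLAP module
        if finetune:
            self.inner = inner              # registered: trained and checkpointed
        else:
            self.__dict__["inner"] = inner  # hidden from parameters()/state_dict()

print(len(list(Wrapper(finetune=True).parameters())))   # 2 (weight and bias)
print(len(list(Wrapper(finetune=False).parameters())))  # 0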

harmonai_tools/training/factory.py (+11 −2)

@@ -75,9 +75,8 @@ def create_training_wrapper_from_config(model_config, model):
     elif model_type == 'diffusion_autoencoder':
         from .diffusion import DiffusionAutoencoderTrainingWrapper
 
-
         ema_copy = create_model_from_config(model_config)
-        #ema_copy = create_model_from_config(model_config) # I don't know why this needs to be called twice but it broke when I called it once
+
         # Copy each weight to the ema copy
         for name, param in model.state_dict().items():
             if isinstance(param, Parameter):
@@ -92,8 +91,18 @@ def create_training_wrapper_from_config(model_config, model):
         )
     elif model_type == 'musicgen':
         from .musicgen import MusicGenTrainingWrapper
+
+        ema_copy = create_model_from_config(model_config).lm
+
+        for name, param in model.lm.state_dict().items():
+            if isinstance(param, Parameter):
+                # backwards compatibility for serialized parameters
+                param = param.data
+            ema_copy.state_dict()[name].copy_(param)
+
         return MusicGenTrainingWrapper(
             model,
+            ema_copy=ema_copy,
             lr=training_config["learning_rate"]
         )
     else:
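
Note: the weight copy into `ema_copy` works because `state_dict()` returns references to the live tensors, so `copy_(param)` writes into the freshly constructed model in place. A self-contained sketch of the same initialization pattern (the `nn.Linear` models are stand-ins):

import torch
import torch.nn as nn

model = nn.Linear(2, 2)
ema_copy = nn.Linear(2, 2)  # fresh construction, different random init

# state_dict() values reference the live tensors, so copy_ mutates ema_copy
for name, param in model.state_dict().items():
    ema_copy.state_dict()[name].copy_(param)

assert torch.equal(model.weight, ema_copy.weight)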

harmonai_tools/training/musicgen.py (+11 −3)

@@ -37,7 +37,7 @@ def __repr__(self):
 
 
 class MusicGenTrainingWrapper(pl.LightningModule):
-    def __init__(self, musicgen_model, lr = 1e-4):
+    def __init__(self, musicgen_model, lr = 1e-4, ema_copy=None):
         super().__init__()
 
         self.musicgen_model: MusicGen = musicgen_model
@@ -48,6 +48,8 @@ def __init__(self, musicgen_model, lr = 1e-4):
 
         self.lm.to(torch.float32).train().requires_grad_(True)
 
+        self.lm_ema = EMA(self.lm, ema_model=ema_copy, beta=0.99, update_every=10)
+
         self.cfg_dropout = ClassifierFreeGuidanceDropout(0.1)
 
         self.lr = lr
@@ -96,7 +98,9 @@ def _compute_cross_entropy(
 
     def training_step(self, batch, batch_idx):
        reals, metadata = batch
-        reals = reals[0]
+
+        if reals.ndim == 4 and reals.shape[0] == 1:
+            reals = reals[0]
 
         # Convert reals to mono if necessary
         if self.musicgen_model.audio_channels == 1:
@@ -113,7 +117,7 @@ def training_step(self, batch, batch_idx):
 
         codes, _ = self.musicgen_model.compression_model.encode(reals) # [b, k, t]
 
-        attributes = [ConditioningAttributes(text={'description': md["prompt"][0]}) for md in metadata]
+        attributes = [ConditioningAttributes(text={'description': md["prompt"][0][:512]}) for md in metadata]
         attributes = self.lm.cfg_dropout(attributes)
         attributes = self.lm.att_dropout(attributes)
         tokenized = self.lm.condition_provider.tokenize(attributes)
@@ -147,7 +151,11 @@ def training_step(self, batch, batch_idx):
         self.log_dict(log_dict, prog_bar=True, on_step=True)
         return loss
 
+    def on_before_zero_grad(self, *args, **kwargs):
+        self.lm_ema.update()
+
     def export_model(self, path):
+        self.musicgen_model.lm = self.lm_ema.ema_model
         export_state_dict = {"state_dict": self.musicgen_model.state_dict()}
 
         torch.save(export_state_dict, path)
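
Note: `EMA` here appears to be the wrapper from the ema-pytorch package; passing a pre-built `ema_model` skips its internal `deepcopy`, which can fail on models holding non-picklable state. Assuming that package, the lifecycle wired up above looks roughly like this manual loop (`lm`, `ema_copy`, `dataloader`, and the optimizer are illustrative):

from ema_pytorch import EMA

lm_ema = EMA(lm, ema_model=ema_copy, beta=0.99, update_every=10)

for batch in dataloader:
    loss = training_step(batch)   # Lightning instead calls on_before_zero_grad
    loss.backward()
    optimizer.step()
    lm_ema.update()               # folds current weights into the average every 10 calls
    optimizer.zero_grad()

smoothed_lm = lm_ema.ema_model    # the averaged model that export_model swaps in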

harmonai_tools/training/utils.py (+45 −1)

@@ -14,6 +14,47 @@ def get_rank():
 
     return torch.distributed.get_rank()
 
+class InverseLR(torch.optim.lr_scheduler._LRScheduler):
+    """Implements an inverse decay learning rate schedule with an optional exponential
+    warmup. When last_epoch=-1, sets initial lr as lr.
+    inv_gamma is the number of steps/epochs required for the learning rate to decay to
+    (1 / 2)**power of its original value.
+    Args:
+        optimizer (Optimizer): Wrapped optimizer.
+        inv_gamma (float): Inverse multiplicative factor of learning rate decay. Default: 1.
+        power (float): Exponential factor of learning rate decay. Default: 1.
+        warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable)
+            Default: 0.
+        final_lr (float): The final learning rate. Default: 0.
+        last_epoch (int): The index of last epoch. Default: -1.
+        verbose (bool): If ``True``, prints a message to stdout for
+            each update. Default: ``False``.
+    """
+
+    def __init__(self, optimizer, inv_gamma=1., power=1., warmup=0., final_lr=0.,
+                 last_epoch=-1, verbose=False):
+        self.inv_gamma = inv_gamma
+        self.power = power
+        if not 0. <= warmup < 1:
+            raise ValueError('Invalid value for warmup')
+        self.warmup = warmup
+        self.final_lr = final_lr
+        super().__init__(optimizer, last_epoch, verbose)
+
+    def get_lr(self):
+        if not self._get_lr_called_within_step:
+            import warnings
+            warnings.warn("To get the last learning rate computed by the scheduler, "
+                          "please use `get_last_lr()`.")
+
+        return self._get_closed_form_lr()
+
+    def _get_closed_form_lr(self):
+        warmup = 1 - self.warmup ** (self.last_epoch + 1)
+        lr_mult = (1 + self.last_epoch / self.inv_gamma) ** -self.power
+        return [warmup * max(self.final_lr, base_lr * lr_mult)
+                for base_lr in self.base_lrs]
+
 def copy_state_dict(model, state_dict):
     """Load state_dict to model, but only for keys that match exactly.
 
@@ -55,6 +96,9 @@ def create_scheduler_from_config(scheduler_config, optimizer):
     Returns:
         torch.optim.lr_scheduler._LRScheduler: scheduler.
     """
-    scheduler_fn = getattr(torch.optim.lr_scheduler, scheduler_config["type"])
+    if scheduler_config["type"] == "InverseLR":
+        scheduler_fn = InverseLR
+    else:
+        scheduler_fn = getattr(torch.optim.lr_scheduler, scheduler_config["type"])
     scheduler = scheduler_fn(optimizer, **scheduler_config["config"])
     return scheduler
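
Note: in closed form the schedule is lr(t) = (1 - warmup**(t + 1)) * max(final_lr, base_lr * (1 + t / inv_gamma)**(-power)). A usage sketch through the factory, with illustrative hyperparameters (not values taken from this commit):

import torch
from harmonai_tools.training.utils import create_scheduler_from_config

optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)

scheduler_config = {
    "type": "InverseLR",  # now resolved locally rather than via torch.optim.lr_scheduler
    "config": {"inv_gamma": 1_000_000, "power": 0.5, "warmup": 0.99},
}
scheduler = create_scheduler_from_config(scheduler_config, optimizer)

for step in range(3):
    optimizer.step()
    scheduler.step()
    print(scheduler.get_last_lr())  # warms up toward 1e-4, then decays slowly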

setup.py (+1 −1)

@@ -34,7 +34,7 @@
         'torchaudio>=2.0.2',
         'torchmetrics==0.11.4',
         'tqdm',
-        'transformers==4.30.2',
+        'transformers==4.33.3',
         'v-diffusion-pytorch==0.0.2',
         'vector-quantize-pytorch==1.6.21',
         'wandb==0.15.4',

train.py (+1 −1)

@@ -74,7 +74,7 @@ def main():
         else:
             strategy = args.strategy
     else:
-        strategy = 'ddp' if args.num_gpus > 1 else None
+        strategy = 'ddp_find_unused_parameters_true' if args.num_gpus > 1 else None
 
     trainer = pl.Trainer(
         devices=args.num_gpus,
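
Note: in PyTorch Lightning 2.x, 'ddp_find_unused_parameters_true' is the registered alias for DDP with unused-parameter detection, which DDP needs once some parameters receive no gradient in a step (for example, the frozen conditioner branches above). A sketch of the explicit equivalent:

import pytorch_lightning as pl
from pytorch_lightning.strategies import DDPStrategy

# equivalent to strategy='ddp_find_unused_parameters_true'
trainer = pl.Trainer(
    devices=2,
    strategy=DDPStrategy(find_unused_parameters=True),
)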
