
Commit 330a7a8

Update to 0.0.18
1 parent c5f5f7e

2 files changed: +32 −7 lines


setup.py (+1 −2)
@@ -2,7 +2,7 @@
 
 setup(
     name='stable-audio-tools',
-    version='0.0.17',
+    version='0.0.18',
     url='https://github.com/Stability-AI/stable-audio-tools.git',
     author='Stability AI',
     description='Training and inference tools for generative audio models from Stability AI',
@@ -25,7 +25,6 @@
         'prefigure==0.0.9',
         'pytorch_lightning==2.1.0',
         'PyWavelets==1.4.1',
-        'pypesq==1.2.4',
         'safetensors',
         'sentencepiece==0.1.99',
         'torch>=2.0.1',

stable_audio_tools/models/conditioners.py (+31 −5)
@@ -545,7 +545,16 @@ class SourceMixConditioner(Conditioner):
         source_keys: a list of keys for the potential sources in the metadata
 
     """
-    def __init__(self, pretransform: Pretransform, output_dim: int, save_pretransform: bool = False, source_keys: tp.List[str] = [], pre_encoded: bool = False):
+    def __init__(
+            self,
+            pretransform: Pretransform,
+            output_dim: int,
+            save_pretransform: bool = False,
+            source_keys: tp.List[str] = [],
+            pre_encoded: bool = False,
+            allow_null_source=False,
+            source_length=None
+    ):
         super().__init__(pretransform.encoded_channels, output_dim)
 
         if not save_pretransform:
@@ -559,16 +568,28 @@ def __init__(self, pretransform: Pretransform, output_dim: int, save_pretransform: bool = False, source_keys: tp.List[str] = [], pre_encoded: bool = False):
 
         self.pre_encoded = pre_encoded
 
+        self.allow_null_source = allow_null_source
+
+        if self.allow_null_source:
+            self.null_source = nn.Parameter(torch.randn(output_dim, 1))
+
+            assert source_length is not None, "Source length must be specified if allowing null sources"
+
+            self.source_length = source_length
+
     def forward(self, sources: tp.List[tp.Dict[str, torch.Tensor]], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
 
         self.pretransform.to(device)
         self.proj_out.to(device)
 
+        dtype = next(self.proj_out.parameters()).dtype
+
         # Output has to be the batch of summed projections
         # Input is per-batch-item list of source audio
 
         mixes = []
 
+        num_null_sources = 0
         for source_dict in sources: # Iterate over batch items
 
             mix = None
@@ -579,14 +600,16 @@ def forward(self, sources: tp.List[tp.Dict[str, torch.Tensor]], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
                 source = source_dict[key]
 
                 if not self.pre_encoded:
-                    audio = set_audio_channels(source, self.pretransform.io_channels)
+                    assert source.dim() == 2, f"Source audio must be shape [channels, samples], got shape: {source.shape}"
+                    audio = set_audio_channels(source.unsqueeze(0), self.pretransform.io_channels)
 
                     audio = audio.to(device)
-
-                    latents = self.pretransform.encode(audio)
+                    latents = self.pretransform.encode(audio).squeeze(0)
                 else:
                     latents = source.to(device)
 
+                latents = latents.to(dtype)
+
                 if mix is None:
                     mix = self.source_heads[key_ix](latents)
                 else:
@@ -595,7 +618,10 @@ def forward(self, sources: tp.List[tp.Dict[str, torch.Tensor]], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
             if mix is not None:
                 mixes.append(mix)
             else:
-                raise ValueError("No sources found for mix")
+                if self.allow_null_source:
+                    mixes.append(self.null_source.repeat(1, self.source_length))
+                else:
+                    raise ValueError("No sources found for mix")
 
         mixes = torch.stack(mixes, dim=0)
 

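For context, the change to SourceMixConditioner adds an allow_null_source option: when a batch item contains none of the configured source_keys, the conditioner now contributes a learned placeholder latent (null_source, tiled to source_length frames) instead of raising ValueError. Below is a minimal sketch of that fallback's shape handling, assuming illustrative dimensions; it uses plain torch and is not code from the repository.

import torch
import torch.nn as nn

# Assumed, illustrative dimensions (not taken from the repo):
output_dim = 64       # conditioner output channels
source_length = 215   # latent sequence length passed as source_length

# Learned placeholder, analogous to self.null_source in the commit
null_source = nn.Parameter(torch.randn(output_dim, 1))

# A batch item with no matching sources contributes the placeholder,
# repeated along the time axis to match the per-item mix shape
mix = null_source.repeat(1, source_length)   # [output_dim, source_length]

# Per-item mixes are then stacked into a batch, as in forward()
mixes = torch.stack([mix, mix], dim=0)       # [2, output_dim, source_length]
print(mixes.shape)                           # torch.Size([2, 64, 215])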