@@ -545,7 +545,16 @@ class SourceMixConditioner(Conditioner):
545
545
source_keys: a list of keys for the potential sources in the metadata
546
546
547
547
"""
548
- def __init__ (self , pretransform : Pretransform , output_dim : int , save_pretransform : bool = False , source_keys : tp .List [str ] = [], pre_encoded : bool = False ):
548
+ def __init__ (
549
+ self ,
550
+ pretransform : Pretransform ,
551
+ output_dim : int ,
552
+ save_pretransform : bool = False ,
553
+ source_keys : tp .List [str ] = [],
554
+ pre_encoded : bool = False ,
555
+ allow_null_source = False ,
556
+ source_length = None
557
+ ):
549
558
super ().__init__ (pretransform .encoded_channels , output_dim )
550
559
551
560
if not save_pretransform :
@@ -559,16 +568,28 @@ def __init__(self, pretransform: Pretransform, output_dim: int, save_pretransfor
559
568
560
569
self .pre_encoded = pre_encoded
561
570
571
+ self .allow_null_source = allow_null_source
572
+
573
+ if self .allow_null_source :
574
+ self .null_source = nn .Parameter (torch .randn (output_dim , 1 ))
575
+
576
+ assert source_length is not None , "Source length must be specified if allowing null sources"
577
+
578
+ self .source_length = source_length
579
+
562
580
def forward (self , sources : tp .List [tp .Dict [str , torch .Tensor ]], device : tp .Union [torch .device , str ]) -> tp .Tuple [torch .Tensor , torch .Tensor ]:
563
581
564
582
self .pretransform .to (device )
565
583
self .proj_out .to (device )
566
584
585
+ dtype = next (self .proj_out .parameters ()).dtype
586
+
567
587
# Output has to be the batch of summed projections
568
588
# Input is per-batch-item list of source audio
569
589
570
590
mixes = []
571
591
592
+ num_null_sources = 0
572
593
for source_dict in sources : # Iterate over batch items
573
594
574
595
mix = None
@@ -579,14 +600,16 @@ def forward(self, sources: tp.List[tp.Dict[str, torch.Tensor]], device: tp.Union
579
600
source = source_dict [key ]
580
601
581
602
if not self .pre_encoded :
582
- audio = set_audio_channels (source , self .pretransform .io_channels )
603
+ assert source .dim () == 2 , f"Source audio must be shape [channels, samples], got shape: { source .shape } "
604
+ audio = set_audio_channels (source .unsqueeze (0 ), self .pretransform .io_channels )
583
605
584
606
audio = audio .to (device )
585
-
586
- latents = self .pretransform .encode (audio )
607
+ latents = self .pretransform .encode (audio ).squeeze (0 )
587
608
else :
588
609
latents = source .to (device )
589
610
611
+ latents = latents .to (dtype )
612
+
590
613
if mix is None :
591
614
mix = self .source_heads [key_ix ](latents )
592
615
else :
@@ -595,7 +618,10 @@ def forward(self, sources: tp.List[tp.Dict[str, torch.Tensor]], device: tp.Union
595
618
if mix is not None :
596
619
mixes .append (mix )
597
620
else :
598
- raise ValueError ("No sources found for mix" )
621
+ if self .allow_null_source :
622
+ mixes .append (self .null_source .repeat (1 , self .source_length ))
623
+ else :
624
+ raise ValueError ("No sources found for mix" )
599
625
600
626
mixes = torch .stack (mixes , dim = 0 )
601
627
0 commit comments