26
26
27
27
def _get_model (n_speakers = 1 , speaker_embed_dim = None ,
28
28
force_monotonic_attention = False ,
29
- use_decoder_state_for_postnet_input = False ):
29
+ use_decoder_state_for_postnet_input = False , use_memory_mask = False ):
30
30
model = deepvoice3 (n_vocab = n_vocab ,
31
31
embed_dim = 32 ,
32
32
mel_dim = num_mels ,
@@ -42,6 +42,7 @@ def _get_model(n_speakers=1, speaker_embed_dim=None,
42
42
converter_channels = 32 ,
43
43
force_monotonic_attention = force_monotonic_attention ,
44
44
use_decoder_state_for_postnet_input = use_decoder_state_for_postnet_input ,
45
+ use_memory_mask = use_memory_mask ,
45
46
)
46
47
return model
47
48
@@ -62,7 +63,7 @@ def _test_data():
62
63
x = torch .LongTensor (seqs )
63
64
y = torch .rand (x .size (0 ), 12 , 80 )
64
65
65
- return x , y
66
+ return x , y , input_lengths
66
67
67
68
68
69
def _deepvoice3 (n_vocab , embed_dim = 256 , mel_dim = 80 ,
@@ -110,11 +111,14 @@ def _deepvoice3(n_vocab, embed_dim=256, mel_dim=80,
110
111
111
112
112
113
def test_single_speaker_deepvoice3 ():
113
- x , y = _test_data ()
114
+ x , y , lengths = _test_data ()
114
115
115
116
for v in [False , True ]:
116
117
model = _get_model (use_decoder_state_for_postnet_input = v )
117
- mel_outputs , linear_outputs , alignments , done = model (x , y )
118
+ mel_outputs , linear_outputs , alignments , done = model (x , y , input_lengths = lengths )
119
+
120
+ model = _get_model (use_memory_mask = True )
121
+ mel_outputs , linear_outputs , alignments , done = model (x , y , input_lengths = lengths )
118
122
119
123
120
124
def _pad_2d (x , max_len , b_pad = 0 ):
@@ -192,7 +196,7 @@ def test_incremental_correctness():
192
196
assert max_target_len % r == 0
193
197
mel = _pad_2d (mel , max_target_len )
194
198
mel = torch .from_numpy (mel )
195
- mel_reshaped = mel .view (1 , - 1 , mel_dim * r )
199
+ mel_reshaped = mel .contiguous (). view (1 , - 1 , mel_dim * r )
196
200
frame_positions = np .arange (1 , mel_reshaped .size (1 ) + 1 ).reshape (1 , mel_reshaped .size (1 ))
197
201
198
202
x = torch .LongTensor (seqs )
@@ -269,7 +273,7 @@ def test_incremental_forward():
269
273
assert max_target_len % r == 0
270
274
mel = _pad_2d (mel , max_target_len )
271
275
mel = torch .from_numpy (mel )
272
- mel_reshaped = mel .view (1 , - 1 , mel_dim * r )
276
+ mel_reshaped = mel .contiguous (). view (1 , - 1 , mel_dim * r )
273
277
274
278
frame_positions = np .arange (1 , mel_reshaped .size (1 ) + 1 ).reshape (1 , mel_reshaped .size (1 ))
275
279
0 commit comments