
Commit 897f31e

Merge pull request #178 from r9y9/fix-pytorch13
Fixes for pytorch 1.3
2 parents f04a271 + f6f87aa commit 897f31e

5 files changed (+20 / -14 lines)

README.md (+2 / -2)

@@ -84,9 +84,9 @@ python synthesis.py --preset=20180505_deepvoice3_ljspeech.json \
 
 ## Requirements
 
-- Python 3 (<= 3.6)
+- Python >= 3.5
 - CUDA >= 8.0
-- PyTorch >= v0.4.0
+- PyTorch >= v1.0.0
 - [nnmnkwii](https://github.com/r9y9/nnmnkwii) >= v0.0.11
 - [MeCab](http://taku910.github.io/mecab/) (Japanese only)

deepvoice3_pytorch/modules.py (+3 / -3)

@@ -235,7 +235,7 @@ def get_mask_from_lengths(memory, memory_lengths):
         memory: (batch, max_time, dim)
         memory_lengths: array like
     """
-    mask = memory.data.new(memory.size(0), memory.size(1)).byte().zero_()
-    for idx, l in enumerate(memory_lengths):
-        mask[idx][:l] = 1
+    max_len = max(memory_lengths)
+    mask = torch.arange(max_len).expand(memory.size(0), max_len) < torch.tensor(memory_lengths).unsqueeze(-1)
+    mask = mask.to(memory.device)
     return ~mask
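The rewritten get_mask_from_lengths builds the padding mask with one broadcasted comparison instead of filling a byte tensor row by row in a Python loop, and the resulting bool mask avoids the byte-mask pattern that newer PyTorch releases warn about. A minimal stand-alone sketch of the same idea (the helper name length_mask is illustrative, not part of the repository):

import torch

def length_mask(lengths, max_len=None):
    # True at padded positions, False at valid frames; same semantics as ~mask above.
    lengths = torch.as_tensor(lengths)
    if max_len is None:
        max_len = int(lengths.max())
    positions = torch.arange(max_len).expand(len(lengths), max_len)
    valid = positions < lengths.unsqueeze(-1)
    return ~valid

print(length_mask([3, 1, 2]))
# tensor([[False, False, False],
#         [False,  True,  True],
#         [False, False,  True]])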

setup.py (+1 / -1)

@@ -79,7 +79,7 @@ def create_readme_rst():
     install_requires=[
         "numpy",
         "scipy",
-        "torch >= 0.4.0",
+        "torch >= 1.0.0",
         "unidecode",
         "inflect",
         "librosa",

tests/test_deepvoice3.py (+10 / -6)

@@ -26,7 +26,7 @@
 
 def _get_model(n_speakers=1, speaker_embed_dim=None,
                force_monotonic_attention=False,
-               use_decoder_state_for_postnet_input=False):
+               use_decoder_state_for_postnet_input=False, use_memory_mask=False):
     model = deepvoice3(n_vocab=n_vocab,
                        embed_dim=32,
                        mel_dim=num_mels,
@@ -42,6 +42,7 @@ def _get_model(n_speakers=1, speaker_embed_dim=None,
                        converter_channels=32,
                        force_monotonic_attention=force_monotonic_attention,
                        use_decoder_state_for_postnet_input=use_decoder_state_for_postnet_input,
+                       use_memory_mask=use_memory_mask,
                        )
     return model
 
@@ -62,7 +63,7 @@ def _test_data():
     x = torch.LongTensor(seqs)
     y = torch.rand(x.size(0), 12, 80)
 
-    return x, y
+    return x, y, input_lengths
 
 
 def _deepvoice3(n_vocab, embed_dim=256, mel_dim=80,
@@ -110,11 +111,14 @@ def _deepvoice3(n_vocab, embed_dim=256, mel_dim=80,
 
 
 def test_single_speaker_deepvoice3():
-    x, y = _test_data()
+    x, y, lengths = _test_data()
 
     for v in [False, True]:
         model = _get_model(use_decoder_state_for_postnet_input=v)
-        mel_outputs, linear_outputs, alignments, done = model(x, y)
+        mel_outputs, linear_outputs, alignments, done = model(x, y, input_lengths=lengths)
+
+    model = _get_model(use_memory_mask=True)
+    mel_outputs, linear_outputs, alignments, done = model(x, y, input_lengths=lengths)
 
 
 def _pad_2d(x, max_len, b_pad=0):
@@ -192,7 +196,7 @@ def test_incremental_correctness():
     assert max_target_len % r == 0
     mel = _pad_2d(mel, max_target_len)
     mel = torch.from_numpy(mel)
-    mel_reshaped = mel.view(1, -1, mel_dim * r)
+    mel_reshaped = mel.contiguous().view(1, -1, mel_dim * r)
     frame_positions = np.arange(1, mel_reshaped.size(1) + 1).reshape(1, mel_reshaped.size(1))
 
     x = torch.LongTensor(seqs)
@@ -269,7 +273,7 @@ def test_incremental_forward():
     assert max_target_len % r == 0
     mel = _pad_2d(mel, max_target_len)
     mel = torch.from_numpy(mel)
-    mel_reshaped = mel.view(1, -1, mel_dim * r)
+    mel_reshaped = mel.contiguous().view(1, -1, mel_dim * r)
 
     frame_positions = np.arange(1, mel_reshaped.size(1) + 1).reshape(1, mel_reshaped.size(1))
 
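The mel_reshaped change in the tests reflects a general PyTorch rule: .view() only works on tensors whose memory layout is contiguous, so a transposed or otherwise strided tensor must be copied first. A small illustration of the failure mode and the fix, independent of this repository's data:

import torch

t = torch.arange(12).reshape(3, 4).t()   # transposing makes the tensor non-contiguous
# t.view(12)                             # would raise a RuntimeError about incompatible size/stride
flat = t.contiguous().view(12)           # copy into contiguous memory, then view works
assert torch.equal(flat, t.reshape(12))  # reshape() copies only when it has to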

train.py (+4 / -2)

@@ -340,8 +340,10 @@ def collate_fn(batch):
     s, e = 1, max_decoder_target_len + 1
     # if b_pad > 0:
     #     s, e = s - 1, e - 1
+    # NOTE: needs clone to suppress RuntimeError in DataLoader...
+    # ref: https://github.com/pytorch/pytorch/issues/10756
     frame_positions = torch.arange(s, e).long().unsqueeze(0).expand(
-        len(batch), max_decoder_target_len)
+        len(batch), max_decoder_target_len).clone()
 
     # done flags
     done = np.array([_pad(np.zeros(len(x[1]) // r // downsample_step - 1),
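The .clone() added in collate_fn matters because expand() returns a view in which every batch row aliases the same underlying storage; the commit's NOTE links the PyTorch issue where returning such a tensor from the DataLoader triggers a RuntimeError. A short sketch of what clone() changes, not tied to this codebase:

import torch

base = torch.arange(1, 6).long().unsqueeze(0)  # shape (1, 5)
expanded = base.expand(4, 5)                   # view: all 4 rows alias base's storage
print(expanded.is_contiguous())                # False
print(expanded.data_ptr() == base.data_ptr())  # True, nothing was copied
owned = expanded.clone()                       # independent, contiguous (4, 5) tensor
print(owned.is_contiguous())                   # True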
@@ -963,7 +965,7 @@ def restore_parts(path, model):
     data_loader = data_utils.DataLoader(
         dataset, batch_size=hparams.batch_size,
         num_workers=hparams.num_workers, sampler=sampler,
-        collate_fn=collate_fn, pin_memory=hparams.pin_memory)
+        collate_fn=collate_fn, pin_memory=hparams.pin_memory, drop_last=True)
 
     device = torch.device("cuda" if use_cuda else "cpu")
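drop_last=True makes the DataLoader discard a final batch that is smaller than batch_size, so every training step sees a uniformly sized batch. A quick illustration with a toy dataset:

import torch
from torch.utils.data import DataLoader

loader = DataLoader(range(10), batch_size=4, drop_last=True)
print([batch.tolist() for batch in loader])
# [[0, 1, 2, 3], [4, 5, 6, 7]] -- the trailing batch of 2 items is dropped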
