import torch
from torch import Tensor
from torch.utils.mobile_optimizer import optimize_for_mobile
import torchaudio
from torchaudio.models.wav2vec2.utils.import_huggingface import import_huggingface_model
from transformers import Wav2Vec2ForCTC

# The Wav2Vec2 model emits a sequence of logit (unnormalized probability)
# distributions over characters. The following class adds a greedy
# best-path decoding step that turns those logits into a transcript.
class SpeechRecognizer(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
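        # Output vocabulary of the model; '|' marks word boundaries.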
        self.labels = [
            "<s>", "<pad>", "</s>", "<unk>", "|", "E", "T", "A", "O", "N", "I", "H", "S",
            "R", "D", "L", "U", "M", "W", "C", "F", "G", "Y", "P", "B", "V", "K", "'", "X",
            "J", "Q", "Z"]

    def forward(self, waveforms: Tensor) -> str:
        """Given single-channel speech data, return the transcription.

        Args:
            waveforms (Tensor): Speech tensor. Shape `[1, num_frames]`.

        Returns:
            str: The resulting transcript.
        """
        logits, _ = self.model(waveforms)  # [batch, num_seq, num_label]
        best_path = torch.argmax(logits[0], dim=-1)  # [num_seq,]
        prev = ''
        hypothesis = ''
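        # Greedy (best-path) CTC decoding: drop consecutive repeats, and treat
        # '<s>' as the blank token separating repeated characters. For example,
        # ["H", "H", "<s>", "E", "L", "<s>", "L", "O"] decodes to "HELLO".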
        for i in best_path:
            char = self.labels[i]
            if char == prev:
                continue
            if char == '<s>':
                prev = ''
                continue
            hypothesis += char
            prev = char
        return hypothesis.replace('|', ' ')

# Load the pretrained Wav2Vec2 model from the Hugging Face Hub
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")

# Convert the model to torchaudio format, which supports TorchScript.
model = import_huggingface_model(model)

# Remove weight normalization, which is not supported by quantization.
model.encoder.transformer.pos_conv_embed.__prepare_scriptable__()
model = model.eval()

# Attach the decoder
model = SpeechRecognizer(model)

# Apply quantization / script / optimize for mobile
quantized_model = torch.quantization.quantize_dynamic(
    model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
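# Dynamic quantization stores the weights of every torch.nn.Linear layer as
# int8 and quantizes activations on the fly at inference time, shrinking the
# model and speeding up CPU inference.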
scripted_model = torch.jit.script(quantized_model)
optimized_model = optimize_for_mobile(scripted_model)

# Sanity check
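# wav2vec2-base-960h was trained on 16 kHz audio, so the input file is
# expected to be single-channel audio sampled at 16 kHz.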
waveform, _ = torchaudio.load('scent_of_a_woman_future.wav')
print(waveform.size())
print('Result:', optimized_model(waveform))
optimized_model.save("SpeechRecognition/wav2vec2.pt")
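
# Note: as an alternative, apps that use the PyTorch Lite Interpreter runtime
# can export the scripted module with (assuming PyTorch >= 1.9):
# optimized_model._save_for_lite_interpreter("SpeechRecognition/wav2vec2.ptl")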