-
-
Notifications
You must be signed in to change notification settings - Fork 33.4k
/
Copy pathwake_word.py
205 lines (169 loc) · 7.62 KB
/
wake_word.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
"""Support for Wyoming wake-word-detection services."""
import asyncio
from collections.abc import AsyncIterable
import logging
from wyoming.audio import AudioChunk, AudioStart
from wyoming.client import AsyncTcpClient
from wyoming.wake import Detect, Detection
from homeassistant.components import wake_word
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .const import DOMAIN
from .data import WyomingService, load_wyoming_info
from .error import WyomingError
from .models import DomainDataItem
_LOGGER = logging.getLogger(__name__)
async def async_setup_entry(
    hass: HomeAssistant,
    config_entry: ConfigEntry,
    async_add_entities: AddEntitiesCallback,
) -> None:
    """Register the Wyoming wake-word entity for a config entry.

    Looks up the data item stored for this entry under ``hass.data[DOMAIN]``
    and hands a single ``WyomingWakeWordProvider`` built from its service to
    ``async_add_entities``.
    """
    entry_data: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id]
    provider = WyomingWakeWordProvider(hass, config_entry, entry_data.service)
    async_add_entities([provider])
class WyomingWakeWordProvider(wake_word.WakeWordDetectionEntity):
    """Wyoming wake-word-detection provider.

    Streams audio to a Wyoming service over TCP and reports the first
    matching wake word detection sent back by that service.
    """

    def __init__(
        self,
        hass: HomeAssistant,
        config_entry: ConfigEntry,
        service: WyomingService,
    ) -> None:
        """Set up provider.

        The first wake service listed in ``service.info`` supplies the entity
        name and the initial list of supported wake words.
        """
        self.hass = hass
        self.service = service
        # NOTE(review): assumes the service advertises at least one wake
        # program; presumably guaranteed by platform setup -- confirm.
        wake_service = service.info.wake[0]

        self._supported_wake_words = [
            wake_word.WakeWord(
                id=ww.name, name=ww.description or ww.name, phrase=ww.phrase
            )
            for ww in wake_service.models
        ]
        self._attr_name = wake_service.name
        self._attr_unique_id = f"{config_entry.entry_id}-wake_word"

    async def get_supported_wake_words(self) -> list[wake_word.WakeWord]:
        """Return a list of supported wake words.

        Tries to refresh the list from the service (no retries, 1 second
        timeout). When the service cannot be reached, or reports no wake
        service, the previously known list is returned unchanged.
        """
        info = await load_wyoming_info(
            self.service.host, self.service.port, retries=0, timeout=1
        )

        # Guard against an empty ``wake`` list so a refresh can never raise
        # IndexError; the cached list is kept in that case.
        if info is not None and info.wake:
            wake_service = info.wake[0]
            self._supported_wake_words = [
                wake_word.WakeWord(
                    id=ww.name,
                    name=ww.description or ww.name,
                    phrase=ww.phrase,
                )
                for ww in wake_service.models
            ]

        return self._supported_wake_words

    async def _async_process_audio_stream(
        self, stream: AsyncIterable[tuple[bytes, int]], wake_word_id: str | None
    ) -> wake_word.DetectionResult | None:
        """Try to detect one or more wake words in an audio stream.

        Audio must be 16Khz sample rate with 16-bit mono PCM samples.

        ``stream`` yields ``(audio_bytes, timestamp)`` tuples. When
        ``wake_word_id`` is given, detections for other wake words are logged
        and skipped.

        Returns a ``DetectionResult`` for the first accepted detection, or
        ``None`` when the audio stream ends, the connection is lost, or an
        ``OSError``/``WyomingError`` occurs (the error is logged).
        """

        async def next_chunk() -> tuple[bytes, int] | None:
            """Get the next chunk from audio stream."""
            async for chunk_bytes in stream:
                return chunk_bytes

            return None

        try:
            async with AsyncTcpClient(self.service.host, self.service.port) as client:
                # Inform client which wake word we want to detect (None = default)
                await client.write_event(
                    Detect(names=[wake_word_id] if wake_word_id else None).event()
                )

                await client.write_event(
                    AudioStart(
                        rate=16000,
                        width=2,
                        channels=1,
                    ).event(),
                )

                # Read audio and wake events in "parallel"
                audio_task = asyncio.create_task(next_chunk())
                wake_task = asyncio.create_task(client.read_event())
                pending = {audio_task, wake_task}

                try:
                    while True:
                        done, pending = await asyncio.wait(
                            pending, return_when=asyncio.FIRST_COMPLETED
                        )

                        if wake_task in done:
                            event = wake_task.result()
                            if event is None:
                                _LOGGER.debug("Connection lost")
                                break

                            if Detection.is_type(event.type):
                                # Possible detection
                                detection = Detection.from_event(event)
                                _LOGGER.info(detection)

                                if wake_word_id and (detection.name != wake_word_id):
                                    # Fall through (no ``continue``): both
                                    # tasks may have completed in this round,
                                    # and the audio branch below must still
                                    # forward the chunk and reschedule the
                                    # audio task or streaming would stall.
                                    _LOGGER.warning(
                                        "Expected wake word %s but got %s, skipping",
                                        wake_word_id,
                                        detection.name,
                                    )
                                else:
                                    # Retrieve queued audio
                                    queued_audio: list[tuple[bytes, int]] | None = None
                                    if audio_task in pending:
                                        # Wait for the in-flight read so the
                                        # stream is left in a clean state.
                                        await audio_task
                                        pending.remove(audio_task)

                                    # The chunk was consumed from the stream
                                    # but never forwarded, whether the task
                                    # was still pending or already done; keep
                                    # it unless the stream has ended (None).
                                    queued_chunk = audio_task.result()
                                    if queued_chunk is not None:
                                        queued_audio = [queued_chunk]

                                    return wake_word.DetectionResult(
                                        wake_word_id=detection.name,
                                        wake_word_phrase=self._get_phrase(
                                            detection.name
                                        ),
                                        timestamp=detection.timestamp,
                                        queued_audio=queued_audio,
                                    )

                            # Next event
                            wake_task = asyncio.create_task(client.read_event())
                            pending.add(wake_task)

                        if audio_task in done:
                            # Forward audio to wake service
                            chunk_info = audio_task.result()
                            if chunk_info is None:
                                break

                            chunk_bytes, chunk_timestamp = chunk_info
                            chunk = AudioChunk(
                                rate=16000,
                                width=2,
                                channels=1,
                                audio=chunk_bytes,
                                timestamp=chunk_timestamp,
                            )
                            await client.write_event(chunk.event())

                            # Next chunk
                            audio_task = asyncio.create_task(next_chunk())
                            pending.add(audio_task)
                finally:
                    # Clean up
                    if audio_task in pending:
                        # It's critical that we don't cancel the audio task or
                        # leave it hanging. This would mess up the pipeline STT
                        # by stopping the audio stream.
                        await audio_task
                        pending.remove(audio_task)

                    for task in pending:
                        task.cancel()

        except (OSError, WyomingError):
            _LOGGER.exception("Error processing audio stream")

        return None

    def _get_phrase(self, model_id: str) -> str:
        """Get wake word phrase for model id.

        Falls back to ``model_id`` itself when no supported wake word with a
        non-empty phrase matches.
        """
        for ww_model in self._supported_wake_words:
            if not ww_model.phrase:
                continue

            if ww_model.id == model_id:
                return ww_model.phrase

        return model_id