
Commit 9833412
Whisper support via FasterWhisper
1 parent 54d1d71

File tree: 9 files changed, +279 -63 lines changed


modules/faster_whisper/__init__.py (new file, +17)

from fastapi import FastAPI

from modules.faster_whisper.load import faster_whisper_load
from modules.faster_whisper.unload import faster_whisper_unload
from modules.faster_whisper.action import faster_whisper_action


def setup_faster_whisper(app: FastAPI) -> None:
    """
    Setup FasterWhisper routes.
    """

    app.post("/faster_whisper/load/")(faster_whisper_load)

    app.post("/faster_whisper/unload")(faster_whisper_unload)

    app.post("/faster_whisper/action")(faster_whisper_action)
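Note the registration style: setup_faster_whisper calls app.post(...) and immediately applies the returned decorator to an already-imported handler. Not part of this commit, but a minimal sketch showing that this is equivalent to the usual decorator form (the route and handlers below are hypothetical):

from fastapi import FastAPI

app = FastAPI()


# Decorator form:
@app.post("/example/ping")
async def ping_example():  # hypothetical route and handler
    return {"status": "ok"}


# Equivalent after-the-fact registration, which is what setup_faster_whisper does:
async def ping_example_2():  # hypothetical handler
    return {"status": "ok"}


app.post("/example/ping2")(ping_example_2)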

modules/faster_whisper/action.py (new file, +27)

from modules.state import get_inference
from modules.faster_whisper.inference import FasterWhisperInference, inference_name

from pydantic import BaseModel


class FasterWhisperInferenceData(BaseModel):
    """
    Task schema for FasterWhisper actions.

    Attributes:
    - audio: str: The audio to transcribe (base64 encoded)
    """

    audio: str


async def faster_whisper_action(data: FasterWhisperInferenceData):
    """
    Use FasterWhisper to transcribe audio.
    """

    inference: FasterWhisperInference = get_inference(inference_name)

    result = inference.inference(data.audio)

    return result
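The audio field is expected to hold a base64-encoded audio file. Not part of this commit, but a rough in-process sketch of exercising this handler, assuming a model has already been loaded via the load route and that sample.wav exists locally:

import asyncio
import base64

from modules.faster_whisper.action import FasterWhisperInferenceData, faster_whisper_action

# Encode a local audio file into the base64 string the schema expects.
with open("sample.wav", "rb") as f:
    payload = FasterWhisperInferenceData(audio=base64.b64encode(f.read()).decode("utf-8"))

# The handler is an async function, so it can be driven directly for a quick test.
text = asyncio.run(faster_whisper_action(payload))
print(text)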

modules/faster_whisper/inference.py (new file, +49)

import base64
from io import BytesIO

import torch
from faster_whisper import WhisperModel


class FasterWhisperInference:
    def __init__(
        self,
        model: str,
        device: str,
    ):
        if device == "cuda":
            if not torch.cuda.is_available():
                raise ValueError("CUDA is not available on this device.")
            else:
                self.device = "cuda"
        else:
            self.device = "cpu"

        self.compute_type = "float16" if self.device == "cuda" else "float32"

        self.model = WhisperModel(
            model,
            device=self.device,
            compute_type=self.compute_type,
        )

    def __del__(self):
        del self.model
        try:
            torch.cuda.empty_cache()
        except:
            pass

    def inference(
        self,
        audioRaw: str,
    ) -> any:
        fileBytes = base64.b64decode(audioRaw)

        segments, info = self.model.transcribe(BytesIO(fileBytes), beam_size=5)
        segments_list = list(segments)  # The transcription will actually run here.

        return " ".join([segment.text for segment in segments_list]).strip()


inference_name = FasterWhisperInference.__name__
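The inline comment on list(segments) is the key detail: WhisperModel.transcribe returns a lazy generator, and the decoding only runs when the segments are consumed. Not part of this commit, but a standalone sketch of the same pattern using faster_whisper directly (the model size, device, and file name are placeholders):

from io import BytesIO

from faster_whisper import WhisperModel

# Placeholder values; the server lets the client pick the model and device.
model = WhisperModel("base", device="cpu", compute_type="float32")

with open("sample.wav", "rb") as f:
    segments, info = model.transcribe(BytesIO(f.read()), beam_size=5)

# Consuming the generator is what actually runs the transcription.
text = " ".join(segment.text for segment in segments).strip()
print(info.language, text)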

modules/faster_whisper/load.py (new file, +30)

from modules.state import load_inference
from modules.faster_whisper.inference import FasterWhisperInference, inference_name

from pydantic import BaseModel


class FasterWhisperData(BaseModel):
    """
    Task schema for loading FasterWhisper inference.

    Attributes:
    - model: str: The model size to download from HuggingFace hub.
    - device: str: The device to load the model to.
    """

    model: str
    device: str
    force_reload: bool = True


async def faster_whisper_load(data: FasterWhisperData):
    """
    Load a FasterWhisper model to RAM/VRAM.
    """

    load_inference(
        inference_name,
        FasterWhisperInference(model=data.model, device=data.device),
        force_reload=data.force_reload,
    )
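Not part of this commit, but a hedged sketch of what a client call to the load route might look like (the requests package, host, and port are assumptions; the model and device values are examples):

import requests  # assumed to be installed

response = requests.post(
    "http://localhost:8000/faster_whisper/load/",  # host and port are assumptions
    json={"model": "base", "device": "cuda", "force_reload": False},
)
response.raise_for_status()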

modules/faster_whisper/unload.py (new file, +10)

from modules.state import unload_inference
from modules.faster_whisper.inference import inference_name


async def faster_whisper_unload():
    """
    Unload the FasterWhisper model from RAM/VRAM.
    """

    unload_inference(inference_name)

modules/state/__init__.py (+10, -4)

@@ -14,14 +14,20 @@ def unload_inference(inference_name: str) -> None:
     gc.collect()
 
 
-def load_inference(inference_name: str, inference: any) -> None:
+def load_inference(inference_name: str, inference: any, force_reload=True) -> None:
     """
     Load inference module into state.
     """
 
-    if inference_name in state:
-        unload_inference(inference_name)
-    state[inference_name] = inference
+    if force_reload:
+        if inference_name in state:
+            unload_inference(inference_name)
+        state[inference_name] = inference
+    else:
+        if inference_name in state:
+            return
+        else:
+            state[inference_name] = inference
 
 
 def get_inference(inference_name: str) -> any:
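With this change, force_reload=True (the default) keeps the old behaviour of unloading any existing instance and replacing it, while force_reload=False leaves an already-registered instance untouched. Not part of this commit, but a small usage sketch, assuming get_inference simply returns whatever was stored under the given name (the DummyInference class is hypothetical):

from modules.state import get_inference, load_inference


class DummyInference:  # hypothetical stand-in for a real inference wrapper
    pass


first = DummyInference()
second = DummyInference()

load_inference("DummyInference", first)                       # registers `first`
load_inference("DummyInference", second, force_reload=False)  # keeps `first`
assert get_inference("DummyInference") is first

load_inference("DummyInference", second, force_reload=True)   # replaces with `second`
assert get_inference("DummyInference") is second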

server.py (+12)

@@ -3,8 +3,19 @@
 sys.dont_write_bytecode = True
 
 
+def setup_cuda_env():
+    import os
+
+    os.environ["LD_LIBRARY_PATH"] = os.path.join(
+        os.getcwd(), "miniconda", "envs", "oc_external", "lib"
+    )
+
+
 if __name__ == "__main__":
+    setup_cuda_env()
+
     from modules.florence2 import setup_florence2
+    from modules.faster_whisper import setup_faster_whisper
 
     from custom_modules import get_custom_modules
 

@@ -37,6 +48,7 @@ async def ping():
 
     # Setup internal module routes.
     setup_florence2(app)
+    setup_faster_whisper(app)
 
     # Setup custom module routes.
     for custom_module_setup in get_custom_modules():
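The new module plugs into server.py through the same setup_* convention already used for florence2 and for custom modules. Not part of this commit, but a hypothetical custom module following that convention (the module name and route are made up):

from fastapi import FastAPI


async def my_module_action():
    # Hypothetical handler; a real module would run its own inference here.
    return {"status": "ok"}


def setup_my_module(app: FastAPI) -> None:
    """
    Setup routes for a hypothetical custom module.
    """

    app.post("/my_module/action")(my_module_action)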
