Skip to content

Commit c6078a6

Browse files
authored
whisper : do not launch log_mel threads when n_thread is 1 (ggml-org#763)
1 parent 004f34b commit c6078a6

File tree

1 file changed

+66
-71
lines changed

1 file changed

+66
-71
lines changed

whisper.cpp

Lines changed: 66 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -2284,6 +2284,60 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
22842284
}
22852285
}
22862286

2287+
static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float> &hann, const float *samples,
2288+
int n_samples, int fft_size, int fft_step, int n_threads,
2289+
const whisper_filters &filters, bool speed_up, whisper_mel &mel) {
2290+
std::vector<float> fft_in(fft_size, 0.0);
2291+
std::vector<float> fft_out(2 * fft_size);
2292+
int n_fft = 1 + (speed_up ? fft_size / 4 : fft_size / 2);
2293+
2294+
for (int i = ith; i < mel.n_len; i += n_threads) {
2295+
const int offset = i * fft_step;
2296+
2297+
// apply Hanning window
2298+
for (int j = 0; j < fft_size; j++) {
2299+
if (offset + j < n_samples) {
2300+
fft_in[j] = hann[j] * samples[offset + j];
2301+
} else {
2302+
fft_in[j] = 0.0;
2303+
}
2304+
}
2305+
2306+
// FFT -> mag^2
2307+
fft(fft_in, fft_out);
2308+
2309+
for (int j = 0; j < fft_size; j++) {
2310+
fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]);
2311+
}
2312+
for (int j = 1; j < fft_size / 2; j++) {
2313+
fft_out[j] += fft_out[fft_size - j];
2314+
}
2315+
2316+
if (speed_up) {
2317+
// scale down in the frequency domain results in a speed up in the time domain
2318+
for (int j = 0; j < n_fft; j++) {
2319+
fft_out[j] = 0.5 * (fft_out[2 * j] + fft_out[2 * j + 1]);
2320+
}
2321+
}
2322+
2323+
// mel spectrogram
2324+
for (int j = 0; j < mel.n_mel; j++) {
2325+
double sum = 0.0;
2326+
2327+
for (int k = 0; k < n_fft; k++) {
2328+
sum += fft_out[k] * filters.data[j * n_fft + k];
2329+
}
2330+
if (sum < 1e-10) {
2331+
sum = 1e-10;
2332+
}
2333+
2334+
sum = log10(sum);
2335+
2336+
mel.data[j * mel.n_len + i] = sum;
2337+
}
2338+
}
2339+
}
2340+
22872341
// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124
22882342
static bool log_mel_spectrogram(
22892343
whisper_state & wstate,
@@ -2310,81 +2364,22 @@ static bool log_mel_spectrogram(
23102364
mel.n_len = (n_samples)/fft_step;
23112365
mel.data.resize(mel.n_mel*mel.n_len);
23122366

2313-
const int n_fft = 1 + (speed_up ? fft_size/4 : fft_size/2);
2314-
23152367
//printf("%s: n_samples = %d, n_len = %d\n", __func__, n_samples, mel.n_len);
23162368
//printf("%s: recording length: %f s\n", __func__, (float) n_samples/sample_rate);
23172369

2318-
std::vector<std::thread> workers(n_threads);
2319-
for (int iw = 0; iw < n_threads; ++iw) {
2320-
workers[iw] = std::thread([&](int ith) {
2321-
std::vector<float> fft_in;
2322-
fft_in.resize(fft_size);
2323-
for (int i = 0; i < fft_size; i++) {
2324-
fft_in[i] = 0.0;
2325-
}
2326-
2327-
std::vector<float> fft_out;
2328-
fft_out.resize(2*fft_size);
2329-
2330-
for (int i = ith; i < mel.n_len; i += n_threads) {
2331-
const int offset = i*fft_step;
2332-
2333-
// apply Hanning window
2334-
for (int j = 0; j < fft_size; j++) {
2335-
if (offset + j < n_samples) {
2336-
fft_in[j] = hann[j]*samples[offset + j];
2337-
} else {
2338-
fft_in[j] = 0.0;
2339-
}
2340-
}
2341-
2342-
// FFT -> mag^2
2343-
fft(fft_in, fft_out);
2344-
2345-
for (int j = 0; j < fft_size; j++) {
2346-
fft_out[j] = (fft_out[2*j + 0]*fft_out[2*j + 0] + fft_out[2*j + 1]*fft_out[2*j + 1]);
2347-
}
2348-
for (int j = 1; j < fft_size/2; j++) {
2349-
//if (i == 0) {
2350-
// printf("%d: %f %f\n", j, fft_out[j], fft_out[fft_size - j]);
2351-
//}
2352-
fft_out[j] += fft_out[fft_size - j];
2353-
}
2354-
if (i == 0) {
2355-
//for (int j = 0; j < fft_size; j++) {
2356-
// printf("%d: %e\n", j, fft_out[j]);
2357-
//}
2358-
}
2359-
2360-
if (speed_up) {
2361-
// scale down in the frequency domain results in a speed up in the time domain
2362-
for (int j = 0; j < n_fft; j++) {
2363-
fft_out[j] = 0.5*(fft_out[2*j] + fft_out[2*j + 1]);
2364-
}
2365-
}
2366-
2367-
// mel spectrogram
2368-
for (int j = 0; j < mel.n_mel; j++) {
2369-
double sum = 0.0;
2370-
2371-
for (int k = 0; k < n_fft; k++) {
2372-
sum += fft_out[k]*filters.data[j*n_fft + k];
2373-
}
2374-
if (sum < 1e-10) {
2375-
sum = 1e-10;
2376-
}
2377-
2378-
sum = log10(sum);
2379-
2380-
mel.data[j*mel.n_len + i] = sum;
2381-
}
2382-
}
2383-
}, iw);
2384-
}
2370+
if (n_threads == 1) {
2371+
log_mel_spectrogram_worker_thread(0, hann, samples, n_samples, fft_size, fft_step, n_threads, filters, speed_up, mel);
2372+
} else {
2373+
std::vector<std::thread> workers(n_threads);
2374+
for (int iw = 0; iw < n_threads; ++iw) {
2375+
workers[iw] = std::thread(log_mel_spectrogram_worker_thread, iw, std::cref(hann), samples,
2376+
n_samples, fft_size, fft_step, n_threads,
2377+
std::cref(filters), speed_up, std::ref(mel));
2378+
}
23852379

2386-
for (int iw = 0; iw < n_threads; ++iw) {
2387-
workers[iw].join();
2380+
for (int iw = 0; iw < n_threads; ++iw) {
2381+
workers[iw].join();
2382+
}
23882383
}
23892384

23902385
// clamping and normalization

0 commit comments

Comments
 (0)