Skip to content

Commit 332530b

Browse files
authored
quantize_block C->C++, use std::thread everywhere (bitsandbytes-foundation#1024)
1 parent 8c507d9 commit 332530b

File tree

3 files changed

+27
-58
lines changed

3 files changed

+27
-58
lines changed

csrc/common.cpp

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,35 @@
11
#include <common.h>
22
#include <float.h>
33

4-
void *quantize_block(void *arguments) {
4+
void quantize_block(const quantize_block_args& args) {
55
// 1. find absmax in block
66
// 2. divide input value by absmax to normalize into [-1.0, 1.0]
77
// 3. do binary search to find the closest value
88
// 4. check minimal distance
99
// 5. store index
1010

11-
struct quantize_block_args *args = (quantize_block_args *) arguments;
12-
1311
// 1. find absmax in block
1412
float absmax_block = -FLT_MAX;
15-
for (long long i = args->block_idx; i < args->block_end; i++)
16-
absmax_block = fmax(absmax_block, fabs(args->A[i]));
13+
for (long long i = args.block_idx; i < args.block_end; i++)
14+
absmax_block = fmax(absmax_block, fabs(args.A[i]));
1715

18-
args->absmax[args->block_idx / args->blocksize] = absmax_block;
16+
args.absmax[args.block_idx / args.blocksize] = absmax_block;
1917

20-
for (long long i = args->block_idx; i < args->block_end; i++) {
18+
for (long long i = args.block_idx; i < args.block_end; i++) {
2119
// 2. divide input value by absmax to normalize into [-1.0, 1.0]
2220
// 3. do binary search to find the closest value
23-
float normed_value = args->A[i] / absmax_block;
24-
long long idx = args->bin_searcher->scalar(normed_value);
21+
float normed_value = args.A[i] / absmax_block;
22+
long long idx = args.bin_searcher->scalar(normed_value);
2523

2624
// 4. check minimal distance
2725
// The binary search returns always the value to the left, which might not be the closest value
2826
if (idx < 255) {
29-
float dist_left = fabs(normed_value - (args->code[idx]));
30-
float dist_right = fabs(normed_value - (args->code[idx + 1]));
27+
float dist_left = fabs(normed_value - (args.code[idx]));
28+
float dist_right = fabs(normed_value - (args.code[idx + 1]));
3129
if (dist_right < dist_left) { idx += 1; }
3230
}
3331

3432
// 5. store index
35-
args->out[i] = (unsigned char) idx;
33+
args.out[i] = (unsigned char) idx;
3634
}
37-
38-
return NULL;
3935
}

csrc/common.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,6 @@ struct quantize_block_args {
2020
};
2121

2222

23-
void *quantize_block(void *arguments);
23+
void quantize_block(const quantize_block_args& args);
2424

2525
#endif

csrc/cpu_ops.cpp

Lines changed: 16 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,6 @@
11
#include <BinSearch.h>
2-
#ifdef _WIN32
3-
#include <thread>
4-
#else
5-
#include <pthread.h>
6-
#endif
72
#include <common.h>
3+
#include <thread>
84

95
using namespace BinSearch;
106

@@ -30,61 +26,38 @@ void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long
3026
BinAlgo<Scalar, float, Direct2> bin_searcher(code, elements_code);
3127

3228
int thread_wave_size = 256;
33-
// we chunk the thresds into waves of 256 since the max limit is
29+
// we chunk the threads into waves of 256 since the max limit is
3430
// between 16k and 64k on Linux (we reach this when running BLOOM-176B with a large batch size)
3531
for(long long offset = 0; offset < num_blocks; offset+=thread_wave_size)
3632
{
3733
long long valid_chunks = num_blocks - offset >= thread_wave_size ? thread_wave_size : num_blocks - offset;
38-
#ifdef _WIN32
39-
std::thread *threads = (std::thread *) malloc(sizeof(std::thread) * valid_chunks);
40-
#else
41-
pthread_t *threads = (pthread_t *) malloc(sizeof(pthread_t) * valid_chunks);
42-
#endif
43-
44-
struct quantize_block_args **args = (quantize_block_args **) malloc(valid_chunks * sizeof(quantize_block_args *));
45-
46-
for(long long i = 0; i < valid_chunks; i++)
47-
args[i] = (quantize_block_args *) malloc(sizeof(quantize_block_args));
34+
std::vector<std::thread> threads(valid_chunks);
35+
std::vector<quantize_block_args> args(valid_chunks);
4836

4937
int chunks_processed = 0;
5038
for(long long block_idx = offset*blocksize; block_idx < n; block_idx += blocksize)
5139
{
5240
long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx;
5341
long long block_end = block_idx + valid_items;
5442

55-
struct quantize_block_args *arg = args[chunks_processed];
56-
arg->bin_searcher = &bin_searcher;
57-
arg->code = code;
58-
arg->A = A;
59-
arg->absmax = absmax;
60-
arg->out = out;
61-
arg->block_end = block_end;
62-
arg->block_idx = block_idx;
63-
arg->threadidx = block_idx / blocksize;
64-
arg->blocksize = blocksize;
65-
66-
#ifdef _WIN32
67-
new (&threads[chunks_processed]) std::thread(quantize_block, arg);
68-
#else
69-
pthread_create(&threads[chunks_processed], NULL, &quantize_block, (void *) arg);
70-
#endif
43+
struct quantize_block_args& arg = args[chunks_processed];
44+
arg.bin_searcher = &bin_searcher;
45+
arg.code = code;
46+
arg.A = A;
47+
arg.absmax = absmax;
48+
arg.out = out;
49+
arg.block_end = block_end;
50+
arg.block_idx = block_idx;
51+
arg.threadidx = block_idx / blocksize;
52+
arg.blocksize = blocksize;
53+
54+
threads[chunks_processed] = std::thread([arg] { quantize_block(arg); });
7155
chunks_processed += 1;
7256
if(chunks_processed == valid_chunks){ break; }
7357
}
7458

7559
for (int i = 0; i < valid_chunks; i++)
76-
{
77-
#ifdef _WIN32
7860
threads[i].join();
79-
#else
80-
int err = pthread_join(threads[i], NULL);
81-
#endif
82-
}
83-
free(threads);
84-
for (int i = 0; i < valid_chunks; i++)
85-
free(args[i]);
86-
free(args);
87-
8861
}
8962

9063
}

0 commit comments

Comments
 (0)