1
1
#include < BinSearch.h>
2
- #ifdef _WIN32
3
- #include < thread>
4
- #else
5
- #include < pthread.h>
6
- #endif
7
2
#include < common.h>
3
+ #include < thread>
8
4
9
5
using namespace BinSearch ;
10
6
@@ -30,61 +26,38 @@ void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long
30
26
BinAlgo<Scalar, float , Direct2> bin_searcher (code, elements_code);
31
27
32
28
int thread_wave_size = 256 ;
33
- // we chunk the thresds into waves of 256 since the max limit is
29
+ // we chunk the threads into waves of 256 since the max limit is
34
30
// between 16k and 64k on Linux (we reach this when running BLOOM-176B with a large batch size)
35
31
for (long long offset = 0 ; offset < num_blocks; offset+=thread_wave_size)
36
32
{
37
33
long long valid_chunks = num_blocks - offset >= thread_wave_size ? thread_wave_size : num_blocks - offset;
38
- #ifdef _WIN32
39
- std::thread *threads = (std::thread *) malloc (sizeof (std::thread) * valid_chunks);
40
- #else
41
- pthread_t *threads = (pthread_t *) malloc (sizeof (pthread_t ) * valid_chunks);
42
- #endif
43
-
44
- struct quantize_block_args **args = (quantize_block_args **) malloc (valid_chunks * sizeof (quantize_block_args *));
45
-
46
- for (long long i = 0 ; i < valid_chunks; i++)
47
- args[i] = (quantize_block_args *) malloc (sizeof (quantize_block_args));
34
+ std::vector<std::thread> threads (valid_chunks);
35
+ std::vector<quantize_block_args> args (valid_chunks);
48
36
49
37
int chunks_processed = 0 ;
50
38
for (long long block_idx = offset*blocksize; block_idx < n; block_idx += blocksize)
51
39
{
52
40
long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx;
53
41
long long block_end = block_idx + valid_items;
54
42
55
- struct quantize_block_args *arg = args[chunks_processed];
56
- arg->bin_searcher = &bin_searcher;
57
- arg->code = code;
58
- arg->A = A;
59
- arg->absmax = absmax;
60
- arg->out = out;
61
- arg->block_end = block_end;
62
- arg->block_idx = block_idx;
63
- arg->threadidx = block_idx / blocksize;
64
- arg->blocksize = blocksize;
65
-
66
- #ifdef _WIN32
67
- new (&threads[chunks_processed]) std::thread (quantize_block, arg);
68
- #else
69
- pthread_create (&threads[chunks_processed], NULL , &quantize_block, (void *) arg);
70
- #endif
43
+ struct quantize_block_args & arg = args[chunks_processed];
44
+ arg.bin_searcher = &bin_searcher;
45
+ arg.code = code;
46
+ arg.A = A;
47
+ arg.absmax = absmax;
48
+ arg.out = out;
49
+ arg.block_end = block_end;
50
+ arg.block_idx = block_idx;
51
+ arg.threadidx = block_idx / blocksize;
52
+ arg.blocksize = blocksize;
53
+
54
+ threads[chunks_processed] = std::thread ([arg] { quantize_block (arg); });
71
55
chunks_processed += 1 ;
72
56
if (chunks_processed == valid_chunks){ break ; }
73
57
}
74
58
75
59
for (int i = 0 ; i < valid_chunks; i++)
76
- {
77
- #ifdef _WIN32
78
60
threads[i].join ();
79
- #else
80
- int err = pthread_join (threads[i], NULL );
81
- #endif
82
- }
83
- free (threads);
84
- for (int i = 0 ; i < valid_chunks; i++)
85
- free (args[i]);
86
- free (args);
87
-
88
61
}
89
62
90
63
}
0 commit comments