Skip to content

Commit 6955b5e

Browse files
Jamezo97 authored and wkpark committed
minimal fix to support Windows
Based on work by @Jamezo97 and @acpopescu, manually cherry-picked from PR bitsandbytes-foundation#788 and PR bitsandbytes-foundation#229, with cleanup by wkpark. Signed-off-by: Won-Kyu Park <[email protected]>
1 parent b90db7e commit 6955b5e

File tree

4 files changed

+34
-8
lines changed

4 files changed

+34
-8
lines changed

csrc/cpu_ops.cpp

Lines changed: 18 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,9 @@
11
#include <BinSearch.h>
2+
#ifdef _WIN32
3+
#include <thread>
4+
#else
25
#include <pthread.h>
6+
#endif
37
#include <common.h>
48

59
using namespace BinSearch;
@@ -31,7 +35,11 @@ void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long
3135
for(long long offset = 0; offset < num_blocks; offset+=thread_wave_size)
3236
{
3337
long long valid_chunks = num_blocks - offset >= thread_wave_size ? thread_wave_size : num_blocks - offset;
38+
#ifdef _WIN32
39+
std::thread *threads = (std::thread *) malloc(sizeof(std::thread) * valid_chunks);
40+
#else
3441
pthread_t *threads = (pthread_t *) malloc(sizeof(pthread_t) * valid_chunks);
42+
#endif
3543

3644
struct quantize_block_args **args = (quantize_block_args **) malloc(valid_chunks * sizeof(quantize_block_args *));
3745

@@ -55,14 +63,23 @@ void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long
5563
arg->threadidx = block_idx / blocksize;
5664
arg->blocksize = blocksize;
5765

66+
#ifdef _WIN32
67+
new (&threads[chunks_processed]) std::thread(quantize_block, arg);
68+
#else
5869
pthread_create(&threads[chunks_processed], NULL, &quantize_block, (void *) arg);
70+
#endif
5971
chunks_processed += 1;
6072
if(chunks_processed == valid_chunks){ break; }
6173
}
6274

6375
for (int i = 0; i < valid_chunks; i++)
76+
{
77+
#ifdef _WIN32
78+
threads[i].join();
79+
#else
6480
int err = pthread_join(threads[i], NULL);
65-
81+
#endif
82+
}
6683
free(threads);
6784
for (int i = 0; i < valid_chunks; i++)
6885
free(args[i]);

csrc/kernels.cu

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -3821,12 +3821,12 @@ template __global__ void kgemm_4bit_inference_naive<float, 128, 32>(int M, int N
38213821
template __global__ void kExtractOutliers<COL_TURING>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
38223822
template __global__ void kExtractOutliers<COL_AMPERE>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
38233823

3824-
template __global__ void kspmm_coo_very_sparse_naive<half, 8, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
3825-
template __global__ void kspmm_coo_very_sparse_naive<half, 16, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
3826-
template __global__ void kspmm_coo_very_sparse_naive<half, 32, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
3827-
template __global__ void kspmm_coo_very_sparse_naive<signed char, 8, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
3828-
template __global__ void kspmm_coo_very_sparse_naive<signed char, 16, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
3829-
template __global__ void kspmm_coo_very_sparse_naive<signed char, 32, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
3824+
template __global__ void kspmm_coo_very_sparse_naive<half, 8, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
3825+
template __global__ void kspmm_coo_very_sparse_naive<half, 16, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
3826+
template __global__ void kspmm_coo_very_sparse_naive<half, 32, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
3827+
template __global__ void kspmm_coo_very_sparse_naive<signed char, 8, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
3828+
template __global__ void kspmm_coo_very_sparse_naive<signed char, 16, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
3829+
template __global__ void kspmm_coo_very_sparse_naive<signed char, 32, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
38303830

38313831
template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
38323832
template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);

csrc/ops.cuh

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,6 @@
99

1010
#include <stdio.h>
1111
#include <iostream>
12-
#include <unistd.h>
1312
#include <assert.h>
1413

1514
#include <cuda_runtime_api.h>

include/SIMD.h

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -64,6 +64,16 @@ template <> struct InstrFloatTraits<SSE, double>
6464
typedef __m128d vec_t;
6565
};
6666

67+
template <> struct InstrFloatTraits<Scalar, float>
68+
{
69+
typedef float vec_t;
70+
};
71+
72+
template <> struct InstrFloatTraits<Scalar, double>
73+
{
74+
typedef double vec_t;
75+
};
76+
6777
template <InstrSet I, typename T>
6878
struct FTOITraits
6979
{

0 commit comments

Comments
 (0)