Skip to content

Commit 6a4fd2b

Browse files
committed
fix workgroup size hardcode
1 parent a7614fa commit 6a4fd2b

File tree

2 files changed

+9
-11
lines changed

2 files changed

+9
-11
lines changed

ggml-sycl.cpp

+9-9
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ void ggml_sycl_get_device_description(int device, char * description, size_t d
4747
bool ggml_backend_is_sycl(ggml_backend_t backend);
4848
int ggml_backend_sycl_get_device(ggml_backend_t backend);
4949
static bool ggml_backend_buffer_is_sycl_split(ggml_backend_buffer_t buffer);
50+
static inline int get_sycl_env(const char *env_name, int default_val);
51+
static inline int get_work_group_size(const sycl::device& device);
5052

5153
void dev2dev_memcpy(sycl::queue &q_dst, sycl::queue &q_src, void *ptr_dst,
5254
const void *ptr_src, size_t size) {
@@ -1768,8 +1770,7 @@ static void norm_f32_sycl(const float *x, float *dst, const int ncols,
17681770
});
17691771
});
17701772
} else {
1771-
// FIXME: 1024 from cuda
1772-
const int work_group_size = GROUP_SIZE;
1773+
const int work_group_size = get_work_group_size(stream->get_device());
17731774
const sycl::range<3> block_dims(1, 1, work_group_size);
17741775
/*
17751776
DPCT1049:17: The work-group size passed to the SYCL kernel may exceed
@@ -1815,7 +1816,7 @@ static void group_norm_f32_sycl(const float *x, float *dst,
18151816
});
18161817
});
18171818
} else {
1818-
const int work_group_size = GROUP_SIZE;
1819+
const int work_group_size = get_work_group_size(stream->get_device());
18191820
const sycl::range<3> block_dims(1, 1, work_group_size);
18201821
/*
18211822
DPCT1049:18: The work-group size passed to the SYCL kernel may exceed
@@ -1904,7 +1905,7 @@ static void rms_norm_f32_sycl(const float *x, float *dst, const int ncols,
19041905
});
19051906
});
19061907
} else {
1907-
const int work_group_size = GROUP_SIZE;
1908+
const int work_group_size = get_work_group_size(stream->get_device());
19081909
const sycl::range<3> block_dims(1, 1, work_group_size);
19091910
/*
19101911
DPCT1049:19: The work-group size passed to the SYCL kernel may exceed
@@ -2444,7 +2445,7 @@ static void soft_max_f32_sycl(const float * x, const float * mask,
24442445
const int nrows_y, const float scale, const float max_bias,
24452446
queue_ptr stream) {
24462447
int nth = WARP_SIZE;
2447-
int max_block_size = GROUP_SIZE;
2448+
int max_block_size = get_work_group_size(stream->get_device());
24482449
while (nth < ncols_x && nth < max_block_size) nth *= 2;
24492450
if (nth>max_block_size) nth = max_block_size;
24502451

@@ -2596,7 +2597,7 @@ void ggml_backend_sycl_print_sycl_devices() {
25962597
}
25972598
}
25982599

2599-
int get_sycl_env(const char *env_name, int default_val) {
2600+
static inline int get_sycl_env(const char *env_name, int default_val) {
26002601
char *user_device_string = getenv(env_name);
26012602
int user_number = default_val;
26022603

@@ -2610,10 +2611,9 @@ int get_sycl_env(const char *env_name, int default_val) {
26102611
return user_number;
26112612
}
26122613

2613-
int get_work_group_size(int user_device_id) {
2614+
static inline int get_work_group_size(const sycl::device& device) {
26142615
dpct::device_info prop;
2615-
dpct::get_device_info(prop,
2616-
dpct::dev_mgr::instance().get_device(user_device_id));
2616+
dpct::get_device_info(prop, device);
26172617
return prop.get_max_work_group_size();
26182618
}
26192619

ggml-sycl/presets.hpp

-2
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@
1818
#define GGML_SYCL_MAX_DEVICES 48
1919
#define GGML_SYCL_NAME "SYCL"
2020

21-
// FIXME: 1024 from cuda
22-
#define GROUP_SIZE 1024
2321
#define WARP_SIZE 32
2422
#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
2523

0 commit comments

Comments
 (0)