@@ -47,6 +47,8 @@ void ggml_sycl_get_device_description(int device, char * description, size_t d
47
47
bool ggml_backend_is_sycl (ggml_backend_t backend);
48
48
int ggml_backend_sycl_get_device (ggml_backend_t backend);
49
49
static bool ggml_backend_buffer_is_sycl_split (ggml_backend_buffer_t buffer);
50
+ static inline int get_sycl_env (const char *env_name, int default_val);
51
+ static inline int get_work_group_size (const sycl::device& device);
50
52
51
53
void dev2dev_memcpy (sycl::queue &q_dst, sycl::queue &q_src, void *ptr_dst,
52
54
const void *ptr_src, size_t size) {
@@ -1768,8 +1770,7 @@ static void norm_f32_sycl(const float *x, float *dst, const int ncols,
1768
1770
});
1769
1771
});
1770
1772
} else {
1771
- // FIXME: 1024 from cuda
1772
- const int work_group_size = GROUP_SIZE;
1773
+ const int work_group_size = get_work_group_size (stream->get_device ());
1773
1774
const sycl::range<3 > block_dims (1 , 1 , work_group_size);
1774
1775
/*
1775
1776
DPCT1049:17: The work-group size passed to the SYCL kernel may exceed
@@ -1815,7 +1816,7 @@ static void group_norm_f32_sycl(const float *x, float *dst,
1815
1816
});
1816
1817
});
1817
1818
} else {
1818
- const int work_group_size = GROUP_SIZE ;
1819
+ const int work_group_size = get_work_group_size (stream-> get_device ()) ;
1819
1820
const sycl::range<3 > block_dims (1 , 1 , work_group_size);
1820
1821
/*
1821
1822
DPCT1049:18: The work-group size passed to the SYCL kernel may exceed
@@ -1904,7 +1905,7 @@ static void rms_norm_f32_sycl(const float *x, float *dst, const int ncols,
1904
1905
});
1905
1906
});
1906
1907
} else {
1907
- const int work_group_size = GROUP_SIZE ;
1908
+ const int work_group_size = get_work_group_size (stream-> get_device ()) ;
1908
1909
const sycl::range<3 > block_dims (1 , 1 , work_group_size);
1909
1910
/*
1910
1911
DPCT1049:19: The work-group size passed to the SYCL kernel may exceed
@@ -2444,7 +2445,7 @@ static void soft_max_f32_sycl(const float * x, const float * mask,
2444
2445
const int nrows_y, const float scale, const float max_bias,
2445
2446
queue_ptr stream) {
2446
2447
int nth = WARP_SIZE;
2447
- int max_block_size = GROUP_SIZE ;
2448
+ int max_block_size = get_work_group_size (stream-> get_device ()) ;
2448
2449
while (nth < ncols_x && nth < max_block_size) nth *= 2 ;
2449
2450
if (nth>max_block_size) nth = max_block_size;
2450
2451
@@ -2596,7 +2597,7 @@ void ggml_backend_sycl_print_sycl_devices() {
2596
2597
}
2597
2598
}
2598
2599
2599
- int get_sycl_env (const char *env_name, int default_val) {
2600
+ static inline int get_sycl_env (const char *env_name, int default_val) {
2600
2601
char *user_device_string = getenv (env_name);
2601
2602
int user_number = default_val;
2602
2603
@@ -2610,10 +2611,9 @@ int get_sycl_env(const char *env_name, int default_val) {
2610
2611
return user_number;
2611
2612
}
2612
2613
2613
- int get_work_group_size (int user_device_id ) {
2614
+ static inline int get_work_group_size (const sycl::device& device ) {
2614
2615
dpct::device_info prop;
2615
- dpct::get_device_info (prop,
2616
- dpct::dev_mgr::instance ().get_device (user_device_id));
2616
+ dpct::get_device_info (prop, device);
2617
2617
return prop.get_max_work_group_size ();
2618
2618
}
2619
2619
0 commit comments