@@ -1568,7 +1568,7 @@ static void ggml_cl_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
1568
1568
ggml_cl_pool_free (d_D, d_size);
1569
1569
}
1570
1570
1571
- static void ggml_cl_mul_mat_f16 (const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t /* wsize */ ) {
1571
+ static void ggml_cl_mul_mat_f16 (const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t wsize) {
1572
1572
GGML_ASSERT (fp16_support);
1573
1573
1574
1574
const int64_t ne00 = src0->ne [0 ];
@@ -1598,6 +1598,10 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
1598
1598
const int y_ne = ne11 * ne10;
1599
1599
const int d_ne = ne11 * ne01;
1600
1600
1601
+ GGML_ASSERT (wsize >= sizeof (ggml_fp16_t ) * y_ne);
1602
+ GGML_ASSERT (wsize >= sizeof (ggml_fp16_t ) * d_ne);
1603
+ ggml_fp16_t * const tmp = (ggml_fp16_t *) wdata;
1604
+
1601
1605
size_t x_size;
1602
1606
size_t y_size;
1603
1607
size_t d_size;
@@ -1634,7 +1638,6 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
1634
1638
1635
1639
// convert src1 to fp16
1636
1640
// TODO: use multiple threads
1637
- ggml_fp16_t * const tmp = (ggml_fp16_t *) wdata + (ne11 * ne10) * (i13 * ne12 + i12);
1638
1641
char * src1i = (char *) src1->data + i13*nb13 + i12*nb12;
1639
1642
if (src1_cont_rows) {
1640
1643
if (src1_cont_cols) {
@@ -1897,8 +1900,8 @@ void ggml_cl_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor *
1897
1900
}
1898
1901
1899
1902
size_t ggml_cl_mul_mat_get_wsize (const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
1900
- if (ggml_cl_mul_mat_use_f16 (src0, src1, dst)) {
1901
- return ggml_nelements (src1 ) * sizeof ( ggml_fp16_t );
1903
+ if (src0-> type == GGML_TYPE_F16 && ggml_cl_mul_mat_use_f16 (src0, src1, dst)) {
1904
+ return sizeof ( ggml_fp16_t ) * std::max (src1-> ne [ 0 ] * src1-> ne [ 1 ], dst-> ne [ 0 ] * dst-> ne [ 1 ] );
1902
1905
}
1903
1906
return 0 ;
1904
1907
}
0 commit comments