@@ -2336,6 +2336,30 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
     *s = sumf;
 }
 
+// TODO: move this to a more sensible place
+typedef void (*dequantize_row_q_t)(const void * restrict x, float * restrict y, int k);
+typedef void (*quantize_row_q_t)(const float * restrict x, void * restrict y, int k);
+typedef void (*vec_dot_q_t)(const int n, float * restrict s, const void * restrict x, const void * restrict y);
+
+typedef struct {
+    dequantize_row_q_t dequantize_row_q;
+    quantize_row_q_t   quantize_row_q;
+    vec_dot_q_t        vec_dot_q;
+} quantize_fns_t;
+
+static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
+    [GGML_TYPE_Q4_0] = {
+        .dequantize_row_q = dequantize_row_q4_0,
+        .quantize_row_q   = quantize_row_q4_0,
+        .vec_dot_q        = ggml_vec_dot_q4_0,
+    },
+    [GGML_TYPE_Q4_1] = {
+        .dequantize_row_q = dequantize_row_q4_1,
+        .quantize_row_q   = quantize_row_q4_1,
+        .vec_dot_q        = ggml_vec_dot_q4_1,
+    },
+};
+
 // compute GGML_VEC_DOT_UNROLL dot products at once
 // xs - x row stride in bytes
 inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
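Aside: quantize_fns above is a type-indexed dispatch table, so callers can look up a type's quantization routines in O(1). A minimal sketch of a lookup (the buffers and sizes are illustrative, not part of the patch); note that entries for non-quantized types are zero-initialized by the designated initializers, so the table must only be indexed with Q4_0 or Q4_1:

    // Hypothetical caller: dequantize one row of a Q4_0 tensor.
    // q4_row is assumed to point at a quantized row of 64 values
    // (row length must be a multiple of the 32-wide block size).
    const enum ggml_type type = GGML_TYPE_Q4_0;
    const quantize_fns_t * fns = &quantize_fns[type];

    float row[64];
    fns->dequantize_row_q(q4_row, row, 64);
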
@@ -5184,13 +5208,13 @@ static void ggml_compute_forward_add_f16_f32(
     const int n  = ggml_nrows(src0);
     const int nc = src0->ne[0];
 
-    const size_t nb00 = src0->nb[0];
+    //const size_t nb00 = src0->nb[0];
     const size_t nb01 = src0->nb[1];
 
     const size_t nb10 = src1->nb[0];
     const size_t nb11 = src1->nb[1];
 
-    const size_t nb0 = dst->nb[0];
+    //const size_t nb0 = dst->nb[0];
     const size_t nb1 = dst->nb[1];
 
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
@@ -5202,12 +5226,163 @@ static void ggml_compute_forward_add_f16_f32(
         ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
         for (int i = 0; i < nc; i++) {
             float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10);
-
             dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + *src1_ptr);
         }
     }
 }
 
+static void ggml_compute_forward_add_f16_f16(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    //const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+
+    const size_t nb10 = src1->nb[0];
+    const size_t nb11 = src1->nb[1];
+
+    //const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F16);
+
+    for (int j = ith; j < n; j += nth) {
+        ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + j*nb1);
+        ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
+        for (int i = 0; i < nc; i++) {
+            ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + j*nb11 + i*nb10);
+            dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + GGML_FP16_TO_FP32(*src1_ptr));
+        }
+    }
+}
+
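The FP16 path above uses no native half-precision arithmetic: each element is widened to FP32, added, and narrowed back. A standalone sketch of that per-element round trip (the helper name is illustrative):

    // Illustrative helper equivalent to one iteration of the inner loop above.
    static inline ggml_fp16_t add_f16_elem(ggml_fp16_t a, ggml_fp16_t b) {
        return GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(a) + GGML_FP16_TO_FP32(b));
    }
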
+static void ggml_compute_forward_add_q_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+
+    //const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    const int64_t ne3 = dst->ne[3];
+
+    const int nb00 = src0->nb[0];
+    const int nb01 = src0->nb[1];
+    const int nb02 = src0->nb[2];
+    const int nb03 = src0->nb[3];
+
+    const int nb10 = src1->nb[0];
+    const int nb11 = src1->nb[1];
+    const int nb12 = src1->nb[2];
+    const int nb13 = src1->nb[3];
+
+    const int nb0 = dst->nb[0];
+    const int nb1 = dst->nb[1];
+    const int nb2 = dst->nb[2];
+    const int nb3 = dst->nb[3];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_ASSERT(ne02 == ne12);
+    GGML_ASSERT(ne03 == ne13);
+    GGML_ASSERT(ne2  == ne12);
+    GGML_ASSERT(ne3  == ne13);
+
+    const enum ggml_type type = src0->type;
+    dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
+    quantize_row_q_t const quantize_row_q = quantize_fns[type].quantize_row_q;
+
+    // we don't support permuted src0 or src1
+    GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    GGML_ASSERT(ne0 == ne01);
+    GGML_ASSERT(ne1 == ne11);
+    GGML_ASSERT(ne2 == ne02);
+    GGML_ASSERT(ne3 == ne03);
+
+    GGML_ASSERT(src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1);
+    GGML_ASSERT(dst->type == src0->type);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+    // total rows in src0
+    const int nr = ne01*ne02*ne03;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 indices
+        const int i03 = ir/(ne02*ne01);
+        const int i02 = (ir - i03*ne02*ne01)/ne01;
+        const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+        // src1 and dst are same shape as src0 => same indices
+        const int i13 = i03;
+        const int i12 = i02;
+        const int i11 = i01;
+
+        const int i3 = i03;
+        const int i2 = i02;
+        const int i1 = i01;
+
+        void  * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
+        float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13));
+        void  * dst_row  = (void *) ((char *)  dst->data + ( i1*nb1  +  i2*nb2  +  i3*nb3));
+
+        assert(ne00 % 32 == 0);
+
+        // unquantize row from src0 to temp buffer
+        float tmp[ne00];
+        dequantize_row_q(src0_row, tmp, ne00);
+        // add src1
+        ggml_vec_acc_f32(ne00, tmp, src1_row);
+        // quantize row to dst
+        quantize_row_q(tmp, dst_row, ne00);
+    }
+}
+
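Two things are worth noting in the function above. First, the row partitioning: each thread takes a contiguous block of dr = ceil(nr/nth) rows; with nr = 10 rows and nth = 4 threads, for example, dr = 3 and the threads cover rows [0,3), [3,6), [6,9), [9,10). Second, the add is lossy: each row is dequantized to floats, accumulated, and re-quantized to 4 bits, so the sum is rounded back onto the quantization grid. A minimal sketch of one row under assumed values (Q4_0, ne00 = 64):

    // Sketch of the per-row body with fixed, illustrative sizes.
    float tmp[64];
    dequantize_row_q4_0(src0_row, tmp, 64);    // widen the quantized row to floats
    for (int i = 0; i < 64; ++i) {
        tmp[i] += src1_row[i];                 // what ggml_vec_acc_f32 does
    }
    quantize_row_q4_0(tmp, dst_row, 64);       // round back to 4-bit blocks
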
 static void ggml_compute_forward_add(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
@@ -5220,10 +5395,21 @@ static void ggml_compute_forward_add(
             } break;
         case GGML_TYPE_F16:
             {
-                ggml_compute_forward_add_f16_f32(params, src0, src1, dst);
+                if (src1->type == GGML_TYPE_F16) {
+                    ggml_compute_forward_add_f16_f16(params, src0, src1, dst);
+                }
+                else if (src1->type == GGML_TYPE_F32) {
+                    ggml_compute_forward_add_f16_f32(params, src0, src1, dst);
+                }
+                else {
+                    GGML_ASSERT(false);
+                }
             } break;
         case GGML_TYPE_Q4_0:
        case GGML_TYPE_Q4_1:
+            {
+                ggml_compute_forward_add_q_f32(params, src0, src1, dst);
+            } break;
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
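With this dispatch in place, ggml_add covers F16+F16, F16+F32, and Q4_0/Q4_1+F32 operand combinations, and per the asserts in ggml_compute_forward_add_q_f32 a quantized dst keeps src0's type. A hedged usage sketch (tensor sizes are illustrative, and whether this path is exercised depends on how the graph is built):

    // Hypothetical graph snippet: add a float delta into quantized weights.
    // Assumes a valid ggml_context * ctx with enough memory.
    struct ggml_tensor * w = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, 4096, 4096);
    struct ggml_tensor * d = ggml_new_tensor_2d(ctx, GGML_TYPE_F32,  4096, 4096);
    struct ggml_tensor * r = ggml_add(ctx, w, d);   // result re-quantized to Q4_0
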
@@ -6608,29 +6794,6 @@ static void ggml_compute_forward_mul_mat_f16_f32(
     //}
 }
 
-typedef void (*dequantize_row_q_t)(const void * restrict x, float * restrict y, int k);
-typedef void (*quantize_row_q_t)(const float * restrict x, void * restrict y, int k);
-typedef void (*vec_dot_q_t)(const int n, float * restrict s, const void * restrict x, const void * restrict y);
-
-typedef struct {
-    dequantize_row_q_t dequantize_row_q;
-    quantize_row_q_t   quantize_row_q;
-    vec_dot_q_t        vec_dot_q;
-} quantize_fns_t;
-
-static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
-    [GGML_TYPE_Q4_0] = {
-        .dequantize_row_q = dequantize_row_q4_0,
-        .quantize_row_q   = quantize_row_q4_0,
-        .vec_dot_q        = ggml_vec_dot_q4_0,
-    },
-    [GGML_TYPE_Q4_1] = {
-        .dequantize_row_q = dequantize_row_q4_1,
-        .quantize_row_q   = quantize_row_q4_1,
-        .vec_dot_q        = ggml_vec_dot_q4_1,
-    },
-};
-
 static void ggml_compute_forward_mul_mat_q_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,