@@ -5352,6 +5352,8 @@ static void ggml_compute_forward_add_q_f32(
5352
5352
const int ir0 = dr * ith ;
5353
5353
const int ir1 = MIN (ir0 + dr , nr );
5354
5354
5355
+ float * wdata = (float * ) params -> wdata + ne00 * ith ;
5356
+
5355
5357
for (int ir = ir0 ; ir < ir1 ; ++ ir ) {
5356
5358
// src0 indices
5357
5359
const int i03 = ir /(ne02 * ne01 );
@@ -5374,12 +5376,11 @@ static void ggml_compute_forward_add_q_f32(
5374
5376
assert (ne00 % 32 == 0 );
5375
5377
5376
5378
// unquantize row from src0 to temp buffer
5377
- float tmp [ne00 ];
5378
- dequantize_row_q (src0_row , tmp , ne00 );
5379
+ dequantize_row_q (src0_row , wdata , ne00 );
5379
5380
// add src1
5380
- ggml_vec_acc_f32 (ne00 , tmp , src1_row );
5381
+ ggml_vec_acc_f32 (ne00 , wdata , src1_row );
5381
5382
// quantize row to dst
5382
- quantize_row_q (tmp , dst_row , ne00 );
5383
+ quantize_row_q (wdata , dst_row , ne00 );
5383
5384
}
5384
5385
}
5385
5386
@@ -9568,6 +9569,14 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
9568
9569
case GGML_OP_ADD :
9569
9570
{
9570
9571
node -> n_tasks = n_threads ;
9572
+
9573
+ size_t cur = 0 ;
9574
+
9575
+ if (node -> src0 -> type == GGML_TYPE_Q4_0 || node -> src0 -> type == GGML_TYPE_Q4_1 ) {
9576
+ cur = GGML_TYPE_SIZE [GGML_TYPE_F32 ] * node -> src0 -> ne [0 ] * n_threads ;
9577
+ }
9578
+
9579
+ work_size = MAX (work_size , cur );
9571
9580
} break ;
9572
9581
case GGML_OP_SUB :
9573
9582
case GGML_OP_MUL :
0 commit comments