@@ -96,14 +96,15 @@ __global__ void q4_matmul_kernel
96
96
{
97
97
half2 w_scale = w_scales_.item_half2half2 (group, w_column);
98
98
uint32_t w_zero = w_zeros_.item (group, w_column) + 1 ;
99
- dot_product_8 (acc, x_cache, w_, k, w_column, w_scale, w_zero, groupsize / 8 );
99
+ acc = dot_product_8 (acc, x_cache, w_, k, w_column, w_scale, w_zero, groupsize / 8 );
100
100
}
101
101
else
102
102
{
103
103
half w_scale = w_scales_.item (group, w_column);
104
104
uint32_t w_zero = w_zeros_.item (group, w_column) + 1 ;
105
- dot_product_8_h (acc_h, x_cache_h, w_, k, w_column, w_scale, w_zero, groupsize / 8 );
105
+ acc_h = dot_product_8_h (acc_h, x_cache_h, w_, k, w_column, w_scale, w_zero, groupsize / 8 );
106
106
}
107
+ __syncthreads ();
107
108
}
108
109
}
109
110
else
@@ -124,15 +125,16 @@ __global__ void q4_matmul_kernel
124
125
int group = k / groupsize;
125
126
half2 w_scale = w_scales_.item_half2half2 (group, w_column);
126
127
uint32_t w_zero = w_zeros_.item (group, w_column) + 1 ;
127
- dot_product_8 (acc, x_cache, w_, k, w_column, w_scale, w_zero, 1 );
128
+ acc = dot_product_8 (acc, x_cache, w_, k, w_column, w_scale, w_zero, 1 );
128
129
}
129
130
else
130
131
{
131
132
int group = k / groupsize;
132
133
half w_scale = w_scales_.item (group, w_column);
133
134
uint32_t w_zero = w_zeros_.item (group, w_column) + 1 ;
134
- dot_product_8_h (acc_h, x_cache_h, w_, k, w_column, w_scale, w_zero, 1 );
135
+ acc_h = dot_product_8_h (acc_h, x_cache_h, w_, k, w_column, w_scale, w_zero, 1 );
135
136
}
137
+ __syncthreads ();
136
138
}
137
139
}
138
140
0 commit comments