Skip to content

Commit 215a715

Browse files
committed
Optimize q4_matmul: add missing assignment back
1 parent caf4de8 commit 215a715

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

Diff for: exllama_ext/cuda_func/q4_matmul.cu

+6-4
Original file line numberDiff line numberDiff line change
@@ -96,14 +96,15 @@ __global__ void q4_matmul_kernel
9696
{
9797
half2 w_scale = w_scales_.item_half2half2(group, w_column);
9898
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
99-
dot_product_8(acc, x_cache, w_, k, w_column, w_scale, w_zero, groupsize / 8);
99+
acc = dot_product_8(acc, x_cache, w_, k, w_column, w_scale, w_zero, groupsize / 8);
100100
}
101101
else
102102
{
103103
half w_scale = w_scales_.item(group, w_column);
104104
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
105-
dot_product_8_h(acc_h, x_cache_h, w_, k, w_column, w_scale, w_zero, groupsize / 8);
105+
acc_h = dot_product_8_h(acc_h, x_cache_h, w_, k, w_column, w_scale, w_zero, groupsize / 8);
106106
}
107+
__syncthreads();
107108
}
108109
}
109110
else
@@ -124,15 +125,16 @@ __global__ void q4_matmul_kernel
124125
int group = k / groupsize;
125126
half2 w_scale = w_scales_.item_half2half2(group, w_column);
126127
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
127-
dot_product_8(acc, x_cache, w_, k, w_column, w_scale, w_zero, 1);
128+
acc = dot_product_8(acc, x_cache, w_, k, w_column, w_scale, w_zero, 1);
128129
}
129130
else
130131
{
131132
int group = k / groupsize;
132133
half w_scale = w_scales_.item(group, w_column);
133134
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
134-
dot_product_8_h(acc_h, x_cache_h, w_, k, w_column, w_scale, w_zero, 1);
135+
acc_h = dot_product_8_h(acc_h, x_cache_h, w_, k, w_column, w_scale, w_zero, 1);
135136
}
137+
__syncthreads();
136138
}
137139
}
138140

0 commit comments

Comments
 (0)