@@ -8,44 +8,42 @@
 
 #version 450 core
 
-#include "indexing_utils.h"
-
 #define PRECISION ${PRECISION}
 
-#define FOUR 4
-
-#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}
-#define FLOAT_T ${buffer_scalar_type(DTYPE)}
+#define T ${buffer_scalar_type(DTYPE)}
+#define VEC4_T ${buffer_gvec_type(DTYPE, 4)}
 
-${define_active_storage_type(STORAGE)}
-
-${define_required_extensions([DTYPE, "uint8", "uint16"])}
-#extension GL_EXT_control_flow_attributes : require
+${define_required_extensions(DTYPE)}
+${define_required_extensions("int8")}
 
 layout(std430) buffer;
 
-${layout_declare_tensor(B, "w", "ret", DTYPE, STORAGE)}
-${layout_declare_tensor(B, "r", "x", DTYPE, STORAGE)}
-${layout_declare_tensor(B, "r", "weights", "uint8", "buffer")}
-${layout_declare_tensor(B, "r", "qparams", DTYPE, STORAGE)}
-${layout_declare_ubo(B, "ivec3", "ret_limits")}
-${layout_declare_ubo(B, "ivec4", "x_sizes")}
-${layout_declare_ubo(B, "ivec4", "weights_strides")}
-${layout_declare_ubo(B, "ivec4", "qparams_strides")}
+${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_mat1", DTYPE, IN_STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_qmat2", "uint8", WEIGHT_STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_qparams", DTYPE, "texture3D")}
+
+layout(push_constant) uniform restrict Block {
+  ivec4 out_sizes;
+  ivec4 mat1_sizes;
+  ivec4 qmat2_sizes;
+};
 
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 
-layout(constant_id = 3) const int group_size = 1;
+layout(constant_id = 3) const int group_size = 64;
 
 /*
  * This shader computes a linear operator between a floating point input matrix
  * x and a weights matrix that is quantized to 4 bits.
  *
  * The (W, H, C) shape of each tensor is:
  * - x: (K, M)
- * - weights: (K / 2, N)
+ * - weights: (N / 2, K)
  * - The weights tensor has a data type of `uint8`. Each element in the tensor
  *   contains 2 4-bit values packed into a uint8.
+ * - See the pack_int4_linear_weight_transposed_interleave shader to see more
+ *   details on how the weight tensor is stored.
  * - qparams: (2, N, number_of_groups)
  *   - This tensor contains the scales and zeros quantization parameters for the
  *     weights tensor. The weight tensor is quantized group-wise, which means
@@ -57,56 +55,68 @@ layout(constant_id = 3) const int group_size = 1;
  * Note that this shader assumes that all tensors are width packed.
  */
 void main() {
-  // output positions being calculated are (n, m), (n + 1, m), ...
-  // This means multiplying the m-th row of x with the n-th, (n+1)-th, ... rows
-  // of the weights tensor.
-  const u16vec3 ret_pos = u16vec3(gl_GlobalInvocationID);
-  if (any(greaterThanEqual(ret_pos, ret_limits))) {
+  const uint out_row = gl_GlobalInvocationID.y;
+  // Each thread writes out 2 texels along the width axis, equivalent to 8
+  // scalar elements. Therefore multiply the thread_idx.x by 8.
+  const uint out_col = gl_GlobalInvocationID.x << 3;
+  // Similar reasoning to the above, each thread works on 2 texels along the
+  // width axis so multiply thread_idx.x by 2.
+  const int out_col_texel_idx = int(gl_GlobalInvocationID.x) << 1;
+
+  if (out_col >= out_sizes.x || out_row >= out_sizes.y) {
     return;
   }
 
-  // Since ret is width packed, need to multiply by 4
-  const uint16_t n = uint16_t(ret_pos.x * 4);
+  const int num_blocks = mat1_sizes.x / group_size;
 
-  // K is guaranteed to be a multiple of group size
-  const uint16_t num_blocks = uint16_t(x_sizes.x / group_size);
+  VEC4_T sums[2];
 
-  uint16_t k_texel_i = uint16_t(0);
-  vec4 sums = vec4(0.0);
-  for (uint16_t block_idx = uint16_t(0); block_idx < num_blocks; block_idx++) {
-    vec4 scales;
-    vec4 zeros;
+  sums[0] = VEC4_T(0);
+  sums[1] = VEC4_T(0);
 
-    [[unroll]] for (int comp = 0; comp < 4; comp++) {
-      const vec4 scale_and_zero = load_texel(
-          qparams, u16vec3(0, n + comp, block_idx));
-      scales[comp] = scale_and_zero.x;
-      zeros[comp] = scale_and_zero.y;
-    }
+  VEC4_T scales[2];
+  VEC4_T zeros[2];
+
+  $if WEIGHT_STORAGE == "buffer":
+    const int qmat2_stride = qmat2_sizes.x >> 2;
+
+  for (int block_idx = 0; block_idx < num_blocks; ++block_idx) {
+    scales[0] = texelFetch(t_qparams, ivec3(out_col_texel_idx, 0, block_idx), 0);
+    zeros[0] = texelFetch(t_qparams, ivec3(out_col_texel_idx, 1, block_idx), 0);
 
-    for (uint16_t i = uint16_t(0); i < group_size; i += uint16_t(4), k_texel_i++) {
-      const VEC4_T x_texel = load_texel(
-          x, u16vec3(k_texel_i, ret_pos.y, ret_pos.z));
-
-      [[unroll]] for (int comp = 0; comp < 4; comp++) {
-        const int weights_bufi = (n + comp) * weights_strides.y + (k_texel_i * 2);
-        // Need to read 4 unpacked values, which corresponds to 2 packed values
-        const uint8_t weights_val_1 = weights[weights_bufi];
-        const uint8_t weights_val_2 = weights[weights_bufi + 1];
-
-        const u8vec4 weights_texel = u8vec4(
-            (weights_val_1 & 0xF0) >> 4,
-            weights_val_1 & 0x0F,
-            (weights_val_2 & 0xF0) >> 4,
-            weights_val_2 & 0x0F);
-
-        // Note that the unpacked 4-bit values are unsigned, therefore they must
-        // first be "centered" around 0 by subtracting 8 before applying the
-        // scale and zero point.
-        sums[comp] += dot(
-            x_texel, (vec4(weights_texel) - 8.0) * scales[comp] + zeros[comp]);
+    scales[1] = texelFetch(t_qparams, ivec3(out_col_texel_idx + 1, 0, block_idx), 0);
+    zeros[1] = texelFetch(t_qparams, ivec3(out_col_texel_idx + 1, 1, block_idx), 0);
+
+    for (int g_idx = 0; g_idx < group_size; g_idx += 4) {
+      const int k = block_idx * group_size + g_idx;
+
+      $if IN_STORAGE == "buffer":
+        const VEC4_T mat1_tex = t_mat1[(out_row * mat1_sizes.x + k) >> 2];
+      $else:
+        const VEC4_T mat1_tex = texelFetch(t_mat1, ivec3(k >> 2, out_row, 0), 0);
+
+      for (int comp = 0; comp < 4; ++comp) {
+        $if WEIGHT_STORAGE == "buffer":
+          const u8vec4 packed_weight_tex = t_qmat2[(k + comp) * qmat2_stride + gl_GlobalInvocationID.x];
+        $else:
+          const uvec4 packed_weight_tex = texelFetch(
+              t_qmat2,
+              ivec3(gl_GlobalInvocationID.x, k + comp, 0),
+              0);
+
+        const uvec4 weight_tex_1 = (packed_weight_tex & 0xF0) >> 4;
+        const uvec4 weight_tex_2 = packed_weight_tex & 0x0F;
+
+        sums[0] += mat1_tex[comp] * ((vec4(weight_tex_1) - 8.0) * scales[0] + zeros[0]);
+        sums[1] += mat1_tex[comp] * ((vec4(weight_tex_2) - 8.0) * scales[1] + zeros[1]);
       }
     }
   }
-  write_texel(ret, ret_pos, sums);
+
+  $if OUT_STORAGE == "buffer":
+    t_out[(out_row * out_sizes.x + out_col) >> 2] = sums[0];
+    t_out[(out_row * out_sizes.x + out_col + 4) >> 2] = sums[1];
+  $else:
+    imageStore(t_out, ivec3(out_col_texel_idx, out_row, 0), sums[0]);
+    imageStore(t_out, ivec3(out_col_texel_idx + 1, out_row, 0), sums[1]);
 }
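
A note on the indexing at the top of the new main(): each thread covers 8 scalar output columns, that is, two width-packed texels of 4 elements, which is what the `<< 3` and `<< 1` shifts encode. A minimal C sketch of that mapping (the loop bound and variable names are illustrative, not from the PR):

    #include <stdio.h>

    // Illustrative mapping from a thread's x index to the output columns
    // and texels it writes, mirroring out_col = x << 3 and
    // out_col_texel_idx = x << 1 in the shader.
    int main(void) {
      for (unsigned x = 0; x < 3; ++x) {
        unsigned out_col = x << 3;  // first of 8 scalar columns
        unsigned texel = x << 1;    // first of 2 output texels
        printf("thread %u -> cols [%u, %u), texels %u and %u\n",
               x, out_col, out_col + 8, texel, texel + 1);
      }
      return 0;
    }

For thread 1 this prints cols [8, 16), texels 2 and 3, matching the pair of stores at out_col_texel_idx and out_col_texel_idx + 1.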
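To make the inner-loop arithmetic concrete: each uint8 in the weight tensor packs two 4-bit values, which are unsigned and so are centered by subtracting 8 before the group's scale and zero point are applied; the high nibble feeds sums[0] and the low nibble feeds sums[1]. A standalone C reference of that per-byte update (the helper name dequant_accumulate and the sample numbers are assumptions for illustration):

    #include <stdint.h>
    #include <stdio.h>

    /* CPU reference for the shader's per-byte update: one packed byte holds
     * two 4-bit weights for two output lanes (high nibble -> sums[0], low
     * nibble -> sums[1]); each is centered by 8, dequantized with its own
     * scale and zero point, and multiplied by the input activation. */
    static void dequant_accumulate(
        uint8_t packed, float x,
        float scale0, float zero0,
        float scale1, float zero1,
        float *sum0, float *sum1) {
      const int w_hi = (packed & 0xF0) >> 4;
      const int w_lo = packed & 0x0F;
      *sum0 += x * (((float)w_hi - 8.0f) * scale0 + zero0);
      *sum1 += x * (((float)w_lo - 8.0f) * scale1 + zero1);
    }

    int main(void) {
      float s0 = 0.0f, s1 = 0.0f;
      /* 0x3B packs the weights 3 (high nibble) and 11 (low nibble). */
      dequant_accumulate(0x3B, 2.0f, 0.5f, 0.0f, 0.25f, 0.1f, &s0, &s1);
      printf("%f %f\n", s0, s1);
      return 0;
    }

Running this prints -5.000000 1.700000: the byte 0x3B unpacks to weights 3 and 11, giving 2.0 * ((3 - 8) * 0.5 + 0.0) and 2.0 * ((11 - 8) * 0.25 + 0.1) for the two lanes.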