
Commit 079d734 (parent: c83aba9)

[ET-VK] Improve packing format for int4 linear operator + misc improvements

Differential Revision: D72412950
Pull Request resolved: #9883

9 files changed: +479 -230 lines

Diff for: backends/vulkan/runtime/api/containers/Tensor.cpp (+1 -3)

@@ -497,9 +497,7 @@ vTensor::vTensor(
   VK_CHECK_COND(
       dim_order_is_valid(dim_order_), "computed dim order is invalid");
 
-  if (storage_type != utils::kBuffer) {
-    set_logical_limits(storage_.image_extents_);
-  }
+  set_logical_limits(storage_.image_extents_);
 }
 
 // NOLINTNEXTLINE

Diff for: backends/vulkan/runtime/graph/ops/glsl/pack_int4_linear_weight_transposed_interleaved.glsl (new file, +136)

@@ -0,0 +1,136 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#version 450 core
+
+#define PRECISION ${PRECISION}
+
+${define_required_extensions("uint8")}
+${define_required_extensions("int8")}
+
+layout(std430) buffer;
+
+${layout_declare_tensor(B, "w", "t_qmat2", "uint8", STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "nchw_4x2", "uint8", "buffer")}
+
+layout(push_constant) uniform restrict Block {
+  ivec4 qmat2_sizes;
+};
+
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+
+uint8_t get_first(const uint8_t packed) {
+  return uint8_t((packed & 0xF0) >> 4);
+}
+
+uint8_t get_second(const uint8_t packed) {
+  return uint8_t(packed & 0x0F);
+}
+
+uint8_t combine(const uint8_t first, const uint8_t second) {
+  return uint8_t(first << 4 | second);
+}
+
+/*
+ * This shader packs the weight tensor into a texture.
+ *
+ * The original tensor has a (W, H) shape of (K / 2, N) and each scalar element
+ * is a uint8_t, which contains 2 packed 4-bit uint values.
+ *
+ * The transform performed by this shader is to first transpose the tensor, so
+ * the shape of the packed tensor becomes (N / 2, K). Then, the 4-bit integers
+ * are re-packed in groups of 8. For each group of 4 uint8_t values, the "left"
+ * 4 bits of each value contain the 0th, 1st, 2nd, and 3rd 4-bit values, and
+ * the "right" 4 bits of each value contain the 4th, 5th, 6th, and 7th 4-bit
+ * values.
+ *
+ * As a concrete example, consider the following weight tensor. The | demarks
+ * the packing boundary, so 1| 2 represents a single uint8_t value with 1 in
+ * the leftmost 4 bits and 2 in the rightmost 4 bits.
+ *
+ *  1| 2,  3| 4,  5| 6,  7| 8,
+ *  9|10, 11|12, 13|14, 15|16,
+ * 17|18, 19|20, 21|22, 23|24,
+ * 25|26, 27|28, 29|30, 31|32,
+ * 33|34, 35|36, 37|38, 39|40,
+ * 41|42, 43|44, 45|46, 47|48,
+ * 49|50, 51|52, 53|54, 55|56,
+ * 57|58, 59|60, 61|62, 63|64,
+ *
+ * After packing, the packed tensor would contain
+ *
+ * 1|33,  9|41, 17|49, 25|57,
+ * 2|34, 10|42, 18|50, 26|58,
+ * 3|35, 11|43, 19|51, 27|59,
+ * 4|36, 12|44, 20|52, 28|60,
+ * 5|37, 13|45, 21|53, 29|61,
+ * 6|38, 14|46, 22|54, 30|62,
+ * 7|39, 15|47, 23|55, 31|63,
+ * 8|40, 16|48, 24|56, 32|64,
+ *
+ * The purpose of interleaving is to make it easier to extract the unpacked
+ * values in order using the u8vec4 vectorized type. With the packing in
+ * place, the 4-bit values can be extracted via
+ *
+ * u8vec4 packed;
+ * u8vec4 vals_0123 = (packed & 0xF0) >> 4;
+ * u8vec4 vals_4567 = (packed & 0x0F);
+ */
+void main() {
+  // Each thread writes 2 output texels along the height axis
+  ivec2 packed_pos = ivec2(
+      gl_GlobalInvocationID.x,
+      gl_GlobalInvocationID.y << 1);
+
+  // The packed tensor is width packed
+  if ((packed_pos.x << 2) >= qmat2_sizes.x || packed_pos.y >= qmat2_sizes.y) {
+    return;
+  }
+
+  int out_col = packed_pos.x << 3;
+  int out_row = packed_pos.y;
+
+  int in_col = out_row;
+  int in_int8_col = in_col >> 1;
+  int in_row = out_col;
+
+  int in_numrows = qmat2_sizes.x << 1;
+  int in_numcols = qmat2_sizes.y;
+  int in_num_int8_cols = qmat2_sizes.y >> 1;
+
+  uint8_t in_vals[8][2];
+  for (int r = 0; r < 8; ++r) {
+    if (in_row + r < in_numrows) {
+      uint8_t in_val_packed = nchw_4x2[(in_row + r) * in_num_int8_cols + in_int8_col];
+      in_vals[r][0] = get_first(in_val_packed);
+      in_vals[r][1] = get_second(in_val_packed);
+    } else {
+      in_vals[r][0] = uint8_t(254);
+      in_vals[r][1] = uint8_t(254);
+    }
+  }
+
+  u8vec4 out_tex_1 = u8vec4(
+      combine(in_vals[0][0], in_vals[4][0]),
+      combine(in_vals[1][0], in_vals[5][0]),
+      combine(in_vals[2][0], in_vals[6][0]),
+      combine(in_vals[3][0], in_vals[7][0]));
+
+  u8vec4 out_tex_2 = u8vec4(
+      combine(in_vals[0][1], in_vals[4][1]),
+      combine(in_vals[1][1], in_vals[5][1]),
+      combine(in_vals[2][1], in_vals[6][1]),
+      combine(in_vals[3][1], in_vals[7][1]));
+
+  $if STORAGE == "buffer":
+    int stride = qmat2_sizes.x >> 2;
+    t_qmat2[packed_pos.y * stride + packed_pos.x] = out_tex_1;
+    t_qmat2[(packed_pos.y + 1) * stride + packed_pos.x] = out_tex_2;
+  $else:
+    imageStore(t_qmat2, ivec3(packed_pos.xy, 0), out_tex_1);
+    imageStore(t_qmat2, ivec3(packed_pos.x, packed_pos.y + 1, 0), out_tex_2);
+}
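
To make the transform concrete, here is a minimal CPU sketch in C++ of the same repacking. It is not part of the commit: the function name and row-major layout are assumptions, it only handles N and K that are exact multiples of 8 and 2, and it omits the shader's out-of-range padding path (which writes 254).

#include <cstdint>
#include <vector>

// Input:  N rows x (K / 2) bytes; byte (n, k / 2) holds the 4-bit values for
//         columns k (high nibble) and k + 1 (low nibble).
// Output: K rows x (N / 2) bytes; for each group of 8 consecutive source rows,
//         high nibbles carry values 0-3 of the group and low nibbles carry
//         values 4-7, matching the worked example above.
std::vector<uint8_t> repack_int4_transposed_interleaved(
    const std::vector<uint8_t>& src, int N, int K) {
  // Unpack the source into one 4-bit value per byte: unpacked[n][k].
  std::vector<uint8_t> unpacked(N * K);
  for (int n = 0; n < N; ++n) {
    for (int k = 0; k < K; k += 2) {
      const uint8_t b = src[n * (K / 2) + k / 2];
      unpacked[n * K + k] = (b & 0xF0) >> 4;  // "left" 4 bits
      unpacked[n * K + k + 1] = b & 0x0F;     // "right" 4 bits
    }
  }

  // Transpose, then interleave each group of 8 values along N.
  std::vector<uint8_t> dst(K * (N / 2));
  for (int k = 0; k < K; ++k) {
    for (int n = 0; n < N; n += 8) {
      for (int i = 0; i < 4; ++i) {
        const uint8_t first = unpacked[(n + i) * K + k];       // values 0..3
        const uint8_t second = unpacked[(n + i + 4) * K + k];  // values 4..7
        dst[k * (N / 2) + n / 2 + i] =
            static_cast<uint8_t>(first << 4 | second);
      }
    }
  }
  return dst;
}

Running this on the 8x4-byte example tensor from the shader comment reproduces the interleaved output shown there (e.g. the first output row becomes 1|33, 9|41, 17|49, 25|57).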

Diff for: backends/vulkan/runtime/graph/ops/glsl/pack_int4_linear_weight_transposed_interleaved.yaml (new file, +13)

@@ -0,0 +1,13 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+pack_int4_linear_weight_transposed_interleaved:
+  parameter_names_with_default_values:
+    STORAGE: texture3d
+  shader_variants:
+    - NAME: pack_int4_linear_weight_transposed_interleaved_texture3d
+    - NAME: pack_int4_linear_weight_transposed_interleaved_buffer
+      STORAGE: buffer

Diff for: backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.glsl (+72 -62)

@@ -8,44 +8,42 @@
 
 #version 450 core
 
-#include "indexing_utils.h"
-
 #define PRECISION ${PRECISION}
 
-#define FOUR 4
-
-#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}
-#define FLOAT_T ${buffer_scalar_type(DTYPE)}
+#define T ${buffer_scalar_type(DTYPE)}
+#define VEC4_T ${buffer_gvec_type(DTYPE, 4)}
 
-${define_active_storage_type(STORAGE)}
-
-${define_required_extensions([DTYPE, "uint8", "uint16"])}
-#extension GL_EXT_control_flow_attributes : require
+${define_required_extensions(DTYPE)}
+${define_required_extensions("int8")}
 
 layout(std430) buffer;
 
-${layout_declare_tensor(B, "w", "ret", DTYPE, STORAGE)}
-${layout_declare_tensor(B, "r", "x", DTYPE, STORAGE)}
-${layout_declare_tensor(B, "r", "weights", "uint8", "buffer")}
-${layout_declare_tensor(B, "r", "qparams", DTYPE, STORAGE)}
-${layout_declare_ubo(B, "ivec3", "ret_limits")}
-${layout_declare_ubo(B, "ivec4", "x_sizes")}
-${layout_declare_ubo(B, "ivec4", "weights_strides")}
-${layout_declare_ubo(B, "ivec4", "qparams_strides")}
+${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_mat1", DTYPE, IN_STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_qmat2", "uint8", WEIGHT_STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_qparams", DTYPE, "texture3D")}
+
+layout(push_constant) uniform restrict Block {
+  ivec4 out_sizes;
+  ivec4 mat1_sizes;
+  ivec4 qmat2_sizes;
+};
 
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 
-layout(constant_id = 3) const int group_size = 1;
+layout(constant_id = 3) const int group_size = 64;
 
 /*
  * This shader computes a linear operator between a floating point input matrix
  * x and a weights matrix that is quantized to 4 bits.
  *
  * The (W, H, C) shape of each tensor is:
  * - x: (K, M)
- * - weights: (K / 2, N)
+ * - weights: (N / 2, K)
  * - The weights tensor has a data type of `uint8`. Each element in the tensor
  *   contains 2 4-bit values packed into a uint8.
+ * - See the pack_int4_linear_weight_transposed_interleaved shader for more
+ *   details on how the weight tensor is stored.
  * - qparams: (2, N, number_of_groups)
  *   - This tensor contains the scales and zeros quantization parameters for the
  *     weights tensor. The weight tensor is quantized group-wise, which means
@@ -57,56 +55,68 @@ layout(constant_id = 3) const int group_size = 1;
  * Note that this shader assumes that all tensors are width packed.
  */
 void main() {
-  // output positions being calculated are (n, m), (n + 1, m), ...
-  // This means multiplying the m-th row of x with the n-th, (n+1)-th, ... rows
-  // of the weights tensor.
-  const u16vec3 ret_pos = u16vec3(gl_GlobalInvocationID);
-  if (any(greaterThanEqual(ret_pos, ret_limits))) {
+  const uint out_row = gl_GlobalInvocationID.y;
+  // Each thread writes out 2 texels along the width axis, equivalent to 8
+  // scalar elements. Therefore multiply thread_idx.x by 8.
+  const uint out_col = gl_GlobalInvocationID.x << 3;
+  // By similar reasoning, each thread works on 2 texels along the width axis,
+  // so multiply thread_idx.x by 2.
+  const int out_col_texel_idx = int(gl_GlobalInvocationID.x) << 1;
+
+  if (out_col >= out_sizes.x || out_row >= out_sizes.y) {
     return;
   }
 
-  // Since ret is width packed, need to multiply by 4
-  const uint16_t n = uint16_t(ret_pos.x * 4);
+  const int num_blocks = mat1_sizes.x / group_size;
 
-  // K is guaranteed to be a multiple of group size
-  const uint16_t num_blocks = uint16_t(x_sizes.x / group_size);
+  VEC4_T sums[2];
 
-  uint16_t k_texel_i = uint16_t(0);
-  vec4 sums = vec4(0.0);
-  for (uint16_t block_idx = uint16_t(0); block_idx < num_blocks; block_idx++) {
-    vec4 scales;
-    vec4 zeros;
+  sums[0] = VEC4_T(0);
+  sums[1] = VEC4_T(0);
 
-    [[unroll]] for (int comp = 0; comp < 4; comp++) {
-      const vec4 scale_and_zero = load_texel(
-          qparams, u16vec3(0, n + comp, block_idx));
-      scales[comp] = scale_and_zero.x;
-      zeros[comp] = scale_and_zero.y;
-    }
+  VEC4_T scales[2];
+  VEC4_T zeros[2];
+
+  $if WEIGHT_STORAGE == "buffer":
+    const int qmat2_stride = qmat2_sizes.x >> 2;
+
+  for (int block_idx = 0; block_idx < num_blocks; ++block_idx) {
+    scales[0] = texelFetch(t_qparams, ivec3(out_col_texel_idx, 0, block_idx), 0);
+    zeros[0] = texelFetch(t_qparams, ivec3(out_col_texel_idx, 1, block_idx), 0);
 
-    for (uint16_t i = uint16_t(0); i < group_size; i += uint16_t(4), k_texel_i++) {
-      const VEC4_T x_texel = load_texel(
-          x, u16vec3(k_texel_i, ret_pos.y, ret_pos.z));
-
-      [[unroll]] for (int comp = 0; comp < 4; comp++) {
-        const int weights_bufi = (n + comp) * weights_strides.y + (k_texel_i * 2);
-        // Need to read 4 unpacked values, which corresponds to 2 packed values
-        const uint8_t weights_val_1 = weights[weights_bufi];
-        const uint8_t weights_val_2 = weights[weights_bufi + 1];
-
-        const u8vec4 weights_texel = u8vec4(
-            (weights_val_1 & 0xF0) >> 4,
-            weights_val_1 & 0x0F,
-            (weights_val_2 & 0xF0) >> 4,
-            weights_val_2 & 0x0F);
-
-        // Note that the unpacked 4-bit values are unsigned, therefore they must
-        // first be "centered" around 0 by subtracting 8 before applying the
-        // scale and zero point.
-        sums[comp] += dot(
-            x_texel, (vec4(weights_texel) - 8.0) * scales[comp] + zeros[comp]);
+    scales[1] = texelFetch(t_qparams, ivec3(out_col_texel_idx + 1, 0, block_idx), 0);
+    zeros[1] = texelFetch(t_qparams, ivec3(out_col_texel_idx + 1, 1, block_idx), 0);
+
+    for (int g_idx = 0; g_idx < group_size; g_idx += 4) {
+      const int k = block_idx * group_size + g_idx;
+
+      $if IN_STORAGE == "buffer":
+        const VEC4_T mat1_tex = t_mat1[(out_row * mat1_sizes.x + k) >> 2];
+      $else:
+        const VEC4_T mat1_tex = texelFetch(t_mat1, ivec3(k >> 2, out_row, 0), 0);
+
+      for (int comp = 0; comp < 4; ++comp) {
+        $if WEIGHT_STORAGE == "buffer":
+          const u8vec4 packed_weight_tex = t_qmat2[(k + comp) * qmat2_stride + gl_GlobalInvocationID.x];
+        $else:
+          const uvec4 packed_weight_tex = texelFetch(
+              t_qmat2,
+              ivec3(gl_GlobalInvocationID.x, k + comp, 0),
+              0);
+
+        const uvec4 weight_tex_1 = (packed_weight_tex & 0xF0) >> 4;
+        const uvec4 weight_tex_2 = packed_weight_tex & 0x0F;
+
+        sums[0] += mat1_tex[comp] * ((vec4(weight_tex_1) - 8.0) * scales[0] + zeros[0]);
+        sums[1] += mat1_tex[comp] * ((vec4(weight_tex_2) - 8.0) * scales[1] + zeros[1]);
       }
     }
   }
-  write_texel(ret, ret_pos, sums);
+
+  $if OUT_STORAGE == "buffer":
+    t_out[(out_row * out_sizes.x + out_col) >> 2] = sums[0];
+    t_out[(out_row * out_sizes.x + out_col + 4) >> 2] = sums[1];
+  $else:
+    imageStore(t_out, ivec3(out_col_texel_idx, out_row, 0), sums[0]);
+    imageStore(t_out, ivec3(out_col_texel_idx + 1, out_row, 0), sums[1]);
 }
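
For reference, below is a scalar C++ sketch of the computation this shader vectorizes. It is illustrative only and not part of the commit: it takes already-unpacked 4-bit weights and assumes simple row-major layouts for the scales and zeros, but it mirrors the shader's dequantization formula ((w - 8) * scale + zero) and its group-wise indexing along K.

#include <cstdint>
#include <vector>

// out[m][n] = sum_k x[m][k] * dequant(w[n][k]), where the scale and zero for
// w[n][k] come from group k / group_size. Assumes K is a multiple of
// group_size.
void q4w_linear_reference(
    const std::vector<float>& x,       // M x K input
    const std::vector<uint8_t>& w4,    // N x K unpacked 4-bit weights (0..15)
    const std::vector<float>& scales,  // num_groups x N
    const std::vector<float>& zeros,   // num_groups x N
    std::vector<float>& out,           // M x N output
    int M, int K, int N, int group_size) {
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float sum = 0.0f;
      for (int k = 0; k < K; ++k) {
        const int g = k / group_size;
        // The 4-bit values are unsigned, so center them around 0 by
        // subtracting 8 before applying the group's scale and zero point.
        const float w =
            (static_cast<float>(w4[n * K + k]) - 8.0f) * scales[g * N + n] +
            zeros[g * N + n];
        sum += x[m * K + k] * w;
      }
      out[m * N + n] = sum;
    }
  }
}

The shader computes the same sums, but 8 output columns per thread, 4 K-elements per texel fetch, and with the weights read in the packed interleaved layout produced by the packing shader.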

Diff for: backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.yaml (+13 -6)

@@ -7,10 +7,17 @@
 q_4w_linear:
   parameter_names_with_default_values:
     DTYPE: float
-    STORAGE: texture3d
-  generate_variant_forall:
-    DTYPE:
-      - VALUE: float
-      - VALUE: half
+    OUT_STORAGE: texture3d
+    IN_STORAGE: texture3d
+    WEIGHT_STORAGE: texture3d
   shader_variants:
-    - NAME: q_4w_linear_texture3d
+    - NAME: q_4w_linear_texture3d_texture3d_texture3d_float
+    - NAME: q_4w_linear_texture3d_buffer_texture3d_float
+      IN_STORAGE: buffer
+    - NAME: q_4w_linear_buffer_buffer_texture3d_float
+      OUT_STORAGE: buffer
+      IN_STORAGE: buffer
+    - NAME: q_4w_linear_buffer_buffer_buffer_float
+      OUT_STORAGE: buffer
+      IN_STORAGE: buffer
+      WEIGHT_STORAGE: buffer
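
The variant names above follow a simple convention: base name, then the OUT/IN/WEIGHT storage types in order, then the dtype suffix. A hypothetical C++ helper (not part of the commit) that composes a name the same way:

#include <string>

// Illustrative only: builds a variant name such as
// "q_4w_linear_texture3d_buffer_texture3d_float" from the three storage
// types, matching the naming scheme in the YAML above.
std::string q4w_linear_variant_name(
    const std::string& out_storage,  // "texture3d" or "buffer"
    const std::string& in_storage,
    const std::string& weight_storage) {
  return "q_4w_linear_" + out_storage + "_" + in_storage + "_" +
      weight_storage + "_float";
}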
