Commit 6a55aac

[ET-VK] Improve packing format for int4 linear operator + misc improvements
Pull Request resolved: #9883

## Context

Improve performance of the quantized int4 linear shader by packing the scales and zeros tensor, as well as the weight tensor, in a more optimal way. See the comments in the `pack_int4_linear_weight_transposed_interleave` shader for more details about how the new packing works.

## Changes

* Split int8 quantized linear and int4 quantized linear into separate C++ files for better code organization
* Introduce packing shader for int4 weights
* Update int4 linear shader to account for packed weights

## Impact

This change massively improves the performance of the weight int4 quantized linear operator. With this change, running LLaMa 3.2 1B now achieves 10 tok/s, up from 0.9 tok/s, on an Adreno 740. This is a 10x improvement!

With this change:

```
/home/ssjia/scratch/bin/app_bin: 1 file pushed, 0 skipped. 332.3 MB/s (74692800 bytes in 0.214s)
I 00:00:00.003353 executorch:cpuinfo_utils.cpp:62] Reading file /sys/devices/soc0/image_version
I 00:00:00.003533 executorch:cpuinfo_utils.cpp:78] Failed to open midr file /sys/devices/soc0/image_version
I 00:00:00.003563 executorch:cpuinfo_utils.cpp:91] Reading file /sys/devices/system/cpu/cpu0/regs/identification/midr_el1
I 00:00:00.003685 executorch:cpuinfo_utils.cpp:91] Reading file /sys/devices/system/cpu/cpu1/regs/identification/midr_el1
I 00:00:00.003747 executorch:cpuinfo_utils.cpp:91] Reading file /sys/devices/system/cpu/cpu2/regs/identification/midr_el1
I 00:00:00.003799 executorch:cpuinfo_utils.cpp:91] Reading file /sys/devices/system/cpu/cpu3/regs/identification/midr_el1
I 00:00:00.003852 executorch:cpuinfo_utils.cpp:91] Reading file /sys/devices/system/cpu/cpu4/regs/identification/midr_el1
I 00:00:00.003902 executorch:cpuinfo_utils.cpp:91] Reading file /sys/devices/system/cpu/cpu5/regs/identification/midr_el1
I 00:00:00.003976 executorch:main.cpp:69] Resetting threadpool with num threads = 6
I 00:00:00.004289 executorch:runner.cpp:68] Creating LLaMa runner: model_path=/data/local/tmp/llama3-1b/vk/llama3.pte, tokenizer_path=/data/local/tmp/tokenizer.model
I 00:00:04.841690 executorch:runner.cpp:101] Reading metadata from model
I 00:00:04.841808 executorch:runner.cpp:126] Metadata: get_vocab_size = 128256
I 00:00:04.841830 executorch:runner.cpp:126] Metadata: get_bos_id = 128000
I 00:00:04.841851 executorch:runner.cpp:126] Metadata: use_sdpa_with_kv_cache = 1
I 00:00:04.841874 executorch:runner.cpp:126] Metadata: use_kv_cache = 1
I 00:00:04.841893 executorch:runner.cpp:126] Metadata: get_max_context_len = 128
I 00:00:04.841909 executorch:runner.cpp:126] Metadata: get_max_seq_len = 128
I 00:00:04.841927 executorch:runner.cpp:126] Metadata: enable_dynamic_shape = 0
I 00:00:04.841945 executorch:runner.cpp:133] eos_id = 128009
I 00:00:04.841951 executorch:runner.cpp:133] eos_id = 128001
I 00:00:04.841963 executorch:runner.cpp:188] RSS after loading model: 2229.828125 MiB (0 if unsupported)
<|begin_of_text|><|start_header_id|>system<|end_header_id|>Tell me a short story.<|eot_id|><|start_header_id|>assistant<|end_header_id|>
I 00:00:06.239633 executorch:runner.cpp:258] RSS after prompt prefill: 2229.828125 MiB (0 if unsupported)
Here's a short story for you: **The Library of Lost Memories** In a small, dusty town nestled between two great rivers, there was a library that held the secrets of the past. It was a place where memories were stored, not retrieved, and the librarians were the guardians of the past. The library was called the Library of Lost Memories, and it was said that anyone who entered its doors would be given a glimpse into the memories of those who had come before. The librarians were wise and kind, and they would only allow those who were
I 00:00:17.699086 executorch:runner.cpp:272] RSS after finishing text generation: 2229.828125 MiB (0 if unsupported)
I 00:00:17.699155 executorch:stats.h:108] Prompt Tokens: 14 Generated Tokens: 113
I 00:00:17.699161 executorch:stats.h:114] Model Load Time: 4.837000 (seconds)
I 00:00:17.699165 executorch:stats.h:124] Total inference time: 12.857000 (seconds) Rate: 8.788987 (tokens/second)
I 00:00:17.699168 executorch:stats.h:132] Prompt evaluation: 1.398000 (seconds) Rate: 10.014306 (tokens/second)
I 00:00:17.699171 executorch:stats.h:143] Generated 113 tokens: 11.459000 (seconds) Rate: 9.861244 (tokens/second)
I 00:00:17.699174 executorch:stats.h:151] Time to first generated token: 1.398000 (seconds)
I 00:00:17.699177 executorch:stats.h:158] Sampling time over 127 tokens: 549246500.843000 (seconds)
```

Before this change:

```
/home/ssjia/scratch/bin/app_bin: 1 file pushed, 0 skipped. 302.0 MB/s (74637464 bytes in 0.236s)
I 00:00:00.003050 executorch:cpuinfo_utils.cpp:62] Reading file /sys/devices/soc0/image_version
I 00:00:00.003200 executorch:cpuinfo_utils.cpp:78] Failed to open midr file /sys/devices/soc0/image_version
I 00:00:00.003226 executorch:cpuinfo_utils.cpp:91] Reading file /sys/devices/system/cpu/cpu0/regs/identification/midr_el1
I 00:00:00.003337 executorch:cpuinfo_utils.cpp:91] Reading file /sys/devices/system/cpu/cpu1/regs/identification/midr_el1
I 00:00:00.003396 executorch:cpuinfo_utils.cpp:91] Reading file /sys/devices/system/cpu/cpu2/regs/identification/midr_el1
I 00:00:00.003449 executorch:cpuinfo_utils.cpp:91] Reading file /sys/devices/system/cpu/cpu3/regs/identification/midr_el1
I 00:00:00.003502 executorch:cpuinfo_utils.cpp:91] Reading file /sys/devices/system/cpu/cpu4/regs/identification/midr_el1
I 00:00:00.003553 executorch:cpuinfo_utils.cpp:91] Reading file /sys/devices/system/cpu/cpu5/regs/identification/midr_el1
I 00:00:00.003629 executorch:main.cpp:69] Resetting threadpool with num threads = 6
I 00:00:00.004075 executorch:runner.cpp:68] Creating LLaMa runner: model_path=/data/local/tmp/llama3-1b/vk/llama3.pte, tokenizer_path=/data/local/tmp/tokenizer.model
I 00:00:05.417531 executorch:runner.cpp:101] Reading metadata from model
I 00:00:05.417647 executorch:runner.cpp:126] Metadata: get_vocab_size = 128256
I 00:00:05.417669 executorch:runner.cpp:126] Metadata: get_bos_id = 128000
I 00:00:05.417698 executorch:runner.cpp:126] Metadata: use_sdpa_with_kv_cache = 1
I 00:00:05.417716 executorch:runner.cpp:126] Metadata: use_kv_cache = 1
I 00:00:05.417735 executorch:runner.cpp:126] Metadata: get_max_context_len = 128
I 00:00:05.417751 executorch:runner.cpp:126] Metadata: get_max_seq_len = 128
I 00:00:05.417768 executorch:runner.cpp:126] Metadata: enable_dynamic_shape = 0
I 00:00:05.417787 executorch:runner.cpp:133] eos_id = 128009
I 00:00:05.417793 executorch:runner.cpp:133] eos_id = 128001
I 00:00:05.417808 executorch:runner.cpp:188] RSS after loading model: 2230.812500 MiB (0 if unsupported)
<|begin_of_text|><|start_header_id|>system<|end_header_id|>Tell me a short story.<|eot_id|><|start_header_id|>assistant<|end_header_id|>
I 00:00:19.689616 executorch:runner.cpp:258] RSS after prompt prefill: 2230.812500 MiB (0 if unsupported)
Here's a short story for you: **The Library of Lost Memories** In a small, dusty town nestled between two great rivers, there was a library that held the secrets of the past. It was a place where memories were stored, not retrieved, and the librarians were the guardians of the past. The library was called the Library of Lost Memories, and it was said that anyone who entered its doors would be given a glimpse into the memories of those who had come before. The librarians were wise and kind, and they would only allow those who were
I 00:02:15.269693 executorch:runner.cpp:272] RSS after finishing text generation: 2230.812500 MiB (0 if unsupported)
I 00:02:15.269810 executorch:stats.h:108] Prompt Tokens: 14 Generated Tokens: 113
I 00:02:15.269825 executorch:stats.h:114] Model Load Time: 5.414000 (seconds)
I 00:02:15.269832 executorch:stats.h:124] Total inference time: 129.852000 (seconds) Rate: 0.870221 (tokens/second)
I 00:02:15.269837 executorch:stats.h:132] Prompt evaluation: 14.271000 (seconds) Rate: 0.981010 (tokens/second)
I 00:02:15.269841 executorch:stats.h:143] Generated 113 tokens: 115.581000 (seconds) Rate: 0.977669 (tokens/second)
I 00:02:15.269844 executorch:stats.h:151] Time to first generated token: 14.271000 (seconds)
I 00:02:15.269847 executorch:stats.h:158] Sampling time over 127 tokens: 549711269.115000 (seconds)
PyTorchObserver {"prompt_tokens":14,"generated_tokens":113,"model_load_start_ms":1743712527974,"model_load_end_ms":1743712533388,"inference_start_ms":1743712533388,"inference_end_ms":1743712663240,"prompt_eval_end_ms":1743712547659,"first_token_ms":1743712547659,"aggregate_sampling_time_ms":549711269115,"SCALING_FACTOR_UNITS_PER_SECOND":1000}
```

ghstack-source-id: 276219518
@exported-using-ghexport

Differential Revision: [D72412950](https://our.internmc.facebook.com/intern/diff/D72412950/)
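As a quick illustration of the nibble packing these shaders rely on, here is a minimal Python sketch (not part of the commit; names are ours) of the mask-and-shift operations that store and recover two unsigned 4-bit values per byte. The GLSL applies the same `& 0xF0`, `>> 4`, and `& 0x0F` operations lane-wise to a `u8vec4`.

```python
def pack_nibbles(first: int, second: int) -> int:
    """Pack two 4-bit values into one byte; `first` takes the high 4 bits."""
    return ((first & 0xF) << 4) | (second & 0xF)

def unpack_nibbles(packed: int) -> tuple:
    """Recover the (high, low) 4-bit values from a packed byte."""
    return (packed & 0xF0) >> 4, packed & 0x0F

byte = pack_nibbles(0x3, 0xA)
assert byte == 0x3A
assert unpack_nibbles(byte) == (0x3, 0xA)
```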
1 parent 95d38c4 commit 6a55aac

9 files changed (+491 −230 lines)

backends/vulkan/runtime/api/containers/Tensor.cpp

+1-3
```diff
@@ -497,9 +497,7 @@ vTensor::vTensor(
   VK_CHECK_COND(
       dim_order_is_valid(dim_order_), "computed dim order is invalid");
 
-  if (storage_type != utils::kBuffer) {
-    set_logical_limits(storage_.image_extents_);
-  }
+  set_logical_limits(storage_.image_extents_);
 }
 
 // NOLINTNEXTLINE
```
```diff
@@ -0,0 +1,136 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#version 450 core
+
+#define PRECISION ${PRECISION}
+
+${define_required_extensions("uint8")}
+${define_required_extensions("int8")}
+
+layout(std430) buffer;
+
+${layout_declare_tensor(B, "w", "t_qmat2", "uint8", STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "nchw_4x2", "uint8", "buffer")}
+
+layout(push_constant) uniform restrict Block {
+  ivec4 qmat2_sizes;
+};
+
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+
+uint8_t get_first(const uint8_t packed) {
+  return uint8_t((packed & 0xF0) >> 4);
+}
+
+uint8_t get_second(const uint8_t packed) {
+  return uint8_t(packed & 0x0F);
+}
+
+uint8_t combine(const uint8_t first, const uint8_t second) {
+  return uint8_t(first << 4 | second);
+}
+
+/*
+ * This shader packs the weight tensor into a texture.
+ *
+ * The original tensor has a (W, H) shape of (K / 2, N) and each scalar element
+ * is a uint8_t, which contains 2 packed 4-bit uint values.
+ *
+ * The transform performed by this shader is to first transpose the tensor, so
+ * the shape of the packed tensor becomes (N / 2, K). Then, the 4-bit integers
+ * are re-packed in groups of 8. For each group of 4 uint8_t values, the "left"
+ * 4 bits of each value contain the 0, 1, 2, 3 4-bit values, and the "right"
+ * 4 bits of each value contain the 4, 5, 6, 7 4-bit values.
+ *
+ * As a concrete example, consider the following weight tensor. The | demarks
+ * the packing boundary, so 1| 2 represents a single uint8_t value with 1 in the
+ * leftmost 4 bits and 2 in the rightmost 4 bits.
+ *
+ *  1| 2,  3| 4,  5| 6,  7| 8,
+ *  9|10, 11|12, 13|14, 15|16,
+ * 17|18, 19|20, 21|22, 23|24,
+ * 25|26, 27|28, 29|30, 31|32,
+ * 33|34, 35|36, 37|38, 39|40,
+ * 41|42, 43|44, 45|46, 47|48,
+ * 49|50, 51|52, 53|54, 55|56,
+ * 57|58, 59|60, 61|62, 63|64,
+ *
+ * After packing, the packed tensor would contain
+ *
+ *  1|33,  9|41, 17|49, 25|57,
+ *  2|34, 10|42, 18|50, 26|58,
+ *  3|35, 11|43, 19|51, 27|59,
+ *  4|36, 12|44, 20|52, 28|60,
+ *  5|37, 13|45, 21|53, 29|61,
+ *  6|38, 14|46, 22|54, 30|62,
+ *  7|39, 15|47, 23|55, 31|63,
+ *  8|40, 16|48, 24|56, 32|64,
+ *
+ * The purpose of interleaving is to make it easier to extract the unpacked
+ * values in order using the u8vec4 vectorized type. With the packing in place,
+ * the 4-bit values can be extracted via
+ *
+ * u8vec4 packed;
+ * u8vec4 vals_0123 = (packed & 0xF0) >> 4;
+ * u8vec4 vals_4567 = packed & 0x0F;
+ */
+void main() {
+  // Each thread writes 2 output texels along the height axis
+  ivec2 packed_pos = ivec2(
+      gl_GlobalInvocationID.x,
+      gl_GlobalInvocationID.y << 1);
+
+  // The packed tensor is width packed
+  if ((packed_pos.x << 2) >= qmat2_sizes.x || packed_pos.y >= qmat2_sizes.y) {
+    return;
+  }
+
+  int out_col = packed_pos.x << 3;
+  int out_row = packed_pos.y;
+
+  int in_col = out_row;
+  int in_int8_col = in_col >> 1;
+  int in_row = out_col;
+
+  int in_numrows = qmat2_sizes.x << 1;
+  int in_numcols = qmat2_sizes.y;
+  int in_num_int8_cols = qmat2_sizes.y >> 1;
+
+  uint8_t in_vals[8][2];
+  for (int r = 0; r < 8; ++r) {
+    if (in_row + r < in_numrows) {
+      uint8_t in_val_packed = nchw_4x2[(in_row + r) * in_num_int8_cols + in_int8_col];
+      in_vals[r][0] = get_first(in_val_packed);
+      in_vals[r][1] = get_second(in_val_packed);
+    } else {
+      in_vals[r][0] = uint8_t(254);
+      in_vals[r][1] = uint8_t(254);
+    }
+  }
+
+  u8vec4 out_tex_1 = u8vec4(
+      combine(in_vals[0][0], in_vals[4][0]),
+      combine(in_vals[1][0], in_vals[5][0]),
+      combine(in_vals[2][0], in_vals[6][0]),
+      combine(in_vals[3][0], in_vals[7][0]));
+
+  u8vec4 out_tex_2 = u8vec4(
+      combine(in_vals[0][1], in_vals[4][1]),
+      combine(in_vals[1][1], in_vals[5][1]),
+      combine(in_vals[2][1], in_vals[6][1]),
+      combine(in_vals[3][1], in_vals[7][1]));
+
+  $if STORAGE == "buffer":
+    int stride = qmat2_sizes.x >> 2;
+    t_qmat2[packed_pos.y * stride + packed_pos.x] = out_tex_1;
+    t_qmat2[(packed_pos.y + 1) * stride + packed_pos.x] = out_tex_2;
+  $else:
+    imageStore(t_qmat2, ivec3(packed_pos.xy, 0), out_tex_1);
+    imageStore(t_qmat2, ivec3(packed_pos.x, packed_pos.y + 1, 0), out_tex_2);
+}
```
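The transposed-interleave layout described in the shader comment can be modeled in a few lines of Python. This is an illustrative sketch (names are ours, not from the commit): nibble pairs are kept as `(hi, lo)` tuples instead of real packed bytes so the comment's labels 1..64 can be reused even though they exceed 4 bits.

```python
def pack_transposed_interleaved(unpacked):
    """unpacked: N x K matrix of weight values (N rows, K columns).

    Returns a K x (N/2) matrix of (hi, lo) pairs: the tensor is transposed,
    then each group of 8 source rows is interleaved so the high nibbles hold
    values 0..3 of the group and the low nibbles hold values 4..7.
    """
    n, k = len(unpacked), len(unpacked[0])
    packed = [[None] * (n // 2) for _ in range(k)]
    for g in range(n // 8):          # each group of 8 source rows
        for i in range(4):           # 4 output bytes per group
            for j in range(k):
                hi = unpacked[8 * g + i][j]
                lo = unpacked[8 * g + i + 4][j]
                packed[j][4 * g + i] = (hi, lo)
    return packed

# The 8x8 worked example from the shader comment: values 1..64, row-major.
weights = [[8 * r + c + 1 for c in range(8)] for r in range(8)]
packed = pack_transposed_interleaved(weights)
assert packed[0] == [(1, 33), (9, 41), (17, 49), (25, 57)]
assert packed[1] == [(2, 34), (10, 42), (18, 50), (26, 58)]
assert packed[7] == [(8, 40), (16, 48), (24, 56), (32, 64)]
```

The assertions reproduce the packed rows listed in the shader comment, which is a useful sanity check that the documented layout and the interleave rule agree.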
```diff
@@ -0,0 +1,13 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+pack_int4_linear_weight_transposed_interleaved:
+  parameter_names_with_default_values:
+    STORAGE: texture3d
+  shader_variants:
+    - NAME: pack_int4_linear_weight_transposed_interleaved_texture3d
+    - NAME: pack_int4_linear_weight_transposed_interleaved_buffer
+      STORAGE: buffer
```

backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.glsl

+84-62
```diff
@@ -8,44 +8,54 @@
 
 #version 450 core
 
-#include "indexing_utils.h"
-
 #define PRECISION ${PRECISION}
 
-#define FOUR 4
-
-#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}
-#define FLOAT_T ${buffer_scalar_type(DTYPE)}
-
-${define_active_storage_type(STORAGE)}
+#define T ${buffer_scalar_type(DTYPE)}
+#define VEC4_T ${buffer_gvec_type(DTYPE, 4)}
 
-${define_required_extensions([DTYPE, "uint8", "uint16"])}
-#extension GL_EXT_control_flow_attributes : require
+${define_required_extensions(DTYPE)}
+${define_required_extensions("int8")}
 
 layout(std430) buffer;
 
-${layout_declare_tensor(B, "w", "ret", DTYPE, STORAGE)}
-${layout_declare_tensor(B, "r", "x", DTYPE, STORAGE)}
-${layout_declare_tensor(B, "r", "weights", "uint8", "buffer")}
-${layout_declare_tensor(B, "r", "qparams", DTYPE, STORAGE)}
-${layout_declare_ubo(B, "ivec3", "ret_limits")}
-${layout_declare_ubo(B, "ivec4", "x_sizes")}
-${layout_declare_ubo(B, "ivec4", "weights_strides")}
-${layout_declare_ubo(B, "ivec4", "qparams_strides")}
+${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_mat1", DTYPE, IN_STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_qmat2", "uint8", WEIGHT_STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_qparams", DTYPE, "texture3D")}
+
+layout(push_constant) uniform restrict Block {
+  ivec4 out_sizes;
+  ivec4 mat1_sizes;
+  ivec4 qmat2_sizes;
+};
 
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 
-layout(constant_id = 3) const int group_size = 1;
+layout(constant_id = 3) const int group_size = 64;
+
+uint8_t get_first(const uint8_t packed) {
+  return uint8_t((packed & 0xF0) >> 4);
+}
+
+uint8_t get_second(const uint8_t packed) {
+  return uint8_t(packed & 0x0F);
+}
+
+uint8_t combine(const uint8_t first, const uint8_t second) {
+  return uint8_t(first << 4 | second);
+}
 
 /*
  * This shader computes a linear operator between a floating point input matrix
  * x and a weights matrix that is quantized to 4 bits.
  *
  * The (W, H, C) shape of each tensor is:
  * - x: (K, M)
- * - weights: (K / 2, N)
+ * - weights: (N / 2, K)
  * - The weights tensor has a data type of `uint8`. Each element in the tensor
  *   contains 2 4-bit values packed into a uint8.
+ * - See the pack_int4_linear_weight_transposed_interleave shader to see more
+ *   details on how the weight tensor is stored.
  * - qparams: (2, N, number_of_groups)
  * - This tensor contains the scales and zeros quantization parameters for the
  *   weights tensor. The weight tensor is quantized group-wise, which means
@@ -57,56 +67,68 @@ layout(constant_id = 3) const int group_size = 1;
  * Note that this shader assumes that all tensors are width packed.
  */
 void main() {
-  // output positions being calculated are (n, m), (n + 1, m), ...
-  // This means multiplying the m-th row of x with the n-th, (n+1)-th, ... rows
-  // of the weights tensor.
-  const u16vec3 ret_pos = u16vec3(gl_GlobalInvocationID);
-  if (any(greaterThanEqual(ret_pos, ret_limits))) {
+  const uint out_row = gl_GlobalInvocationID.y;
+  // Each thread writes out 2 texels along the width axis, equivalent to 8
+  // scalar elements. Therefore multiply the thread_idx.x by 8.
+  const uint out_col = gl_GlobalInvocationID.x << 3;
+  // Similar reasoning to the above, each thread works on 2 texels along the
+  // width axis so multiply thread_idx.x by 2.
+  const int out_col_texel_idx = int(gl_GlobalInvocationID.x) << 1;
+
+  if (out_col >= out_sizes.x || out_row >= out_sizes.y) {
     return;
   }
 
-  // Since ret is width packed, need to multiply by 4
-  const uint16_t n = uint16_t(ret_pos.x * 4);
+  const int num_blocks = mat1_sizes.x / group_size;
 
-  // K is guaranteed to be a multiple of group size
-  const uint16_t num_blocks = uint16_t(x_sizes.x / group_size);
+  VEC4_T sums[2];
 
-  uint16_t k_texel_i = uint16_t(0);
-  vec4 sums = vec4(0.0);
-  for (uint16_t block_idx = uint16_t(0); block_idx < num_blocks; block_idx++) {
-    vec4 scales;
-    vec4 zeros;
+  sums[0] = VEC4_T(0);
+  sums[1] = VEC4_T(0);
 
-    [[unroll]] for (int comp = 0; comp < 4; comp++) {
-      const vec4 scale_and_zero = load_texel(
-          qparams, u16vec3(0, n + comp, block_idx));
-      scales[comp] = scale_and_zero.x;
-      zeros[comp] = scale_and_zero.y;
-    }
+  VEC4_T scales[2];
+  VEC4_T zeros[2];
+
+  $if WEIGHT_STORAGE == "buffer":
+    const int qmat2_stride = qmat2_sizes.x >> 2;
+
+  for (int block_idx = 0; block_idx < num_blocks; ++block_idx) {
+    scales[0] = texelFetch(t_qparams, ivec3(out_col_texel_idx, 0, block_idx), 0);
+    zeros[0] = texelFetch(t_qparams, ivec3(out_col_texel_idx, 1, block_idx), 0);
+
+    scales[1] = texelFetch(t_qparams, ivec3(out_col_texel_idx + 1, 0, block_idx), 0);
+    zeros[1] = texelFetch(t_qparams, ivec3(out_col_texel_idx + 1, 1, block_idx), 0);
+
+    for (int g_idx = 0; g_idx < group_size; g_idx += 4) {
+      const int k = block_idx * group_size + g_idx;
 
-    for (uint16_t i = uint16_t(0); i < group_size; i += uint16_t(4), k_texel_i++) {
-      const VEC4_T x_texel = load_texel(
-          x, u16vec3(k_texel_i, ret_pos.y, ret_pos.z));
-
-      [[unroll]] for (int comp = 0; comp < 4; comp++) {
-        const int weights_bufi = (n + comp) * weights_strides.y + (k_texel_i * 2);
-        // Need to read 4 unpacked values, which corresponds to 2 packed values
-        const uint8_t weights_val_1 = weights[weights_bufi];
-        const uint8_t weights_val_2 = weights[weights_bufi + 1];
-
-        const u8vec4 weights_texel = u8vec4(
-            (weights_val_1 & 0xF0) >> 4,
-            weights_val_1 & 0x0F,
-            (weights_val_2 & 0xF0) >> 4,
-            weights_val_2 & 0x0F);
-
-        // Note that the unpacked 4-bit values are unsigned, therefore they must
-        // first be "centered" around 0 by subtracting 8 before applying the
-        // scale and zero point.
-        sums[comp] += dot(
-            x_texel, (vec4(weights_texel) - 8.0) * scales[comp] + zeros[comp]);
+      $if IN_STORAGE == "buffer":
+        const VEC4_T mat1_tex = t_mat1[(out_row * mat1_sizes.x + k) >> 2];
+      $else:
+        const VEC4_T mat1_tex = texelFetch(t_mat1, ivec3(k >> 2, out_row, 0), 0);
+
+      for (int comp = 0; comp < 4; ++comp) {
+        $if WEIGHT_STORAGE == "buffer":
+          const u8vec4 packed_weight_tex = t_qmat2[(k + comp) * qmat2_stride + gl_GlobalInvocationID.x];
+        $else:
+          const uvec4 packed_weight_tex = texelFetch(
+              t_qmat2,
+              ivec3(gl_GlobalInvocationID.x, k + comp, 0),
+              0);
+
+        const uvec4 weight_tex_1 = (packed_weight_tex & 0xF0) >> 4;
+        const uvec4 weight_tex_2 = packed_weight_tex & 0x0F;
+
+        sums[0] += mat1_tex[comp] * ((vec4(weight_tex_1) - 8.0) * scales[0] + zeros[0]);
+        sums[1] += mat1_tex[comp] * ((vec4(weight_tex_2) - 8.0) * scales[1] + zeros[1]);
       }
     }
   }
-  write_texel(ret, ret_pos, sums);
+
+  $if OUT_STORAGE == "buffer":
+    t_out[(out_row * out_sizes.x + out_col) >> 2] = sums[0];
+    t_out[(out_row * out_sizes.x + out_col + 4) >> 2] = sums[1];
+  $else:
+    imageStore(t_out, ivec3(out_col_texel_idx, out_row, 0), sums[0]);
+    imageStore(t_out, ivec3(out_col_texel_idx + 1, out_row, 0), sums[1]);
 }
```
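The arithmetic this linear shader performs, stripped of the packing and storage details, is group-wise 4-bit dequantization followed by a matrix product: each unsigned nibble is centered by subtracting 8, then scaled and shifted by its group's parameters. A plain-Python sketch of that math (illustrative only; names and the unpacked-weight layout are our assumptions):

```python
def q4_linear(x, qweights, scales, zeros, group_size):
    """x: M x K floats; qweights: N x K unsigned 4-bit ints (already unpacked);
    scales/zeros: (K // group_size) x N per-group quantization parameters."""
    M, K, N = len(x), len(x[0]), len(qweights)
    out = [[0.0] * N for _ in range(M)]
    for m in range(M):
        for n in range(N):
            acc = 0.0
            for k in range(K):
                g = k // group_size
                # center the unsigned nibble around 0, then dequantize
                w = (qweights[n][k] - 8) * scales[g][n] + zeros[g][n]
                acc += x[m][k] * w
            out[m][n] = acc
    return out

# With every quantized value 9, scale 1 and zero 0, each weight dequantizes
# to (9 - 8) * 1 + 0 = 1.0, so a row of ones sums to K.
x = [[1.0, 1.0, 1.0, 1.0]]
qw = [[9] * 4, [9] * 4]
out = q4_linear(x, qw, scales=[[1.0, 1.0]], zeros=[[0.0, 0.0]], group_size=4)
assert out == [[4.0, 4.0]]
```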

backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.yaml

+13-6
```diff
@@ -7,10 +7,17 @@
 q_4w_linear:
   parameter_names_with_default_values:
     DTYPE: float
-    STORAGE: texture3d
-  generate_variant_forall:
-    DTYPE:
-      - VALUE: float
-      - VALUE: half
+    OUT_STORAGE: texture3d
+    IN_STORAGE: texture3d
+    WEIGHT_STORAGE: texture3d
   shader_variants:
-    - NAME: q_4w_linear_texture3d
+    - NAME: q_4w_linear_texture3d_texture3d_texture3d_float
+    - NAME: q_4w_linear_texture3d_buffer_texture3d_float
+      IN_STORAGE: buffer
+    - NAME: q_4w_linear_buffer_buffer_texture3d_float
+      OUT_STORAGE: buffer
+      IN_STORAGE: buffer
+    - NAME: q_4w_linear_buffer_buffer_buffer_float
+      OUT_STORAGE: buffer
+      IN_STORAGE: buffer
+      WEIGHT_STORAGE: buffer
```

0 commit comments

Comments
 (0)