Skip to content

Commit 1df9911

Browse files
satyajandhyalafs-eire
authored andcommitted
[JS/WebGPU] Optimize MatMulNBits (#19852)
### Description Use vec<2> or vec<4>, operands in MatMulNBits ### Motivation and Context Improve performance
1 parent b54dd28 commit 1df9911

File tree

2 files changed

+194
-71
lines changed

2 files changed

+194
-71
lines changed

js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts

+137-71
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import {ShapeUtil} from '../../util';
77
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
88
import {ComputeContext, ProgramInfo, ProgramUniform} from '../types';
99

10-
import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from './common';
10+
import {createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from './common';
1111

1212
// TODO support quantization bits not equal to 4
1313
export interface MatMulNBitsAttributes extends AttributeWithCacheKey {
@@ -51,124 +51,190 @@ const validateInputs = (inputs: readonly TensorView[], attributes: MatMulNBitsAt
5151

5252
export const createMatMulNBitsProgramInfo =
5353
(inputs: readonly TensorView[], attributes: MatMulNBitsAttributes): ProgramInfo => {
54-
const a = inputs[0];
55-
const b = inputs[1];
56-
const scales = inputs[2];
57-
const aRank = a.dims.length;
58-
const outputShape = a.dims.slice(0, aRank - 1).concat(attributes.n);
59-
const outputSize = ShapeUtil.size(outputShape);
60-
61-
54+
const inputShape = inputs[0].dims;
55+
const aRank = inputShape.length;
56+
const outputShape = inputShape.slice(0, aRank - 1).concat(attributes.n);
57+
const m = inputShape[aRank - 2];
58+
const blobSize = attributes.blockSize / 8 * attributes.bits;
59+
const blobSizeInWords = blobSize / 4;
60+
const outputNumber = getMaxComponents(m);
61+
const components = getMaxComponents(attributes.n);
62+
const aComponents = getMaxComponents(attributes.k);
63+
const bComponents = getMaxComponents(blobSizeInWords);
64+
const outputSize = ShapeUtil.size(outputShape) / components / outputNumber;
6265
const programUniforms: ProgramUniform[] = [
6366
{type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: attributes.k},
6467
{type: DataType.uint32, data: attributes.n}, {type: DataType.uint32, data: attributes.accuracyLevel},
6568
{type: DataType.uint32, data: attributes.bits}, {type: DataType.uint32, data: attributes.blockSize}
6669
];
67-
programUniforms.push(...createTensorShapeVariables(a.dims));
68-
programUniforms.push(...createTensorShapeVariables(ShapeUtil.convertShape(b.dims)));
69-
programUniforms.push(...createTensorShapeVariables(scales.dims));
70+
const aShape = inputShape.slice();
71+
aShape.splice(-1, 1, attributes.k / aComponents);
72+
const bShape = ShapeUtil.convertShape(inputs[1].dims).slice();
73+
bShape.splice(-1, 1, blobSizeInWords / bComponents);
74+
programUniforms.push(...createTensorShapeVariables(aShape));
75+
programUniforms.push(...createTensorShapeVariables(bShape));
76+
programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
7077
if (inputs.length === 4) {
7178
programUniforms.push(...createTensorShapeVariables(ShapeUtil.convertShape(inputs[3].dims)));
7279
}
73-
programUniforms.push(...createTensorShapeVariables(outputShape));
80+
const oShape = outputShape.slice();
81+
oShape.splice(-1, 1, attributes.n / components);
82+
programUniforms.push(...createTensorShapeVariables(oShape));
7483
const getShaderSource = (shaderHelper: ShaderHelper) => {
75-
const a = inputVariable('a', inputs[0].dataType, inputs[0].dims.length);
76-
const b = inputVariable('b', DataType.uint32, inputs[1].dims.length);
84+
const a = inputVariable('a', inputs[0].dataType, aShape.length, aComponents);
85+
const b = inputVariable('b', DataType.uint32, bShape.length, bComponents);
7786
const scales = inputVariable('scales', inputs[2].dataType, inputs[2].dims.length);
7887
const inputVariables = [a, b, scales];
7988
const zeroPoints =
8089
inputs.length === 4 ? inputVariable('zero_points', DataType.uint32, inputs[3].dims.length) : undefined;
8190
if (zeroPoints) {
8291
inputVariables.push(zeroPoints);
8392
}
84-
const output = outputVariable('output', inputs[0].dataType, outputShape.length);
93+
const output = outputVariable('output', inputs[0].dataType, outputShape.length, components);
8594
const uniforms: UniformsArrayType = [
86-
{name: 'output_size', type: 'u32'}, {name: 'k', type: 'u32'}, {name: 'n', type: 'u32'},
95+
{name: 'output_size', type: 'u32'}, {name: 'K', type: 'u32'}, {name: 'N', type: 'u32'},
8796
{name: 'accuracy_level', type: 'u32'}, {name: 'bits', type: 'u32'}, {name: 'block_size', type: 'u32'}
8897
];
8998
const nBlocksPerCol = Math.floor((attributes.k + attributes.blockSize - 1) / attributes.blockSize);
90-
const blobSize = attributes.blockSize / 8 * attributes.bits;
91-
const wordPerBlob = blobSize / 4;
9299
const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
93-
return `
94-
fn ortUnpack8x4snorm(value: u32) -> array<${dataType}, 8>{
95-
var result = array<${dataType}, 8>();
100+
101+
const qDqDataType = (() => {
102+
switch (aComponents) {
103+
case 1:
104+
return `array<${dataType}, 8>`;
105+
case 2:
106+
return `mat4x2<${dataType}>`;
107+
case 4:
108+
return `mat2x4<${dataType}>`;
109+
default:
110+
throw new Error(`${aComponents}-component is not supported.`);
111+
}
112+
})();
113+
114+
const dequantizeImpl = `
115+
fn dequantize(quantized: ${qDqDataType}, zero_point: ${dataType}, scale: ${dataType}) -> ${qDqDataType} {
116+
${(() => {
117+
if (aComponents === 1) {
118+
return `var dequantized = ${qDqDataType}(${
119+
Array.from({length: 8}, (_, i) => `(quantized[${i}] - zero_point) * scale`).join(', ')});
120+
return dequantized;`;
121+
} else {
122+
return `var zero_points: ${qDqDataType} = ${qDqDataType}(${Array(8).fill('zero_point').join(',')});
123+
return (quantized - zero_points) * scale;`;
124+
}
125+
})()}
126+
}`;
127+
const ortUnpack8x4snormImpl = `
128+
fn ortUnpack8x4snorm(value: u32) -> ${qDqDataType} {
129+
var quantized: ${qDqDataType};
96130
var offset: u32 = 0;
97131
let count: u32 = 4;
98132
for (var i: u32 = 0; i < 8u; i++) {
99-
result[i] = ${dataType}(extractBits(value, offset, count));
133+
var result = ${dataType}(extractBits(value, offset, count));
134+
${(() => {
135+
switch (aComponents) {
136+
case 1:
137+
return 'quantized[i] = result;';
138+
case 2:
139+
return 'quantized[i / 2][i % 2] = result;';
140+
case 4:
141+
return 'quantized[i / 4][i % 4] = result;';
142+
default:
143+
throw new Error(`${aComponents}-component is not supported.`);
144+
}
145+
})()}
100146
offset += count;
101147
}
102-
return result;
103-
}
148+
return quantized;
149+
}`;
150+
151+
const updateZeroPointIndex = zeroPoints ? `
152+
zero_point_offset += 4;
153+
if (zero_point_offset == 32) {
154+
zero_point_offset = 0;
155+
zero_point_index++;
156+
zero_point_word = ${zeroPoints.getByOffset('zero_point_index')};
157+
}` :
158+
'';
159+
160+
return `
161+
${dequantizeImpl};
162+
${ortUnpack8x4snormImpl};
104163
${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)}
105164
${shaderHelper.mainStart()}
106165
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
107-
var value: ${dataType} = 0.0;
108-
let output_indices = ${output.offsetToIndices('global_idx')};
109-
var a_indices: ${a.type.indices} = output_indices;
166+
var output_values: array<${output.type.value}, ${outputNumber}>;
167+
var output_indices = ${output.offsetToIndices('global_idx')};
110168
var n = ${output.indicesGet('output_indices', aRank - 1)};
169+
var m = ${output.indicesGet('output_indices', aRank - 2)};
170+
var a_indices: ${a.type.indices} = output_indices;
111171
// Two zero points are packed into one byte because uniforms.bits <= 4.
112172
// zero_point_offset is either 0 or 4. It is bit offset within one byte.
113173
// TODO support zero_point_offset for bits > 4
114174
${
115175
zeroPoints ? `
116-
var zero_point_index: u32 = n * ((${nBlocksPerCol} + 1) / 2) / 4;
117-
var zero_point_word: u32 = ${zeroPoints.getByOffset('zero_point_index')};
118-
var zero_point_offset: u32 = 0;` :
176+
var zero_point_index: u32 = n * ${components} * ((${nBlocksPerCol} + 1) / 2) / 4;
177+
var zero_point_word: u32 = ${zeroPoints.getByOffset('zero_point_index')};
178+
var zero_point_offset: u32 = 0;` :
119179
''}
120-
var scale_idex = n * ${nBlocksPerCol};
180+
var scale_index = n * ${nBlocksPerCol * components};
121181
var b_indices: ${b.type.indices};
122-
${b.indicesSet('b_indices', '0', 'n')};
123-
var block_offset: u32 = 0;
124-
for (var block: u32 = 0; block < ${nBlocksPerCol}; block++) {
125-
// The scale and zero points are computed per block.
126-
let scale = ${scales.getByOffset('scale_idex')};
127-
// The default zero point is 8 for unsigned 4-bit quantization.
128-
let zero_point: ${dataType} = ${
129-
zeroPoints ? `${dataType}(extractBits(zero_point_word, zero_point_offset, 4))` : 8.0};
130-
${b.indicesSet('b_indices', '1', 'block')};
131-
var word_offset: u32 = block_offset;
132-
for (var word: u32 = 0; word < ${wordPerBlob}; word++) {
133-
${b.indicesSet('b_indices', '2', 'word')};
134-
let b_value = ${b.getByIndices('b_indices')};
135-
let b_quantized_values: array<${dataType}, 8> = ortUnpack8x4snorm(b_value);
136-
// Number of B elements per 32-bit word is 32/bits = 32/4 = 8
137-
var offset: u32 = word_offset;
138-
for (var i: u32 = 0; i < 8; i++) {
139-
${a.indicesSet('a_indices', aRank - 1, 'offset')};
140-
let a_value = ${a.getByIndices('a_indices')};
141-
let b_quantized_value = b_quantized_values[i];
142-
let b_dequantized_value = (b_quantized_value - zero_point) * scale;
143-
value += a_value * b_dequantized_value;
144-
offset++;
182+
for (var c: u32 = 0; c < ${components}; c++) {
183+
${b.indicesSet('b_indices', '0', `n * ${components} + c`)};
184+
var block_offset: u32 = 0;
185+
for (var block: u32 = 0; block < ${nBlocksPerCol}; block++) {
186+
// The scale and zero points are computed per block.
187+
let scale = ${scales.getByOffset('scale_index')};
188+
// The default zero point is 8 for unsigned 4-bit quantization.
189+
let zero_point = ${dataType}(${zeroPoints ? 'extractBits(zero_point_word, zero_point_offset, 4)' : 8.0});
190+
${b.indicesSet('b_indices', '1', 'block')};
191+
var word_offset: u32 = block_offset;
192+
for (var word: u32 = 0; word < ${blobSizeInWords}; word += ${bComponents}) {
193+
${b.indicesSet('b_indices', '2', 'word')};
194+
let b_data = ${b.getByIndices('b_indices')};
195+
for (var i: u32 = 0; i < ${bComponents}; i++) {
196+
let b_value = ${bComponents === 1 ? 'b_data' : 'b_data[word + i]'};
197+
let b_quantized_values: ${qDqDataType} = ortUnpack8x4snorm(b_value);
198+
let b_dequantized_values = dequantize(b_quantized_values, zero_point, scale);
199+
// Number of B elements per 32-bit word is 32/bits = 32/4 = 8
200+
var offset: u32 = word_offset;
201+
for (var j: u32 = 0; j < 8/${aComponents}; j++) {
202+
${a.indicesSet('a_indices', aRank - 1, `offset/${aComponents}`)};
203+
for (var k: u32 = 0; k < ${outputNumber}u; k++) {
204+
${a.indicesSet('a_indices', aRank - 2, `m * ${outputNumber} + k`)};
205+
let a_data = ${a.getByIndices('a_indices')};
206+
output_values[k]${components > 1 ? '[c]' : ''} += ${
207+
aComponents === 1 ? 'a_data * b_dequantized_values[j]' : 'dot(a_data, b_dequantized_values[j])'};
208+
}
209+
offset += ${aComponents};
210+
}
211+
word_offset += 8;
212+
}
145213
}
146-
word_offset += 8;
214+
scale_index++;
215+
${updateZeroPointIndex}
216+
block_offset += uniforms.block_size;
147217
}
148-
scale_idex++;
218+
// Drop the trailing 4 bits if the zero_poit_offset is not a byte boundary to align with the next byte.
149219
${
150-
zeroPoints ? `
151-
if (zero_point_offset == 28) {
152-
zero_point_offset = 0;
153-
zero_point_index++;
154-
zero_point_word = ${zeroPoints.getByOffset('zero_point_index')};
155-
} else {
156-
zero_point_offset += 4;
157-
}` :
220+
zeroPoints ? `if (zero_point_offset % 8 > 0) {
221+
${updateZeroPointIndex}
222+
}` :
158223
''}
159-
block_offset += uniforms.block_size;
160-
}
161-
${output.setByOffset('global_idx', 'value')};
162-
}
163-
`;
224+
}
225+
for (var k: u32 = 0u; k < ${outputNumber}u; k++) {
226+
${output.indicesSet('output_indices', aRank - 2, `${outputNumber + ' * m + k'}`)};
227+
${output.setByIndices('output_indices', 'output_values[k]')}
228+
}
229+
}`;
164230
};
165231
return {
166232
name: 'MatMulNBits',
167233
shaderCache:
168234
{hint: `${attributes.cacheKey};${inputs.length}`, inputDependencies: Array(inputs.length).fill('rank')},
169235
getRunData: () => ({
170236
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
171-
dispatchGroup: {x: Math.ceil(outputSize / 64)},
237+
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
172238
programUniforms
173239
}),
174240
getShaderSource

js/web/test/data/ops/matmulnbits.jsonc

+57
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,61 @@
11
[
2+
{
3+
"name": "MatMulNBits; K=16, N=16, block_size=16, bits=4",
4+
"operator": "MatMulNBits",
5+
"opset": { "domain": "com.microsoft", "version": 1 },
6+
"attributes": [
7+
{ "name": "K", "data": 16, "type": "int" },
8+
{ "name": "N", "data": 8, "type": "int" },
9+
{ "name": "block_size", "data": 16, "type": "int" },
10+
{ "name": "bits", "data": 4, "type": "int" }
11+
],
12+
"cases": [
13+
{
14+
"name": "MatMulNBits; K=16, N=16, block_size=16, bits=4; symmetric",
15+
"inputs": [
16+
{
17+
"data": [
18+
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
19+
29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
20+
55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80,
21+
81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105,
22+
106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126,
23+
127
24+
],
25+
"dims": [8, 16],
26+
"type": "float32"
27+
},
28+
{
29+
"dims": [8, 1, 8],
30+
"type": "uint8",
31+
"data": [
32+
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
33+
30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
34+
56, 57, 58, 59, 60, 61, 62, 63, 64
35+
]
36+
},
37+
{
38+
"dims": [8],
39+
"type": "float32",
40+
"data": [0, 1, 2, 3, 4, 5, 6, 7]
41+
}
42+
],
43+
"outputs": [
44+
{
45+
"dims": [8, 8],
46+
"type": "float32",
47+
"data": [
48+
0, -385, -1120, -963, -1984, -1285, -2592, -1351, 0, -1073, -3808, -2643, -6848, -3445, -9120, -3479, 0,
49+
-1761, -6496, -4323, -11712, -5605, -15648, -5607, 0, -2449, -9184, -6003, -16576, -7765, -22176, -7735,
50+
0, -3137, -11872, -7683, -21440, -9925, -28704, -9863, 0, -3825, -14560, -9363, -26304, -12085, -35232,
51+
-11991, 0, -4513, -17248, -11043, -31168, -14245, -41760, -14119, 0, -5201, -19936, -12723, -36032,
52+
-16405, -48288, -16247
53+
]
54+
}
55+
]
56+
}
57+
]
58+
},
259
{
360
"name": "MatMulNBits; K=16, N=16, block_size=16, bits=4",
461
"operator": "MatMulNBits",

0 commit comments

Comments
 (0)