Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit f186004

Browse files
committed Mar 11, 2024 ·
Vectorize MatMulNBits.
1 parent d114d9a commit f186004

File tree

1 file changed

+109
-37
lines changed

1 file changed

+109
-37
lines changed
 

‎js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts

+109-37
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@ import {ShapeUtil} from '../../util';
77
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
88
import {ComputeContext, ProgramInfo, ProgramUniform} from '../types';
99

10-
import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from './common';
10+
import {getMaxComponents, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from './common';
1111

1212
// TODO support quantization bits not equal to 4
1313
export interface MatMulNBitsAttributes extends AttributeWithCacheKey {
@@ -51,36 +51,91 @@ const validateInputs = (inputs: readonly TensorView[], attributes: MatMulNBitsAt
5151

5252
export const createMatMulNBitsProgramInfo =
5353
(inputs: readonly TensorView[], attributes: MatMulNBitsAttributes): ProgramInfo => {
54-
const a = inputs[0];
55-
const aRank = a.dims.length;
56-
const outputShape = a.dims.slice(0, aRank - 1).concat(attributes.n);
57-
const outputSize = ShapeUtil.size(outputShape);
58-
59-
54+
const inputShape = inputs[0].dims;
55+
const aRank = inputShape.length;
56+
const outputShape = inputShape.slice(0, aRank - 1).concat(attributes.n);
57+
const m = inputShape[aRank - 2];
58+
const blobSize = attributes.blockSize / 8 * attributes.bits;
59+
const blobSizeInWords = blobSize / 4;
60+
const outputNumber = getMaxComponents(m);
61+
const components = 1; // getMaxComponents(attributes.n);
62+
const aComponents = getMaxComponents(attributes.k);
63+
const bComponents = getMaxComponents(blobSizeInWords);
64+
const zComponents = 1; // getMaxComponents(attributes.n / 8);
65+
const outputSize = ShapeUtil.size(outputShape) / components / outputNumber;
6066
const programUniforms: ProgramUniform[] = [
6167
{type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: attributes.k},
6268
{type: DataType.uint32, data: attributes.n}, {type: DataType.uint32, data: attributes.accuracyLevel},
6369
{type: DataType.uint32, data: attributes.bits}, {type: DataType.uint32, data: attributes.blockSize}
6470
];
6571
const getShaderSource = (shaderHelper: ShaderHelper) => {
66-
const a = inputVariable('a', inputs[0].dataType, inputs[0].dims);
67-
const b = inputVariable('b', DataType.uint32, ShapeUtil.convertShape(inputs[1].dims));
72+
const aShape = inputs[0].dims.slice();
73+
aShape.splice(-1, 1, attributes.k / aComponents);
74+
const a = inputVariable('a', inputs[0].dataType, aShape, aComponents);
75+
const bShape = ShapeUtil.convertShape(inputs[1].dims).slice();
76+
bShape.splice(-1, 1, blobSizeInWords / bComponents);
77+
const b = inputVariable('b', DataType.uint32, bShape, bComponents);
6878
const scales = inputVariable('scales', inputs[2].dataType, inputs[2].dims);
6979
const inputVariables = [a, b, scales];
70-
const zeroPoints =
71-
inputs.length === 4 ? inputVariable('zero_points', DataType.uint32, inputs[3].dims) : undefined;
80+
const zeroPoints = inputs.length === 4 ?
81+
inputVariable('zero_points', DataType.uint32, inputs[3].dims, zComponents) :
82+
undefined;
7283
if (zeroPoints) {
7384
inputVariables.push(zeroPoints);
7485
}
75-
const output = outputVariable('output', inputs[0].dataType, outputShape);
86+
const output = outputVariable('output', inputs[0].dataType, outputShape, components);
7687
const uniforms: UniformsArrayType = [
77-
{name: 'output_size', type: 'u32'}, {name: 'k', type: 'u32'}, {name: 'n', type: 'u32'},
88+
{name: 'output_size', type: 'u32'}, {name: 'K', type: 'u32'}, {name: 'N', type: 'u32'},
7889
{name: 'accuracy_level', type: 'u32'}, {name: 'bits', type: 'u32'}, {name: 'block_size', type: 'u32'}
7990
];
8091
const nBlocksPerCol = Math.floor((attributes.k + attributes.blockSize - 1) / attributes.blockSize);
81-
const blobSize = attributes.blockSize / 8 * attributes.bits;
82-
const wordPerBlob = blobSize / 4;
8392
const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
93+
const dequantizeArrayReturnType = (() => {
94+
switch (aComponents) {
95+
case 1:
96+
return `array<${dataType}, 8>`;
97+
case 2:
98+
return `array<vec2<${dataType}>, 4>`;
99+
case 4:
100+
return `array<vec4<${dataType}>, 2>`;
101+
default:
102+
throw new Error(`${aComponents}-component is not supported.`);
103+
}
104+
})();
105+
const dequantizeArrayImpl =
106+
(() => `fn dequantize_array(quantized_data: array<${dataType}, 8>, zero_point: ${dataType}, scale: ${
107+
dataType}) -> ${dequantizeArrayReturnType} {
108+
var result: ${dequantizeArrayReturnType};
109+
${(() => {
110+
switch (aComponents) {
111+
case 1:
112+
return `
113+
for (var i: u32 = 0; i < 8; i++) {
114+
result[i] = dequantize(quantized_data[i], zero_point, scale);
115+
}`;
116+
case 2:
117+
return `
118+
for (var i: u32 = 0; i < 4; i++) {
119+
let dequantized0 = dequantize(quantized_data[i*2], zero_point, scale);
120+
let dequantized1 = dequantize(quantized_data[i*2+1], zero_point, scale);
121+
result[i] = vec2<${dataType}>(dequantized0, dequantized1);
122+
}`;
123+
case 4:
124+
return `
125+
for (var i: u32 = 0; i < 2; i++) {
126+
let dequantized0 = dequantize(quantized_data[i*4], zero_point, scale);
127+
let dequantized1 = dequantize(quantized_data[i*4+1], zero_point, scale);
128+
let dequantized2 = dequantize(quantized_data[i*4+2], zero_point, scale);
129+
let dequantized3 = dequantize(quantized_data[i*4+3], zero_point, scale);
130+
result[i] = vec4<${dataType}>(dequantized0, dequantized1, dequantized2, dequantized3);
131+
}`;
132+
default:
133+
throw new Error(`${aComponents}-component is not supported.`);
134+
}
135+
})()}
136+
return result;
137+
}`)();
138+
84139
return `
85140
fn ortUnpack8x4snorm(value: u32) -> array<${dataType}, 8>{
86141
var result = array<${dataType}, 8>();
@@ -92,13 +147,21 @@ export const createMatMulNBitsProgramInfo =
92147
}
93148
return result;
94149
}
150+
151+
fn dequantize(value: ${dataType}, zero_point: ${dataType}, scale: ${dataType}) -> ${dataType} {
152+
return (value - zero_point) * scale;
153+
}
154+
155+
${dequantizeArrayImpl};
156+
95157
${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)}
96158
${shaderHelper.mainStart()}
97159
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
98-
var value: ${dataType} = 0.0;
99-
let output_indices = ${output.offsetToIndices('global_idx')};
100-
var a_indices: ${a.type.indices} = output_indices;
160+
var output_values: array<${output.type.value}, ${outputNumber}>;
161+
var output_indices = ${output.offsetToIndices('global_idx')};
101162
var n = ${output.indicesGet('output_indices', aRank - 1)};
163+
var m = ${output.indicesGet('output_indices', aRank - 2)};
164+
var a_indices: ${a.type.indices} = output_indices;
102165
// Two zero points are packed into one byte because uniforms.bits <= 4.
103166
// zero_point_offset is either 0 or 4. It is bit offset within one byte.
104167
// TODO support zero_point_offset for bits > 4
@@ -108,35 +171,41 @@ export const createMatMulNBitsProgramInfo =
108171
var zero_point_word: u32 = ${zeroPoints.getByOffset('zero_point_index')};
109172
var zero_point_offset: u32 = 0;` :
110173
''}
111-
var scale_idex = n * ${nBlocksPerCol};
174+
var scale_index = n * ${nBlocksPerCol};
112175
var b_indices: ${b.type.indices};
113176
${b.indicesSet('b_indices', '0', 'n')};
114177
var block_offset: u32 = 0;
115178
for (var block: u32 = 0; block < ${nBlocksPerCol}; block++) {
116179
// The scale and zero points are computed per block.
117-
let scale = ${scales.getByOffset('scale_idex')};
180+
let scale = ${scales.getByOffset('scale_index')};
118181
// The default zero point is 8 for unsigned 4-bit quantization.
119182
let zero_point: ${dataType} = ${
120183
zeroPoints ? `${dataType}(extractBits(zero_point_word, zero_point_offset, 4))` : 8.0};
121184
${b.indicesSet('b_indices', '1', 'block')};
122185
var word_offset: u32 = block_offset;
123-
for (var word: u32 = 0; word < ${wordPerBlob}; word++) {
186+
for (var word: u32 = 0; word < ${blobSizeInWords}; word += ${bComponents}) {
124187
${b.indicesSet('b_indices', '2', 'word')};
125-
let b_value = ${b.getByIndices('b_indices')};
126-
let b_quantized_values: array<${dataType}, 8> = ortUnpack8x4snorm(b_value);
127-
// Number of B elements per 32-bit word is 32/bits = 32/4 = 8
128-
var offset: u32 = word_offset;
129-
for (var i: u32 = 0; i < 8; i++) {
130-
${a.indicesSet('a_indices', aRank - 1, 'offset')};
131-
let a_value = ${a.getByIndices('a_indices')};
132-
let b_quantized_value = b_quantized_values[i];
133-
let b_dequantized_value = (b_quantized_value - zero_point) * scale;
134-
value += a_value * b_dequantized_value;
135-
offset++;
188+
let b_data = ${b.getByIndices('b_indices')};
189+
for (var i: u32 = 0; i < ${bComponents}; i++) {
190+
let b_value = ${bComponents === 1 ? 'b_data' : 'b_data[word + i]'};
191+
let b_quantized_values: array<${dataType}, 8> = ortUnpack8x4snorm(b_value);
192+
let b_dequantized_values = dequantize_array(b_quantized_values, zero_point, scale);
193+
// Number of B elements per 32-bit word is 32/bits = 32/4 = 8
194+
var offset: u32 = word_offset;
195+
for (var j: u32 = 0; j < 8/${aComponents}; j++) {
196+
${a.indicesSet('a_indices', aRank - 1, `offset/${aComponents}`)};
197+
for (var k: u32 = 0; k < ${outputNumber}; k++) {
198+
${a.indicesSet('a_indices', aRank - 2, `m * ${outputNumber} + k`)};
199+
let a_data = ${a.getByIndices('a_indices')};
200+
output_values[k] += ${
201+
aComponents === 1 ? 'a_data * b_dequantized_values[j]' : `dot(a_data, b_dequantized_values[j])`};
202+
}
203+
offset += ${aComponents};
204+
}
205+
word_offset += 8;
136206
}
137-
word_offset += 8;
138207
}
139-
scale_idex++;
208+
scale_index++;
140209
${
141210
zeroPoints ? `
142211
if (zero_point_offset == 28) {
@@ -149,17 +218,20 @@ export const createMatMulNBitsProgramInfo =
149218
''}
150219
block_offset += uniforms.block_size;
151220
}
152-
${output.setByOffset('global_idx', 'value')};
221+
for (var k: u32 = 0u; k < ${outputNumber}u; k++) {
222+
${output.indicesSet('output_indices', aRank - 2, `${outputNumber + ' * m + k'}`)};
223+
${output.setByIndices('output_indices', 'output_values[k]')}
224+
}
153225
}
154226
`;
155227
};
156228
return {
157229
name: 'MatMulNBits',
158230
shaderCache:
159-
{hint: `${attributes.cacheKey};${inputs.length}`, inputDependencies: Array(inputs.length).fill('dims')},
231+
{hint: `${attributes.cacheKey};${inputs.length}`, inputDependencies: Array(inputs.length).fill('rank')},
160232
getRunData: () => ({
161233
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
162-
dispatchGroup: {x: Math.ceil(outputSize / 64)},
234+
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
163235
programUniforms
164236
}),
165237
getShaderSource

0 commit comments

Comments
 (0)
Please sign in to comment.