@@ -7,7 +7,7 @@ import {ShapeUtil} from '../../util';
7
7
import { AttributeWithCacheKey , createAttributeWithCacheKey } from '../attribute-with-cache-key' ;
8
8
import { ComputeContext , ProgramInfo , ProgramUniform } from '../types' ;
9
9
10
- import { createTensorShapeVariables , inputVariable , outputVariable , ShaderHelper , tensorTypeToWsglStorageType , UniformsArrayType } from './common' ;
10
+ import { createTensorShapeVariables , getMaxComponents , inputVariable , outputVariable , ShaderHelper , tensorTypeToWsglStorageType , UniformsArrayType } from './common' ;
11
11
12
12
// TODO support quantization bits not equal to 4
13
13
export interface MatMulNBitsAttributes extends AttributeWithCacheKey {
@@ -51,124 +51,190 @@ const validateInputs = (inputs: readonly TensorView[], attributes: MatMulNBitsAt
51
51
52
52
export const createMatMulNBitsProgramInfo =
53
53
( inputs : readonly TensorView [ ] , attributes : MatMulNBitsAttributes ) : ProgramInfo => {
54
- const a = inputs [ 0 ] ;
55
- const b = inputs [ 1 ] ;
56
- const scales = inputs [ 2 ] ;
57
- const aRank = a . dims . length ;
58
- const outputShape = a . dims . slice ( 0 , aRank - 1 ) . concat ( attributes . n ) ;
59
- const outputSize = ShapeUtil . size ( outputShape ) ;
60
-
61
-
54
+ const inputShape = inputs [ 0 ] . dims ;
55
+ const aRank = inputShape . length ;
56
+ const outputShape = inputShape . slice ( 0 , aRank - 1 ) . concat ( attributes . n ) ;
57
+ const m = inputShape [ aRank - 2 ] ;
58
+ const blobSize = attributes . blockSize / 8 * attributes . bits ;
59
+ const blobSizeInWords = blobSize / 4 ;
60
+ const outputNumber = getMaxComponents ( m ) ;
61
+ const components = getMaxComponents ( attributes . n ) ;
62
+ const aComponents = getMaxComponents ( attributes . k ) ;
63
+ const bComponents = getMaxComponents ( blobSizeInWords ) ;
64
+ const outputSize = ShapeUtil . size ( outputShape ) / components / outputNumber ;
62
65
const programUniforms : ProgramUniform [ ] = [
63
66
{ type : DataType . uint32 , data : outputSize } , { type : DataType . uint32 , data : attributes . k } ,
64
67
{ type : DataType . uint32 , data : attributes . n } , { type : DataType . uint32 , data : attributes . accuracyLevel } ,
65
68
{ type : DataType . uint32 , data : attributes . bits } , { type : DataType . uint32 , data : attributes . blockSize }
66
69
] ;
67
- programUniforms . push ( ...createTensorShapeVariables ( a . dims ) ) ;
68
- programUniforms . push ( ...createTensorShapeVariables ( ShapeUtil . convertShape ( b . dims ) ) ) ;
69
- programUniforms . push ( ...createTensorShapeVariables ( scales . dims ) ) ;
70
+ const aShape = inputShape . slice ( ) ;
71
+ aShape . splice ( - 1 , 1 , attributes . k / aComponents ) ;
72
+ const bShape = ShapeUtil . convertShape ( inputs [ 1 ] . dims ) . slice ( ) ;
73
+ bShape . splice ( - 1 , 1 , blobSizeInWords / bComponents ) ;
74
+ programUniforms . push ( ...createTensorShapeVariables ( aShape ) ) ;
75
+ programUniforms . push ( ...createTensorShapeVariables ( bShape ) ) ;
76
+ programUniforms . push ( ...createTensorShapeVariables ( inputs [ 2 ] . dims ) ) ;
70
77
if ( inputs . length === 4 ) {
71
78
programUniforms . push ( ...createTensorShapeVariables ( ShapeUtil . convertShape ( inputs [ 3 ] . dims ) ) ) ;
72
79
}
73
- programUniforms . push ( ...createTensorShapeVariables ( outputShape ) ) ;
80
+ const oShape = outputShape . slice ( ) ;
81
+ oShape . splice ( - 1 , 1 , attributes . n / components ) ;
82
+ programUniforms . push ( ...createTensorShapeVariables ( oShape ) ) ;
74
83
const getShaderSource = ( shaderHelper : ShaderHelper ) => {
75
- const a = inputVariable ( 'a' , inputs [ 0 ] . dataType , inputs [ 0 ] . dims . length ) ;
76
- const b = inputVariable ( 'b' , DataType . uint32 , inputs [ 1 ] . dims . length ) ;
84
+ const a = inputVariable ( 'a' , inputs [ 0 ] . dataType , aShape . length , aComponents ) ;
85
+ const b = inputVariable ( 'b' , DataType . uint32 , bShape . length , bComponents ) ;
77
86
const scales = inputVariable ( 'scales' , inputs [ 2 ] . dataType , inputs [ 2 ] . dims . length ) ;
78
87
const inputVariables = [ a , b , scales ] ;
79
88
const zeroPoints =
80
89
inputs . length === 4 ? inputVariable ( 'zero_points' , DataType . uint32 , inputs [ 3 ] . dims . length ) : undefined ;
81
90
if ( zeroPoints ) {
82
91
inputVariables . push ( zeroPoints ) ;
83
92
}
84
- const output = outputVariable ( 'output' , inputs [ 0 ] . dataType , outputShape . length ) ;
93
+ const output = outputVariable ( 'output' , inputs [ 0 ] . dataType , outputShape . length , components ) ;
85
94
const uniforms : UniformsArrayType = [
86
- { name : 'output_size' , type : 'u32' } , { name : 'k ' , type : 'u32' } , { name : 'n ' , type : 'u32' } ,
95
+ { name : 'output_size' , type : 'u32' } , { name : 'K ' , type : 'u32' } , { name : 'N ' , type : 'u32' } ,
87
96
{ name : 'accuracy_level' , type : 'u32' } , { name : 'bits' , type : 'u32' } , { name : 'block_size' , type : 'u32' }
88
97
] ;
89
98
const nBlocksPerCol = Math . floor ( ( attributes . k + attributes . blockSize - 1 ) / attributes . blockSize ) ;
90
- const blobSize = attributes . blockSize / 8 * attributes . bits ;
91
- const wordPerBlob = blobSize / 4 ;
92
99
const dataType = tensorTypeToWsglStorageType ( inputs [ 0 ] . dataType ) ;
93
- return `
94
- fn ortUnpack8x4snorm(value: u32) -> array<${ dataType } , 8>{
95
- var result = array<${ dataType } , 8>();
100
+
101
+ const qDqDataType = ( ( ) => {
102
+ switch ( aComponents ) {
103
+ case 1 :
104
+ return `array<${ dataType } , 8>` ;
105
+ case 2 :
106
+ return `mat4x2<${ dataType } >` ;
107
+ case 4 :
108
+ return `mat2x4<${ dataType } >` ;
109
+ default :
110
+ throw new Error ( `${ aComponents } -component is not supported.` ) ;
111
+ }
112
+ } ) ( ) ;
113
+
114
+ const dequantizeImpl = `
115
+ fn dequantize(quantized: ${ qDqDataType } , zero_point: ${ dataType } , scale: ${ dataType } ) -> ${ qDqDataType } {
116
+ ${ ( ( ) => {
117
+ if ( aComponents === 1 ) {
118
+ return `var dequantized = ${ qDqDataType } (${
119
+ Array . from ( { length : 8 } , ( _ , i ) => `(quantized[${ i } ] - zero_point) * scale` ) . join ( ', ' ) } );
120
+ return dequantized;` ;
121
+ } else {
122
+ return `var zero_points: ${ qDqDataType } = ${ qDqDataType } (${ Array ( 8 ) . fill ( 'zero_point' ) . join ( ',' ) } );
123
+ return (quantized - zero_points) * scale;` ;
124
+ }
125
+ } ) ( ) }
126
+ }` ;
127
+ const ortUnpack8x4snormImpl = `
128
+ fn ortUnpack8x4snorm(value: u32) -> ${ qDqDataType } {
129
+ var quantized: ${ qDqDataType } ;
96
130
var offset: u32 = 0;
97
131
let count: u32 = 4;
98
132
for (var i: u32 = 0; i < 8u; i++) {
99
- result[i] = ${ dataType } (extractBits(value, offset, count));
133
+ var result = ${ dataType } (extractBits(value, offset, count));
134
+ ${ ( ( ) => {
135
+ switch ( aComponents ) {
136
+ case 1 :
137
+ return 'quantized[i] = result;' ;
138
+ case 2 :
139
+ return 'quantized[i / 2][i % 2] = result;' ;
140
+ case 4 :
141
+ return 'quantized[i / 4][i % 4] = result;' ;
142
+ default :
143
+ throw new Error ( `${ aComponents } -component is not supported.` ) ;
144
+ }
145
+ } ) ( ) }
100
146
offset += count;
101
147
}
102
- return result;
103
- }
148
+ return quantized;
149
+ }` ;
150
+
151
+ const updateZeroPointIndex = zeroPoints ? `
152
+ zero_point_offset += 4;
153
+ if (zero_point_offset == 32) {
154
+ zero_point_offset = 0;
155
+ zero_point_index++;
156
+ zero_point_word = ${ zeroPoints . getByOffset ( 'zero_point_index' ) } ;
157
+ }` :
158
+ '' ;
159
+
160
+ return `
161
+ ${ dequantizeImpl } ;
162
+ ${ ortUnpack8x4snormImpl } ;
104
163
${ shaderHelper . registerUniforms ( uniforms ) . declareVariables ( ...inputVariables , output ) }
105
164
${ shaderHelper . mainStart ( ) }
106
165
${ shaderHelper . guardAgainstOutOfBoundsWorkgroupSizes ( 'uniforms.output_size' ) }
107
- var value: ${ dataType } = 0.0;
108
- let output_indices = ${ output . offsetToIndices ( 'global_idx' ) } ;
109
- var a_indices: ${ a . type . indices } = output_indices;
166
+ var output_values: array<${ output . type . value } , ${ outputNumber } >;
167
+ var output_indices = ${ output . offsetToIndices ( 'global_idx' ) } ;
110
168
var n = ${ output . indicesGet ( 'output_indices' , aRank - 1 ) } ;
169
+ var m = ${ output . indicesGet ( 'output_indices' , aRank - 2 ) } ;
170
+ var a_indices: ${ a . type . indices } = output_indices;
111
171
// Two zero points are packed into one byte because uniforms.bits <= 4.
112
172
// zero_point_offset is either 0 or 4. It is bit offset within one byte.
113
173
// TODO support zero_point_offset for bits > 4
114
174
${
115
175
zeroPoints ? `
116
- var zero_point_index: u32 = n * ((${ nBlocksPerCol } + 1) / 2) / 4;
117
- var zero_point_word: u32 = ${ zeroPoints . getByOffset ( 'zero_point_index' ) } ;
118
- var zero_point_offset: u32 = 0;` :
176
+ var zero_point_index: u32 = n * ${ components } * ((${ nBlocksPerCol } + 1) / 2) / 4;
177
+ var zero_point_word: u32 = ${ zeroPoints . getByOffset ( 'zero_point_index' ) } ;
178
+ var zero_point_offset: u32 = 0;` :
119
179
'' }
120
- var scale_idex = n * ${ nBlocksPerCol } ;
180
+ var scale_index = n * ${ nBlocksPerCol * components } ;
121
181
var b_indices: ${ b . type . indices } ;
122
- ${ b . indicesSet ( 'b_indices' , '0' , 'n' ) } ;
123
- var block_offset: u32 = 0;
124
- for (var block: u32 = 0; block < ${ nBlocksPerCol } ; block++) {
125
- // The scale and zero points are computed per block.
126
- let scale = ${ scales . getByOffset ( 'scale_idex' ) } ;
127
- // The default zero point is 8 for unsigned 4-bit quantization.
128
- let zero_point: ${ dataType } = ${
129
- zeroPoints ? `${ dataType } (extractBits(zero_point_word, zero_point_offset, 4))` : 8.0 } ;
130
- ${ b . indicesSet ( 'b_indices' , '1' , 'block' ) } ;
131
- var word_offset: u32 = block_offset;
132
- for (var word: u32 = 0; word < ${ wordPerBlob } ; word++) {
133
- ${ b . indicesSet ( 'b_indices' , '2' , 'word' ) } ;
134
- let b_value = ${ b . getByIndices ( 'b_indices' ) } ;
135
- let b_quantized_values: array<${ dataType } , 8> = ortUnpack8x4snorm(b_value);
136
- // Number of B elements per 32-bit word is 32/bits = 32/4 = 8
137
- var offset: u32 = word_offset;
138
- for (var i: u32 = 0; i < 8; i++) {
139
- ${ a . indicesSet ( 'a_indices' , aRank - 1 , 'offset' ) } ;
140
- let a_value = ${ a . getByIndices ( 'a_indices' ) } ;
141
- let b_quantized_value = b_quantized_values[i];
142
- let b_dequantized_value = (b_quantized_value - zero_point) * scale;
143
- value += a_value * b_dequantized_value;
144
- offset++;
182
+ for (var c: u32 = 0; c < ${ components } ; c++) {
183
+ ${ b . indicesSet ( 'b_indices' , '0' , `n * ${ components } + c` ) } ;
184
+ var block_offset: u32 = 0;
185
+ for (var block: u32 = 0; block < ${ nBlocksPerCol } ; block++) {
186
+ // The scale and zero points are computed per block.
187
+ let scale = ${ scales . getByOffset ( 'scale_index' ) } ;
188
+ // The default zero point is 8 for unsigned 4-bit quantization.
189
+ let zero_point = ${ dataType } (${ zeroPoints ? 'extractBits(zero_point_word, zero_point_offset, 4)' : 8.0 } );
190
+ ${ b . indicesSet ( 'b_indices' , '1' , 'block' ) } ;
191
+ var word_offset: u32 = block_offset;
192
+ for (var word: u32 = 0; word < ${ blobSizeInWords } ; word += ${ bComponents } ) {
193
+ ${ b . indicesSet ( 'b_indices' , '2' , 'word' ) } ;
194
+ let b_data = ${ b . getByIndices ( 'b_indices' ) } ;
195
+ for (var i: u32 = 0; i < ${ bComponents } ; i++) {
196
+ let b_value = ${ bComponents === 1 ? 'b_data' : 'b_data[word + i]' } ;
197
+ let b_quantized_values: ${ qDqDataType } = ortUnpack8x4snorm(b_value);
198
+ let b_dequantized_values = dequantize(b_quantized_values, zero_point, scale);
199
+ // Number of B elements per 32-bit word is 32/bits = 32/4 = 8
200
+ var offset: u32 = word_offset;
201
+ for (var j: u32 = 0; j < 8/${ aComponents } ; j++) {
202
+ ${ a . indicesSet ( 'a_indices' , aRank - 1 , `offset/${ aComponents } ` ) } ;
203
+ for (var k: u32 = 0; k < ${ outputNumber } u; k++) {
204
+ ${ a . indicesSet ( 'a_indices' , aRank - 2 , `m * ${ outputNumber } + k` ) } ;
205
+ let a_data = ${ a . getByIndices ( 'a_indices' ) } ;
206
+ output_values[k]${ components > 1 ? '[c]' : '' } += ${
207
+ aComponents === 1 ? 'a_data * b_dequantized_values[j]' : 'dot(a_data, b_dequantized_values[j])' } ;
208
+ }
209
+ offset += ${ aComponents } ;
210
+ }
211
+ word_offset += 8;
212
+ }
145
213
}
146
- word_offset += 8;
214
+ scale_index++;
215
+ ${ updateZeroPointIndex }
216
+ block_offset += uniforms.block_size;
147
217
}
148
- scale_idex++;
218
+ // Drop the trailing 4 bits if the zero_poit_offset is not a byte boundary to align with the next byte.
149
219
${
150
- zeroPoints ? `
151
- if (zero_point_offset == 28) {
152
- zero_point_offset = 0;
153
- zero_point_index++;
154
- zero_point_word = ${ zeroPoints . getByOffset ( 'zero_point_index' ) } ;
155
- } else {
156
- zero_point_offset += 4;
157
- }` :
220
+ zeroPoints ? `if (zero_point_offset % 8 > 0) {
221
+ ${ updateZeroPointIndex }
222
+ }` :
158
223
'' }
159
- block_offset += uniforms.block_size;
160
- }
161
- ${ output . setByOffset ( 'global_idx' , 'value' ) } ;
162
- }
163
- ` ;
224
+ }
225
+ for (var k: u32 = 0u; k < ${ outputNumber } u; k++) {
226
+ ${ output . indicesSet ( 'output_indices' , aRank - 2 , `${ outputNumber + ' * m + k' } ` ) } ;
227
+ ${ output . setByIndices ( 'output_indices' , 'output_values[k]' ) }
228
+ }
229
+ }` ;
164
230
} ;
165
231
return {
166
232
name : 'MatMulNBits' ,
167
233
shaderCache :
168
234
{ hint : `${ attributes . cacheKey } ;${ inputs . length } ` , inputDependencies : Array ( inputs . length ) . fill ( 'rank' ) } ,
169
235
getRunData : ( ) => ( {
170
236
outputs : [ { dims : outputShape , dataType : inputs [ 0 ] . dataType } ] ,
171
- dispatchGroup : { x : Math . ceil ( outputSize / 64 ) } ,
237
+ dispatchGroup : { x : Math . ceil ( outputSize / 64 /* workgroup size */ ) } ,
172
238
programUniforms
173
239
} ) ,
174
240
getShaderSource
0 commit comments