@@ -7,7 +7,7 @@ import {ShapeUtil} from '../../util';
7
7
import { AttributeWithCacheKey , createAttributeWithCacheKey } from '../attribute-with-cache-key' ;
8
8
import { ComputeContext , ProgramInfo , ProgramUniform } from '../types' ;
9
9
10
- import { inputVariable , outputVariable , ShaderHelper , tensorTypeToWsglStorageType , UniformsArrayType } from './common' ;
10
+ import { getMaxComponents , inputVariable , outputVariable , ShaderHelper , tensorTypeToWsglStorageType , UniformsArrayType } from './common' ;
11
11
12
12
// TODO support quantization bits not equal to 4
13
13
export interface MatMulNBitsAttributes extends AttributeWithCacheKey {
@@ -51,36 +51,91 @@ const validateInputs = (inputs: readonly TensorView[], attributes: MatMulNBitsAt
51
51
52
52
export const createMatMulNBitsProgramInfo =
53
53
( inputs : readonly TensorView [ ] , attributes : MatMulNBitsAttributes ) : ProgramInfo => {
54
- const a = inputs [ 0 ] ;
55
- const aRank = a . dims . length ;
56
- const outputShape = a . dims . slice ( 0 , aRank - 1 ) . concat ( attributes . n ) ;
57
- const outputSize = ShapeUtil . size ( outputShape ) ;
58
-
59
-
54
+ const inputShape = inputs [ 0 ] . dims ;
55
+ const aRank = inputShape . length ;
56
+ const outputShape = inputShape . slice ( 0 , aRank - 1 ) . concat ( attributes . n ) ;
57
+ const m = inputShape [ aRank - 2 ] ;
58
+ const blobSize = attributes . blockSize / 8 * attributes . bits ;
59
+ const blobSizeInWords = blobSize / 4 ;
60
+ const outputNumber = getMaxComponents ( m ) ;
61
+ const components = 1 ; // getMaxComponents(attributes.n);
62
+ const aComponents = getMaxComponents ( attributes . k ) ;
63
+ const bComponents = getMaxComponents ( blobSizeInWords ) ;
64
+ const zComponents = 1 ; // getMaxComponents(attributes.n / 8);
65
+ const outputSize = ShapeUtil . size ( outputShape ) / components / outputNumber ;
60
66
const programUniforms : ProgramUniform [ ] = [
61
67
{ type : DataType . uint32 , data : outputSize } , { type : DataType . uint32 , data : attributes . k } ,
62
68
{ type : DataType . uint32 , data : attributes . n } , { type : DataType . uint32 , data : attributes . accuracyLevel } ,
63
69
{ type : DataType . uint32 , data : attributes . bits } , { type : DataType . uint32 , data : attributes . blockSize }
64
70
] ;
65
71
const getShaderSource = ( shaderHelper : ShaderHelper ) => {
66
- const a = inputVariable ( 'a' , inputs [ 0 ] . dataType , inputs [ 0 ] . dims ) ;
67
- const b = inputVariable ( 'b' , DataType . uint32 , ShapeUtil . convertShape ( inputs [ 1 ] . dims ) ) ;
72
+ const aShape = inputs [ 0 ] . dims . slice ( ) ;
73
+ aShape . splice ( - 1 , 1 , attributes . k / aComponents ) ;
74
+ const a = inputVariable ( 'a' , inputs [ 0 ] . dataType , aShape , aComponents ) ;
75
+ const bShape = ShapeUtil . convertShape ( inputs [ 1 ] . dims ) . slice ( ) ;
76
+ bShape . splice ( - 1 , 1 , blobSizeInWords / bComponents ) ;
77
+ const b = inputVariable ( 'b' , DataType . uint32 , bShape , bComponents ) ;
68
78
const scales = inputVariable ( 'scales' , inputs [ 2 ] . dataType , inputs [ 2 ] . dims ) ;
69
79
const inputVariables = [ a , b , scales ] ;
70
- const zeroPoints =
71
- inputs . length === 4 ? inputVariable ( 'zero_points' , DataType . uint32 , inputs [ 3 ] . dims ) : undefined ;
80
+ const zeroPoints = inputs . length === 4 ?
81
+ inputVariable ( 'zero_points' , DataType . uint32 , inputs [ 3 ] . dims , zComponents ) :
82
+ undefined ;
72
83
if ( zeroPoints ) {
73
84
inputVariables . push ( zeroPoints ) ;
74
85
}
75
- const output = outputVariable ( 'output' , inputs [ 0 ] . dataType , outputShape ) ;
86
+ const output = outputVariable ( 'output' , inputs [ 0 ] . dataType , outputShape , components ) ;
76
87
const uniforms : UniformsArrayType = [
77
- { name : 'output_size' , type : 'u32' } , { name : 'k ' , type : 'u32' } , { name : 'n ' , type : 'u32' } ,
88
+ { name : 'output_size' , type : 'u32' } , { name : 'K ' , type : 'u32' } , { name : 'N ' , type : 'u32' } ,
78
89
{ name : 'accuracy_level' , type : 'u32' } , { name : 'bits' , type : 'u32' } , { name : 'block_size' , type : 'u32' }
79
90
] ;
80
91
const nBlocksPerCol = Math . floor ( ( attributes . k + attributes . blockSize - 1 ) / attributes . blockSize ) ;
81
- const blobSize = attributes . blockSize / 8 * attributes . bits ;
82
- const wordPerBlob = blobSize / 4 ;
83
92
const dataType = tensorTypeToWsglStorageType ( inputs [ 0 ] . dataType ) ;
93
+ const dequantizeArrayReturnType = ( ( ) => {
94
+ switch ( aComponents ) {
95
+ case 1 :
96
+ return `array<${ dataType } , 8>` ;
97
+ case 2 :
98
+ return `array<vec2<${ dataType } >, 4>` ;
99
+ case 4 :
100
+ return `array<vec4<${ dataType } >, 2>` ;
101
+ default :
102
+ throw new Error ( `${ aComponents } -component is not supported.` ) ;
103
+ }
104
+ } ) ( ) ;
105
+ const dequantizeArrayImpl =
106
+ ( ( ) => `fn dequantize_array(quantized_data: array<${ dataType } , 8>, zero_point: ${ dataType } , scale: ${
107
+ dataType } ) -> ${ dequantizeArrayReturnType } {
108
+ var result: ${ dequantizeArrayReturnType } ;
109
+ ${ ( ( ) => {
110
+ switch ( aComponents ) {
111
+ case 1 :
112
+ return `
113
+ for (var i: u32 = 0; i < 8; i++) {
114
+ result[i] = dequantize(quantized_data[i], zero_point, scale);
115
+ }` ;
116
+ case 2 :
117
+ return `
118
+ for (var i: u32 = 0; i < 4; i++) {
119
+ let dequantized0 = dequantize(quantized_data[i*2], zero_point, scale);
120
+ let dequantized1 = dequantize(quantized_data[i*2+1], zero_point, scale);
121
+ result[i] = vec2<${ dataType } >(dequantized0, dequantized1);
122
+ }` ;
123
+ case 4 :
124
+ return `
125
+ for (var i: u32 = 0; i < 2; i++) {
126
+ let dequantized0 = dequantize(quantized_data[i*4], zero_point, scale);
127
+ let dequantized1 = dequantize(quantized_data[i*4+1], zero_point, scale);
128
+ let dequantized2 = dequantize(quantized_data[i*4+2], zero_point, scale);
129
+ let dequantized3 = dequantize(quantized_data[i*4+3], zero_point, scale);
130
+ result[i] = vec4<${ dataType } >(dequantized0, dequantized1, dequantized2, dequantized3);
131
+ }` ;
132
+ default :
133
+ throw new Error ( `${ aComponents } -component is not supported.` ) ;
134
+ }
135
+ } ) ( ) }
136
+ return result;
137
+ }` ) ( ) ;
138
+
84
139
return `
85
140
fn ortUnpack8x4snorm(value: u32) -> array<${ dataType } , 8>{
86
141
var result = array<${ dataType } , 8>();
@@ -92,13 +147,21 @@ export const createMatMulNBitsProgramInfo =
92
147
}
93
148
return result;
94
149
}
150
+
151
+ fn dequantize(value: ${ dataType } , zero_point: ${ dataType } , scale: ${ dataType } ) -> ${ dataType } {
152
+ return (value - zero_point) * scale;
153
+ }
154
+
155
+ ${ dequantizeArrayImpl } ;
156
+
95
157
${ shaderHelper . registerUniforms ( uniforms ) . declareVariables ( ...inputVariables , output ) }
96
158
${ shaderHelper . mainStart ( ) }
97
159
${ shaderHelper . guardAgainstOutOfBoundsWorkgroupSizes ( 'uniforms.output_size' ) }
98
- var value: ${ dataType } = 0.0;
99
- let output_indices = ${ output . offsetToIndices ( 'global_idx' ) } ;
100
- var a_indices: ${ a . type . indices } = output_indices;
160
+ var output_values: array<${ output . type . value } , ${ outputNumber } >;
161
+ var output_indices = ${ output . offsetToIndices ( 'global_idx' ) } ;
101
162
var n = ${ output . indicesGet ( 'output_indices' , aRank - 1 ) } ;
163
+ var m = ${ output . indicesGet ( 'output_indices' , aRank - 2 ) } ;
164
+ var a_indices: ${ a . type . indices } = output_indices;
102
165
// Two zero points are packed into one byte because uniforms.bits <= 4.
103
166
// zero_point_offset is either 0 or 4. It is bit offset within one byte.
104
167
// TODO support zero_point_offset for bits > 4
@@ -108,35 +171,41 @@ export const createMatMulNBitsProgramInfo =
108
171
var zero_point_word: u32 = ${ zeroPoints . getByOffset ( 'zero_point_index' ) } ;
109
172
var zero_point_offset: u32 = 0;` :
110
173
'' }
111
- var scale_idex = n * ${ nBlocksPerCol } ;
174
+ var scale_index = n * ${ nBlocksPerCol } ;
112
175
var b_indices: ${ b . type . indices } ;
113
176
${ b . indicesSet ( 'b_indices' , '0' , 'n' ) } ;
114
177
var block_offset: u32 = 0;
115
178
for (var block: u32 = 0; block < ${ nBlocksPerCol } ; block++) {
116
179
// The scale and zero points are computed per block.
117
- let scale = ${ scales . getByOffset ( 'scale_idex ' ) } ;
180
+ let scale = ${ scales . getByOffset ( 'scale_index ' ) } ;
118
181
// The default zero point is 8 for unsigned 4-bit quantization.
119
182
let zero_point: ${ dataType } = ${
120
183
zeroPoints ? `${ dataType } (extractBits(zero_point_word, zero_point_offset, 4))` : 8.0 } ;
121
184
${ b . indicesSet ( 'b_indices' , '1' , 'block' ) } ;
122
185
var word_offset: u32 = block_offset;
123
- for (var word: u32 = 0; word < ${ wordPerBlob } ; word++ ) {
186
+ for (var word: u32 = 0; word < ${ blobSizeInWords } ; word += ${ bComponents } ) {
124
187
${ b . indicesSet ( 'b_indices' , '2' , 'word' ) } ;
125
- let b_value = ${ b . getByIndices ( 'b_indices' ) } ;
126
- let b_quantized_values: array<${ dataType } , 8> = ortUnpack8x4snorm(b_value);
127
- // Number of B elements per 32-bit word is 32/bits = 32/4 = 8
128
- var offset: u32 = word_offset;
129
- for (var i: u32 = 0; i < 8; i++) {
130
- ${ a . indicesSet ( 'a_indices' , aRank - 1 , 'offset' ) } ;
131
- let a_value = ${ a . getByIndices ( 'a_indices' ) } ;
132
- let b_quantized_value = b_quantized_values[i];
133
- let b_dequantized_value = (b_quantized_value - zero_point) * scale;
134
- value += a_value * b_dequantized_value;
135
- offset++;
188
+ let b_data = ${ b . getByIndices ( 'b_indices' ) } ;
189
+ for (var i: u32 = 0; i < ${ bComponents } ; i++) {
190
+ let b_value = ${ bComponents === 1 ? 'b_data' : 'b_data[word + i]' } ;
191
+ let b_quantized_values: array<${ dataType } , 8> = ortUnpack8x4snorm(b_value);
192
+ let b_dequantized_values = dequantize_array(b_quantized_values, zero_point, scale);
193
+ // Number of B elements per 32-bit word is 32/bits = 32/4 = 8
194
+ var offset: u32 = word_offset;
195
+ for (var j: u32 = 0; j < 8/${ aComponents } ; j++) {
196
+ ${ a . indicesSet ( 'a_indices' , aRank - 1 , `offset/${ aComponents } ` ) } ;
197
+ for (var k: u32 = 0; k < ${ outputNumber } ; k++) {
198
+ ${ a . indicesSet ( 'a_indices' , aRank - 2 , `m * ${ outputNumber } + k` ) } ;
199
+ let a_data = ${ a . getByIndices ( 'a_indices' ) } ;
200
+ output_values[k] += ${
201
+ aComponents === 1 ? 'a_data * b_dequantized_values[j]' : `dot(a_data, b_dequantized_values[j])` } ;
202
+ }
203
+ offset += ${ aComponents } ;
204
+ }
205
+ word_offset += 8;
136
206
}
137
- word_offset += 8;
138
207
}
139
- scale_idex ++;
208
+ scale_index ++;
140
209
${
141
210
zeroPoints ? `
142
211
if (zero_point_offset == 28) {
@@ -149,17 +218,20 @@ export const createMatMulNBitsProgramInfo =
149
218
'' }
150
219
block_offset += uniforms.block_size;
151
220
}
152
- ${ output . setByOffset ( 'global_idx' , 'value' ) } ;
221
+ for (var k: u32 = 0u; k < ${ outputNumber } u; k++) {
222
+ ${ output . indicesSet ( 'output_indices' , aRank - 2 , `${ outputNumber + ' * m + k' } ` ) } ;
223
+ ${ output . setByIndices ( 'output_indices' , 'output_values[k]' ) }
224
+ }
153
225
}
154
226
` ;
155
227
} ;
156
228
return {
157
229
name : 'MatMulNBits' ,
158
230
shaderCache :
159
- { hint : `${ attributes . cacheKey } ;${ inputs . length } ` , inputDependencies : Array ( inputs . length ) . fill ( 'dims ' ) } ,
231
+ { hint : `${ attributes . cacheKey } ;${ inputs . length } ` , inputDependencies : Array ( inputs . length ) . fill ( 'rank ' ) } ,
160
232
getRunData : ( ) => ( {
161
233
outputs : [ { dims : outputShape , dataType : inputs [ 0 ] . dataType } ] ,
162
- dispatchGroup : { x : Math . ceil ( outputSize / 64 ) } ,
234
+ dispatchGroup : { x : Math . ceil ( outputSize / 64 /* workgroup size */ ) } ,
163
235
programUniforms
164
236
} ) ,
165
237
getShaderSource
0 commit comments