21
21
22
22
import { LOG_DEBUG } from '../../../log' ;
23
23
import { TensorView } from '../../../tensor-view' ;
24
- import { ProgramInfo , ProgramUniform } from '../../types' ;
25
- import { createTensorShapeVariables , inputVariable , outputVariable , ShaderHelper , tensorTypeToWsglStorageType } from '../common' ;
24
+ import { ProgramInfo , ProgramInputTensorInfoDependency , ProgramUniform } from '../../types' ;
25
+ import { createTensorShapeVariables , inputVariable , outputVariable , ShaderHelper , tensorTypeToWsglStorageType , UniformsArrayType } from '../common' ;
26
26
import { ConvAttributes } from '../conv' ;
27
27
import { getActivationSnippet } from '../fuse-utils' ;
28
28
@@ -88,10 +88,10 @@ const conv2dCommonSnippet =
88
88
let outRow = ${ row } / outWidth;
89
89
let outCol = ${ row } % outWidth;
90
90
91
- let WRow = ${ col } / (filterDims [1] * inChannels);
92
- let WCol = ${ col } / inChannels % filterDims [1];
93
- let xRow = outRow * stride[0] + dilation[0] * WRow - pad[0];
94
- let xCol = outCol * stride[1] + dilation[1] * WCol - pad[1];
91
+ let WRow = ${ col } / (i32(uniforms.w_shape [1]) * inChannels);
92
+ let WCol = ${ col } / inChannels % i32(uniforms.w_shape [1]) ;
93
+ let xRow = outRow * uniforms. stride[0] + uniforms. dilation[0] * WRow - uniforms. pad[0];
94
+ let xCol = outCol * uniforms. stride[1] + uniforms. dilation[1] * WCol - uniforms. pad[1];
95
95
let xCh = ${ col } % inChannels;
96
96
var resData = ${ typeSnippet ( innerElementSizeX , dataType ) } (0.0);
97
97
// The bounds checking is always needed since we use it to pad zero for
@@ -108,7 +108,7 @@ const conv2dCommonSnippet =
108
108
${ readXSnippet } ` :
109
109
`
110
110
let col = colIn * ${ innerElementSizeX } ;
111
- if (row < uniforms.dimAOuter && col < uniforms.dimInner ) {
111
+ if (row < uniforms.dim_a_outer && col < uniforms.dim_inner ) {
112
112
${ readXSnippet }
113
113
}
114
114
return ${ typeSnippet ( innerElementSizeX , dataType ) } (0.0);` ) :
@@ -117,7 +117,7 @@ const conv2dCommonSnippet =
117
117
${ readXSnippet } ` :
118
118
`
119
119
let col = colIn * ${ innerElementSizeX } ;
120
- if (row < uniforms.dimInner && col < uniforms.dimBOuter ) {
120
+ if (row < uniforms.dim_inner && col < uniforms.dim_b_outer ) {
121
121
${ readXSnippet }
122
122
}
123
123
return ${ typeSnippet ( innerElementSizeX , dataType ) } (0.0);` ) ;
@@ -129,9 +129,8 @@ const conv2dCommonSnippet =
129
129
isChannelsLast ? typeSnippet ( innerElementSizeX , dataType ) : typeSnippet ( innerElementSizeW , dataType ) ;
130
130
const bType =
131
131
isChannelsLast ? typeSnippet ( innerElementSizeW , dataType ) : typeSnippet ( innerElementSizeX , dataType ) ;
132
- const { activationFunction , applyActivation} = getActivationSnippet ( attributes , resType ) ;
132
+ const applyActivation = getActivationSnippet ( attributes , resType ) ;
133
133
const userCode = `
134
- ${ activationFunction }
135
134
fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${ aType } {
136
135
${ isChannelsLast ? sampleX : sampleW }
137
136
}
@@ -142,7 +141,7 @@ const conv2dCommonSnippet =
142
141
143
142
fn mm_write(batch: i32, row : i32, colIn : i32, valueIn : ${ resType } ) {
144
143
let col = colIn * ${ innerElementSize } ;
145
- if (row < uniforms.dimAOuter && col < uniforms.dimBOuter )
144
+ if (row < uniforms.dim_a_outer && col < uniforms.dim_b_outer )
146
145
{
147
146
var value = valueIn;
148
147
let outWidth = ${ isChannelsLast ? 'i32(uniforms.result_shape[2])' : 'i32(uniforms.result_shape[3])' } ;
@@ -181,83 +180,97 @@ export const createConv2DMatMulProgramInfo =
181
180
LOG_DEBUG ( 'verbose' , ( ) => `[conv2d_mm_webgpu] dispatch = ${ dispatch } ` ) ;
182
181
183
182
const innerElementSize = isVec4 ? ( isChannelsLast && inChannels % 4 !== 0 ? 3 : 4 ) : 1 ;
184
-
185
183
const tileAOuter = workGroupSize [ 1 ] * elementsPerThread [ 1 ] ;
186
184
const tileBOuter = workGroupSize [ 0 ] * elementsPerThread [ 0 ] ;
187
185
const tileInner = Math . max ( workGroupSize [ 0 ] * innerElementSize , workGroupSize [ 1 ] ) ;
188
-
189
186
const fitAOuter = dimAOuter % tileAOuter === 0 ;
190
187
const fitBOuter = dimBOuter % tileBOuter === 0 ;
191
188
const fitInner = dimInner % tileInner === 0 ;
192
-
193
189
const elementsSize = isVec4 ? [ innerElementSize , 4 , 4 ] : [ 1 , 1 , 1 ] ;
194
- const t = tensorTypeToWsglStorageType ( inputs [ 0 ] . dataType ) ;
195
190
196
- // TODO: support component 2, 3.
197
- const components = isVec4 ? 4 : 1 ;
198
- const programUniforms : ProgramUniform [ ] =
199
- [ { type : 'int32' , data : dimAOuter } , { type : 'int32' , data : dimBOuter } , { type : 'int32' , data : dimInner } ] ;
200
- const x =
201
- inputVariable ( 'x' , inputs [ 0 ] . dataType , inputs [ 0 ] . dims . length , innerElementSize === 3 ? 1 : innerElementSize ) ;
202
- const w = inputVariable ( 'w' , inputs [ 1 ] . dataType , inputs [ 1 ] . dims . length , components ) ;
203
- const inputVariables = [ x , w ] ;
191
+ const programUniforms : ProgramUniform [ ] = [
192
+ { type : 'int32' , data : dimAOuter } , { type : 'int32' , data : dimBOuter } , { type : 'int32' , data : dimInner } ,
193
+ { type : 'int32' , data : [ attributes . pads [ 0 ] , attributes . pads [ 1 ] ] } , { type : 'int32' , data : attributes . strides } ,
194
+ { type : 'int32' , data : attributes . dilations }
195
+ ] ;
196
+ if ( attributes . activation === 'Clip' ) {
197
+ programUniforms . push (
198
+ { type : 'float32' , data : attributes . clipMax ! } , { type : 'float32' , data : attributes . clipMin ! } ) ;
199
+ }
200
+ programUniforms . push (
201
+ ...createTensorShapeVariables ( inputs [ 0 ] . dims ) , ...createTensorShapeVariables ( inputs [ 1 ] . dims ) ) ;
202
+ const inputDependencies : ProgramInputTensorInfoDependency [ ] = [ 'rank' , 'rank' ] ;
203
+ if ( hasBias ) {
204
+ programUniforms . push ( ...createTensorShapeVariables ( inputs [ 2 ] . dims ) ) ;
205
+ inputDependencies . push ( 'rank' ) ;
206
+ }
207
+ programUniforms . push ( ...createTensorShapeVariables ( outputShape ) ) ;
204
208
205
- programUniforms . push ( ...createTensorShapeVariables ( inputs [ 0 ] . dims ) ) ;
206
- programUniforms . push ( ...createTensorShapeVariables ( inputs [ 1 ] . dims ) ) ;
209
+ const getShaderSource = ( shaderHelper : ShaderHelper ) => {
210
+ const uniforms : UniformsArrayType = [
211
+ { name : 'dim_a_outer' , type : 'i32' } , { name : 'dim_b_outer' , type : 'i32' } , { name : 'dim_inner' , type : 'i32' } ,
212
+ { name : 'pad' , type : 'i32' , length : 2 } , { name : 'stride' , type : 'i32' , length : 2 } ,
213
+ { name : 'dilation' , type : 'i32' , length : 2 }
214
+ ] ;
215
+ if ( attributes . activation === 'Clip' ) {
216
+ uniforms . push ( { name : 'clip_max' , type : 'f32' } , { name : 'clip_min' , type : 'f32' } ) ;
217
+ }
207
218
208
- let declareFunctions = `
219
+ // TODO: support component 2, 3.
220
+ const components = isVec4 ? 4 : 1 ;
221
+ const t = tensorTypeToWsglStorageType ( inputs [ 0 ] . dataType ) ;
222
+ let declareFunctions = `
209
223
fn setOutputAtIndex(flatIndex : i32, value : ${ isVec4 ? `vec4<${ t } >` : t } ) {
210
224
result[flatIndex] = ${ isVec4 ? `vec4<${ t } >` : t } (value);
211
225
}
212
226
fn setOutputAtCoords(d0 : i32, d1 : i32, d2 : i32, d3 : i32, value : ${ isVec4 ? `vec4<${ t } >` : t } ) {
213
227
let flatIndex = getOutputIndexFromCoords(vec4<i32>(d0, d1, d2, d3));
214
228
setOutputAtIndex(flatIndex ${ isVec4 ? '/ 4' : '' } , value);
215
229
}` ;
216
- if ( hasBias ) {
217
- const bias = inputVariable ( 'bias' , inputs [ 2 ] . dataType , inputs [ 2 ] . dims . length , components ) ;
218
- inputVariables . push ( bias ) ;
219
-
220
- programUniforms . push ( ...createTensorShapeVariables ( inputs [ 2 ] . dims ) ) ;
221
-
222
- declareFunctions += `
230
+ const x = inputVariable (
231
+ 'x' , inputs [ 0 ] . dataType , inputs [ 0 ] . dims . length , innerElementSize === 3 ? 1 : innerElementSize ) ;
232
+ const w = inputVariable ( 'w' , inputs [ 1 ] . dataType , inputs [ 1 ] . dims . length , components ) ;
233
+ const inputVariables = [ x , w ] ;
234
+ const output = outputVariable ( 'result' , inputs [ 0 ] . dataType , outputShape . length , components ) ;
235
+ if ( hasBias ) {
236
+ const bias = inputVariable ( 'bias' , inputs [ 2 ] . dataType , inputs [ 2 ] . dims . length , components ) ;
237
+ inputVariables . push ( bias ) ;
238
+ declareFunctions += `
223
239
fn getBiasByOutputCoords(coords : vec4<i32>) -> ${ isVec4 ? `vec4<${ t } >` : t } {
224
240
return bias[coords.${ isChannelsLast ? 'w' : 'y' } ${ isVec4 ? '/ 4' : '' } ];
225
241
}` ;
226
- }
227
- const output = outputVariable ( 'result' , inputs [ 0 ] . dataType , outputShape . length , components ) ;
228
- programUniforms . push ( ...createTensorShapeVariables ( outputShape ) ) ;
229
- return {
230
- name : 'Conv2DMatMul' ,
231
- shaderCache : { hint : attributes . cacheKey } ,
232
- getRunData : ( ) => ( {
233
- outputs : [ { dims : outputShape , dataType : inputs [ 0 ] . dataType } ] ,
234
- dispatchGroup : { x : dispatch [ 0 ] , y : dispatch [ 1 ] , z : dispatch [ 2 ] } ,
235
- programUniforms,
236
- } ) ,
237
- getShaderSource : ( shaderHelper : ShaderHelper ) => `
242
+ }
243
+
244
+ return `
238
245
${ utilFunctions ( 'uniforms.result_strides' ) }
239
246
//struct Uniforms { xShape : vec4<i32>, wShape : vec4<i32>, outShape : vec4<i32>,
240
247
// outShapeStrides: vec3<i32>, filterDims : vec2<i32>, pad : vec2<i32>, stride : vec2<i32>,
241
248
// dilation : vec2<i32>, dimAOuter : i32, dimBOuter : i32, dimInner : i32 };
242
- ${
243
- shaderHelper . registerUniform ( 'dimAOuter' , 'i32' )
244
- . registerUniform ( 'dimBOuter' , 'i32' )
245
- . registerUniform ( 'dimInner' , 'i32' )
246
- . declareVariables ( ...inputVariables , output ) }
247
- const filterDims : vec2<i32> = vec2<i32>(${ attributes . kernelShape [ 0 ] } , ${ attributes . kernelShape [ 1 ] } );
248
- const pad : vec2<i32> = vec2<i32>(${ attributes . pads [ 0 ] } , ${ attributes . pads [ 1 ] } );
249
- const stride : vec2<i32> = vec2<i32>(${ attributes . strides [ 0 ] } , ${ attributes . strides [ 1 ] } );
250
- const dilation : vec2<i32> = vec2<i32>(${ attributes . dilations [ 0 ] } , ${ attributes . dilations [ 1 ] } );
249
+ ${ shaderHelper . registerUniforms ( uniforms ) . declareVariables ( ...inputVariables , output ) }
251
250
${ declareFunctions }
252
251
${
253
252
conv2dCommonSnippet (
254
253
isChannelsLast , fitAOuter , fitBOuter , fitInner , hasBias , attributes , elementsSize [ 0 ] , elementsSize [ 1 ] ,
255
254
elementsSize [ 2 ] , t ) }
256
- ${
255
+ ${
257
256
isVec4 ?
258
257
makeMatMulPackedVec4Source ( elementsPerThread , workGroupSize , t , undefined , ! isChannelsLast , tileInner ) :
259
258
makeMatMulPackedSource (
260
259
elementsPerThread , workGroupSize , t , undefined , ! isChannelsLast , tileInner , false , undefined ,
261
- sequentialAccessByThreads ) } `
260
+ sequentialAccessByThreads ) } `;
261
+ } ;
262
+ return {
263
+ name : 'Conv2DMatMul' ,
264
+ shaderCache : {
265
+ hint : `${ attributes . cacheKey } ;${ innerElementSize } ;${ isVec4 } ;${ fitAOuter } ;${ fitBOuter } ;${ fitInner } ;${
266
+ tileAOuter } ;${ tileBOuter } ;${ tileInner } `,
267
+ inputDependencies
268
+ } ,
269
+ getRunData : ( ) => ( {
270
+ outputs : [ { dims : outputShape , dataType : inputs [ 0 ] . dataType } ] ,
271
+ dispatchGroup : { x : dispatch [ 0 ] , y : dispatch [ 1 ] , z : dispatch [ 2 ] } ,
272
+ programUniforms,
273
+ } ) ,
274
+ getShaderSource
262
275
} ;
263
276
} ;
0 commit comments