21
21
22
22
import { LOG_DEBUG } from '../../../log' ;
23
23
import { TensorView } from '../../../tensor-view' ;
24
- import { ProgramInfo , ProgramUniform } from '../../types' ;
25
- import { createTensorShapeVariables , inputVariable , outputVariable , ShaderHelper } from '../common' ;
24
+ import { ProgramInfo , ProgramInputTensorInfoDependency , ProgramUniform } from '../../types' ;
25
+ import { createTensorShapeVariables , inputVariable , outputVariable , ShaderHelper , UniformsArrayType } from '../common' ;
26
26
import { ConvTransposeAttributes } from '../conv-transpose' ;
27
27
import { getActivationSnippet } from '../fuse-utils' ;
28
28
@@ -74,21 +74,21 @@ const conv2dTransposeCommonSnippet =
74
74
col % outWidth);
75
75
` ;
76
76
77
- const xHeight = isChannelsLast ? 'outBackprop [1]' : 'outBackprop [2]' ;
78
- const xWidth = isChannelsLast ? 'outBackprop [2]' : 'outBackprop [3]' ;
77
+ const xHeight = isChannelsLast ? 'i32(uniforms.x_shape [1]) ' : 'i32(uniforms.x_shape [2]) ' ;
78
+ const xWidth = isChannelsLast ? 'i32(uniforms.x_shape [2]) ' : 'i32(uniforms.x_shape [3]) ' ;
79
79
const row = isChannelsLast ? 'row' : 'col' ;
80
80
const col = isChannelsLast ? 'col' : 'row' ;
81
81
82
82
const readASnippet = `
83
- let inChannels = ${ isChannelsLast ? 'outBackprop [3]' : 'outBackprop [1]' } ;
83
+ let inChannels = ${ isChannelsLast ? 'i32(uniforms.x_shape [3]) ' : 'i32(uniforms.x_shape [1]) ' } ;
84
84
let outWidth = ${ isChannelsLast ? 'i32(uniforms.result_shape[2])' : 'i32(uniforms.result_shape[3])' } ;
85
85
let outRow = ${ row } / outWidth;
86
86
let outCol = ${ row } % outWidth;
87
87
88
- let WRow = ${ col } / (filterDims[1] * inChannels);
89
- let WCol = ${ col } / inChannels % filterDims[1];
90
- let xR = f32(outRow - pads[0] + dilation [0] * WRow) / f32(strides[0]);
91
- let xC = f32(outCol - pads[1] + dilation [1] * WCol) / f32(strides[1]);
88
+ let WRow = ${ col } / (uniforms. filterDims[1] * inChannels);
89
+ let WCol = ${ col } / inChannels % uniforms. filterDims[1];
90
+ let xR = f32(outRow - uniforms. pads[0] + uniforms.dilations [0] * WRow) / f32(uniforms. strides[0]);
91
+ let xC = f32(outCol - uniforms. pads[1] + uniforms.dilations [1] * WCol) / f32(uniforms. strides[1]);
92
92
if (xR < 0.0 || xR >= f32(${ xHeight } ) || fract(xR) > 0.0) {
93
93
return ${ type } (0.0);
94
94
}
@@ -116,9 +116,9 @@ const conv2dTransposeCommonSnippet =
116
116
117
117
const sampleW = `
118
118
let col = colIn * ${ innerElementSize } ;
119
- let inChannels = ${ isChannelsLast ? 'outBackprop [3]' : 'outBackprop [1]' } ;
120
- let coordX = filterDims.x - 1 - row / (filterDims[1] * inChannels);
121
- let coordY = filterDims.y - 1 - (row / inChannels) % filterDims[1];
119
+ let inChannels = ${ isChannelsLast ? 'i32(uniforms.x_shape [3]) ' : 'i32(uniforms.x_shape [1]) ' } ;
120
+ let coordX = uniforms.filterDims[0] - 1 - row / (uniforms. filterDims[1] * inChannels);
121
+ let coordY = uniforms.filterDims[1] - 1 - (row / inChannels) % uniforms. filterDims[1];
122
122
if (${
123
123
isChannelsLast ? 'row < uniforms.dimInner && col < uniforms.dimBOuter' :
124
124
'row < uniforms.dimInner && col < uniforms.dimAOuter' } && coordX >= 0 && coordY >= 0) {
@@ -186,20 +186,33 @@ export const createConv2DTransposeMatMulProgramInfo =
186
186
const innerElementSize = isVec4 ? 4 : 1 ;
187
187
const tileInner = Math . max ( workGroupSize [ 0 ] * innerElementSize , workGroupSize [ 1 ] ) ;
188
188
const components = isVec4 ? 4 : 1 ;
189
- const programUniforms : ProgramUniform [ ] =
190
- [ { type : 'int32' , data : dimAOuter } , { type : 'int32' , data : dimBOuter } , { type : 'int32' , data : dimInner } ] ;
189
+ const filterDims0 = attributes . kernelShape [ isChannelsLast ? 1 : 2 ] ;
190
+ const filterDims1 = attributes . kernelShape [ isChannelsLast ? 2 : 3 ] ;
191
+ const effectiveFilterDims0 =
192
+ filterDims0 + ( attributes . dilations [ 0 ] <= 1 ? 0 : ( filterDims0 - 1 ) * ( attributes . dilations [ 0 ] - 1 ) ) ;
193
+ const effectiveFilterDims1 =
194
+ filterDims1 + ( attributes . dilations [ 1 ] <= 1 ? 0 : ( filterDims1 - 1 ) * ( attributes . dilations [ 1 ] - 1 ) ) ;
195
+ const pads0 = effectiveFilterDims0 - 1 - Math . floor ( ( attributes . pads [ 0 ] + attributes . pads [ 2 ] ) / 2 ) ;
196
+ const pads1 = effectiveFilterDims1 - 1 - Math . floor ( ( attributes . pads [ 1 ] + attributes . pads [ 3 ] ) / 2 ) ;
197
+ const programUniforms : ProgramUniform [ ] = [
198
+ { type : 'int32' , data : dimAOuter } , { type : 'int32' , data : dimBOuter } , { type : 'int32' , data : dimInner } ,
199
+ { type : 'int32' , data : attributes . strides } , { type : 'int32' , data : attributes . dilations } ,
200
+ { type : 'int32' , data : [ filterDims0 , filterDims1 ] } , { type : 'int32' , data : [ pads0 , pads1 ] }
201
+ ] ;
191
202
const x = inputVariable ( 'x' , inputs [ 0 ] . dataType , inputs [ 0 ] . dims . length , components ) ;
192
203
const w = inputVariable ( 'w' , inputs [ 1 ] . dataType , inputs [ 1 ] . dims . length , 1 ) ;
193
204
const output = outputVariable ( 'result' , inputs [ 0 ] . dataType , outputShape . length , components ) ;
194
205
const inputVariables = [ x , w ] ;
195
- programUniforms . push ( ... createTensorShapeVariables ( inputs [ 0 ] . dims ) ) ;
196
- programUniforms . push ( ...createTensorShapeVariables ( inputs [ 1 ] . dims ) ) ;
206
+ programUniforms . push (
207
+ ... createTensorShapeVariables ( inputs [ 0 ] . dims ) , ...createTensorShapeVariables ( inputs [ 1 ] . dims ) ) ;
197
208
209
+ const inputDependencies : ProgramInputTensorInfoDependency [ ] = [ 'rank' , 'rank' ] ;
198
210
let declareFunctions = '' ;
199
211
if ( hasBias ) {
200
212
const bias = inputVariable ( 'bias' , inputs [ 2 ] . dataType , inputs [ 2 ] . dims . length , components ) ;
201
213
inputVariables . push ( bias ) ;
202
214
programUniforms . push ( ...createTensorShapeVariables ( inputs [ 2 ] . dims ) ) ;
215
+ inputDependencies . push ( 'rank' ) ;
203
216
204
217
declareFunctions += `
205
218
fn getBiasByOutputCoords(coords : vec4<i32>) -> ${ isVec4 ? 'vec4<f32>' : 'f32' } {
@@ -209,42 +222,23 @@ export const createConv2DTransposeMatMulProgramInfo =
209
222
210
223
programUniforms . push ( ...createTensorShapeVariables ( outputShape ) ) ;
211
224
225
+ const uniforms : UniformsArrayType = [
226
+ { name : 'dimAOuter' , type : 'i32' } , { name : 'dimBOuter' , type : 'i32' } , { name : 'dimInner' , type : 'i32' } ,
227
+ { name : 'strides' , type : 'i32' , length : 2 } , { name : 'dilations' , type : 'i32' , length : 2 } ,
228
+ { name : 'filterDims' , type : 'i32' , length : 2 } , { name : 'pads' , type : 'i32' , length : 2 }
229
+ ] ;
230
+
212
231
return {
213
232
name : 'Conv2DTransposeMatMul' ,
214
- shaderCache : { hint : attributes . cacheKey } ,
233
+ shaderCache : { hint : ` ${ attributes . format } ` , inputDependencies } ,
215
234
getRunData : ( ) => ( {
216
235
outputs : [ { dims : outputShape , dataType : inputs [ 0 ] . dataType } ] ,
217
236
dispatchGroup : { x : dispatch [ 0 ] , y : dispatch [ 1 ] , z : dispatch [ 2 ] } ,
218
237
programUniforms
219
238
} ) ,
220
239
getShaderSource : ( shaderHelper : ShaderHelper ) => `
221
240
${ utilFunctions ( 'uniforms.result_strides' ) }
222
- ${
223
- shaderHelper . registerUniform ( 'dimAOuter' , 'i32' )
224
- . registerUniform ( 'dimBOuter' , 'i32' )
225
- . registerUniform ( 'dimInner' , 'i32' )
226
- . declareVariables ( ...inputVariables , output ) } ;
227
- const outBackprop : vec4<i32> = vec4<i32>(${ inputs [ 0 ] . dims . join ( ',' ) } );
228
- const filterDims : vec2<i32> = vec2<i32>(${ attributes . kernelShape [ isChannelsLast ? 1 : 2 ] } , ${
229
- attributes . kernelShape [ isChannelsLast ? 2 : 3 ] } );
230
- const effectiveFilterDims : vec2<i32> = filterDims + vec2<i32>(
231
- ${
232
- attributes . dilations [ 0 ] <= 1 ?
233
- 0 :
234
- ( attributes . kernelShape [ isChannelsLast ? 1 : 2 ] - 1 ) * ( attributes . dilations [ 0 ] - 1 ) } ,
235
- ${
236
- attributes . dilations [ 1 ] <= 1 ?
237
- 0 :
238
- ( attributes . kernelShape [ isChannelsLast ? 2 : 3 ] - 1 ) * ( attributes . dilations [ 1 ] - 1 ) } );
239
- const pads : vec2<i32> = vec2<i32>(i32(effectiveFilterDims[0]) - 1 - (${
240
- attributes . pads [ 0 ] + attributes . pads [ 2 ] } )/2,
241
- i32(effectiveFilterDims[1]) - 1 - (${
242
- attributes . pads [ 1 ] + attributes . pads [ 3 ] } )/2);
243
- const strides : vec2<i32> = vec2<i32>(${ attributes . strides [ 0 ] } , ${ attributes . strides [ 1 ] } );
244
- const dilation : vec2<i32> = vec2<i32>(${ attributes . dilations [ 0 ] } , ${ attributes . dilations [ 1 ] } );
245
- const dimAOuter : i32 = ${ dimAOuter } ;
246
- const dimBOuter : i32 = ${ dimBOuter } ;
247
- const dimInner : i32 = ${ dimInner } ;
241
+ ${ shaderHelper . registerUniforms ( uniforms ) . declareVariables ( ...inputVariables , output ) } ;
248
242
${ declareFunctions }
249
243
${ conv2dTransposeCommonSnippet ( isChannelsLast , hasBias , attributes , innerElementSize ) }
250
244
${
0 commit comments