Skip to content

Commit 7e77cc2

Browse files
committed
[js/webgpu] Support uniforms for conv, conv transpose, conv grouped
1 parent 9479ba5 commit 7e77cc2

File tree

7 files changed

+219
-184
lines changed

7 files changed

+219
-184
lines changed

js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts

+5-5
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ import {biasAdd} from './ops/bias-add';
88
import {biasSplitGelu} from './ops/bias-split-gelu';
99
import * as binaryOps from './ops/binary-op';
1010
import {concat, parseConcatAttributes} from './ops/concat';
11-
import {conv, parseConvAttributes} from './ops/conv';
12-
import {convTranspose, parseConvTransposeAttributes} from './ops/conv-transpose';
11+
import {conv} from './ops/conv';
12+
import {convTranspose} from './ops/conv-transpose';
1313
import {cumsum, parseCumSumAttributes} from './ops/cumsum';
1414
import {einsum, parseEinsumAttributes} from './ops/einsum';
1515
import {expand} from './ops/expand';
@@ -60,8 +60,8 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
6060
['Ceil', [unaryOps.ceil]],
6161
['Clip', [unaryOps.clip]],
6262
['Concat', [concat, parseConcatAttributes]],
63-
['Conv', [conv, parseConvAttributes]],
64-
['ConvTranspose', [convTranspose, parseConvTransposeAttributes]],
63+
['Conv', [conv]],
64+
['ConvTranspose', [convTranspose]],
6565
['Cos', [unaryOps.cos]],
6666
['Cosh', [unaryOps.cosh]],
6767
['CumSum', [cumsum, parseCumSumAttributes]],
@@ -73,7 +73,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
7373
['Exp', [unaryOps.exp]],
7474
['Expand', [expand]],
7575
['Floor', [unaryOps.floor]],
76-
['FusedConv', [conv, parseConvAttributes]],
76+
['FusedConv', [conv]],
7777
['Gather', [gather, parseGatherAttributes]],
7878
['GatherElements', [gatherElements, parseGatherElementsAttributes]],
7979
['Gelu', [unaryOps.gelu]],

js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts

+22-20
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121

2222
import {LOG_DEBUG} from '../../../log';
2323
import {TensorView} from '../../../tensor-view';
24-
import {ProgramInfo, ProgramUniform} from '../../types';
25-
import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common';
24+
import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types';
25+
import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common';
2626
import {ConvAttributes} from '../conv';
2727
import {getActivationSnippet} from '../fuse-utils';
2828

@@ -88,10 +88,10 @@ const conv2dCommonSnippet =
8888
let outRow = ${row} / outWidth;
8989
let outCol = ${row} % outWidth;
9090
91-
let WRow = ${col} / (filterDims[1] * inChannels);
92-
let WCol = ${col} / inChannels % filterDims[1];
93-
let xRow = outRow * stride[0] + dilation[0] * WRow - pad[0];
94-
let xCol = outCol * stride[1] + dilation[1] * WCol - pad[1];
91+
let WRow = ${col} / (i32(uniforms.w_shape[1]) * inChannels);
92+
let WCol = ${col} / inChannels % i32(uniforms.w_shape[1]);
93+
let xRow = outRow * uniforms.stride[0] + uniforms.dilation[0] * WRow - uniforms.pad[0];
94+
let xCol = outCol * uniforms.stride[1] + uniforms.dilation[1] * WCol - uniforms.pad[1];
9595
let xCh = ${col} % inChannels;
9696
var resData = ${typeSnippet(innerElementSizeX, dataType)}(0.0);
9797
// The bounds checking is always needed since we use it to pad zero for
@@ -195,15 +195,18 @@ export const createConv2DMatMulProgramInfo =
195195

196196
// TODO: support component 2, 3.
197197
const components = isVec4 ? 4 : 1;
198-
const programUniforms: ProgramUniform[] =
199-
[{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}];
198+
const programUniforms: ProgramUniform[] = [
199+
{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner},
200+
{type: 'int32', data: [attributes.pads[0], attributes.pads[1]]}, {type: 'int32', data: attributes.strides},
201+
{type: 'int32', data: attributes.dilations}
202+
];
200203
const x =
201204
inputVariable('x', inputs[0].dataType, inputs[0].dims.length, innerElementSize === 3 ? 1 : innerElementSize);
202205
const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, components);
203206
const inputVariables = [x, w];
204207

205-
programUniforms.push(...createTensorShapeVariables(inputs[0].dims));
206-
programUniforms.push(...createTensorShapeVariables(inputs[1].dims));
208+
programUniforms.push(...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims));
209+
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank'];
207210

208211
let declareFunctions = `
209212
fn setOutputAtIndex(flatIndex : i32, value : ${isVec4 ? `vec4<${t}>` : t}) {
@@ -218,6 +221,7 @@ export const createConv2DMatMulProgramInfo =
218221
inputVariables.push(bias);
219222

220223
programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
224+
inputDependencies.push('rank');
221225

222226
declareFunctions += `
223227
fn getBiasByOutputCoords(coords : vec4<i32>) -> ${isVec4 ? `vec4<${t}>` : t} {
@@ -226,9 +230,15 @@ export const createConv2DMatMulProgramInfo =
226230
}
227231
const output = outputVariable('result', inputs[0].dataType, outputShape.length, components);
228232
programUniforms.push(...createTensorShapeVariables(outputShape));
233+
234+
const uniforms: UniformsArrayType = [
235+
{name: 'dimAOuter', type: 'i32'}, {name: 'dimBOuter', type: 'i32'}, {name: 'dimInner', type: 'i32'},
236+
{name: 'pad', type: 'i32', length: 2}, {name: 'stride', type: 'i32', length: 2},
237+
{name: 'dilation', type: 'i32', length: 2}
238+
];
229239
return {
230240
name: 'Conv2DMatMul',
231-
shaderCache: {hint: attributes.cacheKey},
241+
shaderCache: {hint: `${attributes.format}`, inputDependencies},
232242
getRunData: () => ({
233243
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
234244
dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]},
@@ -239,15 +249,7 @@ export const createConv2DMatMulProgramInfo =
239249
//struct Uniforms { xShape : vec4<i32>, wShape : vec4<i32>, outShape : vec4<i32>,
240250
// outShapeStrides: vec3<i32>, filterDims : vec2<i32>, pad : vec2<i32>, stride : vec2<i32>,
241251
// dilation : vec2<i32>, dimAOuter : i32, dimBOuter : i32, dimInner : i32 };
242-
${
243-
shaderHelper.registerUniform('dimAOuter', 'i32')
244-
.registerUniform('dimBOuter', 'i32')
245-
.registerUniform('dimInner', 'i32')
246-
.declareVariables(...inputVariables, output)}
247-
const filterDims : vec2<i32> = vec2<i32>(${attributes.kernelShape[0]}, ${attributes.kernelShape[1]});
248-
const pad : vec2<i32> = vec2<i32>(${attributes.pads[0]}, ${attributes.pads[1]});
249-
const stride : vec2<i32> = vec2<i32>(${attributes.strides[0]}, ${attributes.strides[1]});
250-
const dilation : vec2<i32> = vec2<i32>(${attributes.dilations[0]}, ${attributes.dilations[1]});
252+
${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)}
251253
${declareFunctions}
252254
${
253255
conv2dCommonSnippet(

js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts

+37-43
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121

2222
import {LOG_DEBUG} from '../../../log';
2323
import {TensorView} from '../../../tensor-view';
24-
import {ProgramInfo, ProgramUniform} from '../../types';
25-
import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from '../common';
24+
import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types';
25+
import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from '../common';
2626
import {ConvTransposeAttributes} from '../conv-transpose';
2727
import {getActivationSnippet} from '../fuse-utils';
2828

@@ -74,21 +74,21 @@ const conv2dTransposeCommonSnippet =
7474
col % outWidth);
7575
`;
7676

77-
const xHeight = isChannelsLast ? 'outBackprop[1]' : 'outBackprop[2]';
78-
const xWidth = isChannelsLast ? 'outBackprop[2]' : 'outBackprop[3]';
77+
const xHeight = isChannelsLast ? 'i32(uniforms.x_shape[1])' : 'i32(uniforms.x_shape[2])';
78+
const xWidth = isChannelsLast ? 'i32(uniforms.x_shape[2])' : 'i32(uniforms.x_shape[3])';
7979
const row = isChannelsLast ? 'row' : 'col';
8080
const col = isChannelsLast ? 'col' : 'row';
8181

8282
const readASnippet = `
83-
let inChannels = ${isChannelsLast ? 'outBackprop[3]' : 'outBackprop[1]'};
83+
let inChannels = ${isChannelsLast ? 'i32(uniforms.x_shape[3])' : 'i32(uniforms.x_shape[1])'};
8484
let outWidth = ${isChannelsLast ? 'i32(uniforms.result_shape[2])' : 'i32(uniforms.result_shape[3])'};
8585
let outRow = ${row} / outWidth;
8686
let outCol = ${row} % outWidth;
8787
88-
let WRow = ${col} / (filterDims[1] * inChannels);
89-
let WCol = ${col} / inChannels % filterDims[1];
90-
let xR = f32(outRow - pads[0] + dilation[0] * WRow) / f32(strides[0]);
91-
let xC = f32(outCol - pads[1] + dilation[1] * WCol) / f32(strides[1]);
88+
let WRow = ${col} / (uniforms.filterDims[1] * inChannels);
89+
let WCol = ${col} / inChannels % uniforms.filterDims[1];
90+
let xR = f32(outRow - uniforms.pads[0] + uniforms.dilations[0] * WRow) / f32(uniforms.strides[0]);
91+
let xC = f32(outCol - uniforms.pads[1] + uniforms.dilations[1] * WCol) / f32(uniforms.strides[1]);
9292
if (xR < 0.0 || xR >= f32(${xHeight}) || fract(xR) > 0.0) {
9393
return ${type}(0.0);
9494
}
@@ -116,9 +116,9 @@ const conv2dTransposeCommonSnippet =
116116

117117
const sampleW = `
118118
let col = colIn * ${innerElementSize};
119-
let inChannels = ${isChannelsLast ? 'outBackprop[3]' : 'outBackprop[1]'};
120-
let coordX = filterDims.x - 1 - row / (filterDims[1] * inChannels);
121-
let coordY = filterDims.y - 1 - (row / inChannels) % filterDims[1];
119+
let inChannels = ${isChannelsLast ? 'i32(uniforms.x_shape[3])' : 'i32(uniforms.x_shape[1])'};
120+
let coordX = uniforms.filterDims[0] - 1 - row / (uniforms.filterDims[1] * inChannels);
121+
let coordY = uniforms.filterDims[1] - 1 - (row / inChannels) % uniforms.filterDims[1];
122122
if (${
123123
isChannelsLast ? 'row < uniforms.dimInner && col < uniforms.dimBOuter' :
124124
'row < uniforms.dimInner && col < uniforms.dimAOuter'} && coordX >= 0 && coordY >= 0) {
@@ -186,20 +186,33 @@ export const createConv2DTransposeMatMulProgramInfo =
186186
const innerElementSize = isVec4 ? 4 : 1;
187187
const tileInner = Math.max(workGroupSize[0] * innerElementSize, workGroupSize[1]);
188188
const components = isVec4 ? 4 : 1;
189-
const programUniforms: ProgramUniform[] =
190-
[{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}];
189+
const filterDims0 = attributes.kernelShape[isChannelsLast ? 1 : 2];
190+
const filterDims1 = attributes.kernelShape[isChannelsLast ? 2 : 3];
191+
const effectiveFilterDims0 =
192+
filterDims0 + (attributes.dilations[0] <= 1 ? 0 : (filterDims0 - 1) * (attributes.dilations[0] - 1));
193+
const effectiveFilterDims1 =
194+
filterDims1 + (attributes.dilations[1] <= 1 ? 0 : (filterDims1 - 1) * (attributes.dilations[1] - 1));
195+
const pads0 = effectiveFilterDims0 - 1 - Math.floor((attributes.pads[0] + attributes.pads[2]) / 2);
196+
const pads1 = effectiveFilterDims1 - 1 - Math.floor((attributes.pads[1] + attributes.pads[3]) / 2);
197+
const programUniforms: ProgramUniform[] = [
198+
{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner},
199+
{type: 'int32', data: attributes.strides}, {type: 'int32', data: attributes.dilations},
200+
{type: 'int32', data: [filterDims0, filterDims1]}, {type: 'int32', data: [pads0, pads1]}
201+
];
191202
const x = inputVariable('x', inputs[0].dataType, inputs[0].dims.length, components);
192203
const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, 1);
193204
const output = outputVariable('result', inputs[0].dataType, outputShape.length, components);
194205
const inputVariables = [x, w];
195-
programUniforms.push(...createTensorShapeVariables(inputs[0].dims));
196-
programUniforms.push(...createTensorShapeVariables(inputs[1].dims));
206+
programUniforms.push(
207+
...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims));
197208

209+
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank'];
198210
let declareFunctions = '';
199211
if (hasBias) {
200212
const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components);
201213
inputVariables.push(bias);
202214
programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
215+
inputDependencies.push('rank');
203216

204217
declareFunctions += `
205218
fn getBiasByOutputCoords(coords : vec4<i32>) -> ${isVec4 ? 'vec4<f32>' : 'f32'} {
@@ -209,42 +222,23 @@ export const createConv2DTransposeMatMulProgramInfo =
209222

210223
programUniforms.push(...createTensorShapeVariables(outputShape));
211224

225+
const uniforms: UniformsArrayType = [
226+
{name: 'dimAOuter', type: 'i32'}, {name: 'dimBOuter', type: 'i32'}, {name: 'dimInner', type: 'i32'},
227+
{name: 'strides', type: 'i32', length: 2}, {name: 'dilations', type: 'i32', length: 2},
228+
{name: 'filterDims', type: 'i32', length: 2}, {name: 'pads', type: 'i32', length: 2}
229+
];
230+
212231
return {
213232
name: 'Conv2DTransposeMatMul',
214-
shaderCache: {hint: attributes.cacheKey},
233+
shaderCache: {hint: `${attributes.format}`, inputDependencies},
215234
getRunData: () => ({
216235
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
217236
dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]},
218237
programUniforms
219238
}),
220239
getShaderSource: (shaderHelper: ShaderHelper) => `
221240
${utilFunctions('uniforms.result_strides')}
222-
${
223-
shaderHelper.registerUniform('dimAOuter', 'i32')
224-
.registerUniform('dimBOuter', 'i32')
225-
.registerUniform('dimInner', 'i32')
226-
.declareVariables(...inputVariables, output)};
227-
const outBackprop : vec4<i32> = vec4<i32>(${inputs[0].dims.join(',')});
228-
const filterDims : vec2<i32> = vec2<i32>(${attributes.kernelShape[isChannelsLast ? 1 : 2]}, ${
229-
attributes.kernelShape[isChannelsLast ? 2 : 3]});
230-
const effectiveFilterDims : vec2<i32> = filterDims + vec2<i32>(
231-
${
232-
attributes.dilations[0] <= 1 ?
233-
0 :
234-
(attributes.kernelShape[isChannelsLast ? 1 : 2] - 1) * (attributes.dilations[0] - 1)},
235-
${
236-
attributes.dilations[1] <= 1 ?
237-
0 :
238-
(attributes.kernelShape[isChannelsLast ? 2 : 3] - 1) * (attributes.dilations[1] - 1)});
239-
const pads : vec2<i32> = vec2<i32>(i32(effectiveFilterDims[0]) - 1 - (${
240-
attributes.pads[0] + attributes.pads[2]})/2,
241-
i32(effectiveFilterDims[1]) - 1 - (${
242-
attributes.pads[1] + attributes.pads[3]})/2);
243-
const strides : vec2<i32> = vec2<i32>(${attributes.strides[0]}, ${attributes.strides[1]});
244-
const dilation : vec2<i32> = vec2<i32>(${attributes.dilations[0]}, ${attributes.dilations[1]});
245-
const dimAOuter : i32 = ${dimAOuter};
246-
const dimBOuter : i32 = ${dimBOuter};
247-
const dimInner : i32 = ${dimInner};
241+
${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)};
248242
${declareFunctions}
249243
${conv2dTransposeCommonSnippet(isChannelsLast, hasBias, attributes, innerElementSize)}
250244
${

0 commit comments

Comments
 (0)