Skip to content

Commit f02accb

Browse files
axinging and fs-eire
authored and committed
[js/webgpu] Support uniforms for conv, conv transpose, conv grouped (#18753)
1 parent b03a1c5 commit f02accb

9 files changed

+418
-344
lines changed

js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts

+69 -56
Original file line number | Diff line number | Diff line change
@@ -21,8 +21,8 @@
2121

2222
import {LOG_DEBUG} from '../../../log';
2323
import {TensorView} from '../../../tensor-view';
24-
import {ProgramInfo, ProgramUniform} from '../../types';
25-
import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common';
24+
import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types';
25+
import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common';
2626
import {ConvAttributes} from '../conv';
2727
import {getActivationSnippet} from '../fuse-utils';
2828

@@ -88,10 +88,10 @@ const conv2dCommonSnippet =
8888
let outRow = ${row} / outWidth;
8989
let outCol = ${row} % outWidth;
9090
91-
let WRow = ${col} / (filterDims[1] * inChannels);
92-
let WCol = ${col} / inChannels % filterDims[1];
93-
let xRow = outRow * stride[0] + dilation[0] * WRow - pad[0];
94-
let xCol = outCol * stride[1] + dilation[1] * WCol - pad[1];
91+
let WRow = ${col} / (i32(uniforms.w_shape[1]) * inChannels);
92+
let WCol = ${col} / inChannels % i32(uniforms.w_shape[1]);
93+
let xRow = outRow * uniforms.stride[0] + uniforms.dilation[0] * WRow - uniforms.pad[0];
94+
let xCol = outCol * uniforms.stride[1] + uniforms.dilation[1] * WCol - uniforms.pad[1];
9595
let xCh = ${col} % inChannels;
9696
var resData = ${typeSnippet(innerElementSizeX, dataType)}(0.0);
9797
// The bounds checking is always needed since we use it to pad zero for
@@ -108,7 +108,7 @@ const conv2dCommonSnippet =
108108
${readXSnippet}` :
109109
`
110110
let col = colIn * ${innerElementSizeX};
111-
if (row < uniforms.dimAOuter && col < uniforms.dimInner) {
111+
if (row < uniforms.dim_a_outer && col < uniforms.dim_inner) {
112112
${readXSnippet}
113113
}
114114
return ${typeSnippet(innerElementSizeX, dataType)}(0.0);`) :
@@ -117,7 +117,7 @@ const conv2dCommonSnippet =
117117
${readXSnippet}` :
118118
`
119119
let col = colIn * ${innerElementSizeX};
120-
if (row < uniforms.dimInner && col < uniforms.dimBOuter) {
120+
if (row < uniforms.dim_inner && col < uniforms.dim_b_outer) {
121121
${readXSnippet}
122122
}
123123
return ${typeSnippet(innerElementSizeX, dataType)}(0.0);`);
@@ -129,9 +129,8 @@ const conv2dCommonSnippet =
129129
isChannelsLast ? typeSnippet(innerElementSizeX, dataType) : typeSnippet(innerElementSizeW, dataType);
130130
const bType =
131131
isChannelsLast ? typeSnippet(innerElementSizeW, dataType) : typeSnippet(innerElementSizeX, dataType);
132-
const {activationFunction, applyActivation} = getActivationSnippet(attributes, resType);
132+
const applyActivation = getActivationSnippet(attributes, resType);
133133
const userCode = `
134-
${activationFunction}
135134
fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${aType} {
136135
${isChannelsLast ? sampleX : sampleW}
137136
}
@@ -142,7 +141,7 @@ const conv2dCommonSnippet =
142141
143142
fn mm_write(batch: i32, row : i32, colIn : i32, valueIn : ${resType}) {
144143
let col = colIn * ${innerElementSize};
145-
if (row < uniforms.dimAOuter && col < uniforms.dimBOuter)
144+
if (row < uniforms.dim_a_outer && col < uniforms.dim_b_outer)
146145
{
147146
var value = valueIn;
148147
let outWidth = ${isChannelsLast ? 'i32(uniforms.result_shape[2])' : 'i32(uniforms.result_shape[3])'};
@@ -181,83 +180,97 @@ export const createConv2DMatMulProgramInfo =
181180
LOG_DEBUG('verbose', () => `[conv2d_mm_webgpu] dispatch = ${dispatch}`);
182181

183182
const innerElementSize = isVec4 ? (isChannelsLast && inChannels % 4 !== 0 ? 3 : 4) : 1;
184-
185183
const tileAOuter = workGroupSize[1] * elementsPerThread[1];
186184
const tileBOuter = workGroupSize[0] * elementsPerThread[0];
187185
const tileInner = Math.max(workGroupSize[0] * innerElementSize, workGroupSize[1]);
188-
189186
const fitAOuter = dimAOuter % tileAOuter === 0;
190187
const fitBOuter = dimBOuter % tileBOuter === 0;
191188
const fitInner = dimInner % tileInner === 0;
192-
193189
const elementsSize = isVec4 ? [innerElementSize, 4, 4] : [1, 1, 1];
194-
const t = tensorTypeToWsglStorageType(inputs[0].dataType);
195190

196-
// TODO: support component 2, 3.
197-
const components = isVec4 ? 4 : 1;
198-
const programUniforms: ProgramUniform[] =
199-
[{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}];
200-
const x =
201-
inputVariable('x', inputs[0].dataType, inputs[0].dims.length, innerElementSize === 3 ? 1 : innerElementSize);
202-
const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, components);
203-
const inputVariables = [x, w];
191+
const programUniforms: ProgramUniform[] = [
192+
{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner},
193+
{type: 'int32', data: [attributes.pads[0], attributes.pads[1]]}, {type: 'int32', data: attributes.strides},
194+
{type: 'int32', data: attributes.dilations}
195+
];
196+
if (attributes.activation === 'Clip') {
197+
programUniforms.push(
198+
{type: 'float32', data: attributes.clipMax!}, {type: 'float32', data: attributes.clipMin!});
199+
}
200+
programUniforms.push(
201+
...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims));
202+
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank'];
203+
if (hasBias) {
204+
programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
205+
inputDependencies.push('rank');
206+
}
207+
programUniforms.push(...createTensorShapeVariables(outputShape));
204208

205-
programUniforms.push(...createTensorShapeVariables(inputs[0].dims));
206-
programUniforms.push(...createTensorShapeVariables(inputs[1].dims));
209+
const getShaderSource = (shaderHelper: ShaderHelper) => {
210+
const uniforms: UniformsArrayType = [
211+
{name: 'dim_a_outer', type: 'i32'}, {name: 'dim_b_outer', type: 'i32'}, {name: 'dim_inner', type: 'i32'},
212+
{name: 'pad', type: 'i32', length: 2}, {name: 'stride', type: 'i32', length: 2},
213+
{name: 'dilation', type: 'i32', length: 2}
214+
];
215+
if (attributes.activation === 'Clip') {
216+
uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'});
217+
}
207218

208-
let declareFunctions = `
219+
// TODO: support component 2, 3.
220+
const components = isVec4 ? 4 : 1;
221+
const t = tensorTypeToWsglStorageType(inputs[0].dataType);
222+
let declareFunctions = `
209223
fn setOutputAtIndex(flatIndex : i32, value : ${isVec4 ? `vec4<${t}>` : t}) {
210224
result[flatIndex] = ${isVec4 ? `vec4<${t}>` : t}(value);
211225
}
212226
fn setOutputAtCoords(d0 : i32, d1 : i32, d2 : i32, d3 : i32, value : ${isVec4 ? `vec4<${t}>` : t}) {
213227
let flatIndex = getOutputIndexFromCoords(vec4<i32>(d0, d1, d2, d3));
214228
setOutputAtIndex(flatIndex ${isVec4 ? '/ 4' : ''}, value);
215229
}`;
216-
if (hasBias) {
217-
const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components);
218-
inputVariables.push(bias);
219-
220-
programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
221-
222-
declareFunctions += `
230+
const x = inputVariable(
231+
'x', inputs[0].dataType, inputs[0].dims.length, innerElementSize === 3 ? 1 : innerElementSize);
232+
const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, components);
233+
const inputVariables = [x, w];
234+
const output = outputVariable('result', inputs[0].dataType, outputShape.length, components);
235+
if (hasBias) {
236+
const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components);
237+
inputVariables.push(bias);
238+
declareFunctions += `
223239
fn getBiasByOutputCoords(coords : vec4<i32>) -> ${isVec4 ? `vec4<${t}>` : t} {
224240
return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}];
225241
}`;
226-
}
227-
const output = outputVariable('result', inputs[0].dataType, outputShape.length, components);
228-
programUniforms.push(...createTensorShapeVariables(outputShape));
229-
return {
230-
name: 'Conv2DMatMul',
231-
shaderCache: {hint: attributes.cacheKey},
232-
getRunData: () => ({
233-
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
234-
dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]},
235-
programUniforms,
236-
}),
237-
getShaderSource: (shaderHelper: ShaderHelper) => `
242+
}
243+
244+
return `
238245
${utilFunctions('uniforms.result_strides')}
239246
//struct Uniforms { xShape : vec4<i32>, wShape : vec4<i32>, outShape : vec4<i32>,
240247
// outShapeStrides: vec3<i32>, filterDims : vec2<i32>, pad : vec2<i32>, stride : vec2<i32>,
241248
// dilation : vec2<i32>, dimAOuter : i32, dimBOuter : i32, dimInner : i32 };
242-
${
243-
shaderHelper.registerUniform('dimAOuter', 'i32')
244-
.registerUniform('dimBOuter', 'i32')
245-
.registerUniform('dimInner', 'i32')
246-
.declareVariables(...inputVariables, output)}
247-
const filterDims : vec2<i32> = vec2<i32>(${attributes.kernelShape[0]}, ${attributes.kernelShape[1]});
248-
const pad : vec2<i32> = vec2<i32>(${attributes.pads[0]}, ${attributes.pads[1]});
249-
const stride : vec2<i32> = vec2<i32>(${attributes.strides[0]}, ${attributes.strides[1]});
250-
const dilation : vec2<i32> = vec2<i32>(${attributes.dilations[0]}, ${attributes.dilations[1]});
249+
${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)}
251250
${declareFunctions}
252251
${
253252
conv2dCommonSnippet(
254253
isChannelsLast, fitAOuter, fitBOuter, fitInner, hasBias, attributes, elementsSize[0], elementsSize[1],
255254
elementsSize[2], t)}
256-
${
255+
${
257256
isVec4 ?
258257
makeMatMulPackedVec4Source(elementsPerThread, workGroupSize, t, undefined, !isChannelsLast, tileInner) :
259258
makeMatMulPackedSource(
260259
elementsPerThread, workGroupSize, t, undefined, !isChannelsLast, tileInner, false, undefined,
261-
sequentialAccessByThreads)}`
260+
sequentialAccessByThreads)}`;
261+
};
262+
return {
263+
name: 'Conv2DMatMul',
264+
shaderCache: {
265+
hint: `${attributes.cacheKey};${innerElementSize};${isVec4};${fitAOuter};${fitBOuter};${fitInner};${
266+
tileAOuter};${tileBOuter};${tileInner}`,
267+
inputDependencies
268+
},
269+
getRunData: () => ({
270+
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
271+
dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]},
272+
programUniforms,
273+
}),
274+
getShaderSource
262275
};
263276
};

0 commit comments

Comments
 (0)