Skip to content

Commit 465176e

Browse files
authored
[js/webgpu] Fix Conv2DTransposeMatMul f16 compilation failure (microsoft#19596)
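This fix is needed by the sam-h-decoder-f16 model.

### Description
The Conv2DTransposeMatMul shader generator hard-coded `f32` / `vec4<f32>` in the WGSL it emits, so the shader failed to compile for f16 inputs. The element type is now derived from the input tensors instead.

### Motivation and Context
Without this change, f16 models that use Conv2DTransposeMatMul (such as sam-h-decoder-f16) fail at shader compilation.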
1 parent 28a6689 commit 465176e

File tree

1 file changed

+13
-9
lines changed

1 file changed

+13
-9
lines changed

web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts

+13-9
Original file line numberDiff line numberDiff line change
@@ -23,17 +23,17 @@ import {DataType} from '../../../../wasm-common';
2323
import {LOG_DEBUG} from '../../../log';
2424
import {TensorView} from '../../../tensor-view';
2525
import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types';
26-
import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from '../common';
26+
import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common';
2727
import {ConvTransposeAttributes} from '../conv-transpose';
2828
import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet} from '../fuse-utils';
2929

30-
import {biasSnippet, typeSnippet} from './activation_util';
30+
import {biasSnippet} from './activation_util';
3131
import {utilFunctions} from './conv_util';
3232
import {makeMatMulPackedSource, makeMatMulPackedVec4Source} from './matmul_packed_webgpu';
3333

3434
const conv2dTransposeCommonSnippet =
35-
(isChannelsLast: boolean, addBias = false, attributes: ConvTransposeAttributes, innerElementSize = 4): string => {
36-
const type = typeSnippet(innerElementSize, 'f32');
35+
(isChannelsLast: boolean, addBias = false, attributes: ConvTransposeAttributes, type: string,
36+
innerElementSize = 4): string => {
3737
const getWSnippet = (innerElementSize: number) => {
3838
switch (innerElementSize) {
3939
case 1:
@@ -47,7 +47,7 @@ const conv2dTransposeCommonSnippet =
4747
let v1 = w[getIndexFromCoords4D(coord1, vec4<i32>(uniforms.w_shape))];
4848
let v2 = w[getIndexFromCoords4D(coord2, vec4<i32>(uniforms.w_shape))];
4949
let v3 = w[getIndexFromCoords4D(coord3, vec4<i32>(uniforms.w_shape))];
50-
return vec4<f32>(v0, v1, v2, v3);
50+
return ${type}(v0, v1, v2, v3);
5151
`;
5252
default:
5353
throw new Error(`innerElementSize ${innerElementSize} is not supported.`);
@@ -224,7 +224,7 @@ export const createConv2DTransposeMatMulProgramInfo =
224224
const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components);
225225
inputVariables.push(bias);
226226
declareFunctions += `
227-
fn getBiasByOutputCoords(coords : vec4<i32>) -> ${isVec4 ? 'vec4<f32>' : 'f32'} {
227+
fn getBiasByOutputCoords(coords : vec4<i32>) -> ${bias.type.value} {
228228
return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}];
229229
}`;
230230
}
@@ -236,16 +236,20 @@ export const createConv2DTransposeMatMulProgramInfo =
236236
{name: 'pads', type: 'i32', length: pads.length}
237237
];
238238
appendActivationUniforms(attributes, uniforms);
239+
const elemType = tensorTypeToWsglStorageType(inputs[0].dataType, 1);
240+
if (elemType !== 'f16' && elemType !== 'f32') {
241+
throw new Error(`elemType ${elemType} is not supported.`);
242+
}
239243
return `
240244
${utilFunctions('uniforms.result_strides')}
241245
${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)};
242246
${declareFunctions}
243-
${conv2dTransposeCommonSnippet(isChannelsLast, hasBias, attributes, innerElementSize)}
247+
${conv2dTransposeCommonSnippet(isChannelsLast, hasBias, attributes, x.type.value, innerElementSize)}
244248
${
245249
isVec4 ? makeMatMulPackedVec4Source(
246-
elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner) :
250+
elementsPerThread, workGroupSize, elemType, undefined, !isChannelsLast, tileInner) :
247251
makeMatMulPackedSource(
248-
elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner, false,
252+
elementsPerThread, workGroupSize, elemType, undefined, !isChannelsLast, tileInner, false,
249253
undefined, sequentialAccessByThreads)}`;
250254
};
251255

0 commit comments

Comments
 (0)