diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts index b5b6a2a15cd8c..11c8778b72335 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts @@ -23,17 +23,17 @@ import {DataType} from '../../../../wasm-common'; import {LOG_DEBUG} from '../../../log'; import {TensorView} from '../../../tensor-view'; import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; -import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from '../common'; +import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common'; import {ConvTransposeAttributes} from '../conv-transpose'; import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet} from '../fuse-utils'; -import {biasSnippet, typeSnippet} from './activation_util'; +import {biasSnippet} from './activation_util'; import {utilFunctions} from './conv_util'; import {makeMatMulPackedSource, makeMatMulPackedVec4Source} from './matmul_packed_webgpu'; const conv2dTransposeCommonSnippet = - (isChannelsLast: boolean, addBias = false, attributes: ConvTransposeAttributes, innerElementSize = 4): string => { - const type = typeSnippet(innerElementSize, 'f32'); + (isChannelsLast: boolean, addBias = false, attributes: ConvTransposeAttributes, type: string, + innerElementSize = 4): string => { const getWSnippet = (innerElementSize: number) => { switch (innerElementSize) { case 1: @@ -47,7 +47,7 @@ const conv2dTransposeCommonSnippet = let v1 = w[getIndexFromCoords4D(coord1, vec4(uniforms.w_shape))]; let v2 = w[getIndexFromCoords4D(coord2, vec4(uniforms.w_shape))]; let v3 = w[getIndexFromCoords4D(coord3, vec4(uniforms.w_shape))]; - return vec4(v0, v1, v2, v3); + return ${type}(v0, v1, v2, v3); `; default: throw new Error(`innerElementSize ${innerElementSize} is not supported.`); @@ -224,7 +224,7 @@ export const createConv2DTransposeMatMulProgramInfo = const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components); inputVariables.push(bias); declareFunctions += ` - fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? 'vec4' : 'f32'} { + fn getBiasByOutputCoords(coords : vec4) -> ${bias.type.value} { return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}]; }`; } @@ -236,16 +236,20 @@ export const createConv2DTransposeMatMulProgramInfo = {name: 'pads', type: 'i32', length: pads.length} ]; appendActivationUniforms(attributes, uniforms); + const elemType = tensorTypeToWsglStorageType(inputs[0].dataType, 1); + if (elemType !== 'f16' && elemType !== 'f32') { + throw new Error(`elemType ${elemType} is not supported.`); + } return ` ${utilFunctions('uniforms.result_strides')} ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)}; ${declareFunctions} - ${conv2dTransposeCommonSnippet(isChannelsLast, hasBias, attributes, innerElementSize)} + ${conv2dTransposeCommonSnippet(isChannelsLast, hasBias, attributes, x.type.value, innerElementSize)} ${ isVec4 ? makeMatMulPackedVec4Source( - elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner) : + elementsPerThread, workGroupSize, elemType, undefined, !isChannelsLast, tileInner) : makeMatMulPackedSource( - elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner, false, + elementsPerThread, workGroupSize, elemType, undefined, !isChannelsLast, tileInner, false, undefined, sequentialAccessByThreads)}`; };