@@ -23,17 +23,17 @@ import {DataType} from '../../../../wasm-common';
23
23
import { LOG_DEBUG } from '../../../log' ;
24
24
import { TensorView } from '../../../tensor-view' ;
25
25
import { ProgramInfo , ProgramInputTensorInfoDependency , ProgramUniform } from '../../types' ;
26
- import { createTensorShapeVariables , inputVariable , outputVariable , ShaderHelper , UniformsArrayType } from '../common' ;
26
+ import { createTensorShapeVariables , inputVariable , outputVariable , ShaderHelper , tensorTypeToWsglStorageType , UniformsArrayType } from '../common' ;
27
27
import { ConvTransposeAttributes } from '../conv-transpose' ;
28
28
import { appendActivationUniforms , appendActivationUniformsData , getActivationSnippet } from '../fuse-utils' ;
29
29
30
- import { biasSnippet , typeSnippet } from './activation_util' ;
30
+ import { biasSnippet } from './activation_util' ;
31
31
import { utilFunctions } from './conv_util' ;
32
32
import { makeMatMulPackedSource , makeMatMulPackedVec4Source } from './matmul_packed_webgpu' ;
33
33
34
34
const conv2dTransposeCommonSnippet =
35
- ( isChannelsLast : boolean , addBias = false , attributes : ConvTransposeAttributes , innerElementSize = 4 ) : string => {
36
- const type = typeSnippet ( innerElementSize , 'f32' ) ;
35
+ ( isChannelsLast : boolean , addBias = false , attributes : ConvTransposeAttributes , type : string ,
36
+ innerElementSize = 4 ) : string => {
37
37
const getWSnippet = ( innerElementSize : number ) => {
38
38
switch ( innerElementSize ) {
39
39
case 1 :
@@ -47,7 +47,7 @@ const conv2dTransposeCommonSnippet =
47
47
let v1 = w[getIndexFromCoords4D(coord1, vec4<i32>(uniforms.w_shape))];
48
48
let v2 = w[getIndexFromCoords4D(coord2, vec4<i32>(uniforms.w_shape))];
49
49
let v3 = w[getIndexFromCoords4D(coord3, vec4<i32>(uniforms.w_shape))];
50
- return vec4<f32> (v0, v1, v2, v3);
50
+ return ${ type } (v0, v1, v2, v3);
51
51
` ;
52
52
default :
53
53
throw new Error ( `innerElementSize ${ innerElementSize } is not supported.` ) ;
@@ -224,7 +224,7 @@ export const createConv2DTransposeMatMulProgramInfo =
224
224
const bias = inputVariable ( 'bias' , inputs [ 2 ] . dataType , inputs [ 2 ] . dims . length , components ) ;
225
225
inputVariables . push ( bias ) ;
226
226
declareFunctions += `
227
- fn getBiasByOutputCoords(coords : vec4<i32>) -> ${ isVec4 ? 'vec4<f32>' : 'f32' } {
227
+ fn getBiasByOutputCoords(coords : vec4<i32>) -> ${ bias . type . value } {
228
228
return bias[coords.${ isChannelsLast ? 'w' : 'y' } ${ isVec4 ? '/ 4' : '' } ];
229
229
}` ;
230
230
}
@@ -236,16 +236,20 @@ export const createConv2DTransposeMatMulProgramInfo =
236
236
{ name : 'pads' , type : 'i32' , length : pads . length }
237
237
] ;
238
238
appendActivationUniforms ( attributes , uniforms ) ;
239
+ const elemType = tensorTypeToWsglStorageType ( inputs [ 0 ] . dataType , 1 ) ;
240
+ if ( elemType !== 'f16' && elemType !== 'f32' ) {
241
+ throw new Error ( `elemType ${ elemType } is not supported.` ) ;
242
+ }
239
243
return `
240
244
${ utilFunctions ( 'uniforms.result_strides' ) }
241
245
${ shaderHelper . registerUniforms ( uniforms ) . declareVariables ( ...inputVariables , output ) } ;
242
246
${ declareFunctions }
243
- ${ conv2dTransposeCommonSnippet ( isChannelsLast , hasBias , attributes , innerElementSize ) }
247
+ ${ conv2dTransposeCommonSnippet ( isChannelsLast , hasBias , attributes , x . type . value , innerElementSize ) }
244
248
${
245
249
isVec4 ? makeMatMulPackedVec4Source (
246
- elementsPerThread , workGroupSize , 'f32' , undefined , ! isChannelsLast , tileInner ) :
250
+ elementsPerThread , workGroupSize , elemType , undefined , ! isChannelsLast , tileInner ) :
247
251
makeMatMulPackedSource (
248
- elementsPerThread , workGroupSize , 'f32' , undefined , ! isChannelsLast , tileInner , false ,
252
+ elementsPerThread , workGroupSize , elemType , undefined , ! isChannelsLast , tileInner , false ,
249
253
undefined , sequentialAccessByThreads ) } `;
250
254
} ;
251
255
0 commit comments