@@ -7,7 +7,7 @@ import {ShapeUtil} from '../../util';
7
7
import { AttributeWithCacheKey } from '../attribute-with-cache-key' ;
8
8
import { ComputeContext , ProgramInfo , ProgramUniform } from '../types' ;
9
9
10
- import { castToF32 , createTensorShapeVariables , fillVector , getMaxComponents , inputVariable , outputVariable , ShaderHelper , sumVector , tensorTypeToWsglStorageType , UniformsArrayType } from './common' ;
10
+ import { castToF32 , fillVector , getMaxComponents , inputVariable , outputVariable , ShaderHelper , sumVector , tensorTypeToWsglStorageType , UniformsArrayType } from './common' ;
11
11
12
12
export interface SkipLayerNormAttributes extends AttributeWithCacheKey {
13
13
epsilon : number ;
@@ -98,36 +98,30 @@ const createSkipLayerNormProgramInfo =
98
98
{ type : 'uint32' , data : hiddenSize } ,
99
99
{ type : 'float32' , data : attributes . epsilon } ,
100
100
] ;
101
- inputs . forEach ( ( input , _ ) => {
102
- programUniforms . push ( ...createTensorShapeVariables ( input . dims ) ) ;
103
- } ) ;
104
- const variables = [
105
- inputVariable ( 'x' , inputs [ 0 ] . dataType , inputs [ 0 ] . dims . length , components ) ,
106
- inputVariable ( 'skip' , inputs [ 1 ] . dataType , inputs [ 1 ] . dims . length , components ) ,
107
- inputVariable ( 'gamma' , inputs [ 2 ] . dataType , inputs [ 2 ] . dims . length , components ) ,
108
- ] ;
109
- if ( hasBetaInput ) {
110
- variables . push ( inputVariable ( 'beta' , inputs [ 3 ] . dataType , inputs [ 3 ] . dims . length , components ) ) ;
111
- }
112
- if ( hasBiasInput ) {
113
- variables . push ( inputVariable ( 'bias' , inputs [ 4 ] . dataType , inputs [ 4 ] . dims . length , components ) ) ;
114
- }
115
- variables . push ( outputVariable ( 'output' , inputs [ 0 ] . dataType , outputShape . length , components ) ) ;
116
- programUniforms . push ( ...createTensorShapeVariables ( outputShape ) ) ;
117
- if ( hasMeanOutput ) {
118
- variables . push ( outputVariable ( 'mean_output' , DataType . float , meanInvStdDevDim . length ) ) ;
119
- programUniforms . push ( ...createTensorShapeVariables ( meanInvStdDevDim ) ) ;
120
- }
121
- if ( hasInvStdDevOutput ) {
122
- variables . push ( outputVariable ( 'inv_std_output' , DataType . float , meanInvStdDevDim . length ) ) ;
123
- programUniforms . push ( ...createTensorShapeVariables ( meanInvStdDevDim ) ) ;
124
- }
125
- if ( hasInputSkipBiasSumOutput ) {
126
- variables . push ( outputVariable ( 'input_skip_bias_sum' , inputs [ 0 ] . dataType , outputShape . length , components ) ) ;
127
- programUniforms . push ( ...createTensorShapeVariables ( outputShape ) ) ;
128
- }
129
- const dataType = tensorTypeToWsglStorageType ( inputs [ 0 ] . dataType ) ;
130
- const getShaderSource = ( shaderHelper : ShaderHelper ) => `
101
+ const getShaderSource = ( shaderHelper : ShaderHelper ) => {
102
+ const variables = [
103
+ inputVariable ( 'x' , inputs [ 0 ] . dataType , inputs [ 0 ] . dims , components ) ,
104
+ inputVariable ( 'skip' , inputs [ 1 ] . dataType , inputs [ 1 ] . dims , components ) ,
105
+ inputVariable ( 'gamma' , inputs [ 2 ] . dataType , inputs [ 2 ] . dims , components ) ,
106
+ ] ;
107
+ if ( hasBetaInput ) {
108
+ variables . push ( inputVariable ( 'beta' , inputs [ 3 ] . dataType , inputs [ 3 ] . dims , components ) ) ;
109
+ }
110
+ if ( hasBiasInput ) {
111
+ variables . push ( inputVariable ( 'bias' , inputs [ 4 ] . dataType , inputs [ 4 ] . dims , components ) ) ;
112
+ }
113
+ variables . push ( outputVariable ( 'output' , inputs [ 0 ] . dataType , outputShape , components ) ) ;
114
+ if ( hasMeanOutput ) {
115
+ variables . push ( outputVariable ( 'mean_output' , DataType . float , meanInvStdDevDim ) ) ;
116
+ }
117
+ if ( hasInvStdDevOutput ) {
118
+ variables . push ( outputVariable ( 'inv_std_output' , DataType . float , meanInvStdDevDim ) ) ;
119
+ }
120
+ if ( hasInputSkipBiasSumOutput ) {
121
+ variables . push ( outputVariable ( 'input_skip_bias_sum' , inputs [ 0 ] . dataType , outputShape , components ) ) ;
122
+ }
123
+ const dataType = tensorTypeToWsglStorageType ( inputs [ 0 ] . dataType ) ;
124
+ return `
131
125
const epsilon: f32 = ${ attributes . epsilon } ;
132
126
133
127
${ shaderHelper . registerUniforms ( uniformsArray ) . declareVariables ( ...variables ) }
@@ -151,14 +145,15 @@ const createSkipLayerNormProgramInfo =
151
145
}
152
146
let mean = ${ sumVector ( 'sum' , components ) } / f32(uniforms.hidden_size);
153
147
let inv_std_dev = inverseSqrt(${
154
- sumVector ( 'squareSum' , components ) } / f32(uniforms.hidden_size) - mean * mean + uniforms.epsilon);
148
+ sumVector ( 'squareSum' , components ) } / f32(uniforms.hidden_size) - mean * mean + uniforms.epsilon);
155
149
${ hasMeanOutput ? 'mean_output[global_idx] = mean;' : '' }
156
150
${ hasInvStdDevOutput ? 'inv_std_output[global_idx] = inv_std_dev;' : '' }
157
151
for (var i: u32 = 0; i < hidden_size_vectorized; i++) {
158
152
output[offset + i] = (output[offset + i] - ${ dataType } (mean)) * ${ dataType } (inv_std_dev) * gamma[i] + ${
159
- hasBetaInput ? 'beta[i]' : '0.0' } ;
153
+ hasBetaInput ? 'beta[i]' : '0.0' } ;
160
154
}
161
155
}` ;
156
+ } ;
162
157
const outputs = [ { dims : outputShape , dataType : inputs [ 0 ] . dataType } ] ;
163
158
if ( outputCount > 1 ) {
164
159
outputs . push ( { dims : meanInvStdDevDim , dataType : DataType . float } ) ;
@@ -173,7 +168,7 @@ const createSkipLayerNormProgramInfo =
173
168
name : 'SkipLayerNormalization' ,
174
169
shaderCache : {
175
170
hint : `${ components } ;${ hasMeanOutput } ;${ hasInvStdDevOutput } ;${ hasInputSkipBiasSumOutput } ` ,
176
- inputDependencies : inputs . map ( ( _input , _index ) => 'rank ' )
171
+ inputDependencies : inputs . map ( ( _input , _index ) => 'type ' )
177
172
} ,
178
173
getShaderSource,
179
174
getRunData : ( ) => ( { outputs, dispatchGroup : { x : Math . ceil ( outputSize / hiddenSize / 64 ) } , programUniforms} ) ,
0 commit comments