@@ -85,28 +85,28 @@ const createLayerNormProgramInfo =
85
85
${ shaderHelper . mainStart ( ) }
86
86
${ shaderHelper . guardAgainstOutOfBoundsWorkgroupSizes ( 'uniforms.norm_count' ) }
87
87
let offset = global_idx * uniforms.norm_size_vectorized;
88
- var meanVector = ${ fillVector ( 'f32' , components ) } ;
89
- var meanSquareVector = ${ fillVector ( 'f32' , components ) } ;
88
+ var mean_vector = ${ fillVector ( 'f32' , components ) } ;
89
+ var mean_square_vector = ${ fillVector ( 'f32' , components ) } ;
90
90
91
91
for (var h: u32 = 0u; h < uniforms.norm_size_vectorized; h++) {
92
92
let value = ${ castToF32 ( dataType , components , 'x[h + offset]' ) } ;
93
- meanVector += value;
94
- meanSquareVector += value * value;
93
+ mean_vector += value;
94
+ mean_square_vector += value * value;
95
95
}
96
- let mean = ${ sumVector ( 'meanVector ' , components ) } / uniforms.norm_size;
97
- let invStdDev =
98
- inverseSqrt( ${ sumVector ( 'meanSquareVector ' , components ) } / uniforms.norm_size - mean * mean + uniforms.epsilon);
96
+ let mean = ${ sumVector ( 'mean_vector ' , components ) } / uniforms.norm_size;
97
+ let inv_std_dev = inverseSqrt( ${
98
+ sumVector ( 'mean_square_vector ' , components ) } / uniforms.norm_size - mean * mean + uniforms.epsilon);
99
99
100
100
for (var j: u32 = 0; j < uniforms.norm_size_vectorized; j++) {
101
101
let f32input = ${ castToF32 ( dataType , components , 'x[j + offset]' ) } ;
102
102
let f32scale = ${ castToF32 ( dataType , components , 'scale[j]' ) } ;
103
- output[j + offset] = ${ variables [ 0 ] . type . value } ((f32input - mean) * invStdDev * f32scale
103
+ output[j + offset] = ${ variables [ 0 ] . type . value } ((f32input - mean) * inv_std_dev * f32scale
104
104
${ bias ? `+ ${ castToF32 ( dataType , components , 'bias[j]' ) } ` : '' }
105
105
);
106
106
}
107
107
108
108
${ hasMeanDataOutput ? 'mean_data_output[global_idx] = mean' : '' } ;
109
- ${ hasInvStdOutput ? 'inv_std_output[global_idx] = invStdDev ' : '' } ;
109
+ ${ hasInvStdOutput ? 'inv_std_output[global_idx] = inv_std_dev ' : '' } ;
110
110
}` ;
111
111
} ;
112
112
const outputs = [ { dims : outputShape , dataType : inputs [ 0 ] . dataType } ] ;
0 commit comments