4
4
import { DataType } from '../../../wasm-common' ;
5
5
import { TensorView } from '../../tensor-view' ;
6
6
import { ShapeUtil } from '../../util' ;
7
- import { AttributeWithCacheKey , createAttributeWithCacheKey } from '../attribute-with-cache-key' ;
8
- import { ComputeContext , ProgramInfo } from '../types' ;
7
+ import { AttributeWithCacheKey } from '../attribute-with-cache-key' ;
8
+ import { ComputeContext , ProgramInfo , ProgramUniform } from '../types' ;
9
9
10
- import { castToF32 , fillVector , getMaxComponents , inputVariable , outputVariable , ShaderHelper , sumVector , tensorTypeToWsglStorageType , } from './common' ;
10
+ import { castToF32 , fillVector , getMaxComponents , inputVariable , outputVariable , ShaderHelper , sumVector , tensorTypeToWsglStorageType , UniformsArrayType } from './common' ;
11
11
12
12
export interface SkipLayerNormAttributes extends AttributeWithCacheKey {
13
13
epsilon : number ;
@@ -86,60 +86,74 @@ const createSkipLayerNormProgramInfo =
86
86
const hasInputSkipBiasSumOutput = outputCount > 3 ;
87
87
88
88
const components = getMaxComponents ( hiddenSize ) ;
89
- const variables = [
90
- inputVariable ( 'x' , inputs [ 0 ] . dataType , inputs [ 0 ] . dims , components ) ,
91
- inputVariable ( 'skip' , inputs [ 1 ] . dataType , inputs [ 1 ] . dims , components ) ,
92
- inputVariable ( 'gamma' , inputs [ 2 ] . dataType , inputs [ 2 ] . dims , components ) ,
93
- ] ;
94
- if ( hasBetaInput ) {
95
- variables . push ( inputVariable ( 'beta' , inputs [ 3 ] . dataType , inputs [ 3 ] . dims , components ) ) ;
96
- }
97
- if ( hasBiasInput ) {
98
- variables . push ( inputVariable ( 'bias' , inputs [ 4 ] . dataType , inputs [ 4 ] . dims , components ) ) ;
99
- }
100
- variables . push ( outputVariable ( 'output' , inputs [ 0 ] . dataType , outputShape , components ) ) ;
101
- if ( hasMeanOutput ) {
102
- variables . push ( outputVariable ( 'meanOutput' , DataType . float , meanInvStdDevDim ) ) ;
103
- }
104
- if ( hasInvStdDevOutput ) {
105
- variables . push ( outputVariable ( 'invStdOutput' , DataType . float , meanInvStdDevDim ) ) ;
106
- }
107
- if ( hasInputSkipBiasSumOutput ) {
108
- variables . push ( outputVariable ( 'inputSkipBiasSum' , inputs [ 0 ] . dataType , outputShape , components ) ) ;
109
- }
110
- const dataType = tensorTypeToWsglStorageType ( inputs [ 0 ] . dataType ) ;
111
- const getShaderSource = ( shaderHelper : ShaderHelper ) => `
112
- const hiddenSize: f32 = ${ hiddenSize } ;
113
- const hiddenSizeVectorized: u32 = ${ hiddenSize / components } ;
114
- const epsilon: f32 = ${ attributes . epsilon } ;
115
89
116
- ${ shaderHelper . declareVariables ( ...variables ) }
90
+ const programUniforms : ProgramUniform [ ] = [
91
+ { type : 'uint32' , data : outputSize } ,
92
+ { type : 'uint32' , data : components } ,
93
+ { type : 'uint32' , data : hiddenSize } ,
94
+ { type : 'float32' , data : attributes . epsilon } ,
95
+ ] ;
96
+ const getShaderSource = ( shaderHelper : ShaderHelper ) => {
97
+ const uniformsArray : UniformsArrayType = [
98
+ { name : 'output_size' , type : 'u32' } ,
99
+ { name : 'components' , type : 'u32' } ,
100
+ { name : 'hidden_size' , type : 'u32' } ,
101
+ { name : 'epsilon' , type : 'f32' } ,
102
+ ] ;
103
+ const variables = [
104
+ inputVariable ( 'x' , inputs [ 0 ] . dataType , inputs [ 0 ] . dims , components ) ,
105
+ inputVariable ( 'skip' , inputs [ 1 ] . dataType , inputs [ 1 ] . dims , components ) ,
106
+ inputVariable ( 'gamma' , inputs [ 2 ] . dataType , inputs [ 2 ] . dims , components ) ,
107
+ ] ;
108
+ if ( hasBetaInput ) {
109
+ variables . push ( inputVariable ( 'beta' , inputs [ 3 ] . dataType , inputs [ 3 ] . dims , components ) ) ;
110
+ }
111
+ if ( hasBiasInput ) {
112
+ variables . push ( inputVariable ( 'bias' , inputs [ 4 ] . dataType , inputs [ 4 ] . dims , components ) ) ;
113
+ }
114
+ variables . push ( outputVariable ( 'output' , inputs [ 0 ] . dataType , outputShape , components ) ) ;
115
+ if ( hasMeanOutput ) {
116
+ variables . push ( outputVariable ( 'mean_output' , DataType . float , meanInvStdDevDim ) ) ;
117
+ }
118
+ if ( hasInvStdDevOutput ) {
119
+ variables . push ( outputVariable ( 'inv_std_output' , DataType . float , meanInvStdDevDim ) ) ;
120
+ }
121
+ if ( hasInputSkipBiasSumOutput ) {
122
+ variables . push ( outputVariable ( 'input_skip_bias_sum' , inputs [ 0 ] . dataType , outputShape , components ) ) ;
123
+ }
124
+ const dataType = tensorTypeToWsglStorageType ( inputs [ 0 ] . dataType ) ;
125
+ return `
126
+
127
+ ${ shaderHelper . registerUniforms ( uniformsArray ) . declareVariables ( ...variables ) }
117
128
118
129
${ shaderHelper . mainStart ( ) }
119
- ${ shaderHelper . guardAgainstOutOfBoundsWorkgroupSizes ( outputSize / hiddenSize ) }
120
- let offset = global_idx * hiddenSizeVectorized;
130
+ ${ shaderHelper . guardAgainstOutOfBoundsWorkgroupSizes ( 'uniforms.output_size / uniforms.hidden_size' ) }
131
+ let hidden_size_vectorized: u32 = uniforms.hidden_size / uniforms.components;
132
+ let offset = global_idx * hidden_size_vectorized;
121
133
var sum = ${ fillVector ( 'f32' , components ) } ;
122
134
var squareSum = ${ fillVector ( 'f32' , components ) } ;
123
- for (var i: u32 = 0; i < hiddenSizeVectorized ; i++) {
124
- let skipValue = skip[offset + i];
125
- let biasValue = ${ hasBiasInput ? 'bias[i]' : '0.0' } ;
126
- let inputValue = x[offset + i];
127
- let value = inputValue + skipValue + biasValue ;
128
- ${ hasInputSkipBiasSumOutput ? 'inputSkipBiasSum [offset + i] = value;' : '' }
135
+ for (var i: u32 = 0; i < hidden_size_vectorized ; i++) {
136
+ let skip_value = skip[offset + i];
137
+ let bias_value = ${ hasBiasInput ? 'bias[i]' : '0.0' } ;
138
+ let input_value = x[offset + i];
139
+ let value = input_value + skip_value + bias_value ;
140
+ ${ hasInputSkipBiasSumOutput ? 'input_skip_bias_sum [offset + i] = value;' : '' }
129
141
output[offset + i] = value;
130
- let f32Value = ${ castToF32 ( dataType , components , 'value' ) } ;
131
- sum += f32Value ;
132
- squareSum += f32Value * f32Value ;
142
+ let f32_value = ${ castToF32 ( dataType , components , 'value' ) } ;
143
+ sum += f32_value ;
144
+ squareSum += f32_value * f32_value ;
133
145
}
134
- let mean = ${ sumVector ( 'sum' , components ) } / hiddenSize;
135
- let invStdDev = inverseSqrt(${ sumVector ( 'squareSum' , components ) } / hiddenSize - mean * mean + epsilon);
136
- ${ hasMeanOutput ? 'meanOutput[global_idx] = mean;' : '' }
137
- ${ hasInvStdDevOutput ? 'invStdOutput[global_idx] = invStdDev;' : '' }
138
- for (var i: u32 = 0; i < hiddenSizeVectorized; i++) {
139
- output[offset + i] = (output[offset + i] - ${ dataType } (mean)) * ${ dataType } (invStdDev) * gamma[i]
140
- + ${ hasBetaInput ? 'beta[i]' : '0.0' } ;
146
+ let mean = ${ sumVector ( 'sum' , components ) } / f32(uniforms.hidden_size);
147
+ let inv_std_dev = inverseSqrt(${
148
+ sumVector ( 'squareSum' , components ) } / f32(uniforms.hidden_size) - mean * mean + uniforms.epsilon);
149
+ ${ hasMeanOutput ? 'mean_output[global_idx] = mean;' : '' }
150
+ ${ hasInvStdDevOutput ? 'inv_std_output[global_idx] = inv_std_dev;' : '' }
151
+ for (var i: u32 = 0; i < hidden_size_vectorized; i++) {
152
+ output[offset + i] = (output[offset + i] - ${ dataType } (mean)) * ${ dataType } (inv_std_dev) * gamma[i] + ${
153
+ hasBetaInput ? 'beta[i]' : '0.0' } ;
141
154
}
142
155
}` ;
156
+ } ;
143
157
const outputs = [ { dims : outputShape , dataType : inputs [ 0 ] . dataType } ] ;
144
158
if ( outputCount > 1 ) {
145
159
outputs . push ( { dims : meanInvStdDevDim , dataType : DataType . float } ) ;
@@ -150,12 +164,14 @@ const createSkipLayerNormProgramInfo =
150
164
if ( outputCount > 3 ) {
151
165
outputs . push ( { dims : inputShape , dataType : inputs [ 0 ] . dataType } ) ;
152
166
}
153
-
154
167
return {
155
168
name : 'SkipLayerNormalization' ,
156
- shaderCache : { hint : attributes . cacheKey } ,
169
+ shaderCache : {
170
+ hint : `${ components } ;${ hasMeanOutput } ;${ hasInvStdDevOutput } ;${ hasInputSkipBiasSumOutput } ` ,
171
+ inputDependencies : inputs . map ( ( _input , _index ) => 'type' )
172
+ } ,
157
173
getShaderSource,
158
- getRunData : ( ) => ( { outputs, dispatchGroup : { x : Math . ceil ( outputSize / hiddenSize / 64 ) } } ) ,
174
+ getRunData : ( ) => ( { outputs, dispatchGroup : { x : Math . ceil ( outputSize / hiddenSize / 64 ) } , programUniforms } ) ,
159
175
} ;
160
176
} ;
161
177
@@ -178,8 +194,3 @@ export const skipLayerNorm = (context: ComputeContext, attributes: SkipLayerNorm
178
194
context . compute (
179
195
createSkipLayerNormProgramInfo ( context . inputs , attributes , context . outputCount , isTraining ) , { outputs} ) ;
180
196
} ;
181
-
182
- export const parseSkipLayerNormAttributes = ( attributes : Record < string , unknown > ) : SkipLayerNormAttributes => {
183
- const epsilon = attributes . epsilon as number ;
184
- return createAttributeWithCacheKey ( { epsilon} ) ;
185
- } ;
0 commit comments