-
Notifications
You must be signed in to change notification settings - Fork 3.1k
/
Copy pathskip-layer-norm.ts
196 lines (180 loc) · 8.33 KB
/
skip-layer-norm.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, ProgramInfo, ProgramUniform} from '../types';
import {castToF32, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType, UniformsArrayType} from './common';
/**
 * Operator attributes for SkipLayerNormalization.
 */
export interface SkipLayerNormAttributes extends AttributeWithCacheKey {
  /** Small value added to the variance before `inverseSqrt` to avoid division by zero. */
  epsilon: number;
}
/**
 * Checks that `tensor` is 1-D and that its single dimension equals `hiddenSize`.
 * `name` is the capitalized tensor name used in the error messages.
 */
const validate1DTensorWithHiddenSize = (tensor: TensorView, name: string, hiddenSize: number): void => {
  if (tensor.dims.length !== 1) {
    throw new Error(`${name} must be 1D`);
  }
  if (tensor.dims[tensor.dims.length - 1] !== hiddenSize) {
    throw new Error(`${name} must have the same hidden size as input`);
  }
};

/**
 * Validates the inputs of a SkipLayerNormalization node.
 *
 * Expected inputs: input (2D or 3D), skip (2D or 3D), gamma (1D), and
 * optionally beta (1D) and bias (1D). All inputs must share one data type,
 * the last dimension (hidden size) of every input must match, and skip must
 * match input's sequence length.
 *
 * @throws Error when any constraint is violated.
 */
const validateInputs = (inputs: readonly TensorView[]): void => {
  if (!inputs || inputs.length < 3) {
    throw new Error('layerNorm requires at least 3 inputs.');
  }
  const input: TensorView = inputs[0];
  const skip: TensorView = inputs[1];
  const gamma: TensorView = inputs[2];
  if (input.dataType !== skip.dataType || input.dataType !== gamma.dataType) {
    throw new Error('All inputs must have the same data type');
  }
  if (input.dims.length !== 3 && input.dims.length !== 2) {
    throw new Error('Input must be 2D or 3D');
  }
  if (skip.dims.length !== 3 && skip.dims.length !== 2) {
    throw new Error('Skip must be 2D or 3D');
  }
  // Hidden size is the last dimension; sequence length the one before it.
  const hiddenSize = input.dims[input.dims.length - 1];
  const sequenceLength = input.dims[input.dims.length - 2];
  if (skip.dims[skip.dims.length - 1] !== hiddenSize) {
    throw new Error('Skip must have the same hidden size as input');
  }
  if (skip.dims[skip.dims.length - 2] !== sequenceLength) {
    throw new Error('Skip must have the same sequence length as input');
  }
  validate1DTensorWithHiddenSize(gamma, 'Gamma', hiddenSize);
  if (inputs.length > 3) {
    validate1DTensorWithHiddenSize(inputs[3], 'Beta', hiddenSize);
  }
  if (inputs.length > 4) {
    validate1DTensorWithHiddenSize(inputs[4], 'Bias', hiddenSize);
  }
};
/**
 * Builds the WebGPU program for SkipLayerNormalization.
 *
 * The kernel processes one row of `hidden_size` elements per invocation:
 *   value  = input + skip (+ bias)   — also stored to the optional input_skip_bias_sum output
 *   output = (value - mean) * invStdDev * gamma (+ beta)
 * Sums for mean/variance are accumulated in f32 regardless of the storage type.
 *
 * Fix: the previously emitted `const epsilon: f32 = ${'${attributes.epsilon}'};`
 * shader constant was dead code — the shader reads `uniforms.epsilon` — and baking
 * the attribute value into the source while the cache `hint` omits it is a
 * shader-cache hazard. It is removed; epsilon is passed only as a uniform.
 *
 * @param inputs input, skip, gamma and optional beta / bias tensors (pre-validated).
 * @param attributes operator attributes (epsilon).
 * @param outputCount number of requested outputs.
 * @param isTraining whether the training-only mean / inv-std-dev outputs are produced.
 */
const createSkipLayerNormProgramInfo =
    (inputs: readonly TensorView[], attributes: SkipLayerNormAttributes, outputCount: number, isTraining: boolean):
        ProgramInfo => {
      const inputShape = inputs[0].dims;
      const inputSize = ShapeUtil.size(inputShape);
      const outputShape = inputShape;
      const outputSize = inputSize;
      const hiddenSize = inputShape.slice(-1)[0];
      // Mean/inv-std-dev outputs replace the hidden dimension with 1 (one value per row).
      const meanInvStdDevDim = isTraining ? inputShape.slice(0, -1).concat(1) : [];
      const hasBetaInput = inputs.length > 3;
      const hasBiasInput = inputs.length > 4;
      const hasMeanOutput = isTraining && outputCount > 1;
      const hasInvStdDevOutput = isTraining && outputCount > 2;
      const hasInputSkipBiasSumOutput = outputCount > 3;
      // Vector width (1/2/4) used for loads/stores along the hidden dimension.
      const components = getMaxComponents(hiddenSize);
      const uniformsArray: UniformsArrayType = [
        {name: 'output_size', type: 'u32'},
        {name: 'components', type: 'u32'},
        {name: 'hidden_size', type: 'u32'},
        {name: 'epsilon', type: 'f32'},
      ];
      const programUniforms: ProgramUniform[] = [
        {type: 'uint32', data: outputSize},
        {type: 'uint32', data: components},
        {type: 'uint32', data: hiddenSize},
        {type: 'float32', data: attributes.epsilon},
      ];
      const getShaderSource = (shaderHelper: ShaderHelper) => {
        const variables = [
          inputVariable('x', inputs[0].dataType, inputs[0].dims, components),
          inputVariable('skip', inputs[1].dataType, inputs[1].dims, components),
          inputVariable('gamma', inputs[2].dataType, inputs[2].dims, components),
        ];
        if (hasBetaInput) {
          variables.push(inputVariable('beta', inputs[3].dataType, inputs[3].dims, components));
        }
        if (hasBiasInput) {
          variables.push(inputVariable('bias', inputs[4].dataType, inputs[4].dims, components));
        }
        variables.push(outputVariable('output', inputs[0].dataType, outputShape, components));
        if (hasMeanOutput) {
          variables.push(outputVariable('mean_output', DataType.float, meanInvStdDevDim));
        }
        if (hasInvStdDevOutput) {
          variables.push(outputVariable('inv_std_output', DataType.float, meanInvStdDevDim));
        }
        if (hasInputSkipBiasSumOutput) {
          variables.push(outputVariable('input_skip_bias_sum', inputs[0].dataType, outputShape, components));
        }
        const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
        return `
      ${shaderHelper.registerUniforms(uniformsArray).declareVariables(...variables)}
      ${shaderHelper.mainStart()}
        ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size / uniforms.hidden_size')}
        let hidden_size_vectorized: u32 = uniforms.hidden_size / uniforms.components;
        let offset = global_idx * hidden_size_vectorized;
        var sum = ${fillVector('f32', components)};
        var squareSum = ${fillVector('f32', components)};
        for (var i: u32 = 0; i < hidden_size_vectorized; i++) {
          let skip_value = skip[offset + i];
          let bias_value = ${hasBiasInput ? 'bias[i]' : '0.0'};
          let input_value = x[offset + i];
          let value = input_value + skip_value + bias_value;
          ${hasInputSkipBiasSumOutput ? 'input_skip_bias_sum[offset + i] = value;' : ''}
          output[offset + i] = value;
          let f32_value = ${castToF32(dataType, components, 'value')};
          sum += f32_value;
          squareSum += f32_value * f32_value;
        }
        let mean = ${sumVector('sum', components)} / f32(uniforms.hidden_size);
        let inv_std_dev = inverseSqrt(${
            sumVector('squareSum', components)} / f32(uniforms.hidden_size) - mean * mean + uniforms.epsilon);
        ${hasMeanOutput ? 'mean_output[global_idx] = mean;' : ''}
        ${hasInvStdDevOutput ? 'inv_std_output[global_idx] = inv_std_dev;' : ''}
        for (var i: u32 = 0; i < hidden_size_vectorized; i++) {
          output[offset + i] = (output[offset + i] - ${dataType}(mean)) * ${dataType}(inv_std_dev) * gamma[i] + ${
            hasBetaInput ? 'beta[i]' : '0.0'};
        }
      }`;
      };
      const outputs = [{dims: outputShape, dataType: inputs[0].dataType}];
      if (outputCount > 1) {
        outputs.push({dims: meanInvStdDevDim, dataType: DataType.float});
      }
      if (outputCount > 2) {
        outputs.push({dims: meanInvStdDevDim, dataType: DataType.float});
      }
      if (outputCount > 3) {
        outputs.push({dims: inputShape, dataType: inputs[0].dataType});
      }
      return {
        name: 'SkipLayerNormalization',
        shaderCache: {
          hint: `${components};${hasMeanOutput};${hasInvStdDevOutput};${hasInputSkipBiasSumOutput}`,
          inputDependencies: inputs.map((_input, _index) => 'type')
        },
        getShaderSource,
        // One invocation per row of hidden_size elements; 64 invocations per workgroup.
        getRunData: () => ({outputs, dispatchGroup: {x: Math.ceil(outputSize / hiddenSize / 64)}, programUniforms}),
      };
    };
/**
 * Runs the SkipLayerNormalization operator on the given compute context.
 * Validates the inputs, then dispatches the generated WebGPU program.
 */
export const skipLayerNorm = (context: ComputeContext, attributes: SkipLayerNormAttributes): void => {
  // TODO: initialize isTraining from ComputeContext
  const isTraining = false;
  validateInputs(context.inputs);
  // Mean and InvStdDev are only used in training mode and are not required for inference.
  // They are added here for completeness only. NOTE(review): -3 appears to be a sentinel
  // marking an unmaterialized output — confirm against ComputeContext's output-index semantics.
  const outputIndices: number[] = [0];
  if (context.outputCount > 1) {
    outputIndices.push(isTraining ? 1 : -3);
  }
  if (context.outputCount > 2) {
    outputIndices.push(isTraining ? 2 : -3);
  }
  if (context.outputCount > 3) {
    outputIndices.push(3);
  }
  const programInfo = createSkipLayerNormProgramInfo(context.inputs, attributes, context.outputCount, isTraining);
  context.compute(programInfo, {outputs: outputIndices});
};