Skip to content

Commit 7282e23

Browse files
[JS/WebGPU] Added Uniforms to SkipLayerNorm. (#18788)
### Description Added Uniforms to SkipLayerNorm ### Motivation and Context Improve performance --------- Co-authored-by: Yulong Wang <[email protected]>
1 parent 254b543 commit 7282e23

File tree

2 files changed

+69
-58
lines changed

js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts

+2-2
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import * as pool from './ops/pool';
2525
import {range} from './ops/range';
2626
import {reduceL1, reduceL2, reduceLogSum, reduceLogSumExp, reduceMax, reduceMean, reduceMin, reduceProd, reduceSum, reduceSumSquare} from './ops/reduce';
2727
import {parseResizeAttributes, resize} from './ops/resize';
28-
import {parseSkipLayerNormAttributes, skipLayerNorm} from './ops/skip-layer-norm';
28+
import {skipLayerNorm} from './ops/skip-layer-norm';
2929
import {parseSliceAttributes, slice} from './ops/slice';
3030
import {parseSoftmaxAttributes, softmax} from './ops/softmax';
3131
import {parseSplitAttributes, split} from './ops/split';
@@ -116,7 +116,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
116116
['Sin', [unaryOps.sin]],
117117
['Sinh', [unaryOps.sinh]],
118118
['Slice', [slice, parseSliceAttributes]],
119-
['SkipLayerNormalization', [skipLayerNorm, parseSkipLayerNormAttributes]],
119+
['SkipLayerNormalization', [skipLayerNorm]],
120120
['Split', [split, parseSplitAttributes]],
121121
['Sqrt', [unaryOps.sqrt]],
122122
['Softmax', [softmax, parseSoftmaxAttributes]],

js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts

+67-56
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
import {DataType} from '../../../wasm-common';
55
import {TensorView} from '../../tensor-view';
66
import {ShapeUtil} from '../../util';
7-
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
8-
import {ComputeContext, ProgramInfo} from '../types';
7+
import {AttributeWithCacheKey} from '../attribute-with-cache-key';
8+
import {ComputeContext, ProgramInfo, ProgramUniform} from '../types';
99

10-
import {castToF32, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType,} from './common';
10+
import {castToF32, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType, UniformsArrayType} from './common';
1111

1212
export interface SkipLayerNormAttributes extends AttributeWithCacheKey {
1313
epsilon: number;
@@ -86,60 +86,74 @@ const createSkipLayerNormProgramInfo =
8686
const hasInputSkipBiasSumOutput = outputCount > 3;
8787

8888
const components = getMaxComponents(hiddenSize);
89-
const variables = [
90-
inputVariable('x', inputs[0].dataType, inputs[0].dims, components),
91-
inputVariable('skip', inputs[1].dataType, inputs[1].dims, components),
92-
inputVariable('gamma', inputs[2].dataType, inputs[2].dims, components),
93-
];
94-
if (hasBetaInput) {
95-
variables.push(inputVariable('beta', inputs[3].dataType, inputs[3].dims, components));
96-
}
97-
if (hasBiasInput) {
98-
variables.push(inputVariable('bias', inputs[4].dataType, inputs[4].dims, components));
99-
}
100-
variables.push(outputVariable('output', inputs[0].dataType, outputShape, components));
101-
if (hasMeanOutput) {
102-
variables.push(outputVariable('meanOutput', DataType.float, meanInvStdDevDim));
103-
}
104-
if (hasInvStdDevOutput) {
105-
variables.push(outputVariable('invStdOutput', DataType.float, meanInvStdDevDim));
106-
}
107-
if (hasInputSkipBiasSumOutput) {
108-
variables.push(outputVariable('inputSkipBiasSum', inputs[0].dataType, outputShape, components));
109-
}
110-
const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
111-
const getShaderSource = (shaderHelper: ShaderHelper) => `
112-
const hiddenSize: f32 = ${hiddenSize};
113-
const hiddenSizeVectorized: u32 = ${hiddenSize / components};
114-
const epsilon: f32 = ${attributes.epsilon};
11589

116-
${shaderHelper.declareVariables(...variables)}
90+
const programUniforms: ProgramUniform[] = [
91+
{type: 'uint32', data: outputSize},
92+
{type: 'uint32', data: components},
93+
{type: 'uint32', data: hiddenSize},
94+
{type: 'float32', data: attributes.epsilon},
95+
];
96+
const getShaderSource = (shaderHelper: ShaderHelper) => {
97+
const uniformsArray: UniformsArrayType = [
98+
{name: 'output_size', type: 'u32'},
99+
{name: 'components', type: 'u32'},
100+
{name: 'hidden_size', type: 'u32'},
101+
{name: 'epsilon', type: 'f32'},
102+
];
103+
const variables = [
104+
inputVariable('x', inputs[0].dataType, inputs[0].dims, components),
105+
inputVariable('skip', inputs[1].dataType, inputs[1].dims, components),
106+
inputVariable('gamma', inputs[2].dataType, inputs[2].dims, components),
107+
];
108+
if (hasBetaInput) {
109+
variables.push(inputVariable('beta', inputs[3].dataType, inputs[3].dims, components));
110+
}
111+
if (hasBiasInput) {
112+
variables.push(inputVariable('bias', inputs[4].dataType, inputs[4].dims, components));
113+
}
114+
variables.push(outputVariable('output', inputs[0].dataType, outputShape, components));
115+
if (hasMeanOutput) {
116+
variables.push(outputVariable('mean_output', DataType.float, meanInvStdDevDim));
117+
}
118+
if (hasInvStdDevOutput) {
119+
variables.push(outputVariable('inv_std_output', DataType.float, meanInvStdDevDim));
120+
}
121+
if (hasInputSkipBiasSumOutput) {
122+
variables.push(outputVariable('input_skip_bias_sum', inputs[0].dataType, outputShape, components));
123+
}
124+
const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
125+
return `
126+
127+
${shaderHelper.registerUniforms(uniformsArray).declareVariables(...variables)}
117128
118129
${shaderHelper.mainStart()}
119-
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize / hiddenSize)}
120-
let offset = global_idx * hiddenSizeVectorized;
130+
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size / uniforms.hidden_size')}
131+
let hidden_size_vectorized: u32 = uniforms.hidden_size / uniforms.components;
132+
let offset = global_idx * hidden_size_vectorized;
121133
var sum = ${fillVector('f32', components)};
122134
var squareSum = ${fillVector('f32', components)};
123-
for (var i: u32 = 0; i < hiddenSizeVectorized; i++) {
124-
let skipValue = skip[offset + i];
125-
let biasValue = ${hasBiasInput ? 'bias[i]' : '0.0'};
126-
let inputValue = x[offset + i];
127-
let value = inputValue + skipValue + biasValue;
128-
${hasInputSkipBiasSumOutput ? 'inputSkipBiasSum[offset + i] = value;' : ''}
135+
for (var i: u32 = 0; i < hidden_size_vectorized; i++) {
136+
let skip_value = skip[offset + i];
137+
let bias_value = ${hasBiasInput ? 'bias[i]' : '0.0'};
138+
let input_value = x[offset + i];
139+
let value = input_value + skip_value + bias_value;
140+
${hasInputSkipBiasSumOutput ? 'input_skip_bias_sum[offset + i] = value;' : ''}
129141
output[offset + i] = value;
130-
let f32Value = ${castToF32(dataType, components, 'value')};
131-
sum += f32Value;
132-
squareSum += f32Value * f32Value;
142+
let f32_value = ${castToF32(dataType, components, 'value')};
143+
sum += f32_value;
144+
squareSum += f32_value * f32_value;
133145
}
134-
let mean = ${sumVector('sum', components)} / hiddenSize;
135-
let invStdDev = inverseSqrt(${sumVector('squareSum', components)} / hiddenSize - mean * mean + epsilon);
136-
${hasMeanOutput ? 'meanOutput[global_idx] = mean;' : ''}
137-
${hasInvStdDevOutput ? 'invStdOutput[global_idx] = invStdDev;' : ''}
138-
for (var i: u32 = 0; i < hiddenSizeVectorized; i++) {
139-
output[offset + i] = (output[offset + i] - ${dataType}(mean)) * ${dataType}(invStdDev) * gamma[i]
140-
+ ${hasBetaInput ? 'beta[i]' : '0.0'};
146+
let mean = ${sumVector('sum', components)} / f32(uniforms.hidden_size);
147+
let inv_std_dev = inverseSqrt(${
148+
sumVector('squareSum', components)} / f32(uniforms.hidden_size) - mean * mean + uniforms.epsilon);
149+
${hasMeanOutput ? 'mean_output[global_idx] = mean;' : ''}
150+
${hasInvStdDevOutput ? 'inv_std_output[global_idx] = inv_std_dev;' : ''}
151+
for (var i: u32 = 0; i < hidden_size_vectorized; i++) {
152+
output[offset + i] = (output[offset + i] - ${dataType}(mean)) * ${dataType}(inv_std_dev) * gamma[i] + ${
153+
hasBetaInput ? 'beta[i]' : '0.0'};
141154
}
142155
}`;
156+
};
143157
const outputs = [{dims: outputShape, dataType: inputs[0].dataType}];
144158
if (outputCount > 1) {
145159
outputs.push({dims: meanInvStdDevDim, dataType: DataType.float});
@@ -150,12 +164,14 @@ const createSkipLayerNormProgramInfo =
150164
if (outputCount > 3) {
151165
outputs.push({dims: inputShape, dataType: inputs[0].dataType});
152166
}
153-
154167
return {
155168
name: 'SkipLayerNormalization',
156-
shaderCache: {hint: attributes.cacheKey},
169+
shaderCache: {
170+
hint: `${components};${hasMeanOutput};${hasInvStdDevOutput};${hasInputSkipBiasSumOutput}`,
171+
inputDependencies: inputs.map((_input, _index) => 'type')
172+
},
157173
getShaderSource,
158-
getRunData: () => ({outputs, dispatchGroup: {x: Math.ceil(outputSize / hiddenSize / 64)}}),
174+
getRunData: () => ({outputs, dispatchGroup: {x: Math.ceil(outputSize / hiddenSize / 64)}, programUniforms}),
159175
};
160176
};
161177

@@ -178,8 +194,3 @@ export const skipLayerNorm = (context: ComputeContext, attributes: SkipLayerNorm
178194
context.compute(
179195
createSkipLayerNormProgramInfo(context.inputs, attributes, context.outputCount, isTraining), {outputs});
180196
};
181-
182-
export const parseSkipLayerNormAttributes = (attributes: Record<string, unknown>): SkipLayerNormAttributes => {
183-
const epsilon = attributes.epsilon as number;
184-
return createAttributeWithCacheKey({epsilon});
185-
};

0 commit comments

Comments (0)