Skip to content

Commit dca54cf

Browse files
Removed unnecessary shape/stride uniforms from the SkipLayerNorm shader: input/output dims are now baked into the shader source (input dependency 'type' instead of 'rank'), so the per-tensor shape uniforms are no longer pushed.
1 parent 687b266 commit dca54cf

File tree

1 file changed

+29
-34
lines changed

1 file changed

+29
-34
lines changed

js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts

+29-34
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import {ShapeUtil} from '../../util';
77
import {AttributeWithCacheKey} from '../attribute-with-cache-key';
88
import {ComputeContext, ProgramInfo, ProgramUniform} from '../types';
99

10-
import {castToF32, createTensorShapeVariables, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType, UniformsArrayType} from './common';
10+
import {castToF32, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType, UniformsArrayType} from './common';
1111

1212
export interface SkipLayerNormAttributes extends AttributeWithCacheKey {
1313
epsilon: number;
@@ -98,36 +98,30 @@ const createSkipLayerNormProgramInfo =
9898
{type: 'uint32', data: hiddenSize},
9999
{type: 'float32', data: attributes.epsilon},
100100
];
101-
inputs.forEach((input, _) => {
102-
programUniforms.push(...createTensorShapeVariables(input.dims));
103-
});
104-
const variables = [
105-
inputVariable('x', inputs[0].dataType, inputs[0].dims.length, components),
106-
inputVariable('skip', inputs[1].dataType, inputs[1].dims.length, components),
107-
inputVariable('gamma', inputs[2].dataType, inputs[2].dims.length, components),
108-
];
109-
if (hasBetaInput) {
110-
variables.push(inputVariable('beta', inputs[3].dataType, inputs[3].dims.length, components));
111-
}
112-
if (hasBiasInput) {
113-
variables.push(inputVariable('bias', inputs[4].dataType, inputs[4].dims.length, components));
114-
}
115-
variables.push(outputVariable('output', inputs[0].dataType, outputShape.length, components));
116-
programUniforms.push(...createTensorShapeVariables(outputShape));
117-
if (hasMeanOutput) {
118-
variables.push(outputVariable('mean_output', DataType.float, meanInvStdDevDim.length));
119-
programUniforms.push(...createTensorShapeVariables(meanInvStdDevDim));
120-
}
121-
if (hasInvStdDevOutput) {
122-
variables.push(outputVariable('inv_std_output', DataType.float, meanInvStdDevDim.length));
123-
programUniforms.push(...createTensorShapeVariables(meanInvStdDevDim));
124-
}
125-
if (hasInputSkipBiasSumOutput) {
126-
variables.push(outputVariable('input_skip_bias_sum', inputs[0].dataType, outputShape.length, components));
127-
programUniforms.push(...createTensorShapeVariables(outputShape));
128-
}
129-
const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
130-
const getShaderSource = (shaderHelper: ShaderHelper) => `
101+
const getShaderSource = (shaderHelper: ShaderHelper) => {
102+
const variables = [
103+
inputVariable('x', inputs[0].dataType, inputs[0].dims, components),
104+
inputVariable('skip', inputs[1].dataType, inputs[1].dims, components),
105+
inputVariable('gamma', inputs[2].dataType, inputs[2].dims, components),
106+
];
107+
if (hasBetaInput) {
108+
variables.push(inputVariable('beta', inputs[3].dataType, inputs[3].dims, components));
109+
}
110+
if (hasBiasInput) {
111+
variables.push(inputVariable('bias', inputs[4].dataType, inputs[4].dims, components));
112+
}
113+
variables.push(outputVariable('output', inputs[0].dataType, outputShape, components));
114+
if (hasMeanOutput) {
115+
variables.push(outputVariable('mean_output', DataType.float, meanInvStdDevDim));
116+
}
117+
if (hasInvStdDevOutput) {
118+
variables.push(outputVariable('inv_std_output', DataType.float, meanInvStdDevDim));
119+
}
120+
if (hasInputSkipBiasSumOutput) {
121+
variables.push(outputVariable('input_skip_bias_sum', inputs[0].dataType, outputShape, components));
122+
}
123+
const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
124+
return `
131125
const epsilon: f32 = ${attributes.epsilon};
132126
133127
${shaderHelper.registerUniforms(uniformsArray).declareVariables(...variables)}
@@ -151,14 +145,15 @@ const createSkipLayerNormProgramInfo =
151145
}
152146
let mean = ${sumVector('sum', components)} / f32(uniforms.hidden_size);
153147
let inv_std_dev = inverseSqrt(${
154-
sumVector('squareSum', components)} / f32(uniforms.hidden_size) - mean * mean + uniforms.epsilon);
148+
sumVector('squareSum', components)} / f32(uniforms.hidden_size) - mean * mean + uniforms.epsilon);
155149
${hasMeanOutput ? 'mean_output[global_idx] = mean;' : ''}
156150
${hasInvStdDevOutput ? 'inv_std_output[global_idx] = inv_std_dev;' : ''}
157151
for (var i: u32 = 0; i < hidden_size_vectorized; i++) {
158152
output[offset + i] = (output[offset + i] - ${dataType}(mean)) * ${dataType}(inv_std_dev) * gamma[i] + ${
159-
hasBetaInput ? 'beta[i]' : '0.0'};
153+
hasBetaInput ? 'beta[i]' : '0.0'};
160154
}
161155
}`;
156+
};
162157
const outputs = [{dims: outputShape, dataType: inputs[0].dataType}];
163158
if (outputCount > 1) {
164159
outputs.push({dims: meanInvStdDevDim, dataType: DataType.float});
@@ -173,7 +168,7 @@ const createSkipLayerNormProgramInfo =
173168
name: 'SkipLayerNormalization',
174169
shaderCache: {
175170
hint: `${components};${hasMeanOutput};${hasInvStdDevOutput};${hasInputSkipBiasSumOutput}`,
176-
inputDependencies: inputs.map((_input, _index) => 'rank')
171+
inputDependencies: inputs.map((_input, _index) => 'type')
177172
},
178173
getShaderSource,
179174
getRunData: () => ({outputs, dispatchGroup: {x: Math.ceil(outputSize / hiddenSize / 64)}, programUniforms}),

0 commit comments

Comments (0)