Skip to content

Commit d673e39

Browse files
[JS/WebGPU] Added uniforms to Tile and Where Ops (#18768)
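### Description

Added uniforms to the Tile and Where ops: the output size (Tile) or vectorized size (Where) and the tensor shapes are now passed to the shader as uniforms instead of being baked into the shader source.

### Motivation and Context

Improve performance. With shapes supplied as uniforms, the shader cache is declared to depend only on input rank (`inputDependencies: ['rank']`).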
1 parent b4be9e1 commit d673e39

File tree

2 files changed

+47
-39
lines changed

2 files changed

+47
-39
lines changed

Diff for: js/web/lib/wasm/jsep/webgpu/ops/tile.ts

+16-11
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import {TensorView} from '../../tensor-view';
66
import {ShapeUtil} from '../../util';
77
import {ComputeContext, ProgramInfo} from '../types';
88

9-
import {inputVariable, outputVariable, ShaderHelper} from './common';
9+
import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common';
1010

1111
const getRepeats = (repeatsTensorView: TensorView): readonly number[] =>
1212
Array.from(repeatsTensorView.getBigInt64Array(), Number);
@@ -54,30 +54,35 @@ export const createTileProgramInfo = (inputs: readonly TensorView[]): ProgramInf
5454
const outputSize = ShapeUtil.size(outputShape);
5555

5656
const dataType = inputs[0].dataType;
57-
const input = inputVariable('input', dataType, inputShape);
58-
const output = outputVariable('output', dataType, outputShape);
57+
const input = inputVariable('input', dataType, inputShape.length);
58+
const output = outputVariable('output', dataType, outputShape.length);
5959

6060
const getShaderSource = (shaderHelper: ShaderHelper) => `
6161
const inputShape = ${input.indices(...inputShape)};
62-
${shaderHelper.declareVariables(input, output)}
62+
${shaderHelper.registerUniform('output_size', 'u32').declareVariables(input, output)}
6363
${shaderHelper.mainStart()}
64-
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
65-
let outputIndices = ${output.offsetToIndices('global_idx')};
66-
var inputIndices: ${input.type.indices};
64+
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
65+
let output_indices = ${output.offsetToIndices('global_idx')};
66+
var input_indices: ${input.type.indices};
6767
for (var i = 0; i < ${inputShape.length}; i++) {
68-
let inputDimValue = ${output.indicesGet('outputIndices', 'i')} % ${input.indicesGet('inputShape', 'i')};
68+
let input_dim_i = ${input.indicesGet('uniforms.input_shape', 'i')};
69+
let input_dim_value = ${output.indicesGet('output_indices', 'i')} % input_dim_i;
6970
70-
${input.indicesSet('inputIndices', 'i', 'inputDimValue')}
71+
${input.indicesSet('input_indices', 'i', 'input_dim_value')}
7172
}
72-
${output.setByOffset('global_idx', input.getByIndices('inputIndices'))}
73+
${output.setByOffset('global_idx', input.getByIndices('input_indices'))}
7374
}`;
7475

7576
return {
7677
name: 'Tile',
77-
shaderCache: {hint: `${repeats}`},
78+
shaderCache: {hint: `${repeats}`, inputDependencies: ['rank']},
7879
getRunData: () => ({
7980
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
8081
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
82+
programUniforms: [
83+
{type: 'uint32', data: outputSize}, ...createTensorShapeVariables(inputs[0].dims),
84+
...createTensorShapeVariables(outputShape)
85+
],
8186
}),
8287
getShaderSource,
8388
};

Diff for: js/web/lib/wasm/jsep/webgpu/ops/where.ts

+31-28
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,15 @@ import {TensorView} from '../../tensor-view';
66
import {BroadcastUtil, ShapeUtil} from '../../util';
77
import {ComputeContext, ProgramInfo} from '../types';
88

9-
import {inputVariable, outputVariable, ShaderHelper} from './common';
9+
import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common';
1010

1111
const createWhereOpProgramShader =
1212
(shaderHelper: ShaderHelper, inputs: readonly TensorView[], dimsOutput: readonly number[], isBroadcast: boolean,
1313
typeOutput: number) => {
14-
const outputSize = ShapeUtil.size(dimsOutput);
15-
const vecSize = Math.ceil(outputSize / 4);
16-
17-
const output = outputVariable('outputData', typeOutput, dimsOutput, 4);
18-
const a = inputVariable('aData', inputs[1].dataType, inputs[1].dims, 4);
19-
const b = inputVariable('bData', inputs[2].dataType, inputs[2].dims, 4);
20-
const c = inputVariable('cData', inputs[0].dataType, inputs[0].dims, 4);
14+
const output = outputVariable('output_data', typeOutput, dimsOutput.length, 4);
15+
const a = inputVariable('a_data', inputs[1].dataType, inputs[1].dims.length, 4);
16+
const b = inputVariable('b_data', inputs[2].dataType, inputs[2].dims.length, 4);
17+
const c = inputVariable('c_data', inputs[0].dataType, inputs[0].dims.length, 4);
2118

2219
let assignment: string;
2320
const expression = (a: string, b: string, c: string) => `select(${b}, ${a}, ${c})`;
@@ -27,20 +24,20 @@ const createWhereOpProgramShader =
2724
expression(a.getByOffset('global_idx'), b.getByOffset('global_idx'), c.getByOffset('global_idx')));
2825
} else {
2926
const singleAssignment = (resStr: string, x: number, typeCast = '') => {
30-
const expressionA = `aData[indexA${x}][componentA${x}]`;
31-
const expressionB = `bData[indexB${x}][componentB${x}]`;
27+
const expressionA = `a_data[index_a${x}][component_a${x}]`;
28+
const expressionB = `b_data[index_b${x}][component_b${x}]`;
3229
// eslint-disable-next-line no-bitwise
33-
const expressionC = `bool(cData[indexC${x}] & ${0xff000000 >>> ((3 - x) * 8)}u)`;
30+
const expressionC = `bool(c_data[index_c${x}] & ${0xff000000 >>> ((3 - x) * 8)}u)`;
3431
return `
35-
let outputIndices${x} = ${output.offsetToIndices(`global_idx * 4u + ${x}u`)};
36-
let offsetA${x} = ${a.broadcastedIndicesToOffset(`outputIndices${x}`, output)};
37-
let offsetB${x} = ${b.broadcastedIndicesToOffset(`outputIndices${x}`, output)};
38-
let offsetC${x} = ${c.broadcastedIndicesToOffset(`outputIndices${x}`, output)};
39-
let indexA${x} = offsetA${x} / 4u;
40-
let indexB${x} = offsetB${x} / 4u;
41-
let indexC${x} = offsetC${x} / 4u;
42-
let componentA${x} = offsetA${x} % 4u;
43-
let componentB${x} = offsetB${x} % 4u;
32+
let output_indices${x} = ${output.offsetToIndices(`global_idx * 4u + ${x}u`)};
33+
let offset_a${x} = ${a.broadcastedIndicesToOffset(`output_indices${x}`, output)};
34+
let offset_b${x} = ${b.broadcastedIndicesToOffset(`output_indices${x}`, output)};
35+
let offset_c${x} = ${c.broadcastedIndicesToOffset(`output_indices${x}`, output)};
36+
let index_a${x} = offset_a${x} / 4u;
37+
let index_b${x} = offset_b${x} / 4u;
38+
let index_c${x} = offset_c${x} / 4u;
39+
let component_a${x} = offset_a${x} % 4u;
40+
let component_b${x} = offset_b${x} % 4u;
4441
${resStr}[${x}] = ${typeCast}(${expression(expressionA, expressionB, expressionC)});
4542
`;
4643
};
@@ -51,21 +48,21 @@ const createWhereOpProgramShader =
5148
${singleAssignment('data', 1, 'u32')}
5249
${singleAssignment('data', 2, 'u32')}
5350
${singleAssignment('data', 3, 'u32')}
54-
outputData[global_idx] = dot(vec4<u32>(0x1, 0x100, 0x10000, 0x1000000), vec4<u32>(data));`;
51+
output_data[global_idx] = dot(vec4<u32>(0x1, 0x100, 0x10000, 0x1000000), vec4<u32>(data));`;
5552
} else {
5653
assignment = `
57-
${singleAssignment('outputData[global_idx]', 0)}
58-
${singleAssignment('outputData[global_idx]', 1)}
59-
${singleAssignment('outputData[global_idx]', 2)}
60-
${singleAssignment('outputData[global_idx]', 3)}
54+
${singleAssignment('output_data[global_idx]', 0)}
55+
${singleAssignment('output_data[global_idx]', 1)}
56+
${singleAssignment('output_data[global_idx]', 2)}
57+
${singleAssignment('output_data[global_idx]', 3)}
6158
`;
6259
}
6360
}
6461

6562
return `
66-
${shaderHelper.declareVariables(c, a, b, output)}
63+
${shaderHelper.registerUniform('vec_size', 'u32').declareVariables(c, a, b, output)}
6764
${shaderHelper.mainStart()}
68-
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(vecSize)}
65+
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.vec_size')}
6966
${assignment}
7067
}`;
7168
};
@@ -79,6 +76,7 @@ const createWhereOpProgramInfo = (inputs: readonly TensorView[]): ProgramInfo =>
7976
const isBroadcast = !(ShapeUtil.areEqual(dimsA, dimsB) && ShapeUtil.areEqual(dimsB, dimsC));
8077
let outputShape = dimsA;
8178
let outputSize = ShapeUtil.size(dimsA);
79+
const vecSize = Math.ceil(outputSize / 4);
8280
// TODO: deal with zero-sized tensors (eg. dims=[1,0])
8381

8482
if (isBroadcast) {
@@ -92,11 +90,16 @@ const createWhereOpProgramInfo = (inputs: readonly TensorView[]): ProgramInfo =>
9290

9391
return {
9492
name: 'Where',
93+
shaderCache: {inputDependencies: ['rank', 'rank', 'rank']},
9594
getShaderSource: (shaderHelper) =>
9695
createWhereOpProgramShader(shaderHelper, inputs, outputShape, isBroadcast, outputDataType),
9796
getRunData: () => ({
9897
outputs: [{dims: outputShape, dataType: outputDataType}],
99-
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* vec size */)}
98+
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* vec size */)},
99+
programUniforms: [
100+
{type: 'uint32', data: vecSize}, ...createTensorShapeVariables(dimsC), ...createTensorShapeVariables(dimsA),
101+
...createTensorShapeVariables(dimsB), ...createTensorShapeVariables(outputShape)
102+
],
100103
}),
101104
};
102105
};

0 commit comments

Comments
 (0)