Skip to content

Commit e597eae

Browse files
authored
[js/webgpu] Optimize transpose as reshape when suitable (#22870)
BUG #22031
1 parent c4f3742 commit e597eae

File tree

2 files changed

+102
-17
lines changed

2 files changed

+102
-17
lines changed

js/web/lib/wasm/jsep/webgpu/ops/transpose.ts

+78-17
Original file line numberDiff line numberDiff line change
@@ -48,31 +48,73 @@ const squeezeShape = (shape: readonly number[], adjustedPerm: number[]): { newSh
4848
return { newShape, newPerm };
4949
};
5050

51+
const isTransposeReshape = (perm: number[], shape: readonly number[]) => {
52+
// As long as the dims with values > 1 stay in the same order, it's a reshape.
53+
// Example: Shape=(1,1,1024,4096) -> perm=(2,0,3,1).
54+
let lastPermutedAxis = 0;
55+
for (let i = 0; i < perm.length; ++i) {
56+
if (shape[perm[i]] === 1) {
57+
continue;
58+
}
59+
if (perm[i] < lastPermutedAxis) {
60+
return false;
61+
}
62+
lastPermutedAxis = perm[i];
63+
}
64+
return true;
65+
};
66+
5167
export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: number[]): ProgramInfo => {
5268
const inputDataType = inputTensor.dataType;
5369
const inputRank = inputTensor.dims.length;
5470
const perm = getAdjustedPerm(inputRank, permAttr);
5571
const outputShape = getOutputShape(inputTensor.dims, perm);
72+
let newInputShape = inputTensor.dims;
73+
let newOutputShape = outputShape;
74+
const transposeAsReshape = isTransposeReshape(perm, inputTensor.dims);
75+
let getShaderSource;
76+
if (transposeAsReshape) {
77+
getShaderSource = (shaderHelper: ShaderHelper) => {
78+
const input = inputVariable('input', inputDataType, newInputShape, 4);
79+
const output = outputVariable('output', inputDataType, newOutputShape, 4);
80+
return `
81+
${shaderHelper.registerUniform('output_size', 'u32').declareVariables(input, output)}
82+
${shaderHelper.mainStart()}
83+
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
84+
output[global_idx] = input[global_idx];
85+
}`;
86+
};
87+
88+
return {
89+
name: 'TransposeCopy',
90+
shaderCache: { inputDependencies: ['type'] },
91+
getRunData: () => {
92+
const outputSize = ShapeUtil.size(outputShape);
93+
return {
94+
outputs: [{ dims: outputShape, dataType: inputTensor.dataType }],
95+
dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* components */) },
96+
programUniforms: [{ type: DataType.uint32, data: Math.ceil(outputSize / 4) }],
97+
};
98+
},
99+
getShaderSource,
100+
};
101+
}
56102
const { newShape, newPerm } = squeezeShape(inputTensor.dims, perm);
57103
const channelsLast = ShapeUtil.areEqual(newPerm, [2, 3, 1]);
58104
const channelsFirst = ShapeUtil.areEqual(newPerm, [3, 1, 2]);
59-
const useShared = (newShape.length === 2 && newPerm[0] > newPerm[1]) || channelsLast || channelsFirst;
60-
let newInputShape = useShared ? newShape : inputTensor.dims;
61-
let newOutputShape = outputShape;
105+
const useShared = newShape.length === 2 || channelsLast || channelsFirst;
62106
if (useShared) {
63107
newInputShape = channelsLast
64108
? [newShape[0], newShape[1] * newShape[2]]
65109
: channelsFirst
66110
? [newShape[0] * newShape[1], newShape[2]]
67111
: newShape;
68112
newOutputShape = [newInputShape[1], newInputShape[0]];
69-
}
70-
const input = inputVariable('a', inputDataType, newInputShape.length);
71-
const output = outputVariable('output', inputDataType, newOutputShape.length);
72-
const tileSize = 16;
73-
let getShaderSource;
74-
if (useShared) {
75-
getShaderSource = (shaderHelper: ShaderHelper) => `
113+
const tileSize = 16;
114+
getShaderSource = (shaderHelper: ShaderHelper) => {
115+
const input = inputVariable('a', inputDataType, newInputShape.length);
116+
const output = outputVariable('output', inputDataType, newOutputShape.length);
117+
return `
76118
${shaderHelper.registerUniform('output_size', 'u32').declareVariables(input, output)}
77119
var<workgroup> tile : array<array<${output.type.value}, ${tileSize + 1}>, ${tileSize}>;
78120
${shaderHelper.mainStart([tileSize, tileSize, 1])}
@@ -92,8 +134,29 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu
92134
${output.setByIndices(`${output.type.indices}(output_row, output_col)`, 'tile[local_id.x][local_id.y]')}
93135
}
94136
}`;
95-
} else {
96-
getShaderSource = (shaderHelper: ShaderHelper) => `
137+
};
138+
return {
139+
name: 'TransposeShared',
140+
shaderCache: { inputDependencies: ['type'] },
141+
getRunData: () => {
142+
const outputSize = ShapeUtil.size(outputShape);
143+
return {
144+
outputs: [{ dims: outputShape, dataType: inputTensor.dataType }],
145+
dispatchGroup: { x: Math.ceil(newOutputShape[1] / tileSize), y: Math.ceil(newOutputShape[0] / tileSize) },
146+
programUniforms: [
147+
{ type: DataType.uint32, data: outputSize },
148+
...createTensorShapeVariables(newInputShape, newOutputShape),
149+
],
150+
};
151+
},
152+
getShaderSource,
153+
};
154+
}
155+
156+
getShaderSource = (shaderHelper: ShaderHelper) => {
157+
const input = inputVariable('a', inputDataType, newInputShape.length);
158+
const output = outputVariable('output', inputDataType, newOutputShape.length);
159+
return `
97160
${shaderHelper.registerUniform('output_size', 'u32').declareVariables(input, output)}
98161
99162
${permFunctionBody(perm, inputRank, input, output)}
@@ -106,17 +169,15 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu
106169
107170
${output.setByOffset('global_idx', input.getByIndices('aIndices'))}
108171
}`;
109-
}
172+
};
110173
return {
111-
name: useShared ? 'TransposeShared' : 'Transpose',
174+
name: 'Transpose',
112175
shaderCache: { hint: `${permAttr}`, inputDependencies: ['rank'] },
113176
getRunData: () => {
114177
const outputSize = ShapeUtil.size(outputShape);
115178
return {
116179
outputs: [{ dims: outputShape, dataType: inputTensor.dataType }],
117-
dispatchGroup: useShared
118-
? { x: Math.ceil(newOutputShape[1] / tileSize), y: Math.ceil(newOutputShape[0] / tileSize) }
119-
: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
180+
dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
120181
programUniforms: [
121182
{ type: DataType.uint32, data: outputSize },
122183
...createTensorShapeVariables(newInputShape, newOutputShape),

js/web/test/data/ops/transpose.jsonc

+24
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,30 @@
263263
}
264264
]
265265
},
266+
{
267+
"name": "Transpose as reshape - perms:[1, 0, 2, 4, 3]",
268+
"operator": "Transpose",
269+
"attributes": [{ "name": "perm", "data": [1, 0, 2, 4, 3], "type": "ints" }],
270+
"cases": [
271+
{
272+
"name": "T[3, 1, 2, 1, 4]",
273+
"inputs": [
274+
{
275+
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24],
276+
"dims": [3, 1, 2, 1, 4],
277+
"type": "float32"
278+
}
279+
],
280+
"outputs": [
281+
{
282+
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24],
283+
"dims": [1, 3, 2, 4, 1],
284+
"type": "float32"
285+
}
286+
]
287+
}
288+
]
289+
},
266290
{
267291
"name": "Transpose - perms:[1, 0]",
268292
"operator": "Transpose",

0 commit comments

Comments
 (0)