Skip to content

Commit 624b4e2

Browse files
authoredJan 30, 2024
[js/webgpu] Remove enableShapesUniforms (#19279)
1 parent 00d0481 commit 624b4e2

File tree

9 files changed

+68
-134
lines changed

9 files changed

+68
-134
lines changed
 

‎js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts

+6-6
Original file line numberDiff line numberDiff line change
@@ -443,9 +443,9 @@ export const createMatmulProgramInfo =
443443

444444
const components = isVec4 ? 4 : 1;
445445
const aShapeTemp = [...outerDimsA, dimAOuter, dimInner / components];
446-
const aShapeOrRank = aShapeTemp.length;
446+
const aRank = aShapeTemp.length;
447447
const bShapeTemp = [...outerDimsB, dimInner, dimBOuter / components];
448-
const bShapeOrRank = bShapeTemp.length;
448+
const bRank = bShapeTemp.length;
449449
const outputShapeTemp = [batchSize, dimAOuter, dimBOuter / components];
450450
const programUniforms: ProgramUniform[] =
451451
[{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}];
@@ -467,12 +467,12 @@ export const createMatmulProgramInfo =
467467
programUniforms.push(...createTensorShapeVariables(outputShapeTemp));
468468

469469
const getShaderSource = (shaderHelper: ShaderHelper) => {
470-
const batchShapeOrRank = outerDims.length;
471-
const batchDims = internalVariable('batchDims', inputs[0].dataType, batchShapeOrRank, 1);
470+
const batchRank = outerDims.length;
471+
const batchDims = internalVariable('batchDims', inputs[0].dataType, batchRank, 1);
472472
const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
473473

474-
const A = inputVariable('a', inputs[0].dataType, aShapeOrRank, components);
475-
const B = inputVariable('b', inputs[1].dataType, bShapeOrRank, components);
474+
const A = inputVariable('a', inputs[0].dataType, aRank, components);
475+
const B = inputVariable('b', inputs[1].dataType, bRank, components);
476476
const output = outputVariable('result', inputs[0].dataType, outputShapeTemp.length, components);
477477
const inputVariables = [A, B];
478478
if (hasBias) {

‎js/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts

+2-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import {ShapeUtil} from '../../util';
88
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
99
import {ComputeContext, ProgramInfo} from '../types';
1010

11-
import {createTensorShapeVariables, enableShapesUniforms, getMaxComponents, inputVariable, outputVariable, ShaderHelper} from './common';
11+
import {createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper} from './common';
1212

1313
export interface BatchNormAttributes extends AttributeWithCacheKey {
1414
readonly epsilon: number;
@@ -61,7 +61,7 @@ const createBatchNormInferenceProgramInfo =
6161
const cComponents = format === 'NHWC' && yShape.length > 1 ? components : 1;
6262
const outputSize = ShapeUtil.size(yShape) / components;
6363
// Only support uniforms for opset version >= 9 (spatial = true).
64-
const useShapesUniforms = enableShapesUniforms(yShape.length) && spatial;
64+
const useShapesUniforms = spatial;
6565
const shapeOrRank = useShapesUniforms ? yShape.length : yShape;
6666
const x = inputVariable('x', inputs[0].dataType, inputs[0].dims, components);
6767
const scale = inputVariable('scale', inputs[1].dataType, inputs[1].dims, cComponents);

‎js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts

+14-23
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import {TensorView} from '../../tensor-view';
66
import {BroadcastUtil, ShapeUtil} from '../../util';
77
import {ComputeContext, ProgramInfo} from '../types';
88

9-
import {createTensorShapeVariables, enableShapesUniforms, inputVariable, outputVariable, ShaderHelper} from './common';
9+
import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common';
1010

1111
type BuiltinFunctionName = string;
1212
type BinaryCustomExpression = (expressionA: string, expressionB: string) => string;
@@ -18,8 +18,7 @@ type BinaryFunctionCall = BuiltinFunctionName|BinaryCustomExpression|{
1818
const createBinaryOpProgramShader =
1919
(shaderHelper: ShaderHelper, dimsA: readonly number[], dimsB: readonly number[], dimsOutput: readonly number[],
2020
vectorize: boolean, doBroadcast: boolean, sharedDimensionDivisibleBy4: boolean, funcCall: BinaryFunctionCall,
21-
typeA: number, typeB: number, typeOutput: number, useShapesUniforms: boolean,
22-
additionalImplementation?: string) => {
21+
typeA: number, typeB: number, typeOutput: number, additionalImplementation?: string) => {
2322
let expressionScalar: BinaryCustomExpression;
2423
let expressionVector: BinaryCustomExpression;
2524
if (typeof funcCall === 'string') {
@@ -31,12 +30,9 @@ const createBinaryOpProgramShader =
3130
expressionVector = funcCall.vector;
3231
}
3332

34-
const inputAShapeOrRank = useShapesUniforms ? dimsA.length : dimsA;
35-
const inputBShapeOrRank = useShapesUniforms ? dimsB.length : dimsB;
36-
const outputShapeOrRank = useShapesUniforms ? dimsOutput.length : dimsOutput;
37-
const output = outputVariable('outputData', typeOutput, outputShapeOrRank, 4);
38-
const a = inputVariable('aData', typeA, inputAShapeOrRank, 4);
39-
const b = inputVariable('bData', typeB, inputBShapeOrRank, 4);
33+
const output = outputVariable('outputData', typeOutput, dimsOutput.length, 4);
34+
const a = inputVariable('aData', typeA, dimsA.length, 4);
35+
const b = inputVariable('bData', typeB, dimsB.length, 4);
4036

4137
let assignment: string;
4238
if (vectorize) {
@@ -169,30 +165,25 @@ const createBinaryOpProgramInfo =
169165
vectorize = true;
170166
}
171167
cacheKeyAux.push(vectorize);
172-
const useShapesUniforms = enableShapesUniforms(a.dims.length) && enableShapesUniforms(b.dims.length) &&
173-
enableShapesUniforms(outputShape.length);
168+
174169
return {
175170
name,
176171
shaderCache: {
177172
hint: cacheKey + cacheKeyAux.map((x) => x.toString()).join('_'),
178-
inputDependencies: useShapesUniforms ? ['rank', 'rank'] : ['dims', 'dims'],
173+
inputDependencies: ['rank', 'rank'],
179174
},
180175
getShaderSource: (shaderHelper) => createBinaryOpProgramShader(
181176
shaderHelper, a.dims, b.dims, outputShape, vectorize, isBroadcast, sharedDimensionDivisibleBy4, funcCall,
182-
a.dataType, b.dataType, outputDataType, useShapesUniforms, additionalImplementation),
177+
a.dataType, b.dataType, outputDataType, additionalImplementation),
183178
getRunData: () => ({
184179
outputs: [{dims: outputShape, dataType: outputDataType}],
185180
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* component size */)},
186-
programUniforms: useShapesUniforms ?
187-
[
188-
{type: 'uint32', data: Math.ceil(ShapeUtil.size(outputShape) / 4)},
189-
...createTensorShapeVariables(a.dims),
190-
...createTensorShapeVariables(b.dims),
191-
...createTensorShapeVariables(outputShape),
192-
] :
193-
[
194-
{type: 'uint32', data: Math.ceil(ShapeUtil.size(outputShape) / 4)},
195-
],
181+
programUniforms: [
182+
{type: 'uint32', data: Math.ceil(ShapeUtil.size(outputShape) / 4)},
183+
...createTensorShapeVariables(a.dims),
184+
...createTensorShapeVariables(b.dims),
185+
...createTensorShapeVariables(outputShape),
186+
],
196187
}),
197188
};
198189
};

‎js/web/lib/wasm/jsep/webgpu/ops/common.ts

-3
Original file line numberDiff line numberDiff line change
@@ -922,6 +922,3 @@ export const getBroadcastDims = (inShape: readonly number[], outShape: readonly
922922
}
923923
return dims;
924924
};
925-
926-
// TODO: remove this when all related uses have been removed.
927-
export const enableShapesUniforms = (_rank: number): boolean => true;

‎js/web/lib/wasm/jsep/webgpu/ops/concat.ts

+8-18
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import {ShapeUtil} from '../../util';
66
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
77
import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types';
88

9-
import {createTensorShapeVariables, enableShapesUniforms, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common';
9+
import {createTensorShapeVariables, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common';
1010

1111
export interface ConcatAttributes extends AttributeWithCacheKey {
1212
readonly axis: number;
@@ -94,32 +94,22 @@ const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): P
9494

9595
let previousSum = 0;
9696
const inputDependencies: ProgramInputTensorInfoDependency[] = [];
97-
const inputShapeOrRanks = [];
98-
const enableInputShapesUniforms = [];
97+
const inputRanks = [];
9998
const programUniforms: ProgramUniform[] = [{type: 'uint32', data: outputSize}];
10099
for (let i = 0; i < inputs.length; ++i) {
101100
previousSum += inputs[i].dims[adjustedAxis];
102101
sizeInConcatAxis[i] = previousSum;
103-
enableInputShapesUniforms.push(enableShapesUniforms(inputs[i].dims.length));
104-
inputShapeOrRanks.push(enableInputShapesUniforms[i] ? inputs[i].dims.length : inputs[i].dims);
105-
inputVars[i] = inputVariable(`input${i}`, dataType, inputShapeOrRanks[i]);
106-
inputDependencies.push(enableInputShapesUniforms[i] ? 'rank' : 'dims');
102+
inputRanks.push(inputs[i].dims.length);
103+
inputVars[i] = inputVariable(`input${i}`, dataType, inputRanks[i]);
104+
inputDependencies.push('rank');
107105
programUniforms.push({type: 'uint32', data: sizeInConcatAxis[i]});
108106
}
109107
for (let i = 0; i < inputs.length; ++i) {
110-
if (enableInputShapesUniforms[i]) {
111-
programUniforms.push(...createTensorShapeVariables(inputs[i].dims));
112-
}
113-
}
114-
115-
const enableOutputShapesUniforms = enableShapesUniforms(outputShape.length);
116-
if (enableOutputShapesUniforms) {
117-
programUniforms.push(...createTensorShapeVariables(outputShape));
108+
programUniforms.push(...createTensorShapeVariables(inputs[i].dims));
118109
}
110+
programUniforms.push(...createTensorShapeVariables(outputShape));
119111

120-
const outputShapeOrRank = enableOutputShapesUniforms ? outputShape.length : outputShape;
121-
const output = outputVariable('output', dataType, outputShapeOrRank);
122-
112+
const output = outputVariable('output', dataType, outputShape.length);
123113
const indicesAxis = output.indicesGet('indices', adjustedAxis);
124114
const sizeInConcatAxisStr =
125115
Array.from(Array(sizeInConcatAxis.length).keys()).map(i => `uniforms.sizeInConcatAxis${i}`).join(',');

‎js/web/lib/wasm/jsep/webgpu/ops/einsum.ts

+10-21
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@ import {ShapeUtil} from '../../util';
66
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
77
import {ComputeContext, ProgramInfo, ProgramUniform} from '../types';
88

9-
import {createTensorShapeVariables, enableShapesUniforms, inputVariable, outputVariable, ShaderHelper} from './common';
10-
9+
import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common';
1110

1211
export interface EinsumAttributes extends AttributeWithCacheKey {
1312
readonly equation: string;
@@ -181,14 +180,12 @@ class EinsumEquation {
181180
const appendMax = (name: string): string => name + '_max';
182181

183182
const createEinsumProgramInfo =
184-
(enableInputShapesUniforms: readonly boolean[], inputShapes: Array<readonly number[]>, dataType: number,
185-
einsumEquation: EinsumEquation, outputShape: readonly number[]): ProgramInfo => {
186-
const shapeOrRanks = inputShapes.map((dims, index) => enableInputShapesUniforms[index] ? dims.length : dims);
187-
const inputVars = shapeOrRanks.map((shapeOrRank, index) => inputVariable(`input${index}`, dataType, shapeOrRank));
183+
(inputShapes: Array<readonly number[]>, dataType: number, einsumEquation: EinsumEquation,
184+
outputShape: readonly number[]): ProgramInfo => {
185+
const ranks = inputShapes.map((dims) => dims.length);
186+
const inputVars = ranks.map((rank, index) => inputVariable(`input${index}`, dataType, rank));
188187
const outputSize = ShapeUtil.size(outputShape);
189-
const enableOutputShapesUniforms = enableShapesUniforms(outputShape.length);
190-
const outputShapeOrRank = enableOutputShapesUniforms ? outputShape.length : outputShape;
191-
const output = outputVariable('output', dataType, outputShapeOrRank);
188+
const output = outputVariable('output', dataType, outputShape.length);
192189
const uniformsSymbols =
193190
[...einsumEquation.symbolToInfo.keys()].filter((symbol) => !einsumEquation.rhs.symbolToIndices.has(symbol));
194191
const getShaderSource = (shaderHelper: ShaderHelper) => {
@@ -269,10 +266,7 @@ const createEinsumProgramInfo =
269266
};
270267
return {
271268
name: 'Einsum',
272-
shaderCache: {
273-
hint: einsumEquation.equation,
274-
inputDependencies: enableInputShapesUniforms.map((enableShapeUniform) => enableShapeUniform ? 'rank' : 'dims')
275-
},
269+
shaderCache: {hint: einsumEquation.equation, inputDependencies: inputShapes.map(() => 'rank')},
276270
getRunData: () => {
277271
// The symbols from uniformSymbols array are guaranteed to exist in einsumEquations.symbolToInfo map. The
278272
// filter is added to make sure that dimValue is never 0.
@@ -281,12 +275,9 @@ const createEinsumProgramInfo =
281275
.map((symbol) => ({type: 'uint32', data: einsumEquation.symbolToInfo.get(symbol)?.dimValue || 0}));
282276
programUniformsInit.push({type: 'uint32', data: outputSize});
283277
const programUniforms: ProgramUniform[] =
284-
inputShapes.filter((_, index) => enableInputShapesUniforms[index])
285-
.map((dims, _) => [...createTensorShapeVariables(dims)])
278+
inputShapes.map((dims, _) => [...createTensorShapeVariables(dims)])
286279
.reduce((acc, inputProgramUniforms) => acc.concat(inputProgramUniforms), programUniformsInit);
287-
if (enableOutputShapesUniforms) {
288-
programUniforms.push(...createTensorShapeVariables(outputShape));
289-
}
280+
programUniforms.push(...createTensorShapeVariables(outputShape));
290281
return ({
291282
outputs: [{dims: outputShape, dataType}],
292283
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
@@ -299,11 +290,9 @@ const createEinsumProgramInfo =
299290

300291
export const einsum = (context: ComputeContext, attributes: EinsumAttributes): void => {
301292
const einsumEquation = new EinsumEquation(context.inputs, attributes.equation);
302-
const enableInputShapesUniforms = context.inputs.map((input, _) => enableShapesUniforms(input.dims.length));
303293
const outputShape = einsumEquation.outputDims;
304294
const inputShapes = context.inputs.map((input, _) => input.dims);
305-
context.compute(createEinsumProgramInfo(
306-
enableInputShapesUniforms, inputShapes, context.inputs[0].dataType, einsumEquation, outputShape));
295+
context.compute(createEinsumProgramInfo(inputShapes, context.inputs[0].dataType, einsumEquation, outputShape));
307296
};
308297

309298
export const parseEinsumAttributes = (attributes: Record<string, unknown>): EinsumAttributes => {

‎js/web/lib/wasm/jsep/webgpu/ops/expand.ts

+8-17
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import {TensorView} from '../../tensor-view';
66
import {ShapeUtil} from '../../util';
77
import {ComputeContext, ProgramInfo, ProgramUniform} from '../types';
88

9-
import {createTensorShapeVariables, enableShapesUniforms, inputVariable, outputVariable, ShaderHelper} from './common';
9+
import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common';
1010

1111
const validateInputs = (inputs: readonly TensorView[]): void => {
1212
if (!inputs || inputs.length !== 2) {
@@ -49,15 +49,9 @@ const createExpandProgramInfo = (inputs: readonly TensorView[]): ProgramInfo =>
4949
const components = dataType === DataType.bool ? 4 : 1;
5050
const outputSize = Math.ceil(ShapeUtil.size(outputShape) / components);
5151

52-
const enableInputShapeUniform = enableShapesUniforms(inputShape.length);
53-
const enableOutputShapeUniform = enableShapesUniforms(outputShape.length);
54-
55-
5652
const getShaderSource = (shaderHelper: ShaderHelper) => {
57-
const inputShapeOrRank = enableInputShapeUniform ? inputShape.length : inputShape;
58-
const outputShapeOrRank = enableOutputShapeUniform ? outputShape.length : outputShape;
59-
const input = inputVariable('input', dataType, inputShapeOrRank, components);
60-
const output = outputVariable('output', dataType, outputShapeOrRank, components);
53+
const input = inputVariable('input', dataType, inputShape.length, components);
54+
const output = outputVariable('output', dataType, outputShape.length, components);
6155
let assignment: string;
6256
if (dataType === DataType.bool) {
6357
const singleAssignment = (resStr: string, x: number, typeCast = '') => `
@@ -90,16 +84,13 @@ const createExpandProgramInfo = (inputs: readonly TensorView[]): ProgramInfo =>
9084
${assignment}`;
9185
};
9286

93-
const programUniforms: ProgramUniform[] = [{type: 'uint32', data: outputSize}];
94-
if (enableInputShapeUniform) {
95-
programUniforms.push(...createTensorShapeVariables(inputShape));
96-
}
97-
if (enableOutputShapeUniform) {
98-
programUniforms.push(...createTensorShapeVariables(outputShape));
99-
}
87+
const programUniforms: ProgramUniform[] = [
88+
{type: 'uint32', data: outputSize}, ...createTensorShapeVariables(inputShape),
89+
...createTensorShapeVariables(outputShape)
90+
];
10091
return {
10192
name: 'Expand',
102-
shaderCache: {hint: `${outputShape.length}`, inputDependencies: [enableInputShapeUniform ? 'rank' : 'dims']},
93+
shaderCache: {hint: `${outputShape.length}`, inputDependencies: ['rank']},
10394
getShaderSource,
10495
getRunData: () => ({
10596
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],

‎js/web/lib/wasm/jsep/webgpu/ops/gather.ts

+11-28
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@ import {DataType} from '../../../wasm-common';
55
import {TensorView} from '../../tensor-view';
66
import {ShapeUtil} from '../../util';
77
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
8-
import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types';
8+
import {ComputeContext, ProgramInfo, ProgramUniform} from '../types';
99

10-
import {createTensorShapeVariables, enableShapesUniforms, inputVariable, outputVariable, ShaderHelper} from './common';
10+
import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common';
1111

1212
export interface GatherAttributes extends AttributeWithCacheKey {
1313
axis: number;
@@ -33,33 +33,16 @@ const createGatherProgramInfo = (inputs: readonly TensorView[], attributes: Gath
3333
const components = inputs[0].dataType === DataType.bool ? 4 : 1;
3434
const outputSize = Math.ceil(ShapeUtil.size(outputShape) / components);
3535

36-
const enableInputShapesUniforms = enableShapesUniforms(inputs[0].dims.length);
37-
const inputShapeOrRank = enableInputShapesUniforms ? inputs[0].dims.length : inputs[0].dims;
38-
const enableIndicesShapesUniforms = enableShapesUniforms(inputs[1].dims.length);
39-
const indicesShapeOrRank = enableIndicesShapesUniforms ? inputs[1].dims.length : inputs[1].dims;
40-
const enableOutputShapesUniforms = enableShapesUniforms(outputShape.length);
41-
const outputShapeOrRank = enableOutputShapesUniforms ? outputShape.length : outputShape;
42-
43-
const programUniforms: ProgramUniform[] =
44-
[{type: 'uint32', data: outputSize}, {type: 'int32', data: axisDimLimit}, {type: 'uint32', data: axis}];
45-
if (enableInputShapesUniforms) {
46-
programUniforms.push(...createTensorShapeVariables(inputs[0].dims));
47-
}
48-
if (enableIndicesShapesUniforms) {
49-
programUniforms.push(...createTensorShapeVariables(inputs[1].dims));
50-
}
51-
if (enableOutputShapesUniforms) {
52-
programUniforms.push(...createTensorShapeVariables(outputShape));
53-
}
54-
55-
const inputDependencies: ProgramInputTensorInfoDependency[] = [];
56-
inputDependencies.push(enableInputShapesUniforms ? 'rank' : 'dims');
57-
inputDependencies.push(enableIndicesShapesUniforms ? 'rank' : 'dims');
36+
const programUniforms: ProgramUniform[] = [
37+
{type: 'uint32', data: outputSize}, {type: 'int32', data: axisDimLimit}, {type: 'uint32', data: axis},
38+
...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims),
39+
...createTensorShapeVariables(outputShape)
40+
];
5841

5942
const getShaderSource = (shaderHelper: ShaderHelper) => {
60-
const data = inputVariable('data', inputs[0].dataType, inputShapeOrRank, components);
61-
const indices = inputVariable('inputIndices', inputs[1].dataType, indicesShapeOrRank);
62-
const output = outputVariable('output', inputs[0].dataType, outputShapeOrRank, components);
43+
const data = inputVariable('data', inputs[0].dataType, inputs[0].dims.length, components);
44+
const indices = inputVariable('inputIndices', inputs[1].dataType, inputs[1].dims.length);
45+
const output = outputVariable('output', inputs[0].dataType, outputShape.length, components);
6346

6447
const calcDataIndices = (x: number|string): string => {
6548
const indicesRank = indicesShape.length;
@@ -127,7 +110,7 @@ const createGatherProgramInfo = (inputs: readonly TensorView[], attributes: Gath
127110
};
128111
return {
129112
name: 'Gather',
130-
shaderCache: {hint: attributes.cacheKey, inputDependencies},
113+
shaderCache: {hint: attributes.cacheKey, inputDependencies: ['rank', 'rank']},
131114
getRunData: () => ({
132115
outputs: [
133116
{dims: outputShape, dataType: inputs[0].dataType},

‎js/web/lib/wasm/jsep/webgpu/ops/transpose.ts

+9-16
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import {ShapeUtil} from '../../util';
66
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
77
import {ComputeContext, ProgramInfo} from '../types';
88

9-
import {createTensorShapeVariables, enableShapesUniforms, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common';
9+
import {createTensorShapeVariables, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common';
1010

1111
export interface TransposeAttributes extends AttributeWithCacheKey {
1212
readonly perm: number[];
@@ -39,12 +39,9 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu
3939
const inputDataType = inputTensor.dataType;
4040
const inputRank = inputTensor.dims.length;
4141
const perm = getAdjustedPerm(inputRank, permAttr);
42-
const useShapesUniforms = enableShapesUniforms(inputRank);
4342
const outputShape = getOutputShape(inputTensor.dims, perm);
44-
const outShapeOrRank = useShapesUniforms ? outputShape.length : outputShape;
45-
const inShapeOrRank = useShapesUniforms ? inputRank : inputTensor.dims;
46-
const output = outputVariable('output', inputDataType, outShapeOrRank);
47-
const input = inputVariable('a', inputDataType, inShapeOrRank);
43+
const output = outputVariable('output', inputDataType, outputShape.length);
44+
const input = inputVariable('a', inputDataType, inputRank);
4845

4946
const getShaderSource = (shaderHelper: ShaderHelper) => `
5047
${shaderHelper.registerUniform('output_size', 'u32').declareVariables(input, output)}
@@ -61,21 +58,17 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu
6158
}`;
6259
return {
6360
name: 'Transpose',
64-
shaderCache: {hint: `${permAttr}`, inputDependencies: useShapesUniforms ? ['rank'] : ['dims']},
61+
shaderCache: {hint: `${permAttr}`, inputDependencies: ['rank']},
6562
getRunData: (inputs) => {
6663
const outputSize = ShapeUtil.size(outputShape);
6764
return {
6865
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
6966
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
70-
programUniforms: useShapesUniforms ?
71-
[
72-
{type: 'uint32', data: outputSize},
73-
...createTensorShapeVariables(inputs[0].dims),
74-
...createTensorShapeVariables(outputShape),
75-
] :
76-
[
77-
{type: 'uint32', data: outputSize},
78-
],
67+
programUniforms: [
68+
{type: 'uint32', data: outputSize},
69+
...createTensorShapeVariables(inputs[0].dims),
70+
...createTensorShapeVariables(outputShape),
71+
],
7972
};
8073
},
8174
getShaderSource,

0 commit comments

Comments
 (0)
Please sign in to comment.