Skip to content

Commit 1b48054

Browse files
authoredFeb 20, 2024
[js/webgpu] Create Split indices helpers by rank, not by shape (#19554)
### Description This is required to make shape uniforms really work. ### Motivation and Context The bug was unveiled in a model with multiple Split nodes. The later nodes would try to reuse a previous pipeline cache, while the old shapes were hardcoded as constants in cache.
1 parent 7efb0db commit 1b48054

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed
 

‎js/web/lib/wasm/jsep/webgpu/ops/split.ts

+2-2
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: Split
6868
const dataType = inputs[0].dataType;
6969
const axis = ShapeUtil.normalizeAxis(attributes.axis, inputShape.length);
7070
const outputs = new Array<IndicesHelper>(attributes.numOutputs);
71-
const input = inputVariable('input', dataType, inputShape);
71+
const input = inputVariable('input', dataType, inputShape.length);
7272
const sizeInSplitAxis = new Array<number>(attributes.numOutputs);
7373
const outputsTensorInfo: TensorInfo[] = [];
7474
const outputShapes: number[][] = [];
@@ -80,7 +80,7 @@ const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: Split
8080
const outputShape = inputShape.slice();
8181
outputShape[attributes.axis] = attributes.splitSizes[i];
8282
outputShapes.push(outputShape);
83-
outputs[i] = outputVariable(`output${i}`, dataType, outputShape);
83+
outputs[i] = outputVariable(`output${i}`, dataType, outputShape.length);
8484
outputsTensorInfo.push({dims: outputShapes[i], dataType: inputs[0].dataType});
8585
}
8686
programUniforms.push(

0 commit comments

Comments
 (0)
Please sign in to comment.