Skip to content

Commit ae3d73c

Browse files
[JS/WebGPU] Fix Split and Where to handle corner cases. (#19613)
### Description <!-- Describe your changes. --> 1. Fix Where operator to handle Boolean input less than 4 bytes. 2. Fix JSEP test harness to use tensor names consistently. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
1 parent 5e432a3 commit ae3d73c

File tree

3 files changed

+38
-3
lines changed

3 files changed

+38
-3
lines changed

js/web/lib/wasm/jsep/webgpu/ops/where.ts

+2-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ const createWhereOpProgramShader =
2727
const expressionA = `a_data[index_a${x}][component_a${x}]`;
2828
const expressionB = `b_data[index_b${x}][component_b${x}]`;
2929
// eslint-disable-next-line no-bitwise
30-
const expressionC = `bool(c_data[index_c${x}] & ${0xff000000 >>> ((3 - x) * 8)}u)`;
30+
const expressionC = `bool(c_data[index_c${x}] & (0xffu << (component_c${x} * 8)))`;
3131
return `
3232
let output_indices${x} = ${output.offsetToIndices(`global_idx * 4u + ${x}u`)};
3333
let offset_a${x} = ${a.broadcastedIndicesToOffset(`output_indices${x}`, output)};
@@ -38,6 +38,7 @@ const createWhereOpProgramShader =
3838
let index_c${x} = offset_c${x} / 4u;
3939
let component_a${x} = offset_a${x} % 4u;
4040
let component_b${x} = offset_b${x} % 4u;
41+
let component_c${x} = offset_c${x} % 4u;
4142
${resStr}[${x}] = ${typeCast}(${expression(expressionA, expressionB, expressionC)});
4243
`;
4344
};

js/web/test/data/ops/where.jsonc

+34
Original file line numberDiff line numberDiff line change
@@ -168,5 +168,39 @@
168168
]
169169
}
170170
]
171+
},
172+
{
173+
"name": "Where with no attributes",
174+
"operator": "Where",
175+
"attributes": [],
176+
"cases": [
177+
{
178+
"name": "T[1 1 2 1] T[1 4] T[1 1 2 4] float32 broadcast 1",
179+
"inputs": [
180+
{
181+
"data": [true, false],
182+
"dims": [1, 1, 2, 1],
183+
"type": "bool"
184+
},
185+
{
186+
"data": [1, 2, 3, 4],
187+
"dims": [1, 4],
188+
"type": "float32"
189+
},
190+
{
191+
"data": [5, 6, 7, 8, 9, 10, 11, 12],
192+
"dims": [1, 1, 2, 4],
193+
"type": "float32"
194+
}
195+
],
196+
"outputs": [
197+
{
198+
"data": [1, 2, 3, 4, 9, 10, 11, 12],
199+
"dims": [1, 1, 2, 4],
200+
"type": "float32"
201+
}
202+
]
203+
}
204+
]
171205
}
172206
]

js/web/test/test-runner.ts

+2-2
Original file line numberDiff line numberDiff line change
@@ -627,8 +627,8 @@ export async function runModelTestSet(
627627
try {
628628
const feeds: Record<string, ort.Tensor> = {};
629629
const outputsMetaInfo: Record<string, ort.Tensor> = {};
630-
testCase.inputs!.forEach((tensor, i) => feeds[context.session.inputNames[i]] = tensor);
631-
testCase.outputs!.forEach((tensor, i) => outputsMetaInfo[context.session.outputNames[i]] = tensor);
630+
testCase.inputs!.forEach((tensor) => feeds[tensor.name] = tensor);
631+
testCase.outputs!.forEach((tensor) => outputsMetaInfo[tensor.name] = tensor);
632632
const [start, end, outputs] =
633633
await sessionRun({session: context.session, feeds, outputsMetaInfo, ioBinding: context.ioBinding});
634634
if (context.perfData.count === 0) {

0 commit comments

Comments
 (0)