Commit 97e8998

[js/webgpu] allows a ProgramInfo's RunData to use zero sized output (microsoft#19614)
### Description

This PR allows zero-sized outputs. To keep the implementation simple, it does not support partially zero-sized outputs: either all outputs of a program are zero-sized, or an error is reported.

Added 2 tests:
- op test of `Add` with inputs T[2,0] and T[2,1]
- `test_split_zero_size_splits`
1 parent 3acaea2 commit 97e8998
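
For context, a zero-sized tensor is one whose shape contains a 0 dimension, so it holds no elements. Below is a minimal sketch of what this commit enables, assuming a model `add.onnx` with a single `Add` node whose inputs are named `A` and `B` and whose output is named `C` (the file name and the tensor names are illustrative, not from the repo):

```typescript
import * as ort from 'onnxruntime-web';

async function runZeroSizedAdd(): Promise<void> {
  // assumed model/EP setup; any model producing a zero-sized output would do
  const session = await ort.InferenceSession.create('add.onnx', {executionProviders: ['webgpu']});

  // dims [2, 0] describe 0 elements, so the data buffer is empty
  const a = new ort.Tensor('float32', new Float32Array(0), [2, 0]);
  const b = new ort.Tensor('float32', new Float32Array([1, 2]), [2, 1]);

  // broadcasting [2, 0] with [2, 1] yields [2, 0]: the output is also zero-sized
  const results = await session.run({A: a, B: b});
  console.log(results.C.dims);  // expected: [2, 0]
}
```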

File tree: 6 files changed (+71, -9 lines)

web/lib/wasm/jsep/backend-webgpu.ts (+28, -4)

@@ -385,11 +385,16 @@ export class WebGpuBackend {
     // create info for inputs
     const inputDatas: GpuData[] = [];
     for (let i = 0; i < inputTensorViews.length; ++i) {
-      const gpuData = this.gpuDataManager.get(inputTensorViews[i].data);
+      const data = inputTensorViews[i].data;
+      // if tensor view data is 0, it means the output is zero-sized tensor, and there is no GPU data for it.
+      if (data === 0) {
+        continue;
+      }
+      const gpuData = this.gpuDataManager.get(data);
       if (!gpuData) {
-        throw new Error(`no GPU data for input: ${inputTensorViews[i].data}`);
+        throw new Error(`no GPU data for input: ${data}`);
       }
-      inputDatas[i] = gpuData;
+      inputDatas.push(gpuData);
     }

     const {outputs, dispatchGroup, programUniforms} = program.getRunData(inputTensorViews);
@@ -419,6 +424,11 @@
       const tensorView = (isTemporary || isPersistent) ?
           createIntermediateOutput(outputs[i].dataType, outputs[i].dims) :
           createKernelOutput(validatedOutputIndices[i], outputs[i].dataType, outputs[i].dims);
+      outputTensorViews.push(tensorView);
+      // if tensor view data is 0, it means the output is zero-sized tensor, and there is no GPU data for it.
+      if (tensorView.data === 0) {
+        continue;
+      }
       const gpuData = this.gpuDataManager.get(tensorView.data);
       if (!gpuData) {
         throw new Error(`no GPU data for output: ${tensorView.data}`);
@@ -434,10 +444,24 @@
         }
         persistentData.push(gpuData);
       }
-      outputTensorViews.push(tensorView);
       outputDatas.push(gpuData);
     }

+    // when there are any zero-sized tensor in the inputs or outputs, we should report error unless all outputs are
+    // zero-sized tensors.
+    if (inputDatas.length !== inputTensorViews.length || outputDatas.length !== outputTensorViews.length) {
+      // if all outputs are zero-sized tensors, there is no need to run the program.
+      if (outputDatas.length === 0) {
+        TRACE_FUNC_END(program.name);
+        return outputTensorViews;
+      }
+      // if some outputs are zero-sized tensors, report an error.
+      //
+      // TODO: so far we don't see any use case that outputs include both zero-sized tensors and non-zero-sized tensors.
+      // If we see such use case, we need to make a change here to support it.
+      throw new Error(
+          `Program ${program.name} has zero-sized tensor(s) in inputs or outputs. This is not supported now.`);
+    }

     // load uniforms
     // TODO: add cache for uniform (is it necessary?)
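
The hunks above skip the GPU-data lookup for zero-sized tensor views (data id 0) and then enforce an all-or-nothing rule before dispatch. A standalone sketch of that rule follows; the helper is illustrative, not the backend's actual code:

```typescript
// Decide what to do with a program whose inputs/outputs may include zero-sized tensors:
// 'run' when every tensor has GPU data, 'skip' when all outputs are zero-sized, otherwise throw.
function decideDispatch(
    inputDatas: unknown[], inputCount: number, outputDatas: unknown[], outputCount: number,
    programName: string): 'run'|'skip' {
  if (inputDatas.length === inputCount && outputDatas.length === outputCount) {
    return 'run';  // no zero-sized tensors involved
  }
  if (outputDatas.length === 0) {
    return 'skip';  // all outputs are zero-sized, so there is nothing to compute
  }
  // mixed zero-sized and non-zero-sized outputs are not supported by this commit
  throw new Error(`Program ${programName} has zero-sized tensor(s) in inputs or outputs. This is not supported now.`);
}
```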

web/lib/wasm/jsep/init.ts (+2, -1)

@@ -104,7 +104,8 @@ class ComputeContextImpl implements ComputeContext {
         throw new Error(`Unsupported data type: ${dataType}`);
       }
       const bufferSize = elementSize * ShapeUtil.size(dims);
-      return new TensorViewImpl(this.module, dataType, this.backend.gpuDataManager.create(bufferSize).id, dims);
+      const gpuDataId = bufferSize > 0 ? this.backend.gpuDataManager.create(bufferSize).id : 0;
+      return new TensorViewImpl(this.module, dataType, gpuDataId, dims);
     };
     return this.backend.run(program, mappedInputs, outputIndices, createKernelOutput, createTemporaryOutput);
   }
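
The change above establishes the sentinel used throughout this commit: a GPU data id of 0 means the tensor is zero-sized and no buffer was allocated for it. A minimal sketch of that convention, with illustrative helper and parameter names:

```typescript
// Allocate GPU data only for non-empty tensors; 0 is the sentinel id for zero-sized ones.
function resolveGpuDataId(
    gpuDataManager: {create(byteSize: number): {id: number}},
    elementSize: number, dims: readonly number[]): number {
  const elementCount = dims.reduce((a, d) => a * d, 1);  // 0 if any dim is 0
  const bufferSize = elementSize * elementCount;
  return bufferSize > 0 ? gpuDataManager.create(bufferSize).id : 0;
}
```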

web/lib/wasm/jsep/util.ts (+10, -1)

@@ -56,7 +56,16 @@ export class BroadcastUtil {
         if (aLen !== bLen && aLen > 1 && bLen > 1) {
           return undefined;
         }
-        cdims[crank - i] = Math.max(aLen, bLen);
+        const max = Math.max(aLen, bLen);
+        if (aLen && bLen) {
+          cdims[crank - i] = Math.max(aLen, bLen);
+        } else {
+          // when either aLen or bLen is 0, the other should be either 0 or 1, otherwise it is not broadcastable.
+          if (max > 1) {
+            return undefined;
+          }
+          cdims[crank - i] = 0;
+        }
       }

     return cdims;
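
The `BroadcastUtil` change extends the per-dimension broadcasting rule: a 0-length dimension broadcasts only against 0 or 1, and the resulting dimension is 0. A standalone sketch of just that rule (an illustrative helper, not the repo's shape calculation):

```typescript
// Broadcast a single dimension pair; returns undefined when the pair is not broadcastable.
function broadcastDim(aLen: number, bLen: number): number|undefined {
  if (aLen !== bLen && aLen > 1 && bLen > 1) {
    return undefined;  // e.g. 2 vs 3: incompatible
  }
  if (aLen === 0 || bLen === 0) {
    // 0 broadcasts only with 0 or 1, and the resulting dimension is 0
    return Math.max(aLen, bLen) > 1 ? undefined : 0;
  }
  return Math.max(aLen, bLen);
}

// Example matching the new Add test: [2, 0] and [2, 1] broadcast to [2, 0].
console.log([broadcastDim(2, 2), broadcastDim(0, 1)]);  // [2, 0]
```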

web/test/data/ops/add.jsonc (+22)

@@ -157,6 +157,28 @@
           "type": "float32"
         }
       ]
+    },
+    {
+      "name": "T[2,0] T[2,1]",
+      "inputs": [
+        {
+          "data": [],
+          "dims": [2, 0],
+          "type": "float32"
+        },
+        {
+          "data": [1, 2],
+          "dims": [2, 1],
+          "type": "float32"
+        }
+      ],
+      "outputs": [
+        {
+          "data": [],
+          "dims": [2, 0],
+          "type": "float32"
+        }
+      ]
     }
   ]
 }

web/test/suite-test-list.jsonc (+1, -1)

@@ -1231,7 +1231,7 @@
       "test_split_variable_parts_1d",
       "test_split_variable_parts_2d",
       "test_split_variable_parts_default_axis",
-      // // "test_split_zero_size_splits",
+      "test_split_zero_size_splits",
       "test_sqrt_example",
       "test_sqrt",
       "test_squeeze_negative_axes",

web/test/test-runner.ts (+8, -2)

@@ -573,7 +573,9 @@ export async function sessionRun(options: {
       // replace the CPU tensors in feeds into GPU tensors
       for (const name in feeds) {
         if (Object.hasOwnProperty.call(feeds, name)) {
-          feeds[name] = createGpuTensorForInput(feeds[name]);
+          if (feeds[name].size > 0) {
+            feeds[name] = createGpuTensorForInput(feeds[name]);
+          }
         }
       }
     }
@@ -582,7 +584,11 @@
       for (const name in options.outputsMetaInfo) {
        if (Object.hasOwnProperty.call(options.outputsMetaInfo, name)) {
           const {type, dims} = options.outputsMetaInfo[name];
-          fetches[name] = createGpuTensorForOutput(type, dims);
+          if (dims.some(d => d === 0)) {
+            fetches[name] = new ort.Tensor(type, [], dims);
+          } else {
+            fetches[name] = createGpuTensorForOutput(type, dims);
+          }
         }
       }
     }
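
In the test runner, GPU IO-binding is bypassed for zero-sized tensors: feeds stay on CPU when their `size` is 0, and pre-allocated fetches fall back to a plain CPU `ort.Tensor` whenever an output dimension is 0, since there is no WebGPU buffer to wrap. A small illustration of such a zero-sized fetch (the dims are an assumed example):

```typescript
import * as ort from 'onnxruntime-web';

// A zero-sized output slot: empty data, no WebGPU buffer involved.
const zeroSizedFetch = new ort.Tensor('float32', [], [2, 0]);
console.log(zeroSizedFetch.size);      // 0
console.log(zeroSizedFetch.location);  // 'cpu'

// Non-zero-sized outputs would instead be wrapped as GPU tensors
// (createGpuTensorForOutput in the test runner), which requires a real WebGPU buffer.
```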
