@@ -385,11 +385,16 @@ export class WebGpuBackend {
385
385
// create info for inputs
386
386
const inputDatas : GpuData [ ] = [ ] ;
387
387
for ( let i = 0 ; i < inputTensorViews . length ; ++ i ) {
388
- const gpuData = this . gpuDataManager . get ( inputTensorViews [ i ] . data ) ;
388
+ const data = inputTensorViews [ i ] . data ;
389
+ // if tensor view data is 0, it means the output is zero-sized tensor, and there is no GPU data for it.
390
+ if ( data === 0 ) {
391
+ continue ;
392
+ }
393
+ const gpuData = this . gpuDataManager . get ( data ) ;
389
394
if ( ! gpuData ) {
390
- throw new Error ( `no GPU data for input: ${ inputTensorViews [ i ] . data } ` ) ;
395
+ throw new Error ( `no GPU data for input: ${ data } ` ) ;
391
396
}
392
- inputDatas [ i ] = gpuData ;
397
+ inputDatas . push ( gpuData ) ;
393
398
}
394
399
395
400
const { outputs, dispatchGroup, programUniforms} = program . getRunData ( inputTensorViews ) ;
@@ -419,6 +424,11 @@ export class WebGpuBackend {
419
424
const tensorView = ( isTemporary || isPersistent ) ?
420
425
createIntermediateOutput ( outputs [ i ] . dataType , outputs [ i ] . dims ) :
421
426
createKernelOutput ( validatedOutputIndices [ i ] , outputs [ i ] . dataType , outputs [ i ] . dims ) ;
427
+ outputTensorViews . push ( tensorView ) ;
428
+ // if tensor view data is 0, it means the output is zero-sized tensor, and there is no GPU data for it.
429
+ if ( tensorView . data === 0 ) {
430
+ continue ;
431
+ }
422
432
const gpuData = this . gpuDataManager . get ( tensorView . data ) ;
423
433
if ( ! gpuData ) {
424
434
throw new Error ( `no GPU data for output: ${ tensorView . data } ` ) ;
@@ -434,10 +444,24 @@ export class WebGpuBackend {
434
444
}
435
445
persistentData . push ( gpuData ) ;
436
446
}
437
- outputTensorViews . push ( tensorView ) ;
438
447
outputDatas . push ( gpuData ) ;
439
448
}
440
449
450
+ // when there are any zero-sized tensor in the inputs or outputs, we should report error unless all outputs are
451
+ // zero-sized tensors.
452
+ if ( inputDatas . length !== inputTensorViews . length || outputDatas . length !== outputTensorViews . length ) {
453
+ // if all outputs are zero-sized tensors, there is no need to run the program.
454
+ if ( outputDatas . length === 0 ) {
455
+ TRACE_FUNC_END ( program . name ) ;
456
+ return outputTensorViews ;
457
+ }
458
+ // if some outputs are zero-sized tensors, report an error.
459
+ //
460
+ // TODO: so far we don't see any use case that outputs include both zero-sized tensors and non-zero-sized tensors.
461
+ // If we see such use case, we need to make a change here to support it.
462
+ throw new Error (
463
+ `Program ${ program . name } has zero-sized tensor(s) in inputs or outputs. This is not supported now.` ) ;
464
+ }
441
465
442
466
// load uniforms
443
467
// TODO: add cache for uniform (is it necessary?)
0 commit comments