@@ -6,18 +6,15 @@ import {TensorView} from '../../tensor-view';
6
6
import { BroadcastUtil , ShapeUtil } from '../../util' ;
7
7
import { ComputeContext , ProgramInfo } from '../types' ;
8
8
9
- import { inputVariable , outputVariable , ShaderHelper } from './common' ;
9
+ import { createTensorShapeVariables , inputVariable , outputVariable , ShaderHelper } from './common' ;
10
10
11
11
const createWhereOpProgramShader =
12
12
( shaderHelper : ShaderHelper , inputs : readonly TensorView [ ] , dimsOutput : readonly number [ ] , isBroadcast : boolean ,
13
13
typeOutput : number ) => {
14
- const outputSize = ShapeUtil . size ( dimsOutput ) ;
15
- const vecSize = Math . ceil ( outputSize / 4 ) ;
16
-
17
- const output = outputVariable ( 'outputData' , typeOutput , dimsOutput , 4 ) ;
18
- const a = inputVariable ( 'aData' , inputs [ 1 ] . dataType , inputs [ 1 ] . dims , 4 ) ;
19
- const b = inputVariable ( 'bData' , inputs [ 2 ] . dataType , inputs [ 2 ] . dims , 4 ) ;
20
- const c = inputVariable ( 'cData' , inputs [ 0 ] . dataType , inputs [ 0 ] . dims , 4 ) ;
14
+ const output = outputVariable ( 'output_data' , typeOutput , dimsOutput . length , 4 ) ;
15
+ const a = inputVariable ( 'a_data' , inputs [ 1 ] . dataType , inputs [ 1 ] . dims . length , 4 ) ;
16
+ const b = inputVariable ( 'b_data' , inputs [ 2 ] . dataType , inputs [ 2 ] . dims . length , 4 ) ;
17
+ const c = inputVariable ( 'c_data' , inputs [ 0 ] . dataType , inputs [ 0 ] . dims . length , 4 ) ;
21
18
22
19
let assignment : string ;
23
20
const expression = ( a : string , b : string , c : string ) => `select(${ b } , ${ a } , ${ c } )` ;
@@ -27,20 +24,20 @@ const createWhereOpProgramShader =
27
24
expression ( a . getByOffset ( 'global_idx' ) , b . getByOffset ( 'global_idx' ) , c . getByOffset ( 'global_idx' ) ) ) ;
28
25
} else {
29
26
const singleAssignment = ( resStr : string , x : number , typeCast = '' ) => {
30
- const expressionA = `aData[indexA ${ x } ][componentA ${ x } ]` ;
31
- const expressionB = `bData[indexB ${ x } ][componentB ${ x } ]` ;
27
+ const expressionA = `a_data[index_a ${ x } ][component_a ${ x } ]` ;
28
+ const expressionB = `b_data[index_b ${ x } ][component_b ${ x } ]` ;
32
29
// eslint-disable-next-line no-bitwise
33
- const expressionC = `bool(cData[indexC ${ x } ] & ${ 0xff000000 >>> ( ( 3 - x ) * 8 ) } u)` ;
30
+ const expressionC = `bool(c_data[index_c ${ x } ] & ${ 0xff000000 >>> ( ( 3 - x ) * 8 ) } u)` ;
34
31
return `
35
- let outputIndices ${ x } = ${ output . offsetToIndices ( `global_idx * 4u + ${ x } u` ) } ;
36
- let offsetA ${ x } = ${ a . broadcastedIndicesToOffset ( `outputIndices ${ x } ` , output ) } ;
37
- let offsetB ${ x } = ${ b . broadcastedIndicesToOffset ( `outputIndices ${ x } ` , output ) } ;
38
- let offsetC ${ x } = ${ c . broadcastedIndicesToOffset ( `outputIndices ${ x } ` , output ) } ;
39
- let indexA ${ x } = offsetA ${ x } / 4u;
40
- let indexB ${ x } = offsetB ${ x } / 4u;
41
- let indexC ${ x } = offsetC ${ x } / 4u;
42
- let componentA ${ x } = offsetA ${ x } % 4u;
43
- let componentB ${ x } = offsetB ${ x } % 4u;
32
+ let output_indices ${ x } = ${ output . offsetToIndices ( `global_idx * 4u + ${ x } u` ) } ;
33
+ let offset_a ${ x } = ${ a . broadcastedIndicesToOffset ( `output_indices ${ x } ` , output ) } ;
34
+ let offset_b ${ x } = ${ b . broadcastedIndicesToOffset ( `output_indices ${ x } ` , output ) } ;
35
+ let offset_c ${ x } = ${ c . broadcastedIndicesToOffset ( `output_indices ${ x } ` , output ) } ;
36
+ let index_a ${ x } = offset_a ${ x } / 4u;
37
+ let index_b ${ x } = offset_b ${ x } / 4u;
38
+ let index_c ${ x } = offset_c ${ x } / 4u;
39
+ let component_a ${ x } = offset_a ${ x } % 4u;
40
+ let component_b ${ x } = offset_b ${ x } % 4u;
44
41
${ resStr } [${ x } ] = ${ typeCast } (${ expression ( expressionA , expressionB , expressionC ) } );
45
42
` ;
46
43
} ;
@@ -51,21 +48,21 @@ const createWhereOpProgramShader =
51
48
${ singleAssignment ( 'data' , 1 , 'u32' ) }
52
49
${ singleAssignment ( 'data' , 2 , 'u32' ) }
53
50
${ singleAssignment ( 'data' , 3 , 'u32' ) }
54
- outputData [global_idx] = dot(vec4<u32>(0x1, 0x100, 0x10000, 0x1000000), vec4<u32>(data));` ;
51
+ output_data [global_idx] = dot(vec4<u32>(0x1, 0x100, 0x10000, 0x1000000), vec4<u32>(data));` ;
55
52
} else {
56
53
assignment = `
57
- ${ singleAssignment ( 'outputData [global_idx]' , 0 ) }
58
- ${ singleAssignment ( 'outputData [global_idx]' , 1 ) }
59
- ${ singleAssignment ( 'outputData [global_idx]' , 2 ) }
60
- ${ singleAssignment ( 'outputData [global_idx]' , 3 ) }
54
+ ${ singleAssignment ( 'output_data [global_idx]' , 0 ) }
55
+ ${ singleAssignment ( 'output_data [global_idx]' , 1 ) }
56
+ ${ singleAssignment ( 'output_data [global_idx]' , 2 ) }
57
+ ${ singleAssignment ( 'output_data [global_idx]' , 3 ) }
61
58
` ;
62
59
}
63
60
}
64
61
65
62
return `
66
- ${ shaderHelper . declareVariables ( c , a , b , output ) }
63
+ ${ shaderHelper . registerUniform ( 'vec_size' , 'u32' ) . declareVariables ( c , a , b , output ) }
67
64
${ shaderHelper . mainStart ( ) }
68
- ${ shaderHelper . guardAgainstOutOfBoundsWorkgroupSizes ( vecSize ) }
65
+ ${ shaderHelper . guardAgainstOutOfBoundsWorkgroupSizes ( 'uniforms.vec_size' ) }
69
66
${ assignment }
70
67
}` ;
71
68
} ;
@@ -79,6 +76,7 @@ const createWhereOpProgramInfo = (inputs: readonly TensorView[]): ProgramInfo =>
79
76
const isBroadcast = ! ( ShapeUtil . areEqual ( dimsA , dimsB ) && ShapeUtil . areEqual ( dimsB , dimsC ) ) ;
80
77
let outputShape = dimsA ;
81
78
let outputSize = ShapeUtil . size ( dimsA ) ;
79
+ const vecSize = Math . ceil ( outputSize / 4 ) ;
82
80
// TODO: deal with zero-sized tensors (eg. dims=[1,0])
83
81
84
82
if ( isBroadcast ) {
@@ -92,11 +90,16 @@ const createWhereOpProgramInfo = (inputs: readonly TensorView[]): ProgramInfo =>
92
90
93
91
return {
94
92
name : 'Where' ,
93
+ shaderCache : { inputDependencies : [ 'rank' , 'rank' , 'rank' ] } ,
95
94
getShaderSource : ( shaderHelper ) =>
96
95
createWhereOpProgramShader ( shaderHelper , inputs , outputShape , isBroadcast , outputDataType ) ,
97
96
getRunData : ( ) => ( {
98
97
outputs : [ { dims : outputShape , dataType : outputDataType } ] ,
99
- dispatchGroup : { x : Math . ceil ( outputSize / 64 /* workgroup size */ / 4 /* vec size */ ) }
98
+ dispatchGroup : { x : Math . ceil ( outputSize / 64 /* workgroup size */ / 4 /* vec size */ ) } ,
99
+ programUniforms : [
100
+ { type : 'uint32' , data : vecSize } , ...createTensorShapeVariables ( dimsC ) , ...createTensorShapeVariables ( dimsA ) ,
101
+ ...createTensorShapeVariables ( dimsB ) , ...createTensorShapeVariables ( outputShape )
102
+ ] ,
100
103
} ) ,
101
104
} ;
102
105
} ;
0 commit comments