@@ -6,7 +6,7 @@ import {TensorView} from '../../tensor-view';
6
6
import { BroadcastUtil , ShapeUtil } from '../../util' ;
7
7
import { ComputeContext , ProgramInfo } from '../types' ;
8
8
9
- import { createTensorShapeVariables , enableShapesUniforms , inputVariable , outputVariable , ShaderHelper } from './common' ;
9
+ import { createTensorShapeVariables , inputVariable , outputVariable , ShaderHelper } from './common' ;
10
10
11
11
type BuiltinFunctionName = string ;
12
12
type BinaryCustomExpression = ( expressionA : string , expressionB : string ) => string ;
@@ -18,8 +18,7 @@ type BinaryFunctionCall = BuiltinFunctionName|BinaryCustomExpression|{
18
18
const createBinaryOpProgramShader =
19
19
( shaderHelper : ShaderHelper , dimsA : readonly number [ ] , dimsB : readonly number [ ] , dimsOutput : readonly number [ ] ,
20
20
vectorize : boolean , doBroadcast : boolean , sharedDimensionDivisibleBy4 : boolean , funcCall : BinaryFunctionCall ,
21
- typeA : number , typeB : number , typeOutput : number , useShapesUniforms : boolean ,
22
- additionalImplementation ?: string ) => {
21
+ typeA : number , typeB : number , typeOutput : number , additionalImplementation ?: string ) => {
23
22
let expressionScalar : BinaryCustomExpression ;
24
23
let expressionVector : BinaryCustomExpression ;
25
24
if ( typeof funcCall === 'string' ) {
@@ -31,12 +30,9 @@ const createBinaryOpProgramShader =
31
30
expressionVector = funcCall . vector ;
32
31
}
33
32
34
- const inputAShapeOrRank = useShapesUniforms ? dimsA . length : dimsA ;
35
- const inputBShapeOrRank = useShapesUniforms ? dimsB . length : dimsB ;
36
- const outputShapeOrRank = useShapesUniforms ? dimsOutput . length : dimsOutput ;
37
- const output = outputVariable ( 'outputData' , typeOutput , outputShapeOrRank , 4 ) ;
38
- const a = inputVariable ( 'aData' , typeA , inputAShapeOrRank , 4 ) ;
39
- const b = inputVariable ( 'bData' , typeB , inputBShapeOrRank , 4 ) ;
33
+ const output = outputVariable ( 'outputData' , typeOutput , dimsOutput . length , 4 ) ;
34
+ const a = inputVariable ( 'aData' , typeA , dimsA . length , 4 ) ;
35
+ const b = inputVariable ( 'bData' , typeB , dimsB . length , 4 ) ;
40
36
41
37
let assignment : string ;
42
38
if ( vectorize ) {
@@ -169,30 +165,25 @@ const createBinaryOpProgramInfo =
169
165
vectorize = true ;
170
166
}
171
167
cacheKeyAux . push ( vectorize ) ;
172
- const useShapesUniforms = enableShapesUniforms ( a . dims . length ) && enableShapesUniforms ( b . dims . length ) &&
173
- enableShapesUniforms ( outputShape . length ) ;
168
+
174
169
return {
175
170
name,
176
171
shaderCache : {
177
172
hint : cacheKey + cacheKeyAux . map ( ( x ) => x . toString ( ) ) . join ( '_' ) ,
178
- inputDependencies : useShapesUniforms ? [ 'rank' , 'rank' ] : [ 'dims' , 'dims '] ,
173
+ inputDependencies : [ 'rank' , 'rank' ] ,
179
174
} ,
180
175
getShaderSource : ( shaderHelper ) => createBinaryOpProgramShader (
181
176
shaderHelper , a . dims , b . dims , outputShape , vectorize , isBroadcast , sharedDimensionDivisibleBy4 , funcCall ,
182
- a . dataType , b . dataType , outputDataType , useShapesUniforms , additionalImplementation ) ,
177
+ a . dataType , b . dataType , outputDataType , additionalImplementation ) ,
183
178
getRunData : ( ) => ( {
184
179
outputs : [ { dims : outputShape , dataType : outputDataType } ] ,
185
180
dispatchGroup : { x : Math . ceil ( outputSize / 64 /* workgroup size */ / 4 /* component size */ ) } ,
186
- programUniforms : useShapesUniforms ?
187
- [
188
- { type : 'uint32' , data : Math . ceil ( ShapeUtil . size ( outputShape ) / 4 ) } ,
189
- ...createTensorShapeVariables ( a . dims ) ,
190
- ...createTensorShapeVariables ( b . dims ) ,
191
- ...createTensorShapeVariables ( outputShape ) ,
192
- ] :
193
- [
194
- { type : 'uint32' , data : Math . ceil ( ShapeUtil . size ( outputShape ) / 4 ) } ,
195
- ] ,
181
+ programUniforms : [
182
+ { type : 'uint32' , data : Math . ceil ( ShapeUtil . size ( outputShape ) / 4 ) } ,
183
+ ...createTensorShapeVariables ( a . dims ) ,
184
+ ...createTensorShapeVariables ( b . dims ) ,
185
+ ...createTensorShapeVariables ( outputShape ) ,
186
+ ] ,
196
187
} ) ,
197
188
} ;
198
189
} ;
0 commit comments