@@ -7,7 +7,7 @@ import {ShapeUtil} from '../../util';
7
7
import { AttributeWithCacheKey , createAttributeWithCacheKey } from '../attribute-with-cache-key' ;
8
8
import { ComputeContext , ProgramInfo , ProgramUniform } from '../types' ;
9
9
10
- import { createTensorShapeVariables , inputVariable , outputVariable , ShaderHelper , tensorTypeToWsglStorageType , UniformsArrayType } from './common' ;
10
+ import { inputVariable , outputVariable , ShaderHelper , tensorTypeToWsglStorageType , UniformsArrayType } from './common' ;
11
11
12
12
// TODO: support quantization bit widths other than 4.
13
13
export interface MatMulNBitsAttributes extends AttributeWithCacheKey {
@@ -52,8 +52,6 @@ const validateInputs = (inputs: readonly TensorView[], attributes: MatMulNBitsAt
52
52
export const createMatMulNBitsProgramInfo =
53
53
( inputs : readonly TensorView [ ] , attributes : MatMulNBitsAttributes ) : ProgramInfo => {
54
54
const a = inputs [ 0 ] ;
55
- const b = inputs [ 1 ] ;
56
- const scales = inputs [ 2 ] ;
57
55
const aRank = a . dims . length ;
58
56
const outputShape = a . dims . slice ( 0 , aRank - 1 ) . concat ( attributes . n ) ;
59
57
const outputSize = ShapeUtil . size ( outputShape ) ;
@@ -64,24 +62,17 @@ export const createMatMulNBitsProgramInfo =
64
62
{ type : DataType . uint32 , data : attributes . n } , { type : DataType . uint32 , data : attributes . accuracyLevel } ,
65
63
{ type : DataType . uint32 , data : attributes . bits } , { type : DataType . uint32 , data : attributes . blockSize }
66
64
] ;
67
- programUniforms . push ( ...createTensorShapeVariables ( a . dims ) ) ;
68
- programUniforms . push ( ...createTensorShapeVariables ( ShapeUtil . convertShape ( b . dims ) ) ) ;
69
- programUniforms . push ( ...createTensorShapeVariables ( scales . dims ) ) ;
70
- if ( inputs . length === 4 ) {
71
- programUniforms . push ( ...createTensorShapeVariables ( ShapeUtil . convertShape ( inputs [ 3 ] . dims ) ) ) ;
72
- }
73
- programUniforms . push ( ...createTensorShapeVariables ( outputShape ) ) ;
74
65
const getShaderSource = ( shaderHelper : ShaderHelper ) => {
75
- const a = inputVariable ( 'a' , inputs [ 0 ] . dataType , inputs [ 0 ] . dims . length ) ;
76
- const b = inputVariable ( 'b' , DataType . uint32 , inputs [ 1 ] . dims . length ) ;
77
- const scales = inputVariable ( 'scales' , inputs [ 2 ] . dataType , inputs [ 2 ] . dims . length ) ;
66
+ const a = inputVariable ( 'a' , inputs [ 0 ] . dataType , inputs [ 0 ] . dims ) ;
67
+ const b = inputVariable ( 'b' , DataType . uint32 , ShapeUtil . convertShape ( inputs [ 1 ] . dims ) ) ;
68
+ const scales = inputVariable ( 'scales' , inputs [ 2 ] . dataType , inputs [ 2 ] . dims ) ;
78
69
const inputVariables = [ a , b , scales ] ;
79
70
const zeroPoints =
80
- inputs . length === 4 ? inputVariable ( 'zero_points' , DataType . uint32 , inputs [ 3 ] . dims . length ) : undefined ;
71
+ inputs . length === 4 ? inputVariable ( 'zero_points' , DataType . uint32 , inputs [ 3 ] . dims ) : undefined ;
81
72
if ( zeroPoints ) {
82
73
inputVariables . push ( zeroPoints ) ;
83
74
}
84
- const output = outputVariable ( 'output' , inputs [ 0 ] . dataType , outputShape . length ) ;
75
+ const output = outputVariable ( 'output' , inputs [ 0 ] . dataType , outputShape ) ;
85
76
const uniforms : UniformsArrayType = [
86
77
{ name : 'output_size' , type : 'u32' } , { name : 'k' , type : 'u32' } , { name : 'n' , type : 'u32' } ,
87
78
{ name : 'accuracy_level' , type : 'u32' } , { name : 'bits' , type : 'u32' } , { name : 'block_size' , type : 'u32' }
@@ -165,7 +156,7 @@ export const createMatMulNBitsProgramInfo =
165
156
return {
166
157
name : 'MatMulNBits' ,
167
158
shaderCache :
168
- { hint : `${ attributes . cacheKey } ;${ inputs . length } ` , inputDependencies : Array ( inputs . length ) . fill ( 'rank ' ) } ,
159
+ { hint : `${ attributes . cacheKey } ;${ inputs . length } ` , inputDependencies : Array ( inputs . length ) . fill ( 'dims ' ) } ,
169
160
getRunData : ( ) => ( {
170
161
outputs : [ { dims : outputShape , dataType : inputs [ 0 ] . dataType } ] ,
171
162
dispatchGroup : { x : Math . ceil ( outputSize / 64 ) } ,
0 commit comments