@@ -7,12 +7,12 @@ import {ShapeUtil} from '../../util';
7
7
import { AttributeWithCacheKey , createAttributeWithCacheKey } from '../attribute-with-cache-key' ;
8
8
import { ComputeContext , ProgramInfo , ProgramUniform } from '../types' ;
9
9
10
- import { getMaxComponents , inputVariable , outputVariable , ShaderHelper , tensorTypeToWsglStorageType , UniformsArrayType } from './common' ;
10
+ import { createTensorShapeVariables , getMaxComponents , inputVariable , outputVariable , ShaderHelper , tensorTypeToWsglStorageType , UniformsArrayType } from './common' ;
11
11
12
12
// TODO: support quantization bit widths other than 4
13
13
export interface MatMulNBitsAttributes extends AttributeWithCacheKey {
14
- k : number ;
15
- n : number ;
14
+ K : number ;
15
+ N : number ;
16
16
accuracyLevel : number ;
17
17
bits : number ;
18
18
blockSize : number ;
@@ -24,25 +24,25 @@ const validateInputs = (inputs: readonly TensorView[], attributes: MatMulNBitsAt
24
24
}
25
25
const a = inputs [ 0 ] ;
26
26
const aRank = a . dims . length ;
27
- if ( a . dims [ aRank - 1 ] !== attributes . k ) {
27
+ if ( a . dims [ aRank - 1 ] !== attributes . K ) {
28
28
throw new Error ( 'The last dim of input shape does not match the k value' ) ;
29
29
}
30
- const nBlocksPerCol = Math . floor ( ( attributes . k + attributes . blockSize - 1 ) / attributes . blockSize ) ;
30
+ const nBlocksPerCol = Math . floor ( ( attributes . K + attributes . blockSize - 1 ) / attributes . blockSize ) ;
31
31
const blobSize = attributes . blockSize / 8 * attributes . bits ;
32
32
const b = inputs [ 1 ] ;
33
- if ( ! ShapeUtil . areEqual ( b . dims , [ attributes . n , nBlocksPerCol , blobSize ] ) ) {
33
+ if ( ! ShapeUtil . areEqual ( b . dims , [ attributes . N , nBlocksPerCol , blobSize ] ) ) {
34
34
throw new Error ( 'The second inputs must be 3D tensor with shape N X nBlocksPerCol X blobSize' ) ;
35
35
}
36
36
const scales = inputs [ 2 ] ;
37
37
const scalesShape = scales . dims ;
38
- if ( ShapeUtil . size ( scalesShape ) !== attributes . n * nBlocksPerCol ) {
38
+ if ( ShapeUtil . size ( scalesShape ) !== attributes . N * nBlocksPerCol ) {
39
39
throw new Error ( 'scales input size error.' ) ;
40
40
}
41
41
if ( inputs . length === 4 ) {
42
42
const zeroPoints = inputs [ 3 ] ;
43
43
const zeroPointsShape = zeroPoints . dims ;
44
44
const expectedZeroPointsSize =
45
- attributes . bits > 4 ? ( attributes . n * nBlocksPerCol ) : attributes . n * Math . floor ( ( nBlocksPerCol + 1 ) / 2 ) ;
45
+ attributes . bits > 4 ? ( attributes . N * nBlocksPerCol ) : attributes . N * Math . floor ( ( nBlocksPerCol + 1 ) / 2 ) ;
46
46
if ( ShapeUtil . size ( zeroPointsShape ) !== expectedZeroPointsSize ) {
47
47
throw new Error ( 'zeroPoints input size error.' ) ;
48
48
}
@@ -53,19 +53,19 @@ export const createMatMulNBitsProgramInfo =
53
53
( inputs : readonly TensorView [ ] , attributes : MatMulNBitsAttributes ) : ProgramInfo => {
54
54
const inputShape = inputs [ 0 ] . dims ;
55
55
const aRank = inputShape . length ;
56
- const outputShape = inputShape . slice ( 0 , aRank - 1 ) . concat ( attributes . n ) ;
57
- const m = inputShape [ aRank - 2 ] ;
56
+ const outputShape = inputShape . slice ( 0 , aRank - 1 ) . concat ( attributes . N ) ;
57
+ const M = inputShape [ aRank - 2 ] ;
58
58
const blobSize = attributes . blockSize / 8 * attributes . bits ;
59
59
const blobSizeInWords = blobSize / 4 ;
60
- const outputNumber = getMaxComponents ( m ) ;
60
+ const outputNumber = getMaxComponents ( M ) ;
61
61
const components = 1 ; // getMaxComponents(attributes.N);
62
- const aComponents = getMaxComponents ( attributes . k ) ;
62
+ const aComponents = getMaxComponents ( attributes . K ) ;
63
63
const bComponents = getMaxComponents ( blobSizeInWords ) ;
64
- const zComponents = 1 ; // getMaxComponents(attributes.n / 8);
64
+ const zComponents = 1 ; // getMaxComponents(attributes.N / 8);
65
65
const outputSize = ShapeUtil . size ( outputShape ) / components / outputNumber ;
66
66
const programUniforms : ProgramUniform [ ] = [
67
- { type : DataType . uint32 , data : outputSize } , { type : DataType . uint32 , data : attributes . k } ,
68
- { type : DataType . uint32 , data : attributes . n } , { type : DataType . uint32 , data : attributes . accuracyLevel } ,
67
+ { type : DataType . uint32 , data : outputSize } , { type : DataType . uint32 , data : attributes . K } ,
68
+ { type : DataType . uint32 , data : attributes . N } , { type : DataType . uint32 , data : attributes . accuracyLevel } ,
69
69
{ type : DataType . uint32 , data : attributes . bits } , { type : DataType . uint32 , data : attributes . blockSize }
70
70
] ;
71
71
const getShaderSource = ( shaderHelper : ShaderHelper ) => {
@@ -88,7 +88,7 @@ export const createMatMulNBitsProgramInfo =
88
88
{ name : 'output_size' , type : 'u32' } , { name : 'K' , type : 'u32' } , { name : 'N' , type : 'u32' } ,
89
89
{ name : 'accuracy_level' , type : 'u32' } , { name : 'bits' , type : 'u32' } , { name : 'block_size' , type : 'u32' }
90
90
] ;
91
- const nBlocksPerCol = Math . floor ( ( attributes . k + attributes . blockSize - 1 ) / attributes . blockSize ) ;
91
+ const nBlocksPerCol = Math . floor ( ( attributes . K + attributes . blockSize - 1 ) / attributes . blockSize ) ;
92
92
const dataType = tensorTypeToWsglStorageType ( inputs [ 0 ] . dataType ) ;
93
93
const dequantizeArrayReturnType = ( ( ) => {
94
94
switch ( aComponents ) {
0 commit comments