Skip to content

Commit a825fbb

Browse files
Use uppercase letters for M, N and K.
1 parent f186004 commit a825fbb

File tree

2 files changed

+18
-18
lines changed

2 files changed

+18
-18
lines changed

js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts

+16-16
Original file line number | Diff line number | Diff line change
@@ -7,12 +7,12 @@ import {ShapeUtil} from '../../util';
77
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
88
import {ComputeContext, ProgramInfo, ProgramUniform} from '../types';
99

10-
import {getMaxComponents, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from './common';
10+
import {createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from './common';
1111

1212
// TODO support quantization bits not equal to 4
1313
export interface MatMulNBitsAttributes extends AttributeWithCacheKey {
14-
k: number;
15-
n: number;
14+
K: number;
15+
N: number;
1616
accuracyLevel: number;
1717
bits: number;
1818
blockSize: number;
@@ -24,25 +24,25 @@ const validateInputs = (inputs: readonly TensorView[], attributes: MatMulNBitsAt
2424
}
2525
const a = inputs[0];
2626
const aRank = a.dims.length;
27-
if (a.dims[aRank - 1] !== attributes.k) {
27+
if (a.dims[aRank - 1] !== attributes.K) {
2828
throw new Error('The last dim of input shape does not match the k value');
2929
}
30-
const nBlocksPerCol = Math.floor((attributes.k + attributes.blockSize - 1) / attributes.blockSize);
30+
const nBlocksPerCol = Math.floor((attributes.K + attributes.blockSize - 1) / attributes.blockSize);
3131
const blobSize = attributes.blockSize / 8 * attributes.bits;
3232
const b = inputs[1];
33-
if (!ShapeUtil.areEqual(b.dims, [attributes.n, nBlocksPerCol, blobSize])) {
33+
if (!ShapeUtil.areEqual(b.dims, [attributes.N, nBlocksPerCol, blobSize])) {
3434
throw new Error('The second inputs must be 3D tensor with shape N X nBlocksPerCol X blobSize');
3535
}
3636
const scales = inputs[2];
3737
const scalesShape = scales.dims;
38-
if (ShapeUtil.size(scalesShape) !== attributes.n * nBlocksPerCol) {
38+
if (ShapeUtil.size(scalesShape) !== attributes.N * nBlocksPerCol) {
3939
throw new Error('scales input size error.');
4040
}
4141
if (inputs.length === 4) {
4242
const zeroPoints = inputs[3];
4343
const zeroPointsShape = zeroPoints.dims;
4444
const expectedZeroPointsSize =
45-
attributes.bits > 4 ? (attributes.n * nBlocksPerCol) : attributes.n * Math.floor((nBlocksPerCol + 1) / 2);
45+
attributes.bits > 4 ? (attributes.N * nBlocksPerCol) : attributes.N * Math.floor((nBlocksPerCol + 1) / 2);
4646
if (ShapeUtil.size(zeroPointsShape) !== expectedZeroPointsSize) {
4747
throw new Error('zeroPoints input size error.');
4848
}
@@ -53,19 +53,19 @@ export const createMatMulNBitsProgramInfo =
5353
(inputs: readonly TensorView[], attributes: MatMulNBitsAttributes): ProgramInfo => {
5454
const inputShape = inputs[0].dims;
5555
const aRank = inputShape.length;
56-
const outputShape = inputShape.slice(0, aRank - 1).concat(attributes.n);
57-
const m = inputShape[aRank - 2];
56+
const outputShape = inputShape.slice(0, aRank - 1).concat(attributes.N);
57+
const M = inputShape[aRank - 2];
5858
const blobSize = attributes.blockSize / 8 * attributes.bits;
5959
const blobSizeInWords = blobSize / 4;
60-
const outputNumber = getMaxComponents(m);
60+
const outputNumber = getMaxComponents(M);
6161
const components = 1; // getMaxComponents(attributes.n);
62-
const aComponents = getMaxComponents(attributes.k);
62+
const aComponents = getMaxComponents(attributes.K);
6363
const bComponents = getMaxComponents(blobSizeInWords);
64-
const zComponents = 1; // getMaxComponents(attributes.n / 8);
64+
const zComponents = 1; // getMaxComponents(attributes.N / 8);
6565
const outputSize = ShapeUtil.size(outputShape) / components / outputNumber;
6666
const programUniforms: ProgramUniform[] = [
67-
{type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: attributes.k},
68-
{type: DataType.uint32, data: attributes.n}, {type: DataType.uint32, data: attributes.accuracyLevel},
67+
{type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: attributes.K},
68+
{type: DataType.uint32, data: attributes.N}, {type: DataType.uint32, data: attributes.accuracyLevel},
6969
{type: DataType.uint32, data: attributes.bits}, {type: DataType.uint32, data: attributes.blockSize}
7070
];
7171
const getShaderSource = (shaderHelper: ShaderHelper) => {
@@ -88,7 +88,7 @@ export const createMatMulNBitsProgramInfo =
8888
{name: 'output_size', type: 'u32'}, {name: 'K', type: 'u32'}, {name: 'N', type: 'u32'},
8989
{name: 'accuracy_level', type: 'u32'}, {name: 'bits', type: 'u32'}, {name: 'block_size', type: 'u32'}
9090
];
91-
const nBlocksPerCol = Math.floor((attributes.k + attributes.blockSize - 1) / attributes.blockSize);
91+
const nBlocksPerCol = Math.floor((attributes.K + attributes.blockSize - 1) / attributes.blockSize);
9292
const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
9393
const dequantizeArrayReturnType = (() => {
9494
switch (aComponents) {

onnxruntime/contrib_ops/js/quantization/matmul_nbits.h

+2-2
Original file line number | Diff line number | Diff line change
@@ -22,8 +22,8 @@ class MatMulNBits final : public JsKernel {
2222
ORT_ENFORCE(block_size_ >= 16 && !(block_size_ & (block_size_ - 1)),
2323
"Block size must be a power of 2 and greater than or equal to 16.");
2424
JSEP_INIT_KERNEL_ATTRIBUTE(MatMulNBits, ({
25-
"k" : $1,
26-
"n" : $2,
25+
"K" : $1,
26+
"N" : $2,
2727
"accuracyLevel" : $3,
2828
"bits" : $4,
2929
"blockSize" : $5

0 commit comments

Comments
 (0)