Skip to content

Commit d114d9a

Browse files
Temporarily remove uniforms to debug.
1 parent 069d2d6 commit d114d9a

File tree

1 file changed

+7
-16
lines changed

1 file changed

+7
-16
lines changed

js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts

+7 -16
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@ import {ShapeUtil} from '../../util';
77
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
88
import {ComputeContext, ProgramInfo, ProgramUniform} from '../types';
99

10-
import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from './common';
10+
import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from './common';
1111

1212
// TODO support quantization bits not equal to 4
1313
export interface MatMulNBitsAttributes extends AttributeWithCacheKey {
@@ -52,8 +52,6 @@ const validateInputs = (inputs: readonly TensorView[], attributes: MatMulNBitsAt
5252
export const createMatMulNBitsProgramInfo =
5353
(inputs: readonly TensorView[], attributes: MatMulNBitsAttributes): ProgramInfo => {
5454
const a = inputs[0];
55-
const b = inputs[1];
56-
const scales = inputs[2];
5755
const aRank = a.dims.length;
5856
const outputShape = a.dims.slice(0, aRank - 1).concat(attributes.n);
5957
const outputSize = ShapeUtil.size(outputShape);
@@ -64,24 +62,17 @@ export const createMatMulNBitsProgramInfo =
6462
{type: DataType.uint32, data: attributes.n}, {type: DataType.uint32, data: attributes.accuracyLevel},
6563
{type: DataType.uint32, data: attributes.bits}, {type: DataType.uint32, data: attributes.blockSize}
6664
];
67-
programUniforms.push(...createTensorShapeVariables(a.dims));
68-
programUniforms.push(...createTensorShapeVariables(ShapeUtil.convertShape(b.dims)));
69-
programUniforms.push(...createTensorShapeVariables(scales.dims));
70-
if (inputs.length === 4) {
71-
programUniforms.push(...createTensorShapeVariables(ShapeUtil.convertShape(inputs[3].dims)));
72-
}
73-
programUniforms.push(...createTensorShapeVariables(outputShape));
7465
const getShaderSource = (shaderHelper: ShaderHelper) => {
75-
const a = inputVariable('a', inputs[0].dataType, inputs[0].dims.length);
76-
const b = inputVariable('b', DataType.uint32, inputs[1].dims.length);
77-
const scales = inputVariable('scales', inputs[2].dataType, inputs[2].dims.length);
66+
const a = inputVariable('a', inputs[0].dataType, inputs[0].dims);
67+
const b = inputVariable('b', DataType.uint32, ShapeUtil.convertShape(inputs[1].dims));
68+
const scales = inputVariable('scales', inputs[2].dataType, inputs[2].dims);
7869
const inputVariables = [a, b, scales];
7970
const zeroPoints =
80-
inputs.length === 4 ? inputVariable('zero_points', DataType.uint32, inputs[3].dims.length) : undefined;
71+
inputs.length === 4 ? inputVariable('zero_points', DataType.uint32, inputs[3].dims) : undefined;
8172
if (zeroPoints) {
8273
inputVariables.push(zeroPoints);
8374
}
84-
const output = outputVariable('output', inputs[0].dataType, outputShape.length);
75+
const output = outputVariable('output', inputs[0].dataType, outputShape);
8576
const uniforms: UniformsArrayType = [
8677
{name: 'output_size', type: 'u32'}, {name: 'k', type: 'u32'}, {name: 'n', type: 'u32'},
8778
{name: 'accuracy_level', type: 'u32'}, {name: 'bits', type: 'u32'}, {name: 'block_size', type: 'u32'}
@@ -165,7 +156,7 @@ export const createMatMulNBitsProgramInfo =
165156
return {
166157
name: 'MatMulNBits',
167158
shaderCache:
168-
{hint: `${attributes.cacheKey};${inputs.length}`, inputDependencies: Array(inputs.length).fill('rank')},
159+
{hint: `${attributes.cacheKey};${inputs.length}`, inputDependencies: Array(inputs.length).fill('dims')},
169160
getRunData: () => ({
170161
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
171162
dispatchGroup: {x: Math.ceil(outputSize / 64)},

0 commit comments

Comments
 (0)