|
1 | 1 | // Copyright (c) Microsoft Corporation. All rights reserved.
|
2 | 2 | // Licensed under the MIT License.
|
3 | 3 |
|
4 |
| -import {tensorDataTypeEnumToString} from '../../../wasm-common'; |
| 4 | +import {DataType} from '../../../wasm-common'; |
5 | 5 | import {TensorView} from '../../tensor-view';
|
6 | 6 | import {ComputeContext, GpuDataType, ProgramUniform} from '../types';
|
7 | 7 |
|
@@ -241,9 +241,10 @@ export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView
|
241 | 241 | WG = Math.ceil(dComp / 8);
|
242 | 242 | }
|
243 | 243 | const elementsPerWG = Math.ceil(d / components / WG);
|
244 |
| - const tensorDataType = tensorDataTypeEnumToString(input.dataType) as ProgramUniform['type']; |
245 |
| - const programUniforms: ProgramUniform[] = |
246 |
| - [{type: tensorDataType, data: 1 / d}, {type: 'uint32', data: dComp}, {type: 'uint32', data: elementsPerWG}]; |
| 244 | + const programUniforms: ProgramUniform[] = [ |
| 245 | + {type: input.dataType, data: 1 / d}, {type: DataType.uint32, data: dComp}, |
| 246 | + {type: DataType.uint32, data: elementsPerWG} |
| 247 | + ]; |
247 | 248 | const dataType = tensorTypeToWsglStorageType(input.dataType, components);
|
248 | 249 |
|
249 | 250 | const getShaderSource = (shaderHelper: ShaderHelper) => {
|
@@ -336,11 +337,10 @@ const computeAttentionProbs =
|
336 | 337 | y: Math.ceil(parameters.sequenceLength / TILE_SIZE),
|
337 | 338 | z: parameters.batchSize * parameters.numHeads
|
338 | 339 | };
|
339 |
| - const tensorDataType = tensorDataTypeEnumToString(q.dataType) as ProgramUniform['type']; |
340 | 340 | const programUniforms: ProgramUniform[] = [
|
341 |
| - {type: 'uint32', data: parameters.sequenceLength}, {type: 'uint32', data: vectorizedHeadSize}, |
342 |
| - {type: 'uint32', data: parameters.totalSequenceLength}, {type: 'uint32', data: parameters.kvSequenceLength}, |
343 |
| - {type: tensorDataType, data: alpha} |
| 341 | + {type: DataType.uint32, data: parameters.sequenceLength}, {type: DataType.uint32, data: vectorizedHeadSize}, |
| 342 | + {type: DataType.uint32, data: parameters.totalSequenceLength}, |
| 343 | + {type: DataType.uint32, data: parameters.kvSequenceLength}, {type: q.dataType, data: alpha} |
344 | 344 | ];
|
345 | 345 |
|
346 | 346 | const inputs = [q, key];
|
@@ -430,9 +430,9 @@ const computeVxAttentionScore =
|
430 | 430 | z: params.batchSize * params.numHeads
|
431 | 431 | };
|
432 | 432 | const programUniforms: ProgramUniform[] = [
|
433 |
| - {type: 'uint32', data: params.sequenceLength}, {type: 'uint32', data: params.totalSequenceLength}, |
434 |
| - {type: 'uint32', data: params.vHeadSize}, {type: 'uint32', data: params.numHeads}, |
435 |
| - {type: 'uint32', data: params.vHiddenSize} |
| 433 | + {type: DataType.uint32, data: params.sequenceLength}, {type: DataType.uint32, data: params.totalSequenceLength}, |
| 434 | + {type: DataType.uint32, data: params.vHeadSize}, {type: DataType.uint32, data: params.numHeads}, |
| 435 | + {type: DataType.uint32, data: params.vHiddenSize} |
436 | 436 | ];
|
437 | 437 |
|
438 | 438 | const getShaderSource = (shaderHelper: ShaderHelper) => {
|
@@ -526,10 +526,10 @@ const prepare = (context: ComputeContext, parameters: AttentionParameters) => {
|
526 | 526 | };
|
527 | 527 | const inputs = [context.inputs[0], context.inputs[1], context.inputs[2]];
|
528 | 528 | const programUniforms: ProgramUniform[] = [
|
529 |
| - {type: 'uint32', data: M}, {type: 'uint32', data: K}, {type: 'uint32', data: N}, |
530 |
| - {type: 'uint32', data: parameters.numHeads}, {type: 'uint32', data: parameters.headSize}, |
531 |
| - {type: 'uint32', data: parameters.hiddenSize}, |
532 |
| - {type: 'uint32', data: parameters.hiddenSize + parameters.hiddenSize + parameters.vHiddenSize} |
| 529 | + {type: DataType.uint32, data: M}, {type: DataType.uint32, data: K}, {type: DataType.uint32, data: N}, |
| 530 | + {type: DataType.uint32, data: parameters.numHeads}, {type: DataType.uint32, data: parameters.headSize}, |
| 531 | + {type: DataType.uint32, data: parameters.hiddenSize}, |
| 532 | + {type: DataType.uint32, data: parameters.hiddenSize + parameters.hiddenSize + parameters.vHiddenSize} |
533 | 533 | ];
|
534 | 534 |
|
535 | 535 | const getShaderSource = (shaderHelper: ShaderHelper) => {
|
|
0 commit comments