Skip to content

Commit a3f0e24

Browse files
authored
[js/webgpu] Support f16 uniform (#19098)
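### Description
Adds float16 uniform support to the WebGPU JSEP backend. `ProgramUniform.type` now accepts `'float16'` and `UniformDataElementType` now accepts `'f16'`. Uniform offsets are computed with the f16 alignment and size rules; an f16 uniform with more than 4 elements is declared as `@align(16) array<mat2x4<f16>, ceil(length / 8)>`, and `getElementAt` takes an optional `type` argument so it can index that layout. f16 values are written through a `Uint16Array` for now (TODO: use `Float16Array`).

### Motivation and Context
Uniforms previously supported only `int32`, `uint32` and `float32`. This change also lets the Pad operator accept float16 input, both in its JavaScript input validation and in its kernel type constraints (`JsepSupportedFloatTypes()`).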
1 parent 8b45172 commit a3f0e24

File tree

5 files changed

+56
-26
lines changed

5 files changed

+56
-26
lines changed

js/web/lib/wasm/jsep/backend-webgpu.ts

+21-5
Original file line numberDiff line numberDiff line change
@@ -428,13 +428,26 @@ export class WebGpuBackend {
428428
return;
429429
}
430430
// https://www.w3.org/TR/WGSL/#alignof
431-
const baseAlignment = data.length <= 2 ? data.length * 4 : 16;
431+
const sizeOfElement = v.type === 'float16' ? 2 : 4;
432+
let sizeOfVecOrMat;
433+
let baseAlignment;
434+
if (v.type === 'float16') {
435+
baseAlignment = data.length > 4 ? 16 : (data.length > 2 ? 8 : data.length * sizeOfElement);
436+
sizeOfVecOrMat = data.length > 4 ? 16 : sizeOfElement * data.length;
437+
} else {
438+
baseAlignment = data.length <= 2 ? data.length * sizeOfElement : 16;
439+
sizeOfVecOrMat = 16;
440+
}
432441
currentOffset = Math.ceil(currentOffset / baseAlignment) * baseAlignment;
433442
offsets.push(currentOffset);
434-
// When data.length > 4, the uniform variable is of type array<vec4<i32|u32|f32>,N>, where N =
435-
// Math.ceil(data.length / 4) and SizeOf(vec4<i32|u32|f32>) = 16. The total byte length is N *
436-
// SizeOf(vec4<i32|u32|f32>).
437-
currentOffset += data.length > 4 ? Math.ceil(data.length / 4) * 16 : data.length * 4;
443+
// For non-float16 types, when data.length > 4 the uniform variable is of type array<vec4<i32|u32|f32>,N>, where
444+
// N = Math.ceil(data.length / 4) and SizeOf(vec4<i32|u32|f32>) = 16, so the total byte length is N *
445+
// SizeOf(vec4<i32|u32|f32>). For the float16 type, when data.length > 4 the uniform variable is of type
446+
// array<mat2x4<f16>,N>, where N = Math.ceil(data.length / 8) and SizeOf(mat2x4<f16>) = 16, so the total byte
447+
// length is N * SizeOf(mat2x4<f16>).
448+
const elementPerVecOrMat = v.type === 'float16' ? 8 : 4;
449+
currentOffset += data.length > 4 ? Math.ceil(data.length / elementPerVecOrMat) * sizeOfVecOrMat :
450+
data.length * sizeOfElement;
438451
});
439452

440453
// Meet alignment of struct here: https://www.w3.org/TR/WGSL/#alignment-and-size. For simplicity, set
@@ -449,6 +462,9 @@ export class WebGpuBackend {
449462
new Int32Array(arrayBuffer, offset, data.length).set(data);
450463
} else if (v.type === 'uint32') {
451464
new Uint32Array(arrayBuffer, offset, data.length).set(data);
465+
} else if (v.type === 'float16') {
466+
// TODO: use Float16Array.
467+
new Uint16Array(arrayBuffer, offset, data.length).set(data);
452468
} else {
453469
new Float32Array(arrayBuffer, offset, data.length).set(data);
454470
}

js/web/lib/wasm/jsep/webgpu/ops/common.ts

+27-13
Original file line numberDiff line numberDiff line change
@@ -330,18 +330,28 @@ export const sumVector = (name: string, components: number) => {
330330
* @param name - the name of variable.
331331
* @param index - the index of variable element.
332332
* @param length - the length of variable.
333+
* @param type - optional element type of the variable; 'f16' selects the array<mat2x4<f16>, N> indexing used for uniforms with length > 4.
333334
*/
334-
export const getElementAt = (name: string, index: number|string, length: number): string => {
335-
if (name.startsWith('uniforms.') && length > 4) {
336-
if (typeof (index) === 'string') {
337-
return `${name}[(${index}) / 4][(${index}) % 4]`;
338-
} else {
339-
return `${name}[${Math.floor(index / 4)}][${index % 4}]`;
340-
}
341-
} else {
342-
return length > 1 ? `${name}[${index}]` : name;
343-
}
344-
};
335+
export const getElementAt =
336+
(name: string, index: number|string, length: number, type?: UniformDataElementType): string => {
337+
if (name.startsWith('uniforms.') && length > 4) {
338+
if (typeof (index) === 'string') {
339+
if (type === 'f16') {
340+
return `${name}[(${index}) / 8][(${index}) % 8 / 4][(${index}) % 8 % 4]`;
341+
} else {
342+
return `${name}[(${index}) / 4][(${index}) % 4]`;
343+
}
344+
} else {
345+
if (type === 'f16') {
346+
return `${name}[${Math.floor(index / 8)}][${Math.floor(index % 8 / 4)}][${index % 8 % 4}]`;
347+
} else {
348+
return `${name}[${Math.floor(index / 4)}][${index % 4}]`;
349+
}
350+
}
351+
} else {
352+
return length > 1 ? `${name}[${index}]` : name;
353+
}
354+
};
345355

346356
/**
347357
* A helper function to get a IndicesHelper for a given input or output.
@@ -688,7 +698,7 @@ export const internalVariable =
688698
(name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1): IndicesHelper =>
689699
createIndicesHelper(name, type, shapeOrRank, 'internal', components);
690700

691-
export type UniformDataElementType = 'u32'|'f32'|'i32';
701+
export type UniformDataElementType = 'u32'|'f16'|'f32'|'i32';
692702
export type UniformsArrayType = Array<{name: string; type: UniformDataElementType; length?: number}>;
693703

694704
/**
@@ -861,7 +871,11 @@ class ShaderHelperImpl implements ShaderHelper {
861871
const uniformSnippets: string[] = [];
862872
for (const {name, type, length} of this.uniforms) {
863873
if (length && length > 4) {
864-
uniformSnippets.push(`${name}:array<vec4<${type}>, ${Math.ceil(length / 4)}>`);
874+
if (type === 'f16') {
875+
uniformSnippets.push(`@align(16) ${name}:array<mat2x4<${type}>, ${Math.ceil(length / 8)}>`);
876+
} else {
877+
uniformSnippets.push(`${name}:array<vec4<${type}>, ${Math.ceil(length / 4)}>`);
878+
}
865879
} else {
866880
const typeTemp = length == null || length === 1 ? type : `vec${length}<${type}>`;
867881
uniformSnippets.push(`${name}:${typeTemp}`);

js/web/lib/wasm/jsep/webgpu/ops/pad.ts

+2-2
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ const validateInputs = (inputs: readonly TensorView[]): void => {
1919
if (!inputs || inputs.length < 1) {
2020
throw new Error('Too few inputs');
2121
}
22-
if (inputs[0].dataType !== DataType.float) {
23-
throw new Error('Input type must be float.');
22+
if (inputs[0].dataType !== DataType.float && inputs[0].dataType !== DataType.float16) {
23+
throw new Error('Input type must be float or float16.');
2424
}
2525

2626
if (inputs.length >= 2) {

js/web/lib/wasm/jsep/webgpu/types.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ export interface TensorInfo {
2424
}
2525

2626
export interface ProgramUniform {
27-
type: 'int32'|'float32'|'uint32';
27+
type: 'int32'|'float16'|'float32'|'uint32';
2828
data: number|readonly number[];
2929
}
3030

onnxruntime/core/providers/js/operators/pad.cc

+5-5
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
1414
2,
1515
10,
1616
kJsExecutionProvider,
17-
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
17+
(*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()),
1818
Pad);
1919

2020
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
@@ -24,7 +24,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
2424
12,
2525
kJsExecutionProvider,
2626
(*KernelDefBuilder::Create())
27-
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
27+
.TypeConstraint("T", JsepSupportedFloatTypes())
2828
.InputMemoryType(OrtMemTypeCPU, 1)
2929
.InputMemoryType(OrtMemTypeCPU, 2)
3030
.InputMemoryType(OrtMemTypeCPU, 3),
@@ -37,7 +37,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
3737
17,
3838
kJsExecutionProvider,
3939
(*KernelDefBuilder::Create())
40-
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
40+
.TypeConstraint("T", JsepSupportedFloatTypes())
4141
.InputMemoryType(OrtMemTypeCPU, 1)
4242
.InputMemoryType(OrtMemTypeCPU, 2)
4343
.InputMemoryType(OrtMemTypeCPU, 3),
@@ -50,7 +50,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
5050
18,
5151
kJsExecutionProvider,
5252
(*KernelDefBuilder::Create())
53-
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
53+
.TypeConstraint("T", JsepSupportedFloatTypes())
5454
.InputMemoryType(OrtMemTypeCPU, 1)
5555
.InputMemoryType(OrtMemTypeCPU, 2)
5656
.InputMemoryType(OrtMemTypeCPU, 3),
@@ -62,7 +62,7 @@ ONNX_OPERATOR_KERNEL_EX(
6262
19,
6363
kJsExecutionProvider,
6464
(*KernelDefBuilder::Create())
65-
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
65+
.TypeConstraint("T", JsepSupportedFloatTypes())
6666
.InputMemoryType(OrtMemTypeCPU, 1)
6767
.InputMemoryType(OrtMemTypeCPU, 2)
6868
.InputMemoryType(OrtMemTypeCPU, 3),

0 commit comments

Comments
 (0)