Skip to content

Commit 58f4921

Browse files
authored
[js] changes to allow Float16Array if any polyfill is available (#19305)
### Description This change adds only necessary code to enable ort-web works with any Float16Array polyfill. Unlike #19302, in this PR, ort-web does not include any specific polyfill; instead, it's user's choice for how to use a polyfill. ORT-web uses Float16Array if it's available; otherwise, fallback to use Uint16Array. ```js // case 1: user does not use polyfill: import * as ort from 'onnxruntime-web'; const myF16Data = new Uint16Array(...); // need to use Uint16Array const myF16tensor = new ort.Tensor('float16', myF16Data, dims); ``` ```js // case 2: user use polyfill: import * as ort from 'onnxruntime-web'; import { Float16Array, isFloat16Array, isTypedArray, getFloat16, setFloat16, f16round, } from "@petamoriken/float16"; globalThis.Float16Array = Float16Array; // ort-web will pick the global Float16Array const myF16Data = new Float16Array(...); // Use the polyfilled Float16Array type const myF16tensor = new ort.Tensor('float16', myF16Data, dims); ```
1 parent 8092a89 commit 58f4921

File tree

3 files changed

+37
-16
lines changed

3 files changed

+37
-16
lines changed

js/common/lib/tensor-impl-type-mapping.ts

+23-11
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ export const NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP = new Map<string, SupportedTy
1414
['uint8', Uint8Array],
1515
['int8', Int8Array],
1616
['uint16', Uint16Array],
17-
['float16', Uint16Array],
1817
['int16', Int16Array],
1918
['int32', Int32Array],
2019
['bool', Uint8Array],
@@ -34,16 +33,22 @@ export const NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP = new Map<SupportedTypedArray
3433
[Uint32Array, 'uint32'],
3534
]);
3635

37-
// the following code allows delaying execution of BigInt checking. This allows lazy initialization for
38-
// NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP and NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP, which allows BigInt polyfill
39-
// if available.
40-
let isBigIntChecked = false;
41-
export const checkBigInt = () => {
42-
if (!isBigIntChecked) {
43-
isBigIntChecked = true;
44-
const isBigInt64ArrayAvailable = typeof BigInt64Array !== 'undefined' && typeof BigInt64Array.from === 'function';
45-
const isBigUint64ArrayAvailable =
46-
typeof BigUint64Array !== 'undefined' && typeof BigUint64Array.from === 'function';
36+
// a dummy type declaration for Float16Array in case any polyfill is available.
37+
declare global {
38+
// eslint-disable-next-line @typescript-eslint/naming-convention, @typescript-eslint/no-explicit-any
39+
const Float16Array: any;
40+
}
41+
42+
// the following code allows delaying execution of BigInt/Float16Array checking. This allows lazy initialization for
43+
// NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP and NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP, which allows BigInt/Float16Array
44+
// polyfill if available.
45+
let isTypedArrayChecked = false;
46+
export const checkTypedArray = () => {
47+
if (!isTypedArrayChecked) {
48+
isTypedArrayChecked = true;
49+
const isBigInt64ArrayAvailable = typeof BigInt64Array !== 'undefined' && BigInt64Array.from;
50+
const isBigUint64ArrayAvailable = typeof BigUint64Array !== 'undefined' && BigUint64Array.from;
51+
const isFloat16ArrayAvailable = typeof Float16Array !== 'undefined' && Float16Array.from;
4752

4853
if (isBigInt64ArrayAvailable) {
4954
NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.set('int64', BigInt64Array);
@@ -53,5 +58,12 @@ export const checkBigInt = () => {
5358
NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.set('uint64', BigUint64Array);
5459
NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP.set(BigUint64Array, 'uint64');
5560
}
61+
if (isFloat16ArrayAvailable) {
62+
NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.set('float16', Float16Array);
63+
NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP.set(Float16Array, 'float16');
64+
} else {
65+
// if Float16Array is not available, use 'Uint16Array' to store the data.
66+
NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.set('float16', Uint16Array);
67+
}
5668
}
5769
};

js/common/lib/tensor-impl.ts

+6-4
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import {tensorToDataURL, tensorToImageData} from './tensor-conversion-impl.js';
55
import {TensorToDataUrlOptions, TensorToImageDataOptions} from './tensor-conversion.js';
66
import {tensorFromGpuBuffer, tensorFromImage, tensorFromPinnedBuffer, tensorFromTexture} from './tensor-factory-impl.js';
77
import {CpuPinnedConstructorParameters, GpuBufferConstructorParameters, TensorFromGpuBufferOptions, TensorFromImageBitmapOptions, TensorFromImageDataOptions, TensorFromImageElementOptions, TensorFromTextureOptions, TensorFromUrlOptions, TextureConstructorParameters} from './tensor-factory.js';
8-
import {checkBigInt, NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP, NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP, SupportedTypedArray, SupportedTypedArrayConstructors} from './tensor-impl-type-mapping.js';
8+
import {checkTypedArray, NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP, NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP, SupportedTypedArray, SupportedTypedArrayConstructors} from './tensor-impl-type-mapping.js';
99
import {calculateSize, tensorReshape} from './tensor-utils-impl.js';
1010
import {Tensor as TensorInterface} from './tensor.js';
1111

@@ -67,8 +67,8 @@ export class Tensor implements TensorInterface {
6767
arg0: TensorType|TensorDataType|readonly string[]|readonly boolean[]|CpuPinnedConstructorParameters|
6868
TextureConstructorParameters|GpuBufferConstructorParameters,
6969
arg1?: TensorDataType|readonly number[]|readonly string[]|readonly boolean[], arg2?: readonly number[]) {
70-
// perform one-time check for BigInt support
71-
checkBigInt();
70+
// perform one-time check for BigInt/Float16Array support
71+
checkTypedArray();
7272

7373
let type: TensorType;
7474
let dims: readonly number[];
@@ -142,7 +142,9 @@ export class Tensor implements TensorInterface {
142142
throw new TypeError(`Unsupported tensor type: ${arg0}.`);
143143
}
144144
if (Array.isArray(arg1)) {
145-
if (arg0 === 'float16') {
145+
if (arg0 === 'float16' && typedArrayConstructor === Uint16Array) {
146+
// When no Float16Array polyfill is used, we cannot create 'float16' tensor from number array.
147+
//
146148
// Throw error here because when user try to use number array as data,
147149
// e.g. new Tensor('float16', [1, 2, 3, 4], dims)), it will actually call
148150
// Uint16Array.from(arg1) which generates wrong data.

js/web/lib/wasm/wasm-common.ts

+8-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,12 @@
33

44
import {Tensor} from 'onnxruntime-common';
55

6+
// a dummy type declaration for Float16Array in case any polyfill is available.
7+
declare global {
8+
// eslint-disable-next-line @typescript-eslint/naming-convention, @typescript-eslint/no-explicit-any
9+
const Float16Array: any;
10+
}
11+
612
// This file includes common definitions. They do NOT have dependency on the WebAssembly instance.
713

814
/**
@@ -117,7 +123,8 @@ export const tensorTypeToTypedArrayConstructor = (type: Tensor.Type): Float32Arr
117123
Uint8ArrayConstructor|Float64ArrayConstructor|Uint32ArrayConstructor|BigUint64ArrayConstructor => {
118124
switch (type) {
119125
case 'float16':
120-
return Uint16Array;
126+
// allow Float16Array polyfill.
127+
return typeof Float16Array !== 'undefined' && Float16Array.from ? Float16Array : Uint16Array;
121128
case 'float32':
122129
return Float32Array;
123130
case 'uint8':

0 commit comments

Comments
 (0)