Skip to content

Commit 7a3e61a

Browse files
Fix missing isTypedArray when mixing versions of @tensorflow packages (#7489)
A new function, `isTypedArray` was added to the `platform` interface by #7181 and first published in tfjs-core 4.2.0. This made 4.2.0 incompatible with earlier versions of backends that implemented `platform`, such as node and react-native. This change adds a fallback to the use of `isTypedArray` so earlier versions of platforms that don't implement `isTypedArray` will not throw an error. Note that the fallback behavior may not be perfect, such as when running Jest tests in node. See #7175 for more details and upgrade all @tensorflow scoped packages to ^4.2.0 to avoid this.
1 parent 8fc6fe9 commit 7a3e61a

14 files changed

+348
-309
lines changed

Diff for: tfjs-core/src/platforms/is_typed_array_browser.ts

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
/**
2+
* @license
3+
* Copyright 2023 Google LLC.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
export function isTypedArrayBrowser(a: unknown): a is Uint8Array
19+
| Float32Array | Int32Array | Uint8ClampedArray {
20+
return a instanceof Float32Array || a instanceof Int32Array ||
21+
a instanceof Uint8Array || a instanceof Uint8ClampedArray;
22+
}

Diff for: tfjs-core/src/platforms/platform_browser.ts

+2-2
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import {BrowserLocalStorage, BrowserLocalStorageManager} from '../io/local_stora
2323
import {ModelStoreManagerRegistry} from '../io/model_management';
2424

2525
import {Platform} from './platform';
26+
import {isTypedArrayBrowser} from './is_typed_array_browser';
2627

2728
export class PlatformBrowser implements Platform {
2829
// According to the spec, the built-in encoder can do only UTF-8 encoding.
@@ -93,8 +94,7 @@ export class PlatformBrowser implements Platform {
9394

9495
isTypedArray(a: unknown): a is Uint8Array | Float32Array | Int32Array
9596
| Uint8ClampedArray {
96-
return a instanceof Float32Array || a instanceof Int32Array ||
97-
a instanceof Uint8Array || a instanceof Uint8ClampedArray;
97+
return isTypedArrayBrowser(a);
9898
}
9999
}
100100

Diff for: tfjs-core/src/util.ts

+7-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
*/
1717

1818
import {env} from './environment';
19+
import {isTypedArrayBrowser} from './platforms/is_typed_array_browser';
1920
import {BackendValues, DataType, RecursiveArray, TensorLike, TypedArray} from './types';
2021
import * as base from './util_base';
2122
export * from './util_base';
@@ -134,7 +135,12 @@ export function decodeString(bytes: Uint8Array, encoding = 'utf-8'): string {
134135

135136
export function isTypedArray(a: {}): a is Float32Array|Int32Array|Uint8Array|
136137
Uint8ClampedArray {
137-
return env().platform.isTypedArray(a);
138+
// TODO(mattsoulanille): Remove this fallback in 5.0.0
139+
if (env().platform.isTypedArray != null) {
140+
return env().platform.isTypedArray(a);
141+
} else {
142+
return isTypedArrayBrowser(a);
143+
}
138144
}
139145

140146
// NOTE: We explicitly type out what T extends instead of any so that

Diff for: tfjs-core/src/util_test.ts

+17
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import {ALL_ENVS, describeWithFlags} from './jasmine_util';
2020
import {complex, scalar, tensor2d} from './ops/ops';
2121
import {inferShape} from './tensor_util_env';
2222
import * as util from './util';
23+
import {env} from './environment';
2324

2425
describe('Util', () => {
2526
it('Correctly gets size from shape', () => {
@@ -133,6 +134,22 @@ describe('Util', () => {
133134
];
134135
expect(inferShape(a, 'string')).toEqual([2, 2, 1]);
135136
});
137+
describe('isTypedArray', () => {
138+
it('checks if a value is a typed array', () => {
139+
expect(util.isTypedArray(new Uint8Array([1,2,3]))).toBeTrue();
140+
expect(util.isTypedArray([1,2,3])).toBeFalse();
141+
});
142+
it('uses fallback if platform is missing isTypedArray', () => {
143+
const tmpIsTypedArray = env().platform.isTypedArray;
144+
try {
145+
env().platform.isTypedArray = null;
146+
expect(util.isTypedArray(new Uint8Array([1,2,3]))).toBeTrue();
147+
expect(util.isTypedArray([1,2,3])).toBeFalse();
148+
} finally {
149+
env().platform.isTypedArray = tmpIsTypedArray;
150+
}
151+
});
152+
});
136153
});
137154

138155
describe('util.flatten', () => {

Diff for: tfjs-node-gpu/package.json

+3-3
Original file line numberDiff line numberDiff line change
@@ -45,17 +45,17 @@
4545
},
4646
"devDependencies": {
4747
"@tensorflow/tfjs-core": "link:../link-package/node_modules/@tensorflow/tfjs-core",
48-
"@types/jasmine": "~2.8.6",
48+
"@types/jasmine": "~4.0.3",
4949
"@types/node": "^10.5.1",
5050
"@types/progress": "^2.0.1",
5151
"@types/rimraf": "~2.0.2",
5252
"@types/yargs": "^13.0.3",
5353
"clang-format": "~1.8.0",
54-
"jasmine": "~3.1.0",
54+
"jasmine": "~4.2.1",
5555
"node-fetch": "~2.6.1",
5656
"nyc": "^15.1.0",
5757
"tmp": "^0.0.33",
58-
"ts-node": "^5.0.1",
58+
"ts-node": "~8.8.2",
5959
"tslint": "~6.1.3",
6060
"tslint-no-circular-imports": "^0.7.0",
6161
"typescript": "4.9.4",

Diff for: tfjs-node/package.json

+3-3
Original file line numberDiff line numberDiff line change
@@ -43,17 +43,17 @@
4343
},
4444
"devDependencies": {
4545
"@tensorflow/tfjs-core": "link:../link-package/node_modules/@tensorflow/tfjs-core",
46-
"@types/jasmine": "~2.8.6",
46+
"@types/jasmine": "~4.0.3",
4747
"@types/node": "^10.5.1",
4848
"@types/progress": "^2.0.1",
4949
"@types/rimraf": "~2.0.2",
5050
"@types/yargs": "^13.0.3",
5151
"clang-format": "~1.8.0",
52-
"jasmine": "~3.1.0",
52+
"jasmine": "~4.2.1",
5353
"node-fetch": "~2.6.1",
5454
"nyc": "^15.1.0",
5555
"tmp": "^0.0.33",
56-
"ts-node": "^5.0.1",
56+
"ts-node": "~8.8.2",
5757
"tslint": "~6.1.3",
5858
"tslint-no-circular-imports": "^0.7.0",
5959
"typescript": "4.9.4",

Diff for: tfjs-node/src/callbacks.ts

+2-1
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,11 @@ import * as ProgressBar from 'progress';
2121

2222
import {summaryFileWriter, SummaryFileWriter} from './tensorboard';
2323

24+
type LogFunction = (message: string) => void;
2425
// A helper class created for testing with the jasmine `spyOn` method, which
2526
// operates only on member methods of objects.
2627
// tslint:disable-next-line:no-any
27-
export const progressBarHelper: {ProgressBar: any, log: Function} = {
28+
export const progressBarHelper: {ProgressBar: any, log: LogFunction} = {
2829
ProgressBar,
2930
log: console.log
3031
};

Diff for: tfjs-node/src/image_test.ts

+18-37
Original file line numberDiff line numberDiff line change
@@ -273,50 +273,31 @@ describe('decode images', () => {
273273
.toBe(beforeNumTFTensors + 1);
274274
});
275275

276-
it('throw error if request non int32 dtype', async done => {
277-
try {
278-
const uint8array = await getUint8ArrayFromImage(
279-
'test_objects/images/image_png_test.png');
280-
tf.node.decodeImage(uint8array, 0, 'uint8');
281-
done.fail();
282-
} catch (error) {
283-
expect(error.message)
284-
.toBe(
285-
'decodeImage could only return Tensor of type `int32` for now.');
286-
done();
287-
}
276+
it('throw error if request non int32 dtype', async () => {
277+
const uint8array = await getUint8ArrayFromImage(
278+
'test_objects/images/image_png_test.png');
279+
expect(() => tf.node.decodeImage(uint8array, 0, 'uint8')).toThrowError(
280+
'decodeImage could only return Tensor of type `int32` for now.');
288281
});
289282

290-
it('throw error if decode invalid image type', async done => {
291-
try {
292-
const uint8array = await getUint8ArrayFromImage('package.json');
293-
tf.node.decodeImage(uint8array);
294-
done.fail();
295-
} catch (error) {
296-
expect(error.message)
297-
.toBe(
298-
'Expected image (BMP, JPEG, PNG, or GIF), ' +
299-
'but got unsupported image type');
300-
done();
301-
}
283+
it('throw error if decode invalid image type', async () => {
284+
const uint8array = await getUint8ArrayFromImage('package.json');
285+
expect(() => tf.node.decodeImage(uint8array)).toThrowError(
286+
'Expected image (BMP, JPEG, PNG, or GIF), ' +
287+
'but got unsupported image type');
302288
});
303289

304-
it('throw error if backend is not tensorflow', async done => {
290+
it('throw error if backend is not tensorflow', async () => {
291+
const testBackend = new TestKernelBackend();
292+
registerBackend('fake', () => testBackend);
293+
setBackend('fake');
305294
try {
306-
const testBackend = new TestKernelBackend();
307-
registerBackend('fake', () => testBackend);
308-
setBackend('fake');
309-
310295
const uint8array = await getUint8ArrayFromImage(
311-
'test_objects/images/image_png_test.png');
312-
tf.node.decodeImage(uint8array);
313-
done.fail();
314-
} catch (err) {
315-
expect(err.message)
316-
.toBe(
317-
'Expect the current backend to be "tensorflow", but got "fake"');
296+
'test_objects/images/image_png_test.png');
297+
expect(() => tf.node.decodeImage(uint8array)).toThrowError(
298+
'Expect the current backend to be "tensorflow", but got "fake"');
299+
} finally {
318300
setBackend('tensorflow');
319-
done();
320301
}
321302
});
322303
});

0 commit comments

Comments
 (0)