Skip to content

Commit 086e9d8

Browse files
Support loading models with weights above 2GB on Chrome (#7609)
Chrome ArrayBuffers throw allocation errors above 2GB in size. This makes it impossible to load TFJS models above this size in Chrome (even with weight sharding) because model loading involves concatenating all the weights into a single ArrayBuffer. This PR avoids this concatenation. Instead of slicing the weight tensors out of a single concatenated ArrayBuffer, it keeps the weight buffers in their original shards and slices them using the CompositeArrayBuffer class created in #7598.
1 parent dcd3b43 commit 086e9d8

22 files changed

+205
-136
lines changed

Diff for: tfjs-converter/src/executor/graph_model_test.ts

+6-3
Original file line numberDiff line numberDiff line change
@@ -586,7 +586,8 @@ describe('Model', () => {
586586
expect(handler.savedArtifacts.modelTopology).toEqual(CUSTOM_OP_MODEL);
587587
expect(handler.savedArtifacts.weightSpecs).toEqual(weightsManifest);
588588
tfc.test_util.expectArraysClose(
589-
new Int32Array(handler.savedArtifacts.weightData), bias.dataSync());
589+
new Int32Array(io.CompositeArrayBuffer.join(
590+
handler.savedArtifacts.weightData)), bias.dataSync());
590591
});
591592
});
592593
});
@@ -616,7 +617,8 @@ describe('Model', () => {
616617
});
617618
expect(handler.savedArtifacts.weightSpecs).toEqual(weightsManifest);
618619
tfc.test_util.expectArraysClose(
619-
new Int32Array(handler.savedArtifacts.weightData), bias.dataSync());
620+
new Int32Array(io.CompositeArrayBuffer.join(
621+
handler.savedArtifacts.weightData)), bias.dataSync());
620622
});
621623
});
622624

@@ -904,7 +906,8 @@ describe('Model', () => {
904906
});
905907
expect(handler.savedArtifacts.weightSpecs).toEqual(weightsManifest);
906908
tfc.test_util.expectArraysClose(
907-
new Int32Array(handler.savedArtifacts.weightData), bias.dataSync());
909+
new Int32Array(io.CompositeArrayBuffer.join(handler.savedArtifacts
910+
.weightData)), bias.dataSync());
908911
});
909912
});
910913

Diff for: tfjs-converter/src/operations/executors/spy_ops.ts

+13-1
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,23 @@
1515
* =============================================================================
1616
*/
1717

18+
// The opposite of Extract<T, U>
19+
type Without<T, U> = T extends U ? never : T;
20+
21+
// Do not spy on CompositeArrayBuffer because it is a class constructor.
22+
type NotSpiedOn = 'CompositeArrayBuffer';
23+
1824
export type RecursiveSpy<T> =
19-
T extends Function ? jasmine.Spy : {[K in keyof T]: RecursiveSpy<T[K]>};
25+
T extends Function ? jasmine.Spy :
26+
{[K in Without<keyof T, NotSpiedOn>]: RecursiveSpy<T[K]>} &
27+
{[K in Extract<keyof T, NotSpiedOn>]: T[K]};
2028

2129
export function spyOnAllFunctions<T>(obj: T): RecursiveSpy<T> {
2230
return Object.fromEntries(Object.entries(obj).map(([key, val]) => {
31+
// TODO(mattSoulanille): Do not hard code this
32+
if (key === 'CompositeArrayBuffer') {
33+
return val;
34+
}
2335
if (val instanceof Function) {
2436
return [key, jasmine.createSpy(`${key} spy`, val).and.callThrough()];
2537
} else if (val instanceof Array) {

Diff for: tfjs-core/src/io/browser_files.ts

+11-5
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,10 @@
2323
import '../flags';
2424
import {env} from '../environment';
2525

26-
import {basename, concatenateArrayBuffers, getModelArtifactsForJSON, getModelArtifactsInfoForJSON, getModelJSONForModelArtifacts} from './io_utils';
26+
import {basename, getModelArtifactsForJSON, getModelArtifactsInfoForJSON, getModelJSONForModelArtifacts} from './io_utils';
2727
import {IORouter, IORouterRegistry} from './router_registry';
28-
import {IOHandler, ModelArtifacts, ModelJSON, SaveResult, WeightsManifestConfig, WeightsManifestEntry} from './types';
28+
import {IOHandler, ModelArtifacts, ModelJSON, SaveResult, WeightData, WeightsManifestConfig, WeightsManifestEntry} from './types';
29+
import {CompositeArrayBuffer} from './composite_array_buffer';
2930

3031
const DEFAULT_FILE_NAME_PREFIX = 'model';
3132
const DEFAULT_JSON_EXTENSION_NAME = '.json';
@@ -70,8 +71,13 @@ export class BrowserDownloads implements IOHandler {
7071
'Browser downloads are not supported in ' +
7172
'this environment since `document` is not present');
7273
}
74+
75+
// TODO(mattsoulanille): Support saving models over 2GB that exceed
76+
// Chrome's ArrayBuffer size limit.
77+
const weightBuffer = CompositeArrayBuffer.join(modelArtifacts.weightData);
78+
7379
const weightsURL = window.URL.createObjectURL(new Blob(
74-
[modelArtifacts.weightData], {type: 'application/octet-stream'}));
80+
[weightBuffer], {type: 'application/octet-stream'}));
7581

7682
if (modelArtifacts.modelTopology instanceof ArrayBuffer) {
7783
throw new Error(
@@ -169,7 +175,7 @@ class BrowserFiles implements IOHandler {
169175
}
170176

171177
private loadWeights(weightsManifest: WeightsManifestConfig): Promise<[
172-
/* weightSpecs */ WeightsManifestEntry[], /* weightData */ ArrayBuffer
178+
/* weightSpecs */ WeightsManifestEntry[], WeightData,
173179
]> {
174180
const weightSpecs: WeightsManifestEntry[] = [];
175181
const paths: string[] = [];
@@ -185,7 +191,7 @@ class BrowserFiles implements IOHandler {
185191
paths.map(path => this.loadWeightsFile(path, pathToFile[path]));
186192

187193
return Promise.all(promises).then(
188-
buffers => [weightSpecs, concatenateArrayBuffers(buffers)]);
194+
buffers => [weightSpecs, buffers]);
189195
}
190196

191197
private loadWeightsFile(path: string, file: File): Promise<ArrayBuffer> {

Diff for: tfjs-core/src/io/browser_files_test.ts

+14-10
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import * as tf from '../index';
2323
import {BROWSER_ENVS, describeWithFlags} from '../jasmine_util';
2424
import {browserDownloads, BrowserDownloads, browserDownloadsRouter} from './browser_files';
2525
import {WeightsManifestConfig, WeightsManifestEntry} from './types';
26+
import {CompositeArrayBuffer} from './composite_array_buffer';
2627

2728
const modelTopology1: {} = {
2829
'class_name': 'Sequential',
@@ -310,7 +311,7 @@ describeWithFlags('browserFiles', BROWSER_ENVS, () => {
310311
expect(modelArtifacts.modelInitializer).toEqual({});
311312
expect(modelArtifacts.trainingConfig).toEqual(trainingConfig1);
312313

313-
expect(new Uint8Array(modelArtifacts.weightData))
314+
expect(new Uint8Array(CompositeArrayBuffer.join(modelArtifacts.weightData)))
314315
.toEqual(new Uint8Array(weightData1));
315316
});
316317

@@ -351,9 +352,10 @@ describeWithFlags('browserFiles', BROWSER_ENVS, () => {
351352
const modelArtifacts = await filesHandler.load();
352353
expect(modelArtifacts.modelTopology).toEqual(modelTopology1);
353354
expect(modelArtifacts.weightSpecs).toEqual(weightSpecs);
354-
expect(new Uint8Array(modelArtifacts.weightData)).toEqual(new Uint8Array([
355-
1, 2, 3, 4, 10, 20, 30, 40
356-
]));
355+
expect(new Uint8Array(CompositeArrayBuffer.join(modelArtifacts.weightData)))
356+
.toEqual(new Uint8Array([
357+
1, 2, 3, 4, 10, 20, 30, 40
358+
]));
357359
});
358360

359361
it(`Two groups, four paths, reverseOrder=false`, async () => {
@@ -418,9 +420,10 @@ describeWithFlags('browserFiles', BROWSER_ENVS, () => {
418420
expect(modelArtifacts.modelTopology).toEqual(modelTopology1);
419421
expect(modelArtifacts.weightSpecs)
420422
.toEqual(weightSpecs1.concat(weightSpecs2));
421-
expect(new Uint8Array(modelArtifacts.weightData)).toEqual(new Uint8Array([
422-
1, 3, 5, 7, 10, 30, 50, 70, 2, 4, 6, 8, 20, 40, 60, 80
423-
]));
423+
expect(new Uint8Array(CompositeArrayBuffer.join(modelArtifacts.weightData)))
424+
.toEqual(new Uint8Array([
425+
1, 3, 5, 7, 10, 30, 50, 70, 2, 4, 6, 8, 20, 40, 60, 80
426+
]));
424427
});
425428

426429
it(`Two groups, four paths, reverseOrder=true`, async () => {
@@ -485,9 +488,10 @@ describeWithFlags('browserFiles', BROWSER_ENVS, () => {
485488
expect(modelArtifacts.modelTopology).toEqual(modelTopology1);
486489
expect(modelArtifacts.weightSpecs)
487490
.toEqual(weightSpecs1.concat(weightSpecs2));
488-
expect(new Uint8Array(modelArtifacts.weightData)).toEqual(new Uint8Array([
489-
1, 3, 5, 7, 10, 30, 50, 70, 2, 4, 6, 8, 20, 40, 60, 80
490-
]));
491+
expect(new Uint8Array(CompositeArrayBuffer.join(modelArtifacts.weightData)))
492+
.toEqual(new Uint8Array([
493+
1, 3, 5, 7, 10, 30, 50, 70, 2, 4, 6, 8, 20, 40, 60, 80
494+
]));
491495
});
492496

493497
it('Upload model topology only', async () => {

Diff for: tfjs-core/src/io/composite_array_buffer.ts

+23-3
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,22 @@ export class CompositeArrayBuffer {
4040
private bufferUniformSize?: number;
4141
public readonly byteLength: number;
4242

43-
constructor(buffers: ArrayBuffer | ArrayBuffer[] | TypedArray |
43+
/**
44+
* Concatenate a number of ArrayBuffers into one.
45+
*
46+
* @param buffers An array of ArrayBuffers to concatenate, or a single
47+
* ArrayBuffer.
48+
* @returns Result of concatenating `buffers` in order.
49+
*/
50+
static join(buffers?: ArrayBuffer[] | ArrayBuffer) {
51+
return new CompositeArrayBuffer(buffers).slice();
52+
}
53+
54+
constructor(buffers?: ArrayBuffer | ArrayBuffer[] | TypedArray |
4455
TypedArray[]) {
56+
if (buffers == null) {
57+
return;
58+
}
4559
// Normalize the `buffers` input to be `ArrayBuffer[]`.
4660
if (!(buffers instanceof Array)) {
4761
buffers = [buffers];
@@ -85,6 +99,12 @@ export class CompositeArrayBuffer {
8599
}
86100

87101
slice(start = 0, end = this.byteLength): ArrayBuffer {
102+
// If there are no shards, then the CompositeArrayBuffer was initialized
103+
// with no data.
104+
if (this.shards.length === 0) {
105+
return new ArrayBuffer(0);
106+
}
107+
88108
// NaN is treated as zero for slicing. This matches ArrayBuffer's behavior.
89109
start = isNaN(Number(start)) ? 0 : start;
90110
end = isNaN(Number(end)) ? 0 : end;
@@ -117,8 +137,8 @@ export class CompositeArrayBuffer {
117137
const globalEnd = Math.min(end, shard.end);
118138
const localEnd = globalEnd - shard.start;
119139

120-
const outputSlice = new Uint8Array(shard.buffer.slice(localStart,
121-
localEnd));
140+
const outputSlice = new Uint8Array(shard.buffer, localStart,
141+
localEnd - localStart);
122142
outputArray.set(outputSlice, outputStart);
123143
sliced += outputSlice.length;
124144

Diff for: tfjs-core/src/io/composite_array_buffer_test.ts

+13-1
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ describe('CompositeArrayBuffer', () => {
8080
});
8181
}
8282

83-
it('can be passed an empty arraybuffer', () => {
83+
it('can be created from an empty arraybuffer', () => {
8484
const array = new Uint8Array([]);
8585
const singleComposite = new CompositeArrayBuffer(array.buffer);
8686
expectArraysEqual(new Uint8Array(singleComposite.slice()), []);
@@ -92,6 +92,18 @@ describe('CompositeArrayBuffer', () => {
9292
expectArraysEqual(new Uint8Array(singleComposite.slice()), array);
9393
});
9494

95+
it('can be created from zero arrays', () => {
96+
const singleComposite = new CompositeArrayBuffer([]);
97+
expectArraysEqual(new Uint8Array(singleComposite.slice()),
98+
new Uint8Array());
99+
});
100+
101+
it('can be created from undefined input', () => {
102+
const singleComposite = new CompositeArrayBuffer();
103+
expectArraysEqual(new Uint8Array(singleComposite.slice()),
104+
new Uint8Array());
105+
});
106+
95107
it('treats NaN as zero when passed as the start of slice', () => {
96108
const array = new Uint8Array([1,2,3]);
97109
const composite = new CompositeArrayBuffer(array.buffer);

Diff for: tfjs-core/src/io/http.ts

+10-5
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,10 @@
2424
import {env} from '../environment';
2525

2626
import {assert} from '../util';
27-
import {concatenateArrayBuffers, getModelArtifactsForJSON, getModelArtifactsInfoForJSON, getModelJSONForModelArtifacts, getWeightSpecs} from './io_utils';
27+
import {getModelArtifactsForJSON, getModelArtifactsInfoForJSON, getModelJSONForModelArtifacts, getWeightSpecs} from './io_utils';
28+
import {CompositeArrayBuffer} from './composite_array_buffer';
2829
import {IORouter, IORouterRegistry} from './router_registry';
29-
import {IOHandler, LoadOptions, ModelArtifacts, ModelJSON, OnProgressCallback, SaveResult, WeightsManifestConfig, WeightsManifestEntry} from './types';
30+
import {IOHandler, LoadOptions, ModelArtifacts, ModelJSON, OnProgressCallback, SaveResult, WeightData, WeightsManifestConfig, WeightsManifestEntry} from './types';
3031
import {loadWeightsAsArrayBuffer} from './weights_loader';
3132

3233
const OCTET_STREAM_MIME_TYPE = 'application/octet-stream';
@@ -110,9 +111,13 @@ export class HTTPRequest implements IOHandler {
110111
'model.json');
111112

112113
if (modelArtifacts.weightData != null) {
114+
// TODO(mattsoulanille): Support saving models over 2GB that exceed
115+
// Chrome's ArrayBuffer size limit.
116+
const weightBuffer = CompositeArrayBuffer.join(modelArtifacts.weightData);
117+
113118
init.body.append(
114119
'model.weights.bin',
115-
new Blob([modelArtifacts.weightData], {type: OCTET_STREAM_MIME_TYPE}),
120+
new Blob([weightBuffer], {type: OCTET_STREAM_MIME_TYPE}),
116121
'model.weights.bin');
117122
}
118123

@@ -182,7 +187,7 @@ export class HTTPRequest implements IOHandler {
182187
}
183188

184189
private async loadWeights(weightsManifest: WeightsManifestConfig):
185-
Promise<[WeightsManifestEntry[], ArrayBuffer]> {
190+
Promise<[WeightsManifestEntry[], WeightData]> {
186191
const weightPath = Array.isArray(this.path) ? this.path[1] : this.path;
187192
const [prefix, suffix] = parseUrl(weightPath);
188193
const pathPrefix = this.weightPathPrefix || prefix;
@@ -210,7 +215,7 @@ export class HTTPRequest implements IOHandler {
210215
fetchFunc: this.fetch,
211216
onProgress: this.onProgress
212217
});
213-
return [weightSpecs, concatenateArrayBuffers(buffers)];
218+
return [weightSpecs, buffers];
214219
}
215220
}
216221

Diff for: tfjs-core/src/io/http_test.ts

+25-17
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import * as tf from '../index';
1919
import {BROWSER_ENVS, CHROME_ENVS, describeWithFlags, NODE_ENVS} from '../jasmine_util';
2020
import {HTTPRequest, httpRouter, parseUrl} from './http';
21+
import {CompositeArrayBuffer} from './composite_array_buffer';
2122

2223
// Test data.
2324
const modelTopology1: {} = {
@@ -161,7 +162,8 @@ describeWithFlags('http-load fetch', NODE_ENVS, () => {
161162
expect(modelArtifacts.generatedBy).toEqual('1.15');
162163
expect(modelArtifacts.convertedBy).toEqual('1.3.1');
163164
expect(modelArtifacts.userDefinedMetadata).toEqual({});
164-
expect(new Float32Array(modelArtifacts.weightData)).toEqual(floatData);
165+
expect(new Float32Array(CompositeArrayBuffer.join(
166+
modelArtifacts.weightData))).toEqual(floatData);
165167
});
166168

167169
it('throw exception if no fetch polyfill', () => {
@@ -507,7 +509,8 @@ describeWithFlags('http-load', BROWSER_ENVS, () => {
507509
expect(modelArtifacts.userDefinedMetadata).toEqual({});
508510
expect(modelArtifacts.modelInitializer).toEqual({});
509511

510-
expect(new Float32Array(modelArtifacts.weightData)).toEqual(floatData);
512+
expect(new Float32Array(CompositeArrayBuffer.join(modelArtifacts
513+
.weightData))).toEqual(floatData);
511514
expect(Object.keys(requestInits).length).toEqual(2);
512515
// Assert that fetch is invoked with `window` as the context.
513516
expect(fetchSpy.calls.mostRecent().object).toEqual(window);
@@ -550,7 +553,8 @@ describeWithFlags('http-load', BROWSER_ENVS, () => {
550553
const modelArtifacts = await handler.load();
551554
expect(modelArtifacts.modelTopology).toEqual(modelTopology1);
552555
expect(modelArtifacts.weightSpecs).toEqual(weightManifest1[0].weights);
553-
expect(new Float32Array(modelArtifacts.weightData)).toEqual(floatData);
556+
expect(new Float32Array(CompositeArrayBuffer.join(modelArtifacts
557+
.weightData))).toEqual(floatData);
554558
expect(Object.keys(requestInits).length).toEqual(2);
555559
expect(Object.keys(requestInits).length).toEqual(2);
556560
expect(requestInits['./model.json'].headers['header_key_1'])
@@ -599,8 +603,8 @@ describeWithFlags('http-load', BROWSER_ENVS, () => {
599603
const modelArtifacts = await handler.load();
600604
expect(modelArtifacts.modelTopology).toEqual(modelTopology1);
601605
expect(modelArtifacts.weightSpecs).toEqual(weightManifest1[0].weights);
602-
expect(new Float32Array(modelArtifacts.weightData))
603-
.toEqual(new Float32Array([1, 3, 3, 7, 4]));
606+
expect(new Float32Array(CompositeArrayBuffer.join(modelArtifacts
607+
.weightData))).toEqual(new Float32Array([1, 3, 3, 7, 4]));
604608
});
605609

606610
it('2 groups, 2 weight, 2 paths', async () => {
@@ -644,8 +648,9 @@ describeWithFlags('http-load', BROWSER_ENVS, () => {
644648
expect(modelArtifacts.weightSpecs)
645649
.toEqual(
646650
weightsManifest[0].weights.concat(weightsManifest[1].weights));
647-
expect(new Float32Array(modelArtifacts.weightData))
648-
.toEqual(new Float32Array([1, 3, 3, 7, 4]));
651+
expect(new Float32Array(CompositeArrayBuffer.join(
652+
modelArtifacts.weightData)))
653+
.toEqual(new Float32Array([1, 3, 3, 7, 4]));
649654
});
650655

651656
it('2 groups, 2 weight, 2 paths, Int32 and Uint8 Data', async () => {
@@ -689,10 +694,10 @@ describeWithFlags('http-load', BROWSER_ENVS, () => {
689694
expect(modelArtifacts.weightSpecs)
690695
.toEqual(
691696
weightsManifest[0].weights.concat(weightsManifest[1].weights));
692-
expect(new Int32Array(modelArtifacts.weightData.slice(0, 12)))
693-
.toEqual(new Int32Array([1, 3, 3]));
694-
expect(new Uint8Array(modelArtifacts.weightData.slice(12, 14)))
695-
.toEqual(new Uint8Array([7, 4]));
697+
expect(new Int32Array(CompositeArrayBuffer.join(modelArtifacts.weightData)
698+
.slice(0, 12))).toEqual(new Int32Array([1, 3, 3]));
699+
expect(new Uint8Array(CompositeArrayBuffer.join(modelArtifacts.weightData)
700+
.slice(12, 14))).toEqual(new Uint8Array([7, 4]));
696701
});
697702

698703
it('topology only', async () => {
@@ -752,10 +757,11 @@ describeWithFlags('http-load', BROWSER_ENVS, () => {
752757
expect(modelArtifacts.weightSpecs)
753758
.toEqual(
754759
weightsManifest[0].weights.concat(weightsManifest[1].weights));
755-
expect(new Int32Array(modelArtifacts.weightData.slice(0, 12)))
756-
.toEqual(new Int32Array([1, 3, 3]));
757-
expect(new Float32Array(modelArtifacts.weightData.slice(12, 20)))
758-
.toEqual(new Float32Array([-7, -4]));
760+
expect(new Int32Array(CompositeArrayBuffer.join(modelArtifacts.weightData)
761+
.slice(0, 12))).toEqual(new Int32Array([1, 3, 3]));
762+
expect(new Float32Array(CompositeArrayBuffer
763+
.join(modelArtifacts.weightData)
764+
.slice(12, 20))).toEqual(new Float32Array([-7, -4]));
759765
});
760766

761767
it('Missing modelTopology and weightsManifest leads to error', async () => {
@@ -840,7 +846,8 @@ describeWithFlags('http-load', BROWSER_ENVS, () => {
840846
const modelArtifacts = await handler.load();
841847
expect(modelArtifacts.modelTopology).toEqual(modelTopology1);
842848
expect(modelArtifacts.weightSpecs).toEqual(weightManifest1[0].weights);
843-
expect(new Float32Array(modelArtifacts.weightData)).toEqual(floatData);
849+
expect(new Float32Array(CompositeArrayBuffer.join(
850+
modelArtifacts.weightData))).toEqual(floatData);
844851
expect(Object.keys(requestInits).length).toEqual(2);
845852
expect(Object.keys(requestInits).length).toEqual(2);
846853
expect(requestInits['./model.json'].headers['header_key_1'])
@@ -902,7 +909,8 @@ describeWithFlags('http-load', BROWSER_ENVS, () => {
902909
expect(modelArtifacts.modelTopology).toEqual(modelTopology1);
903910
expect(modelArtifacts.trainingConfig).toEqual(trainingConfig1);
904911
expect(modelArtifacts.weightSpecs).toEqual(weightManifest1[0].weights);
905-
expect(new Float32Array(modelArtifacts.weightData)).toEqual(floatData);
912+
expect(new Float32Array(CompositeArrayBuffer
913+
.join(modelArtifacts.weightData))).toEqual(floatData);
906914

907915
expect(fetchInputs).toEqual(['./model.json', './weightfile0']);
908916
expect(fetchInits.length).toEqual(2);

0 commit comments

Comments
 (0)