Skip to content

Commit 85cef0a

Browse files
authored
[js/webgpu] Support capture and replay for jsep (#18989)
### Description This PR expands the graph capture capability to the JS EP, which is similar to #16081. But for the JS EP, we don't use CUDA Graph; instead, we record all GPU commands and replay them, which removes most of the CPU overhead and avoids the situation where the GPU waits for the CPU. mobilenetv2-12 improves from 6ms to 3.7ms on NV 3090 and from 4.58ms to 3.38ms on Intel A770. All limitations are similar to the CUDA EP's: 1. Models with control-flow ops (i.e. If, Loop and Scan ops) are not supported. 2. Usage of graph capture is limited to models wherein all ops in the model can be partitioned to the JS EP or CPU EP and no memory copy between them. 3. Shapes of inputs/outputs cannot change across inference calls. 4. IO binding is required. The usage is like below: Method 1: specify output buffers explicitly. ``` const sessionOptions = { executionProviders: [ { name: "webgpu", }, ], enableGraphCapture: true, }; const session = await ort.InferenceSession.create('./models/mobilenetv2-12.onnx', sessionOptions); // prepare the inputBuffer/outputBuffer ... ... const feeds = { 'input': ort.Tensor.fromGpuBuffer(inputBuffer, { dataType: 'float32', dims }) }; const fetches = { 'output': ort.Tensor.fromGpuBuffer(outputBuffer, { dataType: 'float32', dims: [1, 1000] }) }; let results = await session.run(feeds, fetches); // The first run will begin to capture the graph. // update inputBuffer content ... ... results = await session.run(feeds, fetches); // The 2nd run and after will directly call replay to execute the graph. ... ... session.release(); ``` Method 2: Don't specify output buffers explicitly. Internally, when graph capture is enabled, it will set all outputs' location to 'gpu-buffer'. ``` const sessionOptions = { executionProviders: [ { name: "webgpu", }, ], enableGraphCapture: true, }; const session = await ort.InferenceSession.create('./models/mobilenetv2-12.onnx', sessionOptions); // prepare the inputBuffer ... ... 
const feeds = { 'input': ort.Tensor.fromGpuBuffer(inputBuffer, { dataType: 'float32', dims }) }; let results = await session.run(feeds); // The first run will begin to capture the graph. // update inputBuffer content ... ... results = await session.run(feeds); // The 2nd run and after will directly call replay to execute the graph. ... ... session.release();
1 parent 6dd0079 commit 85cef0a

16 files changed

+436
-136
lines changed

js/common/lib/inference-session.ts

+7-1
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ export declare namespace InferenceSession {
111111
optimizedModelFilePath?: string;
112112

113113
/**
114-
* Wether enable profiling.
114+
* Whether enable profiling.
115115
*
116116
* This setting is a placeholder for a future use.
117117
*/
@@ -154,6 +154,12 @@ export declare namespace InferenceSession {
154154
*/
155155
preferredOutputLocation?: OnnxValueDataLocation|{readonly [outputName: string]: OnnxValueDataLocation};
156156

157+
/**
158+
* Whether enable graph capture.
159+
* This setting is available only in ONNXRuntime Web for WebGPU EP.
160+
*/
161+
enableGraphCapture?: boolean;
162+
157163
/**
158164
* Store configurations for a session. See
159165
* https://github.com/microsoft/onnxruntime/blob/main/include/onnxruntime/core/session/

js/web/lib/wasm/binding/ort-wasm.d.ts

+16-9
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ export declare namespace JSEP {
1313
type ReleaseKernelFunction = (kernel: number) => void;
1414
type RunFunction =
1515
(kernel: number, contextDataOffset: number, sessionHandle: number, errors: Array<Promise<string|null>>) => number;
16+
type CaptureBeginFunction = () => void;
17+
type CaptureEndFunction = () => void;
18+
type ReplayFunction = () => void;
1619
}
1720

1821
export interface OrtWasmModule extends EmscriptenModule {
@@ -128,7 +131,8 @@ export interface OrtWasmModule extends EmscriptenModule {
128131
jsepInit?
129132
(backend: JSEP.BackendType, alloc: JSEP.AllocFunction, free: JSEP.FreeFunction, upload: JSEP.UploadFunction,
130133
download: JSEP.DownloadFunction, createKernel: JSEP.CreateKernelFunction,
131-
releaseKernel: JSEP.ReleaseKernelFunction, run: JSEP.RunFunction): void;
134+
releaseKernel: JSEP.ReleaseKernelFunction, run: JSEP.RunFunction, captureBegin: JSEP.CaptureBeginFunction,
135+
captureEnd: JSEP.CaptureEndFunction, replay: JSEP.ReplayFunction): void;
132136

133137
/**
134138
* [exported from wasm] Specify a kernel's output when running OpKernel::Compute().
@@ -158,12 +162,6 @@ export interface OrtWasmModule extends EmscriptenModule {
158162
* @returns the GPU data ID for the registered GPU buffer.
159163
*/
160164
jsepRegisterBuffer: (sessionId: number, index: number, buffer: GPUBuffer, size: number) => number;
161-
/**
162-
* [exported from js_internal_api.js] Unregister all user GPU buffers for a session.
163-
*
164-
* @param sessionId - specify the session ID.
165-
*/
166-
jsepUnregisterBuffers?: (sessionId: number) => void;
167165
/**
168166
* [exported from js_internal_api.js] Get the GPU buffer by GPU data ID.
169167
*
@@ -183,9 +181,18 @@ export interface OrtWasmModule extends EmscriptenModule {
183181
(gpuBuffer: GPUBuffer, size: number,
184182
type: Tensor.GpuBufferDataTypes) => () => Promise<Tensor.DataTypeMap[Tensor.GpuBufferDataTypes]>;
185183
/**
186-
* [exported from js_internal_api.js] Called when InferenceSession.run started.
184+
* [exported from js_internal_api.js] Called when InferenceSession.run started. This function will be called before
185+
* _OrtRun[WithBinding]() is called.
186+
* @param sessionId - specify the session ID.
187+
*/
188+
jsepOnRunStart: (sessionId: number) => void;
189+
/**
190+
* [exported from js_internal_api.js] Release a session. This function will be called before _OrtReleaseSession() is
191+
* called.
192+
* @param sessionId - specify the session ID.
193+
* @returns
187194
*/
188-
jsepOnRunStart: () => void;
195+
jsepOnReleaseSession: (sessionId: number) => void;
189196
// #endregion
190197
}
191198

js/web/lib/wasm/jsep/backend-webgpu.ts

+96-4
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,14 @@ import {createView, TensorView} from './tensor-view';
1010
import {createGpuDataManager, downloadGpuData, GpuDataManager} from './webgpu/gpu-data-manager';
1111
import {RunFunction, WEBGPU_OP_RESOLVE_RULES} from './webgpu/op-resolve-rules';
1212
import {ProgramManager} from './webgpu/program-manager';
13-
import {ComputeContext, GpuData, ProgramInfo, ProgramInputTensorInfoDependency, TimestampQuery} from './webgpu/types';
13+
import {ComputeContext, GpuData, ProgramInfo, ProgramInputTensorInfoDependency, SessionState, TimestampQuery} from './webgpu/types';
14+
15+
interface CommandInfo {
16+
readonly kernelId: number;
17+
readonly computePipeline: GPUComputePipeline;
18+
readonly bindGroup: GPUBindGroup;
19+
readonly dispatchGroup: [number, number, number];
20+
}
1421

1522
interface KernelInfo {
1623
readonly kernelType: string;
@@ -103,6 +110,13 @@ export class WebGpuBackend {
103110
*/
104111
programManager: ProgramManager;
105112

113+
/**
114+
* representing the session ID of which is currently being run.
115+
* `null` means no session is being run.
116+
* only valid when session.run is executed.
117+
*/
118+
currentSessionId: number|null = null;
119+
106120
/**
107121
* representing the kernel ID of which is currently being computed (CPU code perspective).
108122
* `null` means no kernel is being computed.
@@ -155,6 +169,16 @@ export class WebGpuBackend {
155169
queryType: TimestampQuery;
156170

157171
env: Env;
172+
sessionStatus: SessionState = 'default';
173+
/**
174+
* a SessionID -> CommandInfo[] mapping. It's used to record all GPU commands for corresponding session.
175+
*/
176+
capturedCommandList: Map<number, CommandInfo[]> = new Map();
177+
178+
/**
179+
* a SessionID -> PendingKernelInfo[] mapping for profiling.
180+
*/
181+
private capturedPendingKernels: Map<number, PendingKernelInfo[]> = new Map();
158182

159183
/**
160184
* a SessionID -> a Map of (InputOutputIndex -> [ID, GPUBuffer]) mapping.
@@ -228,6 +252,7 @@ export class WebGpuBackend {
228252

229253
getComputePassEncoder(): GPUComputePassEncoder {
230254
if (!this.computePassEncoder) {
255+
const commandEncoder = this.getCommandEncoder();
231256
const computePassDescriptor: GPUComputePassDescriptor = {};
232257

233258
if (this.queryType === 'at-passes') {
@@ -238,7 +263,7 @@ export class WebGpuBackend {
238263
};
239264
}
240265

241-
this.computePassEncoder = this.getCommandEncoder().beginComputePass(computePassDescriptor);
266+
this.computePassEncoder = commandEncoder.beginComputePass(computePassDescriptor);
242267
}
243268
return this.computePassEncoder;
244269
}
@@ -494,14 +519,17 @@ export class WebGpuBackend {
494519
() => `[ProgramManager] run "${program.name}" (key=${key}) with ${normalizedDispatchGroup[0]}x${
495520
normalizedDispatchGroup[1]}x${normalizedDispatchGroup[2]}`);
496521

497-
if (this.queryType !== 'none') {
522+
if (this.queryType !== 'none' || this.sessionStatus === 'capturing') {
498523
const pendingKernelInfo: PendingKernelInfo = {
499524
kernelId: this.currentKernelId!,
500525
programName: artifact.programInfo.name,
501526
inputTensorViews,
502527
outputTensorViews,
503528
};
504529
this.pendingKernels.push(pendingKernelInfo);
530+
531+
const sessionPendingKernels = this.capturedPendingKernels.get(this.currentSessionId!);
532+
sessionPendingKernels!.push(pendingKernelInfo);
505533
}
506534

507535
this.programManager.run(artifact, inputDatas, outputDatas, normalizedDispatchGroup, uniformBufferBinding);
@@ -672,7 +700,71 @@ export class WebGpuBackend {
672700
}
673701
}
674702
}
675-
onRunStart(): void {
703+
704+
captureBegin(): void {
705+
LOG_DEBUG('info', 'captureBegin');
706+
let sessionCommandList = this.capturedCommandList.get(this.currentSessionId!);
707+
let sessionPendingKernels = this.capturedPendingKernels.get(this.currentSessionId!);
708+
if (!sessionCommandList) {
709+
sessionCommandList = [];
710+
this.capturedCommandList.set(this.currentSessionId!, sessionCommandList);
711+
sessionPendingKernels = [];
712+
this.capturedPendingKernels.set(this.currentSessionId!, sessionPendingKernels);
713+
}
714+
// flush the left commands before we change the status.
715+
this.flush();
716+
this.sessionStatus = 'capturing';
717+
}
718+
captureEnd(): void {
719+
LOG_DEBUG('info', 'captureEnd');
720+
// flush the left commands before we change the status.
721+
this.flush();
722+
this.sessionStatus = 'default';
723+
}
724+
replay(): void {
725+
LOG_DEBUG('info', 'replay');
726+
this.sessionStatus = 'replaying';
727+
const sessionCommandList = this.capturedCommandList.get(this.currentSessionId!);
728+
const sessionPendingKernels = this.capturedPendingKernels.get(this.currentSessionId!);
729+
const length = sessionCommandList!.length;
730+
this.pendingKernels = [];
731+
for (let i = 0; i < length; i++) {
732+
const computePassEncoder = this.getComputePassEncoder();
733+
const command = sessionCommandList![i];
734+
this.writeTimestamp(this.pendingDispatchNumber * 2);
735+
computePassEncoder.setPipeline(command.computePipeline);
736+
computePassEncoder.setBindGroup(0, command.bindGroup);
737+
computePassEncoder.dispatchWorkgroups(...command.dispatchGroup);
738+
this.writeTimestamp(this.pendingDispatchNumber * 2 + 1);
739+
this.pendingDispatchNumber++;
740+
if (this.queryType !== 'none') {
741+
this.pendingKernels.push(sessionPendingKernels![i]);
742+
}
743+
if (this.pendingDispatchNumber >= this.maxDispatchNumber || this.queryType === 'at-passes') {
744+
this.endComputePass();
745+
}
746+
if (this.pendingDispatchNumber >= this.maxDispatchNumber) {
747+
this.flush();
748+
}
749+
}
750+
// flush the left commands before we change the status.
751+
this.flush();
752+
this.sessionStatus = 'default';
753+
}
754+
755+
onReleaseSession(sessionId: number): void {
756+
this.unregisterBuffers(sessionId);
757+
if (this.capturedCommandList.has(sessionId)) {
758+
this.capturedCommandList.delete(sessionId);
759+
}
760+
if (this.capturedPendingKernels.has(sessionId)) {
761+
this.capturedPendingKernels.delete(sessionId);
762+
}
763+
this.gpuDataManager.onReleaseSession(sessionId);
764+
}
765+
766+
onRunStart(sessionId: number): void {
767+
this.currentSessionId = sessionId;
676768
this.setQueryType();
677769
}
678770
}

js/web/lib/wasm/jsep/init.ts

+7-1
Original file line numberDiff line numberDiff line change
@@ -201,5 +201,11 @@ export const init = async(module: OrtWasmModule, env: Env, gpuAdapter: GPUAdapte
201201
contextDataOffset}`);
202202
const context = new ComputeContextImpl(module, backend, contextDataOffset);
203203
return backend.computeKernel(kernel, context, errors);
204-
});
204+
},
205+
// jsepCaptureBegin
206+
() => backend.captureBegin(),
207+
// jsepCaptureEnd
208+
() => backend.captureEnd(),
209+
// jsepReplay
210+
() => backend.replay());
205211
};

js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts

+62-12
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,15 @@ export interface GpuDataManager {
6060
unregisterExternalBuffer(buffer: GPUBuffer): void;
6161

6262
/**
63-
* destroy all gpu buffers. Call this when the session.release is called.
63+
* destroy all gpu buffers.
6464
*/
6565
dispose(): void;
66+
67+
/**
68+
* release session related data.
69+
* @param sessionId - specify the session ID.
70+
*/
71+
onReleaseSession(sessionId: number): void;
6672
}
6773

6874
interface StorageCacheValue {
@@ -139,13 +145,18 @@ class GpuDataManagerImpl implements GpuDataManager {
139145
// The external buffers registered users for IO Binding.
140146
private externalBuffers: Map<GPUBuffer, GpuDataId>;
141147

148+
// The pendingBuffers for capture graph.
149+
// a SessionID -> GPUBuffer[] mapping.
150+
private capturedPendingBuffers: Map<number, GPUBuffer[]>;
151+
142152
constructor(private backend: WebGpuBackend) {
143153
this.storageCache = new Map();
144154
this.freeBuffers = new Map();
145155
this.freeUniformBuffers = new Map();
146156
this.buffersForUploadingPending = [];
147157
this.buffersPending = [];
148158
this.externalBuffers = new Map();
159+
this.capturedPendingBuffers = new Map();
149160
}
150161

151162
upload(id: GpuDataId, data: Uint8Array): void {
@@ -220,6 +231,9 @@ class GpuDataManagerImpl implements GpuDataManager {
220231
() => `[WebGPU] GpuDataManager.registerExternalBuffer(size=${originalSize}) => id=${
221232
id}, buffer is the same, skip.`);
222233
return id;
234+
} else if (this.backend.capturedCommandList.has(this.backend.currentSessionId!)) {
235+
throw new Error(`Registering a different external buffer under graph capture mode is not supported yet.
236+
Please use the previous external buffer!`);
223237
}
224238
this.externalBuffers.delete(previousBuffer);
225239
} else {
@@ -312,20 +326,39 @@ class GpuDataManagerImpl implements GpuDataManager {
312326
buffer.destroy();
313327
}
314328
this.buffersForUploadingPending = [];
315-
for (const buffer of this.buffersPending) {
316-
// eslint-disable-next-line no-bitwise
317-
if ((buffer.usage & GPUBufferUsage.STORAGE) === GPUBufferUsage.STORAGE) {
318-
// Put the pending buffer to freeBuffers list instead of really destroying it for buffer reusing.
319-
this.freeBuffers.get(buffer.size)!.push(buffer);
329+
330+
if (this.buffersPending.length === 0) {
331+
return;
332+
}
333+
334+
if (this.backend.sessionStatus === 'default') {
335+
for (const buffer of this.buffersPending) {
320336
// eslint-disable-next-line no-bitwise
321-
} else if ((buffer.usage & GPUBufferUsage.UNIFORM) === GPUBufferUsage.UNIFORM) {
322-
// Put the pending buffer to freeUniformBuffers list instead of really destroying it for buffer reusing.
323-
this.freeUniformBuffers.get(buffer.size)!.push(buffer);
324-
} else {
325-
buffer.destroy();
337+
if ((buffer.usage & GPUBufferUsage.STORAGE) === GPUBufferUsage.STORAGE) {
338+
// Put the pending buffer to freeBuffers list instead of really destroying it for buffer reusing.
339+
this.freeBuffers.get(buffer.size)!.push(buffer);
340+
// eslint-disable-next-line no-bitwise
341+
} else if ((buffer.usage & GPUBufferUsage.UNIFORM) === GPUBufferUsage.UNIFORM) {
342+
// Put the pending buffer to freeUniformBuffers list instead of really destroying it for buffer reusing.
343+
this.freeUniformBuffers.get(buffer.size)!.push(buffer);
344+
} else {
345+
buffer.destroy();
346+
}
347+
}
348+
this.buffersPending = [];
349+
} else {
350+
// Don't release intermediate tensors in non-default mode.
351+
// TODO: reuse the storage buffers in non-default mode.
352+
let capturedBuffers = this.capturedPendingBuffers.get(this.backend.currentSessionId!);
353+
if (!capturedBuffers) {
354+
capturedBuffers = [];
355+
this.capturedPendingBuffers.set(this.backend.currentSessionId!, capturedBuffers);
326356
}
357+
for (const buffer of this.buffersPending) {
358+
capturedBuffers.push(buffer);
359+
}
360+
this.buffersPending = [];
327361
}
328-
this.buffersPending = [];
329362
}
330363

331364
dispose() {
@@ -344,9 +377,26 @@ class GpuDataManagerImpl implements GpuDataManager {
344377
storage.gpuData.buffer.destroy();
345378
});
346379

380+
this.capturedPendingBuffers.forEach((buffers) => {
381+
buffers.forEach(buffer => {
382+
buffer.destroy();
383+
});
384+
});
347385
this.storageCache = new Map();
348386
this.freeBuffers = new Map();
349387
this.freeUniformBuffers = new Map();
388+
this.capturedPendingBuffers = new Map();
389+
}
390+
391+
onReleaseSession(sessionId: number) {
392+
// release the captured pending buffers.
393+
const pendingBuffers = this.capturedPendingBuffers.get(sessionId);
394+
if (pendingBuffers) {
395+
pendingBuffers.forEach(buffer => {
396+
buffer.destroy();
397+
});
398+
this.capturedPendingBuffers.delete(sessionId);
399+
}
350400
}
351401
}
352402

js/web/lib/wasm/jsep/webgpu/program-manager.ts

+13-2
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ export class ProgramManager {
3838
const device = this.backend.device;
3939
const computePassEncoder = this.backend.getComputePassEncoder();
4040
this.backend.writeTimestamp(this.backend.pendingDispatchNumber * 2);
41-
computePassEncoder.setPipeline(buildArtifact.computePipeline);
4241
const entries = [];
4342
for (const input of inputs) {
4443
entries.push({binding: entries.length, resource: {buffer: input.buffer}});
@@ -51,8 +50,20 @@ export class ProgramManager {
5150
}
5251
const bindGroup = device.createBindGroup(
5352
{layout: buildArtifact.computePipeline.getBindGroupLayout(0), entries, label: buildArtifact.programInfo.name});
54-
computePassEncoder.setBindGroup(0, bindGroup);
5553

54+
if (this.backend.sessionStatus === 'capturing') {
55+
const commandInfo = {
56+
kernelId: this.backend.currentKernelId!,
57+
computePipeline: buildArtifact.computePipeline,
58+
bindGroup,
59+
dispatchGroup
60+
};
61+
const sessionCommandList = this.backend.capturedCommandList.get(this.backend.currentSessionId!);
62+
sessionCommandList!.push(commandInfo);
63+
}
64+
65+
computePassEncoder.setPipeline(buildArtifact.computePipeline);
66+
computePassEncoder.setBindGroup(0, bindGroup);
5667
computePassEncoder.dispatchWorkgroups(...dispatchGroup);
5768
this.backend.writeTimestamp(this.backend.pendingDispatchNumber * 2 + 1);
5869
this.backend.pendingDispatchNumber++;

0 commit comments

Comments
 (0)