Skip to content

Commit c44d497

Browse files
qjia7fs-eire
authored andcommitted
[js/webgpu] set query type in onRunStart (#19202)
### Description <!-- Describe your changes. --> `env.webgpu.profiling` is a global flag. It may change before each session.run. So the best place is to update it in `onRunStart` event. After this, we can directly check `this.queryType`'s value. Without this pr, we need to make sure that `getCommandEncoder()` is called before checking `this.queryType`. Otherwise, it may happen that `pendingKernels`'s length is not equal to `pendingDispatchNumber`'s length. See the two ugly workarounds [1)](e630dbf#diff-006fc84d3997f96a29b8033bd2075d6a0a9509211bd5812a6b934fc74fedfd9dR267-R268) and [2)](e630dbf#diff-618fe297fbe7a1da586380163b8fd2627311ccc217640a3c5cdc9c17a33472c1R73-R80) if we don't introduce `onRunStart`. Or we need to call `setQueryType` in each kernel run.
1 parent a24273e commit c44d497

File tree

4 files changed

+13
-5
lines changed

4 files changed

+13
-5
lines changed

js/web/lib/wasm/binding/ort-wasm.d.ts

+4
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,10 @@ export interface OrtWasmModule extends EmscriptenModule {
182182
jsepCreateDownloader:
183183
(gpuBuffer: GPUBuffer, size: number,
184184
type: Tensor.GpuBufferDataTypes) => () => Promise<Tensor.DataTypeMap[Tensor.GpuBufferDataTypes]>;
185+
/**
186+
* [exported from js_internal_api.js] Called when InferenceSession.run started.
187+
*/
188+
jsepOnRunStart: () => void;
185189
// #endregion
186190
}
187191

js/web/lib/wasm/jsep/backend-webgpu.ts

+5-4
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ export class WebGpuBackend {
208208

209209
Object.defineProperty(this.env.webgpu, 'device', {value: this.device});
210210

211-
// init queryType, which is necessary for createKernel
211+
// init queryType, which is necessary for InferenceSession.create
212212
this.setQueryType();
213213
}
214214

@@ -223,8 +223,6 @@ export class WebGpuBackend {
223223
if (!this.commandEncoder) {
224224
this.commandEncoder = this.device.createCommandEncoder();
225225

226-
// refresh queryType, as sometimes we only need to enable query for a specific run
227-
this.setQueryType();
228226
if (this.queryType !== 'none' && typeof this.querySet === 'undefined') {
229227
this.querySet = this.device.createQuerySet({
230228
type: 'timestamp',
@@ -639,6 +637,7 @@ export class WebGpuBackend {
639637
return createView(data.buffer, type);
640638
};
641639
}
640+
// #endregion
642641
writeTimestamp(index: number): void {
643642
if (this.queryType !== 'inside-passes') {
644643
return;
@@ -657,5 +656,7 @@ export class WebGpuBackend {
657656
}
658657
}
659658
}
660-
// #endregion
659+
onRunStart(): void {
660+
this.setQueryType();
661+
}
661662
}

js/web/lib/wasm/wasm-core-impl.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -488,8 +488,8 @@ export const run = async(
488488
}
489489
}
490490

491+
wasm.jsepOnRunStart?.();
491492
let errorCode: number;
492-
493493
if (!BUILD_DEFS.DISABLE_WEBGPU && ioBindingState) {
494494
errorCode = await wasm._OrtRunWithBinding(
495495
sessionHandle, ioBindingState.handle, outputCount, outputValuesOffset, runOptionsHandle);

onnxruntime/wasm/js_internal_api.js

+3
Original file line numberDiff line numberDiff line change
@@ -186,4 +186,7 @@ Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, relea
186186
Module['jsepCreateDownloader'] = (gpuBuffer, size, type) => {
187187
return backend['createDownloader'](gpuBuffer, size, type);
188188
};
189+
Module['jsepOnRunStart'] = () => {
190+
return backend['onRunStart']();
191+
};
189192
};

0 commit comments

Comments
 (0)