Skip to content

Commit 01c7aaf

Browse files
authored
[js/webgpu] allow setting env.webgpu.adapter (#19940)
### Description Allow user to set `env.webgpu.adapter` before creating the first inference session. Feature request: #19857 (comment) @xenova
1 parent 8293aa1 commit 01c7aaf

File tree

3 files changed

+35
-16
lines changed

3 files changed

+35
-16
lines changed

js/common/lib/env.ts

+7-3
Original file line numberDiff line numberDiff line change
@@ -166,16 +166,20 @@ export declare namespace Env {
166166
*/
167167
forceFallbackAdapter?: boolean;
168168
/**
169-
* Get the adapter for WebGPU.
169+
* Set or get the adapter for WebGPU.
170170
*
171-
* This property is only available after the first WebGPU inference session is created.
171+
* Setting this property only has effect before the first WebGPU inference session is created. The value will be
172+
* used as the GPU adapter for the underlying WebGPU backend to create GPU device.
173+
*
174+
* If this property is not set, it will be available to get after the first WebGPU inference session is created. The
175+
* value will be the GPU adapter that created by the underlying WebGPU backend.
172176
*
173177
* When use with TypeScript, the type of this property is `GPUAdapter` defined in "@webgpu/types".
174178
* Use `const adapter = env.webgpu.adapter as GPUAdapter;` in TypeScript to access this property with correct type.
175179
*
176180
* see comments on {@link Tensor.GpuBufferType}
177181
*/
178-
readonly adapter: unknown;
182+
adapter: unknown;
179183
/**
180184
* Get the device for WebGPU.
181185
*

js/web/lib/wasm/jsep/backend-webgpu.ts

+4-2
Original file line numberDiff line numberDiff line change
@@ -252,8 +252,10 @@ export class WebGpuBackend {
252252
}
253253
};
254254

255-
Object.defineProperty(this.env.webgpu, 'device', {value: this.device});
256-
Object.defineProperty(this.env.webgpu, 'adapter', {value: adapter});
255+
Object.defineProperty(
256+
this.env.webgpu, 'device', {value: this.device, writable: false, enumerable: true, configurable: false});
257+
Object.defineProperty(
258+
this.env.webgpu, 'adapter', {value: adapter, writable: false, enumerable: true, configurable: false});
257259

258260
// init queryType, which is necessary for InferenceSession.create
259261
this.setQueryType();

js/web/lib/wasm/wasm-core-impl.ts

+24-11
Original file line numberDiff line numberDiff line change
@@ -93,18 +93,31 @@ export const initEp = async(env: Env, epName: string): Promise<void> => {
9393
if (typeof navigator === 'undefined' || !navigator.gpu) {
9494
throw new Error('WebGPU is not supported in current environment');
9595
}
96-
const powerPreference = env.webgpu?.powerPreference;
97-
if (powerPreference !== undefined && powerPreference !== 'low-power' && powerPreference !== 'high-performance') {
98-
throw new Error(`Invalid powerPreference setting: "${powerPreference}"`);
99-
}
100-
const forceFallbackAdapter = env.webgpu?.forceFallbackAdapter;
101-
if (forceFallbackAdapter !== undefined && typeof forceFallbackAdapter !== 'boolean') {
102-
throw new Error(`Invalid forceFallbackAdapter setting: "${forceFallbackAdapter}"`);
103-
}
104-
const adapter = await navigator.gpu.requestAdapter({powerPreference, forceFallbackAdapter});
96+
97+
let adapter = env.webgpu.adapter as GPUAdapter | null;
10598
if (!adapter) {
106-
throw new Error(
107-
'Failed to get GPU adapter. You may need to enable flag "--enable-unsafe-webgpu" if you are using Chrome.');
99+
// if adapter is not set, request a new adapter.
100+
const powerPreference = env.webgpu.powerPreference;
101+
if (powerPreference !== undefined && powerPreference !== 'low-power' &&
102+
powerPreference !== 'high-performance') {
103+
throw new Error(`Invalid powerPreference setting: "${powerPreference}"`);
104+
}
105+
const forceFallbackAdapter = env.webgpu.forceFallbackAdapter;
106+
if (forceFallbackAdapter !== undefined && typeof forceFallbackAdapter !== 'boolean') {
107+
throw new Error(`Invalid forceFallbackAdapter setting: "${forceFallbackAdapter}"`);
108+
}
109+
adapter = await navigator.gpu.requestAdapter({powerPreference, forceFallbackAdapter});
110+
if (!adapter) {
111+
throw new Error(
112+
'Failed to get GPU adapter. ' +
113+
'You may need to enable flag "--enable-unsafe-webgpu" if you are using Chrome.');
114+
}
115+
} else {
116+
// if adapter is set, validate it.
117+
if (typeof adapter.limits !== 'object' || typeof adapter.features !== 'object' ||
118+
typeof adapter.requestDevice !== 'function') {
119+
throw new Error('Invalid GPU adapter set in `env.webgpu.adapter`. It must be a GPUAdapter object.');
120+
}
108121
}
109122

110123
if (!env.wasm.simd) {

0 commit comments

Comments
 (0)