Skip to content

Commit 341f709

Browse files
committed
split d.ts and fix break
1 parent 1a20c40 commit 341f709

File tree

4 files changed

+146
-127
lines changed

4 files changed

+146
-127
lines changed

js/web/lib/wasm/binding/ort-wasm.d.ts

+126-117
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,97 @@ export declare namespace JSEP {
1616
type CaptureBeginFunction = () => void;
1717
type CaptureEndFunction = () => void;
1818
type ReplayFunction = () => void;
19-
}
2019

21-
export interface OrtWasmModule extends EmscriptenModule {
22-
// #region emscripten functions
23-
stackSave(): number;
24-
stackRestore(stack: number): void;
25-
stackAlloc(size: number): number;
26-
27-
UTF8ToString(offset: number, maxBytesToRead?: number): string;
28-
lengthBytesUTF8(str: string): number;
29-
stringToUTF8(str: string, offset: number, maxBytes: number): void;
30-
// #endregion
20+
export interface Module extends WebGpuModule {
21+
/**
22+
* Mount the external data file to an internal map, which will be used during session initialization.
23+
*
24+
* @param externalDataFilePath - specify the relative path of the external data file.
25+
* @param externalDataFileData - specify the content data.
26+
*/
27+
mountExternalData(externalDataFilePath: string, externalDataFileData: Uint8Array): void;
28+
/**
29+
* Unmount all external data files from the internal map.
30+
*/
31+
unmountExternalData(): void;
32+
33+
/**
34+
* This is the entry of JSEP initialization. This function is called once when initializing ONNX Runtime per
35+
* backend. This function initializes Asyncify support. If name is 'webgpu', also initializes WebGPU backend and
36+
* registers a few callbacks that will be called in C++ code.
37+
*/
38+
jsepInit(name: 'webgpu', initParams: [
39+
backend: BackendType, alloc: AllocFunction, free: FreeFunction, upload: UploadFunction,
40+
download: DownloadFunction, createKernel: CreateKernelFunction, releaseKernel: ReleaseKernelFunction,
41+
run: RunFunction, captureBegin: CaptureBeginFunction, captureEnd: CaptureEndFunction, replay: ReplayFunction
42+
]): void;
43+
jsepInit(name: 'webnn', initParams?: never): void;
44+
}
45+
46+
export interface WebGpuModule {
47+
/**
48+
* [exported from wasm] Specify a kernel's output when running OpKernel::Compute().
49+
*
50+
* @param context - specify the kernel context pointer.
51+
* @param index - specify the index of the output.
52+
* @param data - specify the pointer to encoded data of type and dims.
53+
*/
54+
_JsepOutput(context: number, index: number, data: number): number;
55+
/**
56+
* [exported from wasm] Get name of an operator node.
57+
*
58+
* @param kernel - specify the kernel pointer.
59+
* @returns the pointer to a C-style UTF8 encoded string representing the node name.
60+
*/
61+
_JsepGetNodeName(kernel: number): number;
62+
63+
/**
64+
* [exported from js_internal_api.js] Register a user GPU buffer for usage of a session's input or output.
65+
*
66+
* @param sessionId - specify the session ID.
67+
* @param index - specify an integer to represent which input/output it is registering for. For input, it is the
68+
* input_index corresponding to the session's inputNames. For output, it is the inputCount + output_index
69+
* corresponding to the session's ouputNames.
70+
* @param buffer - specify the GPU buffer to register.
71+
* @param size - specify the original data size in byte.
72+
* @returns the GPU data ID for the registered GPU buffer.
73+
*/
74+
jsepRegisterBuffer: (sessionId: number, index: number, buffer: GPUBuffer, size: number) => number;
75+
/**
76+
* [exported from js_internal_api.js] Get the GPU buffer by GPU data ID.
77+
*
78+
* @param dataId - specify the GPU data ID
79+
* @returns the GPU buffer.
80+
*/
81+
jsepGetBuffer: (dataId: number) => GPUBuffer;
82+
/**
83+
* [exported from js_internal_api.js] Create a function to be used to create a GPU Tensor.
84+
*
85+
* @param gpuBuffer - specify the GPU buffer
86+
* @param size - specify the original data size in byte.
87+
* @param type - specify the tensor type.
88+
* @returns the generated downloader function.
89+
*/
90+
jsepCreateDownloader:
91+
(gpuBuffer: GPUBuffer, size: number,
92+
type: Tensor.GpuBufferDataTypes) => () => Promise<Tensor.DataTypeMap[Tensor.GpuBufferDataTypes]>;
93+
/**
94+
* [exported from js_internal_api.js] Called when InferenceSession.run started. This function will be called before
95+
* _OrtRun[WithBinding]() is called.
96+
* @param sessionId - specify the session ID.
97+
*/
98+
jsepOnRunStart: (sessionId: number) => void;
99+
/**
100+
* [exported from js_internal_api.js] Release a session. This function will be called before _OrtReleaseSession() is
101+
* called.
102+
* @param sessionId - specify the session ID.
103+
* @returns
104+
*/
105+
jsepOnReleaseSession: (sessionId: number) => void;
106+
}
107+
}
31108

32-
// #region ORT APIs
109+
export interface OrtInferenceAPIs {
33110
_OrtInit(numThreads: number, loggingLevel: number): number;
34111

35112
_OrtGetLastError(errorCodeOffset: number, errorMessageOffset: number): void;
@@ -74,129 +151,61 @@ export interface OrtWasmModule extends EmscriptenModule {
74151
_OrtReleaseRunOptions(runOptionsHandle: number): void;
75152

76153
_OrtEndProfiling(sessionHandle: number): number;
77-
// #endregion
154+
}
155+
156+
export interface OrtTrainingAPIs {
157+
_OrtTrainingLoadCheckpoint(dataOffset: number, dataLength: number): number;
78158

79-
// #region ORT Training APIs
80-
_OrtTrainingLoadCheckpoint?(dataOffset: number, dataLength: number): number;
159+
_OrtTrainingReleaseCheckpoint(checkpointHandle: number): void;
81160

82-
_OrtTrainingReleaseCheckpoint?(checkpointHandle: number): void;
161+
_OrtTrainingCreateSession(
162+
sessionOptionsHandle: number, checkpointHandle: number, trainOffset: number, trainLength: number,
163+
evalOffset: number, evalLength: number, optimizerOffset: number, optimizerLength: number): number;
83164

84-
_OrtTrainingCreateSession?
85-
(sessionOptionsHandle: number, checkpointHandle: number, trainOffset: number, trainLength: number,
86-
evalOffset: number, evalLength: number, optimizerOffset: number, optimizerLength: number): number;
165+
_OrtTrainingLazyResetGrad(trainingHandle: number): number;
87166

88-
_OrtTrainingLazyResetGrad?(trainingHandle: number): number;
167+
_OrtTrainingRunTrainStep(
168+
trainingHandle: number, inputsOffset: number, inputCount: number, outputsOffset: number, outputCount: number,
169+
runOptionsHandle: number): number;
89170

90-
_OrtTrainingRunTrainStep?
91-
(trainingHandle: number, inputsOffset: number, inputCount: number, outputsOffset: number, outputCount: number,
92-
runOptionsHandle: number): number;
171+
_OrtTrainingOptimizerStep(trainingHandle: number, runOptionsHandle: number): number;
93172

94-
_OrtTrainingOptimizerStep?(trainingHandle: number, runOptionsHandle: number): number;
173+
_OrtTrainingEvalStep(
174+
trainingHandle: number, inputsOffset: number, inputCount: number, outputsOffset: number, outputCount: number,
175+
runOptionsHandle: number): number;
95176

96-
_OrtTrainingEvalStep?
97-
(trainingHandle: number, inputsOffset: number, inputCount: number, outputsOffset: number, outputCount: number,
98-
runOptionsHandle: number): number;
177+
_OrtTrainingGetParametersSize(trainingHandle: number, paramSizeT: number, trainableOnly: boolean): number;
99178

100-
_OrtTrainingGetParametersSize?(trainingHandle: number, paramSizeT: number, trainableOnly: boolean): number;
179+
_OrtTrainingCopyParametersToBuffer(
180+
trainingHandle: number, parametersBuffer: number, parameterCount: number, trainableOnly: boolean): number;
101181

102-
_OrtTrainingCopyParametersToBuffer?
103-
(trainingHandle: number, parametersBuffer: number, parameterCount: number, trainableOnly: boolean): number;
182+
_OrtTrainingCopyParametersFromBuffer(
183+
trainingHandle: number, parametersBuffer: number, parameterCount: number, trainableOnly: boolean): number;
104184

105-
_OrtTrainingCopyParametersFromBuffer?
106-
(trainingHandle: number, parametersBuffer: number, parameterCount: number, trainableOnly: boolean): number;
185+
_OrtTrainingGetModelInputOutputCount(
186+
trainingHandle: number, inputCount: number, outputCount: number, isEvalModel: boolean): number;
187+
_OrtTrainingGetModelInputOutputName(trainingHandle: number, index: number, isInput: boolean, isEvalModel: boolean):
188+
number;
189+
190+
_OrtTrainingReleaseSession(trainingHandle: number): void;
191+
}
107192

108-
_OrtTrainingGetModelInputOutputCount?
109-
(trainingHandle: number, inputCount: number, outputCount: number, isEvalModel: boolean): number;
110-
_OrtTrainingGetModelInputOutputName?
111-
(trainingHandle: number, index: number, isInput: boolean, isEvalModel: boolean): number;
193+
export interface OrtWasmModule extends EmscriptenModule, OrtInferenceAPIs, Partial<OrtTrainingAPIs>,
194+
Partial<JSEP.Module> {
195+
// #region emscripten functions
196+
stackSave(): number;
197+
stackRestore(stack: number): void;
198+
stackAlloc(size: number): number;
112199

113-
_OrtTrainingReleaseSession?(trainingHandle: number): void;
200+
UTF8ToString(offset: number, maxBytesToRead?: number): string;
201+
lengthBytesUTF8(str: string): number;
202+
stringToUTF8(str: string, offset: number, maxBytes: number): void;
114203
// #endregion
115204

116205
// #region config
117206
numThreads?: number;
118207
mainScriptUrlOrBlob?: string|Blob;
119208
// #endregion
120-
121-
// #region external data API
122-
mountExternalData?(externalDataFilePath: string, externalDataFileData: Uint8Array): void;
123-
unmountExternalData?(): void;
124-
// #endregion
125-
126-
// #region JSEP
127-
/**
128-
* This is the entry of JSEP initialization. This function is called once when initializing ONNX Runtime per backend.
129-
* This function initializes Asyncify support.
130-
* If name is 'webgpu', also initializes WebGPU backend and registers a few callbacks that will be called in C++ code.
131-
*/
132-
jsepInit?(name: 'webgpu', initParams: [
133-
backend: JSEP.BackendType, alloc: JSEP.AllocFunction, free: JSEP.FreeFunction, upload: JSEP.UploadFunction,
134-
download: JSEP.DownloadFunction, createKernel: JSEP.CreateKernelFunction, releaseKernel: JSEP.ReleaseKernelFunction,
135-
run: JSEP.RunFunction, captureBegin: JSEP.CaptureBeginFunction, captureEnd: JSEP.CaptureEndFunction,
136-
replay: JSEP.ReplayFunction
137-
]): void;
138-
jsepInit?(name: 'webnn', initParams?: never): void;
139-
140-
/**
141-
* [exported from wasm] Specify a kernel's output when running OpKernel::Compute().
142-
*
143-
* @param context - specify the kernel context pointer.
144-
* @param index - specify the index of the output.
145-
* @param data - specify the pointer to encoded data of type and dims.
146-
*/
147-
_JsepOutput(context: number, index: number, data: number): number;
148-
/**
149-
* [exported from wasm] Get name of an operator node.
150-
*
151-
* @param kernel - specify the kernel pointer.
152-
* @returns the pointer to a C-style UTF8 encoded string representing the node name.
153-
*/
154-
_JsepGetNodeName(kernel: number): number;
155-
156-
/**
157-
* [exported from js_internal_api.js] Register a user GPU buffer for usage of a session's input or output.
158-
*
159-
* @param sessionId - specify the session ID.
160-
* @param index - specify an integer to represent which input/output it is registering for. For input, it is the
161-
* input_index corresponding to the session's inputNames. For output, it is the inputCount + output_index
162-
* corresponding to the session's ouputNames.
163-
* @param buffer - specify the GPU buffer to register.
164-
* @param size - specify the original data size in byte.
165-
* @returns the GPU data ID for the registered GPU buffer.
166-
*/
167-
jsepRegisterBuffer: (sessionId: number, index: number, buffer: GPUBuffer, size: number) => number;
168-
/**
169-
* [exported from js_internal_api.js] Get the GPU buffer by GPU data ID.
170-
*
171-
* @param dataId - specify the GPU data ID
172-
* @returns the GPU buffer.
173-
*/
174-
jsepGetBuffer: (dataId: number) => GPUBuffer;
175-
/**
176-
* [exported from js_internal_api.js] Create a function to be used to create a GPU Tensor.
177-
*
178-
* @param gpuBuffer - specify the GPU buffer
179-
* @param size - specify the original data size in byte.
180-
* @param type - specify the tensor type.
181-
* @returns the generated downloader function.
182-
*/
183-
jsepCreateDownloader:
184-
(gpuBuffer: GPUBuffer, size: number,
185-
type: Tensor.GpuBufferDataTypes) => () => Promise<Tensor.DataTypeMap[Tensor.GpuBufferDataTypes]>;
186-
/**
187-
* [exported from js_internal_api.js] Called when InferenceSession.run started. This function will be called before
188-
* _OrtRun[WithBinding]() is called.
189-
* @param sessionId - specify the session ID.
190-
*/
191-
jsepOnRunStart: (sessionId: number) => void;
192-
/**
193-
* [exported from js_internal_api.js] Release a session. This function will be called before _OrtReleaseSession() is
194-
* called.
195-
* @param sessionId - specify the session ID.
196-
* @returns
197-
*/
198-
jsepOnReleaseSession: (sessionId: number) => void;
199-
// #endregion
200209
}
201210

202211
declare const moduleFactory: EmscriptenModuleFactory<OrtWasmModule>;

js/web/lib/wasm/jsep/init.ts

+3-3
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ class ComputeContextImpl implements ComputeContext {
119119
for (let i = 0; i < dims.length; i++) {
120120
this.module.HEAPU32[offset++] = dims[i];
121121
}
122-
return this.module._JsepOutput(this.opKernelContext, index, data);
122+
return this.module._JsepOutput!(this.opKernelContext, index, data);
123123
} catch (e) {
124124
throw new Error(
125125
`Failed to generate kernel's output[${index}] with dims [${dims}]. ` +
@@ -200,8 +200,8 @@ export const init =
200200
},
201201

202202
// jsepCreateKernel
203-
(kernelType: string, kernelId: number, attribute: unknown) =>
204-
backend.createKernel(kernelType, kernelId, attribute, module.UTF8ToString(module._JsepGetNodeName(kernelId))),
203+
(kernelType: string, kernelId: number, attribute: unknown) => backend.createKernel(
204+
kernelType, kernelId, attribute, module.UTF8ToString(module._JsepGetNodeName!(kernelId))),
205205

206206
// jsepReleaseKernel
207207
(kernel: number) => backend.releaseKernel(kernel),

js/web/lib/wasm/wasm-core-impl.ts

+12-3
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,12 @@ export const prepareInputOutputTensor =
381381
const gpuBuffer = tensor[2].gpuBuffer as GPUBuffer;
382382
const elementSizeInBytes = getTensorElementSize(tensorDataTypeStringToEnum(dataType))!;
383383
dataByteLength = dims.reduce((a, b) => a * b, 1) * elementSizeInBytes;
384-
rawData = wasm.jsepRegisterBuffer(sessionId, index, gpuBuffer, dataByteLength);
384+
385+
const registerBuffer = wasm.jsepRegisterBuffer;
386+
if (!registerBuffer) {
387+
throw new Error('Tensor location "gpu-buffer" is not supported without using WebGPU.');
388+
}
389+
rawData = registerBuffer(sessionId, index, gpuBuffer, dataByteLength);
385390
} else {
386391
const data = tensor[2];
387392

@@ -596,7 +601,11 @@ export const run = async(
596601
// If a certain output's preferred location is GPU but the tensor is empty, we still need to create a CPU
597602
// tensor for it. There is no mapping GPU buffer for an empty tensor.
598603
if (preferredLocation === 'gpu-buffer' && size > 0) {
599-
const gpuBuffer = wasm.jsepGetBuffer(dataOffset);
604+
const getBuffer = wasm.jsepGetBuffer;
605+
if (!getBuffer) {
606+
throw new Error('preferredLocation "gpu-buffer" is not supported without using WebGPU.');
607+
}
608+
const gpuBuffer = getBuffer(dataOffset);
600609
const elementSize = getTensorElementSize(dataType);
601610
if (elementSize === undefined || !isGpuBufferSupportedType(type)) {
602611
throw new Error(`Unsupported data type: ${type}`);
@@ -608,7 +617,7 @@ export const run = async(
608617
output.push([
609618
type, dims, {
610619
gpuBuffer,
611-
download: wasm.jsepCreateDownloader(gpuBuffer, size * elementSize, type),
620+
download: wasm.jsepCreateDownloader!(gpuBuffer, size * elementSize, type),
612621
dispose: () => {
613622
wasm._OrtReleaseTensor(tensor);
614623
}

onnxruntime/wasm/js_internal_api.js

+5-4
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
'use strict';
55

66
/**
7-
* Mount external data files of a model to the virtual file system (MEMFS).
7+
* Mount external data files of a model to an internal map, which will be used during session initialization.
88
*
99
* @param {string} externalDataFilesPath
1010
* @param {Uint8Array} externalDataFilesData
@@ -15,7 +15,7 @@ Module['mountExternalData'] = (externalDataFilePath, externalDataFileData) => {
1515
};
1616

1717
/**
18-
* Unmount external data files of a model from the virtual file system (MEMFS).
18+
* Unmount external data files of a model.
1919
*/
2020
Module['unmountExternalData'] = () => {
2121
delete Module.MountedFiles;
@@ -131,7 +131,7 @@ let jsepInitAsync = () => {
131131
}
132132

133133
// Flush the backend. This will submit all pending commands to the GPU.
134-
backend['flush']();
134+
Module.jsepBackend?.['flush']();
135135

136136
// Await all pending promises. This includes GPU validation promises for diagnostic purposes.
137137
const errorPromises = state.errors;
@@ -180,7 +180,8 @@ Module['jsepInit'] = (name, params) => {
180180
jsepInitAsync?.();
181181

182182
if (name === 'webgpu') {
183-
[Module.jsepBackend, Module.jsepAlloc,
183+
[Module.jsepBackend,
184+
Module.jsepAlloc,
184185
Module.jsepFree,
185186
Module.jsepCopy,
186187
Module.jsepCopyAsync,

0 commit comments

Comments
 (0)