@@ -16,20 +16,97 @@ export declare namespace JSEP {
16
16
type CaptureBeginFunction = ( ) => void ;
17
17
type CaptureEndFunction = ( ) => void ;
18
18
type ReplayFunction = ( ) => void ;
19
- }
20
19
21
- export interface OrtWasmModule extends EmscriptenModule {
22
- // #region emscripten functions
23
- stackSave ( ) : number ;
24
- stackRestore ( stack : number ) : void ;
25
- stackAlloc ( size : number ) : number ;
26
-
27
- UTF8ToString ( offset : number , maxBytesToRead ?: number ) : string ;
28
- lengthBytesUTF8 ( str : string ) : number ;
29
- stringToUTF8 ( str : string , offset : number , maxBytes : number ) : void ;
30
- // #endregion
20
+ export interface Module extends WebGpuModule {
21
+ /**
22
+ * Mount the external data file to an internal map, which will be used during session initialization.
23
+ *
24
+ * @param externalDataFilePath - specify the relative path of the external data file.
25
+ * @param externalDataFileData - specify the content data.
26
+ */
27
+ mountExternalData ( externalDataFilePath : string , externalDataFileData : Uint8Array ) : void ;
28
+ /**
29
+ * Unmount all external data files from the internal map.
30
+ */
31
+ unmountExternalData ( ) : void ;
32
+
33
+ /**
34
+ * This is the entry of JSEP initialization. This function is called once when initializing ONNX Runtime per
35
+ * backend. This function initializes Asyncify support. If name is 'webgpu', also initializes WebGPU backend and
36
+ * registers a few callbacks that will be called in C++ code.
37
+ */
38
+ jsepInit ( name : 'webgpu' , initParams : [
39
+ backend : BackendType , alloc : AllocFunction , free : FreeFunction , upload : UploadFunction ,
40
+ download : DownloadFunction , createKernel : CreateKernelFunction , releaseKernel : ReleaseKernelFunction ,
41
+ run : RunFunction , captureBegin : CaptureBeginFunction , captureEnd : CaptureEndFunction , replay : ReplayFunction
42
+ ] ) : void ;
43
+ jsepInit ( name : 'webnn' , initParams ?: never ) : void ;
44
+ }
45
+
46
+ export interface WebGpuModule {
47
+ /**
48
+ * [exported from wasm] Specify a kernel's output when running OpKernel::Compute().
49
+ *
50
+ * @param context - specify the kernel context pointer.
51
+ * @param index - specify the index of the output.
52
+ * @param data - specify the pointer to encoded data of type and dims.
53
+ */
54
+ _JsepOutput ( context : number , index : number , data : number ) : number ;
55
+ /**
56
+ * [exported from wasm] Get name of an operator node.
57
+ *
58
+ * @param kernel - specify the kernel pointer.
59
+ * @returns the pointer to a C-style UTF8 encoded string representing the node name.
60
+ */
61
+ _JsepGetNodeName ( kernel : number ) : number ;
62
+
63
+ /**
64
+ * [exported from js_internal_api.js] Register a user GPU buffer for usage of a session's input or output.
65
+ *
66
+ * @param sessionId - specify the session ID.
67
+ * @param index - specify an integer to represent which input/output it is registering for. For input, it is the
68
+ * input_index corresponding to the session's inputNames. For output, it is the inputCount + output_index
69
+ * corresponding to the session's ouputNames.
70
+ * @param buffer - specify the GPU buffer to register.
71
+ * @param size - specify the original data size in byte.
72
+ * @returns the GPU data ID for the registered GPU buffer.
73
+ */
74
+ jsepRegisterBuffer : ( sessionId : number , index : number , buffer : GPUBuffer , size : number ) => number ;
75
+ /**
76
+ * [exported from js_internal_api.js] Get the GPU buffer by GPU data ID.
77
+ *
78
+ * @param dataId - specify the GPU data ID
79
+ * @returns the GPU buffer.
80
+ */
81
+ jsepGetBuffer : ( dataId : number ) => GPUBuffer ;
82
+ /**
83
+ * [exported from js_internal_api.js] Create a function to be used to create a GPU Tensor.
84
+ *
85
+ * @param gpuBuffer - specify the GPU buffer
86
+ * @param size - specify the original data size in byte.
87
+ * @param type - specify the tensor type.
88
+ * @returns the generated downloader function.
89
+ */
90
+ jsepCreateDownloader :
91
+ ( gpuBuffer : GPUBuffer , size : number ,
92
+ type : Tensor . GpuBufferDataTypes ) => ( ) => Promise < Tensor . DataTypeMap [ Tensor . GpuBufferDataTypes ] > ;
93
+ /**
94
+ * [exported from js_internal_api.js] Called when InferenceSession.run started. This function will be called before
95
+ * _OrtRun[WithBinding]() is called.
96
+ * @param sessionId - specify the session ID.
97
+ */
98
+ jsepOnRunStart : ( sessionId : number ) => void ;
99
+ /**
100
+ * [exported from js_internal_api.js] Release a session. This function will be called before _OrtReleaseSession() is
101
+ * called.
102
+ * @param sessionId - specify the session ID.
103
+ * @returns
104
+ */
105
+ jsepOnReleaseSession : ( sessionId : number ) => void ;
106
+ }
107
+ }
31
108
32
- // #region ORT APIs
109
+ export interface OrtInferenceAPIs {
33
110
_OrtInit ( numThreads : number , loggingLevel : number ) : number ;
34
111
35
112
_OrtGetLastError ( errorCodeOffset : number , errorMessageOffset : number ) : void ;
@@ -74,129 +151,61 @@ export interface OrtWasmModule extends EmscriptenModule {
74
151
_OrtReleaseRunOptions ( runOptionsHandle : number ) : void ;
75
152
76
153
_OrtEndProfiling ( sessionHandle : number ) : number ;
77
- // #endregion
154
+ }
155
+
156
+ export interface OrtTrainingAPIs {
157
+ _OrtTrainingLoadCheckpoint ( dataOffset : number , dataLength : number ) : number ;
78
158
79
- // #region ORT Training APIs
80
- _OrtTrainingLoadCheckpoint ?( dataOffset : number , dataLength : number ) : number ;
159
+ _OrtTrainingReleaseCheckpoint ( checkpointHandle : number ) : void ;
81
160
82
- _OrtTrainingReleaseCheckpoint ?( checkpointHandle : number ) : void ;
161
+ _OrtTrainingCreateSession (
162
+ sessionOptionsHandle : number , checkpointHandle : number , trainOffset : number , trainLength : number ,
163
+ evalOffset : number , evalLength : number , optimizerOffset : number , optimizerLength : number ) : number ;
83
164
84
- _OrtTrainingCreateSession ?
85
- ( sessionOptionsHandle : number , checkpointHandle : number , trainOffset : number , trainLength : number ,
86
- evalOffset : number , evalLength : number , optimizerOffset : number , optimizerLength : number ) : number ;
165
+ _OrtTrainingLazyResetGrad ( trainingHandle : number ) : number ;
87
166
88
- _OrtTrainingLazyResetGrad ?( trainingHandle : number ) : number ;
167
+ _OrtTrainingRunTrainStep (
168
+ trainingHandle : number , inputsOffset : number , inputCount : number , outputsOffset : number , outputCount : number ,
169
+ runOptionsHandle : number ) : number ;
89
170
90
- _OrtTrainingRunTrainStep ?
91
- ( trainingHandle : number , inputsOffset : number , inputCount : number , outputsOffset : number , outputCount : number ,
92
- runOptionsHandle : number ) : number ;
171
+ _OrtTrainingOptimizerStep ( trainingHandle : number , runOptionsHandle : number ) : number ;
93
172
94
- _OrtTrainingOptimizerStep ?( trainingHandle : number , runOptionsHandle : number ) : number ;
173
+ _OrtTrainingEvalStep (
174
+ trainingHandle : number , inputsOffset : number , inputCount : number , outputsOffset : number , outputCount : number ,
175
+ runOptionsHandle : number ) : number ;
95
176
96
- _OrtTrainingEvalStep ?
97
- ( trainingHandle : number , inputsOffset : number , inputCount : number , outputsOffset : number , outputCount : number ,
98
- runOptionsHandle : number ) : number ;
177
+ _OrtTrainingGetParametersSize ( trainingHandle : number , paramSizeT : number , trainableOnly : boolean ) : number ;
99
178
100
- _OrtTrainingGetParametersSize ?( trainingHandle : number , paramSizeT : number , trainableOnly : boolean ) : number ;
179
+ _OrtTrainingCopyParametersToBuffer (
180
+ trainingHandle : number , parametersBuffer : number , parameterCount : number , trainableOnly : boolean ) : number ;
101
181
102
- _OrtTrainingCopyParametersToBuffer ?
103
- ( trainingHandle : number , parametersBuffer : number , parameterCount : number , trainableOnly : boolean ) : number ;
182
+ _OrtTrainingCopyParametersFromBuffer (
183
+ trainingHandle : number , parametersBuffer : number , parameterCount : number , trainableOnly : boolean ) : number ;
104
184
105
- _OrtTrainingCopyParametersFromBuffer ?
106
- ( trainingHandle : number , parametersBuffer : number , parameterCount : number , trainableOnly : boolean ) : number ;
185
+ _OrtTrainingGetModelInputOutputCount (
186
+ trainingHandle : number , inputCount : number , outputCount : number , isEvalModel : boolean ) : number ;
187
+ _OrtTrainingGetModelInputOutputName ( trainingHandle : number , index : number , isInput : boolean , isEvalModel : boolean ) :
188
+ number ;
189
+
190
+ _OrtTrainingReleaseSession ( trainingHandle : number ) : void ;
191
+ }
107
192
108
- _OrtTrainingGetModelInputOutputCount ?
109
- ( trainingHandle : number , inputCount : number , outputCount : number , isEvalModel : boolean ) : number ;
110
- _OrtTrainingGetModelInputOutputName ?
111
- ( trainingHandle : number , index : number , isInput : boolean , isEvalModel : boolean ) : number ;
193
+ export interface OrtWasmModule extends EmscriptenModule , OrtInferenceAPIs , Partial < OrtTrainingAPIs > ,
194
+ Partial < JSEP . Module > {
195
+ // #region emscripten functions
196
+ stackSave ( ) : number ;
197
+ stackRestore ( stack : number ) : void ;
198
+ stackAlloc ( size : number ) : number ;
112
199
113
- _OrtTrainingReleaseSession ?( trainingHandle : number ) : void ;
200
+ UTF8ToString ( offset : number , maxBytesToRead ?: number ) : string ;
201
+ lengthBytesUTF8 ( str : string ) : number ;
202
+ stringToUTF8 ( str : string , offset : number , maxBytes : number ) : void ;
114
203
// #endregion
115
204
116
205
// #region config
117
206
numThreads ?: number ;
118
207
mainScriptUrlOrBlob ?: string | Blob ;
119
208
// #endregion
120
-
121
- // #region external data API
122
- mountExternalData ?( externalDataFilePath : string , externalDataFileData : Uint8Array ) : void ;
123
- unmountExternalData ?( ) : void ;
124
- // #endregion
125
-
126
- // #region JSEP
127
- /**
128
- * This is the entry of JSEP initialization. This function is called once when initializing ONNX Runtime per backend.
129
- * This function initializes Asyncify support.
130
- * If name is 'webgpu', also initializes WebGPU backend and registers a few callbacks that will be called in C++ code.
131
- */
132
- jsepInit ?( name : 'webgpu' , initParams : [
133
- backend : JSEP . BackendType , alloc : JSEP . AllocFunction , free : JSEP . FreeFunction , upload : JSEP . UploadFunction ,
134
- download : JSEP . DownloadFunction , createKernel : JSEP . CreateKernelFunction , releaseKernel : JSEP . ReleaseKernelFunction ,
135
- run : JSEP . RunFunction , captureBegin : JSEP . CaptureBeginFunction , captureEnd : JSEP . CaptureEndFunction ,
136
- replay : JSEP . ReplayFunction
137
- ] ) : void ;
138
- jsepInit ?( name : 'webnn' , initParams ?: never ) : void ;
139
-
140
- /**
141
- * [exported from wasm] Specify a kernel's output when running OpKernel::Compute().
142
- *
143
- * @param context - specify the kernel context pointer.
144
- * @param index - specify the index of the output.
145
- * @param data - specify the pointer to encoded data of type and dims.
146
- */
147
- _JsepOutput ( context : number , index : number , data : number ) : number ;
148
- /**
149
- * [exported from wasm] Get name of an operator node.
150
- *
151
- * @param kernel - specify the kernel pointer.
152
- * @returns the pointer to a C-style UTF8 encoded string representing the node name.
153
- */
154
- _JsepGetNodeName ( kernel : number ) : number ;
155
-
156
- /**
157
- * [exported from js_internal_api.js] Register a user GPU buffer for usage of a session's input or output.
158
- *
159
- * @param sessionId - specify the session ID.
160
- * @param index - specify an integer to represent which input/output it is registering for. For input, it is the
161
- * input_index corresponding to the session's inputNames. For output, it is the inputCount + output_index
162
- * corresponding to the session's ouputNames.
163
- * @param buffer - specify the GPU buffer to register.
164
- * @param size - specify the original data size in byte.
165
- * @returns the GPU data ID for the registered GPU buffer.
166
- */
167
- jsepRegisterBuffer : ( sessionId : number , index : number , buffer : GPUBuffer , size : number ) => number ;
168
- /**
169
- * [exported from js_internal_api.js] Get the GPU buffer by GPU data ID.
170
- *
171
- * @param dataId - specify the GPU data ID
172
- * @returns the GPU buffer.
173
- */
174
- jsepGetBuffer : ( dataId : number ) => GPUBuffer ;
175
- /**
176
- * [exported from js_internal_api.js] Create a function to be used to create a GPU Tensor.
177
- *
178
- * @param gpuBuffer - specify the GPU buffer
179
- * @param size - specify the original data size in byte.
180
- * @param type - specify the tensor type.
181
- * @returns the generated downloader function.
182
- */
183
- jsepCreateDownloader :
184
- ( gpuBuffer : GPUBuffer , size : number ,
185
- type : Tensor . GpuBufferDataTypes ) => ( ) => Promise < Tensor . DataTypeMap [ Tensor . GpuBufferDataTypes ] > ;
186
- /**
187
- * [exported from js_internal_api.js] Called when InferenceSession.run started. This function will be called before
188
- * _OrtRun[WithBinding]() is called.
189
- * @param sessionId - specify the session ID.
190
- */
191
- jsepOnRunStart : ( sessionId : number ) => void ;
192
- /**
193
- * [exported from js_internal_api.js] Release a session. This function will be called before _OrtReleaseSession() is
194
- * called.
195
- * @param sessionId - specify the session ID.
196
- * @returns
197
- */
198
- jsepOnReleaseSession : ( sessionId : number ) => void ;
199
- // #endregion
200
209
}
201
210
202
211
declare const moduleFactory : EmscriptenModuleFactory < OrtWasmModule > ;
0 commit comments