@@ -175,44 +175,137 @@ static inline const char *getUrResultString(ur_result_t Result) {
175
175
fprintf (stderr, " UR <--- %s(%s)\n " , #Call, getUrResultString (Result)); \
176
176
}
177
177
178
+ // Handle to a kernel command.
179
+ //
180
+ // Struct that stores all the information related to a kernel command in a
181
+ // command-buffer, such that the command can be recreated. When handles can
182
+ // be returned from other command types this struct will need refactored.
183
+ struct ur_exp_command_buffer_command_handle_t_ {
184
+ ur_exp_command_buffer_command_handle_t_ (
185
+ ur_exp_command_buffer_handle_t CommandBuffer, ur_kernel_handle_t Kernel,
186
+ std::shared_ptr<hipGraphNode_t> &&Node, hipKernelNodeParams Params,
187
+ uint32_t WorkDim, const size_t *GlobalWorkOffsetPtr,
188
+ const size_t *GlobalWorkSizePtr, const size_t *LocalWorkSizePtr);
189
+
190
+ void setGlobalOffset (const size_t *GlobalWorkOffsetPtr) {
191
+ const size_t CopySize = sizeof (size_t ) * WorkDim;
192
+ std::memcpy (GlobalWorkOffset, GlobalWorkOffsetPtr, CopySize);
193
+ if (WorkDim < 3 ) {
194
+ const size_t ZeroSize = sizeof (size_t ) * (3 - WorkDim);
195
+ std::memset (GlobalWorkOffset + WorkDim, 0 , ZeroSize);
196
+ }
197
+ }
198
+
199
+ void setGlobalSize (const size_t *GlobalWorkSizePtr) {
200
+ const size_t CopySize = sizeof (size_t ) * WorkDim;
201
+ std::memcpy (GlobalWorkSize, GlobalWorkSizePtr, CopySize);
202
+ if (WorkDim < 3 ) {
203
+ const size_t ZeroSize = sizeof (size_t ) * (3 - WorkDim);
204
+ std::memset (GlobalWorkSize + WorkDim, 0 , ZeroSize);
205
+ }
206
+ }
207
+
208
+ void setLocalSize (const size_t *LocalWorkSizePtr) {
209
+ const size_t CopySize = sizeof (size_t ) * WorkDim;
210
+ std::memcpy (LocalWorkSize, LocalWorkSizePtr, CopySize);
211
+ if (WorkDim < 3 ) {
212
+ const size_t ZeroSize = sizeof (size_t ) * (3 - WorkDim);
213
+ std::memset (LocalWorkSize + WorkDim, 0 , ZeroSize);
214
+ }
215
+ }
216
+
217
+ uint32_t incrementInternalReferenceCount () noexcept {
218
+ return ++RefCountInternal;
219
+ }
220
+ uint32_t decrementInternalReferenceCount () noexcept {
221
+ return --RefCountInternal;
222
+ }
223
+
224
+ uint32_t incrementExternalReferenceCount () noexcept {
225
+ return ++RefCountExternal;
226
+ }
227
+ uint32_t decrementExternalReferenceCount () noexcept {
228
+ return --RefCountExternal;
229
+ }
230
+ uint32_t getExternalReferenceCount () const noexcept {
231
+ return RefCountExternal;
232
+ }
233
+
234
+ ur_exp_command_buffer_handle_t CommandBuffer;
235
+ ur_kernel_handle_t Kernel;
236
+ std::shared_ptr<hipGraphNode_t> Node;
237
+ hipKernelNodeParams Params;
238
+
239
+ uint32_t WorkDim;
240
+ size_t GlobalWorkOffset[3 ];
241
+ size_t GlobalWorkSize[3 ];
242
+ size_t LocalWorkSize[3 ];
243
+
244
+ private:
245
+ std::atomic_uint32_t RefCountInternal;
246
+ std::atomic_uint32_t RefCountExternal;
247
+ };
248
+
178
249
struct ur_exp_command_buffer_handle_t_ {
179
250
180
251
ur_exp_command_buffer_handle_t_ (ur_context_handle_t hContext,
181
- ur_device_handle_t hDevice);
252
+ ur_device_handle_t hDevice, bool IsUpdatable );
182
253
183
254
~ur_exp_command_buffer_handle_t_ ();
184
255
185
- void RegisterSyncPoint (ur_exp_command_buffer_sync_point_t SyncPoint,
256
+ void registerSyncPoint (ur_exp_command_buffer_sync_point_t SyncPoint,
186
257
std::shared_ptr<hipGraphNode_t> HIPNode) {
187
258
SyncPoints[SyncPoint] = HIPNode;
188
259
NextSyncPoint++;
189
260
}
190
261
191
- ur_exp_command_buffer_sync_point_t GetNextSyncPoint () const {
262
+ ur_exp_command_buffer_sync_point_t getNextSyncPoint () const {
192
263
return NextSyncPoint;
193
264
}
194
265
195
266
// Helper to register next sync point
196
267
// @param HIPNode Node to register as next sync point
197
268
// @return Pointer to the sync that registers the Node
198
269
ur_exp_command_buffer_sync_point_t
199
- AddSyncPoint (std::shared_ptr<hipGraphNode_t> HIPNode) {
270
+ addSyncPoint (std::shared_ptr<hipGraphNode_t> HIPNode) {
200
271
ur_exp_command_buffer_sync_point_t SyncPoint = NextSyncPoint;
201
- RegisterSyncPoint (SyncPoint, HIPNode);
272
+ registerSyncPoint (SyncPoint, HIPNode);
202
273
return SyncPoint;
203
274
}
275
+ uint32_t incrementInternalReferenceCount () noexcept {
276
+ return ++RefCountInternal;
277
+ }
278
+ uint32_t decrementInternalReferenceCount () noexcept {
279
+ return --RefCountInternal;
280
+ }
281
+ uint32_t getInternalReferenceCount () const noexcept {
282
+ return RefCountInternal;
283
+ }
284
+
285
+ uint32_t incrementExternalReferenceCount () noexcept {
286
+ return ++RefCountExternal;
287
+ }
288
+ uint32_t decrementExternalReferenceCount () noexcept {
289
+ return --RefCountExternal;
290
+ }
291
+ uint32_t getExternalReferenceCount () const noexcept {
292
+ return RefCountExternal;
293
+ }
204
294
205
295
// UR context associated with this command-buffer
206
296
ur_context_handle_t Context;
207
297
// Device associated with this command buffer
208
298
ur_device_handle_t Device;
299
+ // Whether commands in the command-buffer can be updated
300
+ bool IsUpdatable;
209
301
// HIP Graph handle
210
302
hipGraph_t HIPGraph;
211
303
// HIP Graph Exec handle
212
304
hipGraphExec_t HIPGraphExec;
213
305
// Atomic variables counting the number of references to this command_buffer;
214
306
// using std::atomic prevents data race when incrementing/decrementing.
215
- std::atomic_uint32_t RefCount;
307
+ std::atomic_uint32_t RefCountInternal;
308
+ std::atomic_uint32_t RefCountExternal;
216
309
217
310
// Map of sync_points to ur_events
218
311
std::unordered_map<ur_exp_command_buffer_sync_point_t ,
@@ -222,9 +315,6 @@ struct ur_exp_command_buffer_handle_t_ {
222
315
// is not enough)
223
316
ur_exp_command_buffer_sync_point_t NextSyncPoint;
224
317
225
- // Used when retaining an object.
226
- uint32_t incrementReferenceCount () noexcept { return ++RefCount; }
227
- // Used when releasing an object.
228
- uint32_t decrementReferenceCount () noexcept { return --RefCount; }
229
- uint32_t getReferenceCount () const noexcept { return RefCount; }
318
+ // Handles to individual commands in the command-buffer
319
+ std::vector<ur_exp_command_buffer_command_handle_t > CommandHandles;
230
320
};
0 commit comments