@@ -170,6 +170,65 @@ ur_result_t getEventsFromSyncPoints(
170
170
return UR_RESULT_SUCCESS;
171
171
}
172
172
173
+ /* *
174
+ * If necessary, it creates a signal event and appends it to the previous
175
+ * command list (copy or compute), to indicate when it's finished executing.
176
+ * @param[in] CommandBuffer The CommandBuffer where the command is appended.
177
+ * @param[in] ZeCommandList the CommandList that's currently in use.
178
+ * @param[out] WaitEventList The list of event for the future command list to
179
+ * wait on before execution.
180
+ * @return UR_RESULT_SUCCESS or an error code on failure
181
+ */
182
+ ur_result_t createSyncPointBetweenCopyAndCompute (
183
+ ur_exp_command_buffer_handle_t CommandBuffer,
184
+ ze_command_list_handle_t ZeCommandList,
185
+ std::vector<ze_event_handle_t > &WaitEventList) {
186
+
187
+ if (!CommandBuffer->ZeCopyCommandList ) {
188
+ return UR_RESULT_SUCCESS;
189
+ }
190
+
191
+ bool IsCopy{ZeCommandList == CommandBuffer->ZeCopyCommandList };
192
+
193
+ // Skip synchronization for the first node in a graph or if the current
194
+ // command list matches the previous one.
195
+ if (!CommandBuffer->MWasPrevCopyCommandList .has_value ()) {
196
+ CommandBuffer->MWasPrevCopyCommandList = IsCopy;
197
+ return UR_RESULT_SUCCESS;
198
+ } else if (IsCopy == CommandBuffer->MWasPrevCopyCommandList ) {
199
+ return UR_RESULT_SUCCESS;
200
+ }
201
+
202
+ /*
203
+ * If the current CommandList differs from the previously used one, we must
204
+ * append a signal event to the previous CommandList to track when
205
+ * its execution is complete.
206
+ */
207
+ ur_event_handle_t SignalPrevCommandEvent = nullptr ;
208
+ UR_CALL (EventCreate (CommandBuffer->Context , nullptr /* Queue*/ ,
209
+ false /* IsMultiDevice*/ , false , &SignalPrevCommandEvent,
210
+ false /* CounterBasedEventEnabled*/ ,
211
+ !CommandBuffer->IsProfilingEnabled ,
212
+ false /* InterruptBasedEventEnabled*/ ));
213
+
214
+ // Determine which command list to signal.
215
+ auto CommandListToSignal = (!IsCopy && CommandBuffer->MWasPrevCopyCommandList )
216
+ ? CommandBuffer->ZeCopyCommandList
217
+ : CommandBuffer->ZeComputeCommandList ;
218
+ CommandBuffer->MWasPrevCopyCommandList = IsCopy;
219
+
220
+ ZE2UR_CALL (zeCommandListAppendSignalEvent,
221
+ (CommandListToSignal, SignalPrevCommandEvent->ZeEvent ));
222
+
223
+ // Add the event to the dependencies for future command list to wait on.
224
+ WaitEventList.push_back (SignalPrevCommandEvent->ZeEvent );
225
+
226
+ // Mark the event for future reset.
227
+ CommandBuffer->ZeEventsList .push_back (SignalPrevCommandEvent->ZeEvent );
228
+
229
+ return UR_RESULT_SUCCESS;
230
+ }
231
+
173
232
/* *
174
233
* If needed, creates a sync point for a given command and returns the L0
175
234
* events associated with the sync point.
@@ -190,7 +249,7 @@ ur_result_t getEventsFromSyncPoints(
190
249
*/
191
250
ur_result_t createSyncPointAndGetZeEvents (
192
251
ur_command_t CommandType, ur_exp_command_buffer_handle_t CommandBuffer,
193
- uint32_t NumSyncPointsInWaitList,
252
+ ze_command_list_handle_t ZeCommandList, uint32_t NumSyncPointsInWaitList,
194
253
const ur_exp_command_buffer_sync_point_t *SyncPointWaitList,
195
254
bool HostVisible, ur_exp_command_buffer_sync_point_t *RetSyncPoint,
196
255
std::vector<ze_event_handle_t > &ZeEventList,
@@ -199,6 +258,11 @@ ur_result_t createSyncPointAndGetZeEvents(
199
258
ZeLaunchEvent = nullptr ;
200
259
201
260
if (CommandBuffer->IsInOrderCmdList ) {
261
+ UR_CALL (createSyncPointBetweenCopyAndCompute (CommandBuffer, ZeCommandList,
262
+ ZeEventList));
263
+ if (!ZeEventList.empty ()) {
264
+ NumSyncPointsInWaitList = ZeEventList.size ();
265
+ }
202
266
return UR_RESULT_SUCCESS;
203
267
}
204
268
@@ -225,24 +289,24 @@ ur_result_t createSyncPointAndGetZeEvents(
225
289
return UR_RESULT_SUCCESS;
226
290
}
227
291
228
- // Shared by all memory read/write/copy PI interfaces.
229
- // Helper function for common code when enqueuing memory operations to a command
230
- // buffer.
292
+ // Shared by all memory read/write/copy UR interfaces.
293
+ // Helper function for common code when enqueuing memory operations to a
294
+ // command buffer.
231
295
ur_result_t enqueueCommandBufferMemCopyHelper (
232
296
ur_command_t CommandType, ur_exp_command_buffer_handle_t CommandBuffer,
233
297
void *Dst, const void *Src, size_t Size , bool PreferCopyEngine,
234
298
uint32_t NumSyncPointsInWaitList,
235
299
const ur_exp_command_buffer_sync_point_t *SyncPointWaitList,
236
300
ur_exp_command_buffer_sync_point_t *RetSyncPoint) {
237
301
302
+ ze_command_list_handle_t ZeCommandList =
303
+ CommandBuffer->chooseCommandList (PreferCopyEngine);
304
+
238
305
std::vector<ze_event_handle_t > ZeEventList;
239
306
ze_event_handle_t ZeLaunchEvent = nullptr ;
240
307
UR_CALL (createSyncPointAndGetZeEvents (
241
- CommandType, CommandBuffer, NumSyncPointsInWaitList, SyncPointWaitList,
242
- false , RetSyncPoint, ZeEventList, ZeLaunchEvent));
243
-
244
- ze_command_list_handle_t ZeCommandList =
245
- CommandBuffer->chooseCommandList (PreferCopyEngine);
308
+ CommandType, CommandBuffer, ZeCommandList, NumSyncPointsInWaitList,
309
+ SyncPointWaitList, false , RetSyncPoint, ZeEventList, ZeLaunchEvent));
246
310
247
311
ZE2UR_CALL (zeCommandListAppendMemoryCopy,
248
312
(ZeCommandList, Dst, Src, Size , ZeLaunchEvent, ZeEventList.size (),
@@ -293,14 +357,14 @@ ur_result_t enqueueCommandBufferMemCopyRectHelper(
293
357
const ze_copy_region_t ZeDstRegion = {DstOriginX, DstOriginY, DstOriginZ,
294
358
Width, Height, Depth};
295
359
360
+ ze_command_list_handle_t ZeCommandList =
361
+ CommandBuffer->chooseCommandList (PreferCopyEngine);
362
+
296
363
std::vector<ze_event_handle_t > ZeEventList;
297
364
ze_event_handle_t ZeLaunchEvent = nullptr ;
298
365
UR_CALL (createSyncPointAndGetZeEvents (
299
- CommandType, CommandBuffer, NumSyncPointsInWaitList, SyncPointWaitList,
300
- false , RetSyncPoint, ZeEventList, ZeLaunchEvent));
301
-
302
- ze_command_list_handle_t ZeCommandList =
303
- CommandBuffer->chooseCommandList (PreferCopyEngine);
366
+ CommandType, CommandBuffer, ZeCommandList, NumSyncPointsInWaitList,
367
+ SyncPointWaitList, false , RetSyncPoint, ZeEventList, ZeLaunchEvent));
304
368
305
369
ZE2UR_CALL (zeCommandListAppendMemoryCopyRegion,
306
370
(ZeCommandList, Dst, &ZeDstRegion, DstPitch, DstSlicePitch, Src,
@@ -321,19 +385,19 @@ ur_result_t enqueueCommandBufferFillHelper(
321
385
UR_ASSERT ((PatternSize > 0 ) && ((PatternSize & (PatternSize - 1 )) == 0 ),
322
386
UR_RESULT_ERROR_INVALID_VALUE);
323
387
324
- std::vector<ze_event_handle_t > ZeEventList;
325
- ze_event_handle_t ZeLaunchEvent = nullptr ;
326
- UR_CALL (createSyncPointAndGetZeEvents (
327
- CommandType, CommandBuffer, NumSyncPointsInWaitList, SyncPointWaitList,
328
- true , RetSyncPoint, ZeEventList, ZeLaunchEvent));
329
-
330
388
bool PreferCopyEngine;
331
389
UR_CALL (
332
390
preferCopyEngineForFill (CommandBuffer, PatternSize, PreferCopyEngine));
333
391
334
392
ze_command_list_handle_t ZeCommandList =
335
393
CommandBuffer->chooseCommandList (PreferCopyEngine);
336
394
395
+ std::vector<ze_event_handle_t > ZeEventList;
396
+ ze_event_handle_t ZeLaunchEvent = nullptr ;
397
+ UR_CALL (createSyncPointAndGetZeEvents (
398
+ CommandType, CommandBuffer, ZeCommandList, NumSyncPointsInWaitList,
399
+ SyncPointWaitList, true , RetSyncPoint, ZeEventList, ZeLaunchEvent));
400
+
337
401
ZE2UR_CALL (zeCommandListAppendMemoryFill,
338
402
(ZeCommandList, Ptr , Pattern, PatternSize, Size , ZeLaunchEvent,
339
403
ZeEventList.size (), getPointerFromVector (ZeEventList)));
@@ -477,12 +541,12 @@ void ur_exp_command_buffer_handle_t_::registerSyncPoint(
477
541
478
542
ze_command_list_handle_t
479
543
ur_exp_command_buffer_handle_t_::chooseCommandList (bool PreferCopyEngine) {
480
- if (PreferCopyEngine && this -> useCopyEngine () && !this -> IsInOrderCmdList ) {
544
+ if (PreferCopyEngine && useCopyEngine () && !IsInOrderCmdList) {
481
545
// We indicate that ZeCopyCommandList contains commands to be submitted.
482
- this -> MCopyCommandListEmpty = false ;
483
- return this -> ZeCopyCommandList ;
546
+ MCopyCommandListEmpty = false ;
547
+ return ZeCopyCommandList;
484
548
}
485
- return this -> ZeComputeCommandList ;
549
+ return ZeComputeCommandList;
486
550
}
487
551
488
552
ur_result_t ur_exp_command_buffer_handle_t_::getFenceForQueue (
@@ -646,7 +710,7 @@ urCommandBufferCreateExp(ur_context_handle_t Context, ur_device_handle_t Device,
646
710
// the current implementation only uses the main copy engine and does not use
647
711
// the link engine even if available.
648
712
if (Device->hasMainCopyEngine ()) {
649
- UR_CALL (createMainCommandList (Context, Device, false , false , true ,
713
+ UR_CALL (createMainCommandList (Context, Device, IsInOrder , false , true ,
650
714
ZeCopyCommandList));
651
715
}
652
716
@@ -812,18 +876,24 @@ finalizeWaitEventPath(ur_exp_command_buffer_handle_t CommandBuffer) {
812
876
(CommandBuffer->ZeCommandListResetEvents ,
813
877
CommandBuffer->ExecutionFinishedEvent ->ZeEvent ));
814
878
815
- if (CommandBuffer->IsInOrderCmdList ) {
816
- ZE2UR_CALL (zeCommandListAppendSignalEvent,
817
- (CommandBuffer->ZeComputeCommandList ,
818
- CommandBuffer->ExecutionFinishedEvent ->ZeEvent ));
819
- } else {
820
- // Reset the L0 events we use for command-buffer sync-points to the
821
- // non-signaled state. This is required for multiple submissions.
879
+ // Reset the L0 events we use for command-buffer sync-points to the
880
+ // non-signaled state. This is required for multiple submissions.
881
+ auto resetEvents = [&CommandBuffer]() {
822
882
for (auto &Event : CommandBuffer->ZeEventsList ) {
823
883
ZE2UR_CALL (zeCommandListAppendEventReset,
824
884
(CommandBuffer->ZeCommandListResetEvents , Event));
825
885
}
886
+ };
826
887
888
+ if (CommandBuffer->IsInOrderCmdList ) {
889
+ if (!CommandBuffer->MCopyCommandListEmpty ) {
890
+ resetEvents ();
891
+ }
892
+ ZE2UR_CALL (zeCommandListAppendSignalEvent,
893
+ (CommandBuffer->ZeComputeCommandList ,
894
+ CommandBuffer->ExecutionFinishedEvent ->ZeEvent ));
895
+ } else {
896
+ resetEvents ();
827
897
// Wait for all the user added commands to complete, and signal the
828
898
// command-buffer signal-event when they are done.
829
899
ZE2UR_CALL (zeCommandListAppendBarrier,
@@ -1073,7 +1143,8 @@ ur_result_t urCommandBufferAppendKernelLaunchExp(
1073
1143
std::vector<ze_event_handle_t > ZeEventList;
1074
1144
ze_event_handle_t ZeLaunchEvent = nullptr ;
1075
1145
UR_CALL (createSyncPointAndGetZeEvents (
1076
- UR_COMMAND_KERNEL_LAUNCH, CommandBuffer, NumSyncPointsInWaitList,
1146
+ UR_COMMAND_KERNEL_LAUNCH, CommandBuffer,
1147
+ CommandBuffer->ZeComputeCommandList , NumSyncPointsInWaitList,
1077
1148
SyncPointWaitList, false , RetSyncPoint, ZeEventList, ZeLaunchEvent));
1078
1149
1079
1150
ZE2UR_CALL (zeCommandListAppendLaunchKernel,
@@ -1306,29 +1377,25 @@ ur_result_t urCommandBufferAppendUSMPrefetchExp(
1306
1377
std::ignore = Command;
1307
1378
std::ignore = Flags;
1308
1379
1309
- if (CommandBuffer->IsInOrderCmdList ) {
1310
- // Add the prefetch command to the command-buffer.
1311
- // Note that L0 does not handle migration flags.
1312
- ZE2UR_CALL (zeCommandListAppendMemoryPrefetch,
1313
- (CommandBuffer->ZeComputeCommandList , Mem, Size ));
1314
- } else {
1315
- std::vector<ze_event_handle_t > ZeEventList;
1316
- ze_event_handle_t ZeLaunchEvent = nullptr ;
1317
- UR_CALL (createSyncPointAndGetZeEvents (
1318
- UR_COMMAND_USM_PREFETCH, CommandBuffer, NumSyncPointsInWaitList,
1319
- SyncPointWaitList, true , RetSyncPoint, ZeEventList, ZeLaunchEvent));
1320
-
1321
- if (NumSyncPointsInWaitList) {
1322
- ZE2UR_CALL (zeCommandListAppendWaitOnEvents,
1323
- (CommandBuffer->ZeComputeCommandList , NumSyncPointsInWaitList,
1324
- ZeEventList.data ()));
1325
- }
1380
+ std::vector<ze_event_handle_t > ZeEventList;
1381
+ ze_event_handle_t ZeLaunchEvent = nullptr ;
1382
+ UR_CALL (createSyncPointAndGetZeEvents (
1383
+ UR_COMMAND_USM_PREFETCH, CommandBuffer,
1384
+ CommandBuffer->ZeComputeCommandList , NumSyncPointsInWaitList,
1385
+ SyncPointWaitList, true , RetSyncPoint, ZeEventList, ZeLaunchEvent));
1386
+
1387
+ if (NumSyncPointsInWaitList) {
1388
+ ZE2UR_CALL (zeCommandListAppendWaitOnEvents,
1389
+ (CommandBuffer->ZeComputeCommandList , NumSyncPointsInWaitList,
1390
+ ZeEventList.data ()));
1391
+ }
1326
1392
1327
- // Add the prefetch command to the command-buffer.
1328
- // Note that L0 does not handle migration flags.
1329
- ZE2UR_CALL (zeCommandListAppendMemoryPrefetch,
1330
- (CommandBuffer->ZeComputeCommandList , Mem, Size ));
1393
+ // Add the prefetch command to the command-buffer.
1394
+ // Note that L0 does not handle migration flags.
1395
+ ZE2UR_CALL (zeCommandListAppendMemoryPrefetch,
1396
+ (CommandBuffer->ZeComputeCommandList , Mem, Size ));
1331
1397
1398
+ if (!CommandBuffer->IsInOrderCmdList ) {
1332
1399
// Level Zero does not have a completion "event" with the prefetch API,
1333
1400
// so manually add command to signal our event.
1334
1401
ZE2UR_CALL (zeCommandListAppendSignalEvent,
@@ -1376,27 +1443,24 @@ ur_result_t urCommandBufferAppendUSMAdviseExp(
1376
1443
1377
1444
ze_memory_advice_t ZeAdvice = static_cast <ze_memory_advice_t >(Value);
1378
1445
1379
- if (CommandBuffer->IsInOrderCmdList ) {
1380
- ZE2UR_CALL (zeCommandListAppendMemAdvise,
1381
- (CommandBuffer->ZeComputeCommandList ,
1382
- CommandBuffer->Device ->ZeDevice , Mem, Size , ZeAdvice));
1383
- } else {
1384
- std::vector<ze_event_handle_t > ZeEventList;
1385
- ze_event_handle_t ZeLaunchEvent = nullptr ;
1386
- UR_CALL (createSyncPointAndGetZeEvents (
1387
- UR_COMMAND_USM_ADVISE, CommandBuffer, NumSyncPointsInWaitList,
1388
- SyncPointWaitList, true , RetSyncPoint, ZeEventList, ZeLaunchEvent));
1389
-
1390
- if (NumSyncPointsInWaitList) {
1391
- ZE2UR_CALL (zeCommandListAppendWaitOnEvents,
1392
- (CommandBuffer->ZeComputeCommandList , NumSyncPointsInWaitList,
1393
- ZeEventList.data ()));
1394
- }
1446
+ std::vector<ze_event_handle_t > ZeEventList;
1447
+ ze_event_handle_t ZeLaunchEvent = nullptr ;
1448
+ UR_CALL (createSyncPointAndGetZeEvents (
1449
+ UR_COMMAND_USM_ADVISE, CommandBuffer, CommandBuffer->ZeComputeCommandList ,
1450
+ NumSyncPointsInWaitList, SyncPointWaitList, true , RetSyncPoint,
1451
+ ZeEventList, ZeLaunchEvent));
1395
1452
1396
- ZE2UR_CALL (zeCommandListAppendMemAdvise,
1397
- (CommandBuffer->ZeComputeCommandList ,
1398
- CommandBuffer->Device ->ZeDevice , Mem, Size , ZeAdvice));
1453
+ if (NumSyncPointsInWaitList) {
1454
+ ZE2UR_CALL (zeCommandListAppendWaitOnEvents,
1455
+ (CommandBuffer->ZeComputeCommandList , NumSyncPointsInWaitList,
1456
+ ZeEventList.data ()));
1457
+ }
1458
+
1459
+ ZE2UR_CALL (zeCommandListAppendMemAdvise,
1460
+ (CommandBuffer->ZeComputeCommandList ,
1461
+ CommandBuffer->Device ->ZeDevice , Mem, Size , ZeAdvice));
1399
1462
1463
+ if (!CommandBuffer->IsInOrderCmdList ) {
1400
1464
// Level Zero does not have a completion "event" with the advise API,
1401
1465
// so manually add command to signal our event.
1402
1466
ZE2UR_CALL (zeCommandListAppendSignalEvent,
0 commit comments