Skip to content

Commit 6f361e3

Browse files
Moved changes from UR PR
1 parent 3730930 commit 6f361e3

File tree

2 files changed

+140
-73
lines changed

2 files changed

+140
-73
lines changed

unified-runtime/source/adapters/level_zero/command_buffer.cpp

Lines changed: 137 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,65 @@ ur_result_t getEventsFromSyncPoints(
170170
return UR_RESULT_SUCCESS;
171171
}
172172

173+
/**
174+
* If necessary, it creates a signal event and appends it to the previous
175+
* command list (copy or compute), to indicate when it's finished executing.
176+
* @param[in] CommandBuffer The CommandBuffer where the command is appended.
177+
* @param[in] ZeCommandList the CommandList that's currently in use.
178+
* @param[out] WaitEventList The list of event for the future command list to
179+
* wait on before execution.
180+
* @return UR_RESULT_SUCCESS or an error code on failure
181+
*/
182+
ur_result_t createSyncPointBetweenCopyAndCompute(
183+
ur_exp_command_buffer_handle_t CommandBuffer,
184+
ze_command_list_handle_t ZeCommandList,
185+
std::vector<ze_event_handle_t> &WaitEventList) {
186+
187+
if (!CommandBuffer->ZeCopyCommandList) {
188+
return UR_RESULT_SUCCESS;
189+
}
190+
191+
bool IsCopy{ZeCommandList == CommandBuffer->ZeCopyCommandList};
192+
193+
// Skip synchronization for the first node in a graph or if the current
194+
// command list matches the previous one.
195+
if (!CommandBuffer->MWasPrevCopyCommandList.has_value()) {
196+
CommandBuffer->MWasPrevCopyCommandList = IsCopy;
197+
return UR_RESULT_SUCCESS;
198+
} else if (IsCopy == CommandBuffer->MWasPrevCopyCommandList) {
199+
return UR_RESULT_SUCCESS;
200+
}
201+
202+
/*
203+
* If the current CommandList differs from the previously used one, we must
204+
* append a signal event to the previous CommandList to track when
205+
* its execution is complete.
206+
*/
207+
ur_event_handle_t SignalPrevCommandEvent = nullptr;
208+
UR_CALL(EventCreate(CommandBuffer->Context, nullptr /*Queue*/,
209+
false /*IsMultiDevice*/, false, &SignalPrevCommandEvent,
210+
false /*CounterBasedEventEnabled*/,
211+
!CommandBuffer->IsProfilingEnabled,
212+
false /*InterruptBasedEventEnabled*/));
213+
214+
// Determine which command list to signal.
215+
auto CommandListToSignal = (!IsCopy && CommandBuffer->MWasPrevCopyCommandList)
216+
? CommandBuffer->ZeCopyCommandList
217+
: CommandBuffer->ZeComputeCommandList;
218+
CommandBuffer->MWasPrevCopyCommandList = IsCopy;
219+
220+
ZE2UR_CALL(zeCommandListAppendSignalEvent,
221+
(CommandListToSignal, SignalPrevCommandEvent->ZeEvent));
222+
223+
// Add the event to the dependencies for future command list to wait on.
224+
WaitEventList.push_back(SignalPrevCommandEvent->ZeEvent);
225+
226+
// Mark the event for future reset.
227+
CommandBuffer->ZeEventsList.push_back(SignalPrevCommandEvent->ZeEvent);
228+
229+
return UR_RESULT_SUCCESS;
230+
}
231+
173232
/**
174233
* If needed, creates a sync point for a given command and returns the L0
175234
* events associated with the sync point.
@@ -190,7 +249,7 @@ ur_result_t getEventsFromSyncPoints(
190249
*/
191250
ur_result_t createSyncPointAndGetZeEvents(
192251
ur_command_t CommandType, ur_exp_command_buffer_handle_t CommandBuffer,
193-
uint32_t NumSyncPointsInWaitList,
252+
ze_command_list_handle_t ZeCommandList, uint32_t NumSyncPointsInWaitList,
194253
const ur_exp_command_buffer_sync_point_t *SyncPointWaitList,
195254
bool HostVisible, ur_exp_command_buffer_sync_point_t *RetSyncPoint,
196255
std::vector<ze_event_handle_t> &ZeEventList,
@@ -199,6 +258,11 @@ ur_result_t createSyncPointAndGetZeEvents(
199258
ZeLaunchEvent = nullptr;
200259

201260
if (CommandBuffer->IsInOrderCmdList) {
261+
UR_CALL(createSyncPointBetweenCopyAndCompute(CommandBuffer, ZeCommandList,
262+
ZeEventList));
263+
if (!ZeEventList.empty()) {
264+
NumSyncPointsInWaitList = ZeEventList.size();
265+
}
202266
return UR_RESULT_SUCCESS;
203267
}
204268

@@ -225,24 +289,24 @@ ur_result_t createSyncPointAndGetZeEvents(
225289
return UR_RESULT_SUCCESS;
226290
}
227291

228-
// Shared by all memory read/write/copy PI interfaces.
229-
// Helper function for common code when enqueuing memory operations to a command
230-
// buffer.
292+
// Shared by all memory read/write/copy UR interfaces.
293+
// Helper function for common code when enqueuing memory operations to a
294+
// command buffer.
231295
ur_result_t enqueueCommandBufferMemCopyHelper(
232296
ur_command_t CommandType, ur_exp_command_buffer_handle_t CommandBuffer,
233297
void *Dst, const void *Src, size_t Size, bool PreferCopyEngine,
234298
uint32_t NumSyncPointsInWaitList,
235299
const ur_exp_command_buffer_sync_point_t *SyncPointWaitList,
236300
ur_exp_command_buffer_sync_point_t *RetSyncPoint) {
237301

302+
ze_command_list_handle_t ZeCommandList =
303+
CommandBuffer->chooseCommandList(PreferCopyEngine);
304+
238305
std::vector<ze_event_handle_t> ZeEventList;
239306
ze_event_handle_t ZeLaunchEvent = nullptr;
240307
UR_CALL(createSyncPointAndGetZeEvents(
241-
CommandType, CommandBuffer, NumSyncPointsInWaitList, SyncPointWaitList,
242-
false, RetSyncPoint, ZeEventList, ZeLaunchEvent));
243-
244-
ze_command_list_handle_t ZeCommandList =
245-
CommandBuffer->chooseCommandList(PreferCopyEngine);
308+
CommandType, CommandBuffer, ZeCommandList, NumSyncPointsInWaitList,
309+
SyncPointWaitList, false, RetSyncPoint, ZeEventList, ZeLaunchEvent));
246310

247311
ZE2UR_CALL(zeCommandListAppendMemoryCopy,
248312
(ZeCommandList, Dst, Src, Size, ZeLaunchEvent, ZeEventList.size(),
@@ -293,14 +357,14 @@ ur_result_t enqueueCommandBufferMemCopyRectHelper(
293357
const ze_copy_region_t ZeDstRegion = {DstOriginX, DstOriginY, DstOriginZ,
294358
Width, Height, Depth};
295359

360+
ze_command_list_handle_t ZeCommandList =
361+
CommandBuffer->chooseCommandList(PreferCopyEngine);
362+
296363
std::vector<ze_event_handle_t> ZeEventList;
297364
ze_event_handle_t ZeLaunchEvent = nullptr;
298365
UR_CALL(createSyncPointAndGetZeEvents(
299-
CommandType, CommandBuffer, NumSyncPointsInWaitList, SyncPointWaitList,
300-
false, RetSyncPoint, ZeEventList, ZeLaunchEvent));
301-
302-
ze_command_list_handle_t ZeCommandList =
303-
CommandBuffer->chooseCommandList(PreferCopyEngine);
366+
CommandType, CommandBuffer, ZeCommandList, NumSyncPointsInWaitList,
367+
SyncPointWaitList, false, RetSyncPoint, ZeEventList, ZeLaunchEvent));
304368

305369
ZE2UR_CALL(zeCommandListAppendMemoryCopyRegion,
306370
(ZeCommandList, Dst, &ZeDstRegion, DstPitch, DstSlicePitch, Src,
@@ -321,19 +385,19 @@ ur_result_t enqueueCommandBufferFillHelper(
321385
UR_ASSERT((PatternSize > 0) && ((PatternSize & (PatternSize - 1)) == 0),
322386
UR_RESULT_ERROR_INVALID_VALUE);
323387

324-
std::vector<ze_event_handle_t> ZeEventList;
325-
ze_event_handle_t ZeLaunchEvent = nullptr;
326-
UR_CALL(createSyncPointAndGetZeEvents(
327-
CommandType, CommandBuffer, NumSyncPointsInWaitList, SyncPointWaitList,
328-
true, RetSyncPoint, ZeEventList, ZeLaunchEvent));
329-
330388
bool PreferCopyEngine;
331389
UR_CALL(
332390
preferCopyEngineForFill(CommandBuffer, PatternSize, PreferCopyEngine));
333391

334392
ze_command_list_handle_t ZeCommandList =
335393
CommandBuffer->chooseCommandList(PreferCopyEngine);
336394

395+
std::vector<ze_event_handle_t> ZeEventList;
396+
ze_event_handle_t ZeLaunchEvent = nullptr;
397+
UR_CALL(createSyncPointAndGetZeEvents(
398+
CommandType, CommandBuffer, ZeCommandList, NumSyncPointsInWaitList,
399+
SyncPointWaitList, true, RetSyncPoint, ZeEventList, ZeLaunchEvent));
400+
337401
ZE2UR_CALL(zeCommandListAppendMemoryFill,
338402
(ZeCommandList, Ptr, Pattern, PatternSize, Size, ZeLaunchEvent,
339403
ZeEventList.size(), getPointerFromVector(ZeEventList)));
@@ -477,12 +541,12 @@ void ur_exp_command_buffer_handle_t_::registerSyncPoint(
477541

478542
ze_command_list_handle_t
479543
ur_exp_command_buffer_handle_t_::chooseCommandList(bool PreferCopyEngine) {
480-
if (PreferCopyEngine && this->useCopyEngine() && !this->IsInOrderCmdList) {
544+
if (PreferCopyEngine && useCopyEngine() && !IsInOrderCmdList) {
481545
// We indicate that ZeCopyCommandList contains commands to be submitted.
482-
this->MCopyCommandListEmpty = false;
483-
return this->ZeCopyCommandList;
546+
MCopyCommandListEmpty = false;
547+
return ZeCopyCommandList;
484548
}
485-
return this->ZeComputeCommandList;
549+
return ZeComputeCommandList;
486550
}
487551

488552
ur_result_t ur_exp_command_buffer_handle_t_::getFenceForQueue(
@@ -646,7 +710,7 @@ urCommandBufferCreateExp(ur_context_handle_t Context, ur_device_handle_t Device,
646710
// the current implementation only uses the main copy engine and does not use
647711
// the link engine even if available.
648712
if (Device->hasMainCopyEngine()) {
649-
UR_CALL(createMainCommandList(Context, Device, false, false, true,
713+
UR_CALL(createMainCommandList(Context, Device, IsInOrder, false, true,
650714
ZeCopyCommandList));
651715
}
652716

@@ -812,18 +876,24 @@ finalizeWaitEventPath(ur_exp_command_buffer_handle_t CommandBuffer) {
812876
(CommandBuffer->ZeCommandListResetEvents,
813877
CommandBuffer->ExecutionFinishedEvent->ZeEvent));
814878

815-
if (CommandBuffer->IsInOrderCmdList) {
816-
ZE2UR_CALL(zeCommandListAppendSignalEvent,
817-
(CommandBuffer->ZeComputeCommandList,
818-
CommandBuffer->ExecutionFinishedEvent->ZeEvent));
819-
} else {
820-
// Reset the L0 events we use for command-buffer sync-points to the
821-
// non-signaled state. This is required for multiple submissions.
879+
// Reset the L0 events we use for command-buffer sync-points to the
880+
// non-signaled state. This is required for multiple submissions.
881+
auto resetEvents = [&CommandBuffer]() {
822882
for (auto &Event : CommandBuffer->ZeEventsList) {
823883
ZE2UR_CALL(zeCommandListAppendEventReset,
824884
(CommandBuffer->ZeCommandListResetEvents, Event));
825885
}
886+
};
826887

888+
if (CommandBuffer->IsInOrderCmdList) {
889+
if (!CommandBuffer->MCopyCommandListEmpty) {
890+
resetEvents();
891+
}
892+
ZE2UR_CALL(zeCommandListAppendSignalEvent,
893+
(CommandBuffer->ZeComputeCommandList,
894+
CommandBuffer->ExecutionFinishedEvent->ZeEvent));
895+
} else {
896+
resetEvents();
827897
// Wait for all the user added commands to complete, and signal the
828898
// command-buffer signal-event when they are done.
829899
ZE2UR_CALL(zeCommandListAppendBarrier,
@@ -1073,7 +1143,8 @@ ur_result_t urCommandBufferAppendKernelLaunchExp(
10731143
std::vector<ze_event_handle_t> ZeEventList;
10741144
ze_event_handle_t ZeLaunchEvent = nullptr;
10751145
UR_CALL(createSyncPointAndGetZeEvents(
1076-
UR_COMMAND_KERNEL_LAUNCH, CommandBuffer, NumSyncPointsInWaitList,
1146+
UR_COMMAND_KERNEL_LAUNCH, CommandBuffer,
1147+
CommandBuffer->ZeComputeCommandList, NumSyncPointsInWaitList,
10771148
SyncPointWaitList, false, RetSyncPoint, ZeEventList, ZeLaunchEvent));
10781149

10791150
ZE2UR_CALL(zeCommandListAppendLaunchKernel,
@@ -1306,29 +1377,25 @@ ur_result_t urCommandBufferAppendUSMPrefetchExp(
13061377
std::ignore = Command;
13071378
std::ignore = Flags;
13081379

1309-
if (CommandBuffer->IsInOrderCmdList) {
1310-
// Add the prefetch command to the command-buffer.
1311-
// Note that L0 does not handle migration flags.
1312-
ZE2UR_CALL(zeCommandListAppendMemoryPrefetch,
1313-
(CommandBuffer->ZeComputeCommandList, Mem, Size));
1314-
} else {
1315-
std::vector<ze_event_handle_t> ZeEventList;
1316-
ze_event_handle_t ZeLaunchEvent = nullptr;
1317-
UR_CALL(createSyncPointAndGetZeEvents(
1318-
UR_COMMAND_USM_PREFETCH, CommandBuffer, NumSyncPointsInWaitList,
1319-
SyncPointWaitList, true, RetSyncPoint, ZeEventList, ZeLaunchEvent));
1320-
1321-
if (NumSyncPointsInWaitList) {
1322-
ZE2UR_CALL(zeCommandListAppendWaitOnEvents,
1323-
(CommandBuffer->ZeComputeCommandList, NumSyncPointsInWaitList,
1324-
ZeEventList.data()));
1325-
}
1380+
std::vector<ze_event_handle_t> ZeEventList;
1381+
ze_event_handle_t ZeLaunchEvent = nullptr;
1382+
UR_CALL(createSyncPointAndGetZeEvents(
1383+
UR_COMMAND_USM_PREFETCH, CommandBuffer,
1384+
CommandBuffer->ZeComputeCommandList, NumSyncPointsInWaitList,
1385+
SyncPointWaitList, true, RetSyncPoint, ZeEventList, ZeLaunchEvent));
1386+
1387+
if (NumSyncPointsInWaitList) {
1388+
ZE2UR_CALL(zeCommandListAppendWaitOnEvents,
1389+
(CommandBuffer->ZeComputeCommandList, NumSyncPointsInWaitList,
1390+
ZeEventList.data()));
1391+
}
13261392

1327-
// Add the prefetch command to the command-buffer.
1328-
// Note that L0 does not handle migration flags.
1329-
ZE2UR_CALL(zeCommandListAppendMemoryPrefetch,
1330-
(CommandBuffer->ZeComputeCommandList, Mem, Size));
1393+
// Add the prefetch command to the command-buffer.
1394+
// Note that L0 does not handle migration flags.
1395+
ZE2UR_CALL(zeCommandListAppendMemoryPrefetch,
1396+
(CommandBuffer->ZeComputeCommandList, Mem, Size));
13311397

1398+
if (!CommandBuffer->IsInOrderCmdList) {
13321399
// Level Zero does not have a completion "event" with the prefetch API,
13331400
// so manually add command to signal our event.
13341401
ZE2UR_CALL(zeCommandListAppendSignalEvent,
@@ -1376,27 +1443,24 @@ ur_result_t urCommandBufferAppendUSMAdviseExp(
13761443

13771444
ze_memory_advice_t ZeAdvice = static_cast<ze_memory_advice_t>(Value);
13781445

1379-
if (CommandBuffer->IsInOrderCmdList) {
1380-
ZE2UR_CALL(zeCommandListAppendMemAdvise,
1381-
(CommandBuffer->ZeComputeCommandList,
1382-
CommandBuffer->Device->ZeDevice, Mem, Size, ZeAdvice));
1383-
} else {
1384-
std::vector<ze_event_handle_t> ZeEventList;
1385-
ze_event_handle_t ZeLaunchEvent = nullptr;
1386-
UR_CALL(createSyncPointAndGetZeEvents(
1387-
UR_COMMAND_USM_ADVISE, CommandBuffer, NumSyncPointsInWaitList,
1388-
SyncPointWaitList, true, RetSyncPoint, ZeEventList, ZeLaunchEvent));
1389-
1390-
if (NumSyncPointsInWaitList) {
1391-
ZE2UR_CALL(zeCommandListAppendWaitOnEvents,
1392-
(CommandBuffer->ZeComputeCommandList, NumSyncPointsInWaitList,
1393-
ZeEventList.data()));
1394-
}
1446+
std::vector<ze_event_handle_t> ZeEventList;
1447+
ze_event_handle_t ZeLaunchEvent = nullptr;
1448+
UR_CALL(createSyncPointAndGetZeEvents(
1449+
UR_COMMAND_USM_ADVISE, CommandBuffer, CommandBuffer->ZeComputeCommandList,
1450+
NumSyncPointsInWaitList, SyncPointWaitList, true, RetSyncPoint,
1451+
ZeEventList, ZeLaunchEvent));
13951452

1396-
ZE2UR_CALL(zeCommandListAppendMemAdvise,
1397-
(CommandBuffer->ZeComputeCommandList,
1398-
CommandBuffer->Device->ZeDevice, Mem, Size, ZeAdvice));
1453+
if (NumSyncPointsInWaitList) {
1454+
ZE2UR_CALL(zeCommandListAppendWaitOnEvents,
1455+
(CommandBuffer->ZeComputeCommandList, NumSyncPointsInWaitList,
1456+
ZeEventList.data()));
1457+
}
1458+
1459+
ZE2UR_CALL(zeCommandListAppendMemAdvise,
1460+
(CommandBuffer->ZeComputeCommandList,
1461+
CommandBuffer->Device->ZeDevice, Mem, Size, ZeAdvice));
13991462

1463+
if (!CommandBuffer->IsInOrderCmdList) {
14001464
// Level Zero does not have a completion "event" with the advise API,
14011465
// so manually add command to signal our event.
14021466
ZE2UR_CALL(zeCommandListAppendSignalEvent,

unified-runtime/source/adapters/level_zero/command_buffer.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
//===----------------------------------------------------------------------===//
1010
#pragma once
1111

12+
#include <optional>
1213
#include <ur/ur.hpp>
1314
#include <ur_api.h>
1415
#include <ze_api.h>
@@ -110,6 +111,8 @@ struct ur_exp_command_buffer_handle_t_ : public _ur_object {
110111
// This flag must be set to false if at least one copy command has been
111112
// added to `ZeCopyCommandList`
112113
bool MCopyCommandListEmpty = true;
114+
// This flag tracks if the previous node submission was of a copy type.
115+
std::optional<bool> MWasPrevCopyCommandList;
113116
// [WaitEvent Path only] Level Zero fences for each queue the command-buffer
114117
// has been enqueued to. These should be destroyed when the command-buffer is
115118
// released.

0 commit comments

Comments
 (0)