Skip to content

Commit 50ef8c9

Browse files
authored
Merge pull request #2307 from igchor/cooperative_exec
[L0 v2] implement enqueueCooperativeKernelLaunchExp
2 parents 3a5b23c + 2e4c3ce commit 50ef8c9

File tree

1 file changed

+51
-9
lines changed

1 file changed

+51
-9
lines changed

source/adapters/level_zero/v2/queue_immediate_in_order.cpp

Lines changed: 51 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -978,15 +978,57 @@ ur_result_t ur_queue_immediate_in_order_t::enqueueCooperativeKernelLaunchExp(
978978
const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
979979
const size_t *pLocalWorkSize, uint32_t numEventsInWaitList,
980980
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
981-
std::ignore = hKernel;
982-
std::ignore = workDim;
983-
std::ignore = pGlobalWorkOffset;
984-
std::ignore = pGlobalWorkSize;
985-
std::ignore = pLocalWorkSize;
986-
std::ignore = numEventsInWaitList;
987-
std::ignore = phEventWaitList;
988-
std::ignore = phEvent;
989-
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
981+
TRACK_SCOPE_LATENCY(
982+
"ur_queue_immediate_in_order_t::enqueueCooperativeKernelLaunchExp");
983+
984+
UR_ASSERT(hKernel, UR_RESULT_ERROR_INVALID_NULL_HANDLE);
985+
UR_ASSERT(hKernel->getProgramHandle(), UR_RESULT_ERROR_INVALID_NULL_POINTER);
986+
987+
UR_ASSERT(workDim > 0, UR_RESULT_ERROR_INVALID_WORK_DIMENSION);
988+
UR_ASSERT(workDim < 4, UR_RESULT_ERROR_INVALID_WORK_DIMENSION);
989+
990+
ze_kernel_handle_t hZeKernel = hKernel->getZeHandle(hDevice);
991+
992+
std::scoped_lock<ur_shared_mutex, ur_shared_mutex> Lock(this->Mutex,
993+
hKernel->Mutex);
994+
995+
ze_group_count_t zeThreadGroupDimensions{1, 1, 1};
996+
uint32_t WG[3]{};
997+
UR_CALL(calculateKernelWorkDimensions(hZeKernel, hDevice,
998+
zeThreadGroupDimensions, WG, workDim,
999+
pGlobalWorkSize, pLocalWorkSize));
1000+
1001+
auto signalEvent = getSignalEvent(phEvent, UR_COMMAND_KERNEL_LAUNCH);
1002+
1003+
auto waitList = getWaitListView(phEventWaitList, numEventsInWaitList);
1004+
1005+
bool memoryMigrated = false;
1006+
auto memoryMigrate = [&](void *src, void *dst, size_t size) {
1007+
ZE2UR_CALL_THROWS(zeCommandListAppendMemoryCopy,
1008+
(handler.commandList.get(), dst, src, size, nullptr,
1009+
waitList.second, waitList.first));
1010+
memoryMigrated = true;
1011+
};
1012+
1013+
UR_CALL(hKernel->prepareForSubmission(hContext, hDevice, pGlobalWorkOffset,
1014+
workDim, WG[0], WG[1], WG[2],
1015+
memoryMigrate));
1016+
1017+
if (memoryMigrated) {
1018+
// If memory was migrated, the copy command has already waited on the
1019+
// wait list, so we don't need to pass it to the kernel launch again.
1020+
waitList.first = nullptr;
1021+
waitList.second = 0;
1022+
}
1023+
1024+
TRACK_SCOPE_LATENCY("ur_queue_immediate_in_order_t::"
1025+
"zeCommandListAppendLaunchCooperativeKernel");
1026+
auto zeSignalEvent = signalEvent ? signalEvent->getZeEvent() : nullptr;
1027+
ZE2UR_CALL(zeCommandListAppendLaunchCooperativeKernel,
1028+
(handler.commandList.get(), hZeKernel, &zeThreadGroupDimensions,
1029+
zeSignalEvent, waitList.second, waitList.first));
1030+
1031+
return UR_RESULT_SUCCESS;
9901032
}
9911033

9921034
ur_result_t ur_queue_immediate_in_order_t::enqueueTimestampRecordingExp(

0 commit comments

Comments
 (0)