Skip to content

Commit 2e4c3ce

Browse files
committed
[L0 v2] implement enqueueCooperativeKernelLaunchExp
1 parent c94dbc8 commit 2e4c3ce

File tree

1 file changed

+51
-9
lines changed

1 file changed

+51
-9
lines changed

source/adapters/level_zero/v2/queue_immediate_in_order.cpp

Lines changed: 51 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -971,15 +971,57 @@ ur_result_t ur_queue_immediate_in_order_t::enqueueCooperativeKernelLaunchExp(
971971
const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
972972
const size_t *pLocalWorkSize, uint32_t numEventsInWaitList,
973973
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
974-
std::ignore = hKernel;
975-
std::ignore = workDim;
976-
std::ignore = pGlobalWorkOffset;
977-
std::ignore = pGlobalWorkSize;
978-
std::ignore = pLocalWorkSize;
979-
std::ignore = numEventsInWaitList;
980-
std::ignore = phEventWaitList;
981-
std::ignore = phEvent;
982-
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
974+
TRACK_SCOPE_LATENCY(
975+
"ur_queue_immediate_in_order_t::enqueueCooperativeKernelLaunchExp");
976+
977+
UR_ASSERT(hKernel, UR_RESULT_ERROR_INVALID_NULL_HANDLE);
978+
UR_ASSERT(hKernel->getProgramHandle(), UR_RESULT_ERROR_INVALID_NULL_POINTER);
979+
980+
UR_ASSERT(workDim > 0, UR_RESULT_ERROR_INVALID_WORK_DIMENSION);
981+
UR_ASSERT(workDim < 4, UR_RESULT_ERROR_INVALID_WORK_DIMENSION);
982+
983+
ze_kernel_handle_t hZeKernel = hKernel->getZeHandle(hDevice);
984+
985+
std::scoped_lock<ur_shared_mutex, ur_shared_mutex> Lock(this->Mutex,
986+
hKernel->Mutex);
987+
988+
ze_group_count_t zeThreadGroupDimensions{1, 1, 1};
989+
uint32_t WG[3]{};
990+
UR_CALL(calculateKernelWorkDimensions(hZeKernel, hDevice,
991+
zeThreadGroupDimensions, WG, workDim,
992+
pGlobalWorkSize, pLocalWorkSize));
993+
994+
auto signalEvent = getSignalEvent(phEvent, UR_COMMAND_KERNEL_LAUNCH);
995+
996+
auto waitList = getWaitListView(phEventWaitList, numEventsInWaitList);
997+
998+
bool memoryMigrated = false;
999+
auto memoryMigrate = [&](void *src, void *dst, size_t size) {
1000+
ZE2UR_CALL_THROWS(zeCommandListAppendMemoryCopy,
1001+
(handler.commandList.get(), dst, src, size, nullptr,
1002+
waitList.second, waitList.first));
1003+
memoryMigrated = true;
1004+
};
1005+
1006+
UR_CALL(hKernel->prepareForSubmission(hContext, hDevice, pGlobalWorkOffset,
1007+
workDim, WG[0], WG[1], WG[2],
1008+
memoryMigrate));
1009+
1010+
if (memoryMigrated) {
1011+
// If memory was migrated, the migration copy was already submitted with
1012+
// the wait list, so it is not passed again to the kernel launch below.
1013+
waitList.first = nullptr;
1014+
waitList.second = 0;
1015+
}
1016+
1017+
TRACK_SCOPE_LATENCY("ur_queue_immediate_in_order_t::"
1018+
"zeCommandListAppendLaunchCooperativeKernel");
1019+
auto zeSignalEvent = signalEvent ? signalEvent->getZeEvent() : nullptr;
1020+
ZE2UR_CALL(zeCommandListAppendLaunchCooperativeKernel,
1021+
(handler.commandList.get(), hZeKernel, &zeThreadGroupDimensions,
1022+
zeSignalEvent, waitList.second, waitList.first));
1023+
1024+
return UR_RESULT_SUCCESS;
9831025
}
9841026

9851027
ur_result_t ur_queue_immediate_in_order_t::enqueueTimestampRecordingExp(

0 commit comments

Comments
 (0)