@@ -971,15 +971,57 @@ ur_result_t ur_queue_immediate_in_order_t::enqueueCooperativeKernelLaunchExp(
971
971
const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
972
972
const size_t *pLocalWorkSize, uint32_t numEventsInWaitList,
973
973
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
974
- std::ignore = hKernel;
975
- std::ignore = workDim;
976
- std::ignore = pGlobalWorkOffset;
977
- std::ignore = pGlobalWorkSize;
978
- std::ignore = pLocalWorkSize;
979
- std::ignore = numEventsInWaitList;
980
- std::ignore = phEventWaitList;
981
- std::ignore = phEvent;
982
- return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
974
+ TRACK_SCOPE_LATENCY (
975
+ " ur_queue_immediate_in_order_t::enqueueCooperativeKernelLaunchExp" );
976
+
977
+ UR_ASSERT (hKernel, UR_RESULT_ERROR_INVALID_NULL_HANDLE);
978
+ UR_ASSERT (hKernel->getProgramHandle (), UR_RESULT_ERROR_INVALID_NULL_POINTER);
979
+
980
+ UR_ASSERT (workDim > 0 , UR_RESULT_ERROR_INVALID_WORK_DIMENSION);
981
+ UR_ASSERT (workDim < 4 , UR_RESULT_ERROR_INVALID_WORK_DIMENSION);
982
+
983
+ ze_kernel_handle_t hZeKernel = hKernel->getZeHandle (hDevice);
984
+
985
+ std::scoped_lock<ur_shared_mutex, ur_shared_mutex> Lock (this ->Mutex ,
986
+ hKernel->Mutex );
987
+
988
+ ze_group_count_t zeThreadGroupDimensions{1 , 1 , 1 };
989
+ uint32_t WG[3 ]{};
990
+ UR_CALL (calculateKernelWorkDimensions (hZeKernel, hDevice,
991
+ zeThreadGroupDimensions, WG, workDim,
992
+ pGlobalWorkSize, pLocalWorkSize));
993
+
994
+ auto signalEvent = getSignalEvent (phEvent, UR_COMMAND_KERNEL_LAUNCH);
995
+
996
+ auto waitList = getWaitListView (phEventWaitList, numEventsInWaitList);
997
+
998
+ bool memoryMigrated = false ;
999
+ auto memoryMigrate = [&](void *src, void *dst, size_t size) {
1000
+ ZE2UR_CALL_THROWS (zeCommandListAppendMemoryCopy,
1001
+ (handler.commandList .get (), dst, src, size, nullptr ,
1002
+ waitList.second , waitList.first ));
1003
+ memoryMigrated = true ;
1004
+ };
1005
+
1006
+ UR_CALL (hKernel->prepareForSubmission (hContext, hDevice, pGlobalWorkOffset,
1007
+ workDim, WG[0 ], WG[1 ], WG[2 ],
1008
+ memoryMigrate));
1009
+
1010
+ if (memoryMigrated) {
1011
+ // If memory was migrated, we don't need to pass the wait list to
1012
+ // the copy command again.
1013
+ waitList.first = nullptr ;
1014
+ waitList.second = 0 ;
1015
+ }
1016
+
1017
+ TRACK_SCOPE_LATENCY (" ur_queue_immediate_in_order_t::"
1018
+ " zeCommandListAppendLaunchCooperativeKernel" );
1019
+ auto zeSignalEvent = signalEvent ? signalEvent->getZeEvent () : nullptr ;
1020
+ ZE2UR_CALL (zeCommandListAppendLaunchCooperativeKernel,
1021
+ (handler.commandList .get (), hZeKernel, &zeThreadGroupDimensions,
1022
+ zeSignalEvent, waitList.second , waitList.first ));
1023
+
1024
+ return UR_RESULT_SUCCESS;
983
1025
}
984
1026
985
1027
ur_result_t ur_queue_immediate_in_order_t::enqueueTimestampRecordingExp (
0 commit comments