@@ -978,15 +978,57 @@ ur_result_t ur_queue_immediate_in_order_t::enqueueCooperativeKernelLaunchExp(
978
978
const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
979
979
const size_t *pLocalWorkSize, uint32_t numEventsInWaitList,
980
980
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
981
- std::ignore = hKernel;
982
- std::ignore = workDim;
983
- std::ignore = pGlobalWorkOffset;
984
- std::ignore = pGlobalWorkSize;
985
- std::ignore = pLocalWorkSize;
986
- std::ignore = numEventsInWaitList;
987
- std::ignore = phEventWaitList;
988
- std::ignore = phEvent;
989
- return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
981
+ TRACK_SCOPE_LATENCY (
982
+ " ur_queue_immediate_in_order_t::enqueueCooperativeKernelLaunchExp" );
983
+
984
+ UR_ASSERT (hKernel, UR_RESULT_ERROR_INVALID_NULL_HANDLE);
985
+ UR_ASSERT (hKernel->getProgramHandle (), UR_RESULT_ERROR_INVALID_NULL_POINTER);
986
+
987
+ UR_ASSERT (workDim > 0 , UR_RESULT_ERROR_INVALID_WORK_DIMENSION);
988
+ UR_ASSERT (workDim < 4 , UR_RESULT_ERROR_INVALID_WORK_DIMENSION);
989
+
990
+ ze_kernel_handle_t hZeKernel = hKernel->getZeHandle (hDevice);
991
+
992
+ std::scoped_lock<ur_shared_mutex, ur_shared_mutex> Lock (this ->Mutex ,
993
+ hKernel->Mutex );
994
+
995
+ ze_group_count_t zeThreadGroupDimensions{1 , 1 , 1 };
996
+ uint32_t WG[3 ]{};
997
+ UR_CALL (calculateKernelWorkDimensions (hZeKernel, hDevice,
998
+ zeThreadGroupDimensions, WG, workDim,
999
+ pGlobalWorkSize, pLocalWorkSize));
1000
+
1001
+ auto signalEvent = getSignalEvent (phEvent, UR_COMMAND_KERNEL_LAUNCH);
1002
+
1003
+ auto waitList = getWaitListView (phEventWaitList, numEventsInWaitList);
1004
+
1005
+ bool memoryMigrated = false ;
1006
+ auto memoryMigrate = [&](void *src, void *dst, size_t size) {
1007
+ ZE2UR_CALL_THROWS (zeCommandListAppendMemoryCopy,
1008
+ (handler.commandList .get (), dst, src, size, nullptr ,
1009
+ waitList.second , waitList.first ));
1010
+ memoryMigrated = true ;
1011
+ };
1012
+
1013
+ UR_CALL (hKernel->prepareForSubmission (hContext, hDevice, pGlobalWorkOffset,
1014
+ workDim, WG[0 ], WG[1 ], WG[2 ],
1015
+ memoryMigrate));
1016
+
1017
+ if (memoryMigrated) {
1018
+ // If memory was migrated, the migration copy was already given the wait
1019
+ // list, so the kernel launch below does not need to be passed it again.
1020
+ waitList.first = nullptr ;
1021
+ waitList.second = 0 ;
1022
+ }
1023
+
1024
+ TRACK_SCOPE_LATENCY (" ur_queue_immediate_in_order_t::"
1025
+ " zeCommandListAppendLaunchCooperativeKernel" );
1026
+ auto zeSignalEvent = signalEvent ? signalEvent->getZeEvent () : nullptr ;
1027
+ ZE2UR_CALL (zeCommandListAppendLaunchCooperativeKernel,
1028
+ (handler.commandList .get (), hZeKernel, &zeThreadGroupDimensions,
1029
+ zeSignalEvent, waitList.second , waitList.first ));
1030
+
1031
+ return UR_RESULT_SUCCESS;
990
1032
}
991
1033
992
1034
ur_result_t ur_queue_immediate_in_order_t::enqueueTimestampRecordingExp (
0 commit comments