@@ -894,28 +894,31 @@ urCommandBufferFinalizeExp(ur_exp_command_buffer_handle_t CommandBuffer) {
894
894
/**
895
895
* Sets the kernel arguments for a kernel command that will be appended to the
896
896
* command buffer.
897
- * @param[in] CommandBuffer The CommandBuffer where the command will be
897
+ * @param[in] Device The Device associated with the command-buffer where the
898
+ * kernel command will be appended.
899
+ * @param[in,out] PendingArguments Arguments stored in the ur_kernel_handle_t
900
+ * object, to be set on the \p ZeKernel object and then cleared.
901
+ * @param[in] ZeKernel The handle to the Level-Zero kernel that will be
898
902
* appended.
899
- * @param[in] Kernel The handle to the kernel that will be appended.
900
903
* @return UR_RESULT_SUCCESS or an error code on failure
901
904
*/
902
- ur_result_t
903
- setKernelPendingArguments ( ur_exp_command_buffer_handle_t CommandBuffer ,
904
- ur_kernel_handle_t Kernel) {
905
-
905
+ ur_result_t setKernelPendingArguments (
906
+ ur_device_handle_t Device ,
907
+ std::vector<ur_kernel_handle_t_::ArgumentInfo> &PendingArguments,
908
+ ze_kernel_handle_t ZeKernel) {
906
909
// If there are any pending arguments set them now.
907
- for (auto &Arg : Kernel-> PendingArguments ) {
910
+ for (auto &Arg : PendingArguments) {
908
911
// The ArgValue may be a NULL pointer in which case a NULL value is used for
909
912
// the kernel argument declared as a pointer to global or constant memory.
910
913
char **ZeHandlePtr = nullptr ;
911
914
if (Arg.Value ) {
912
- UR_CALL (Arg.Value ->getZeHandlePtr (ZeHandlePtr, Arg.AccessMode ,
913
- CommandBuffer-> Device , nullptr , 0u ));
915
+ UR_CALL (Arg.Value ->getZeHandlePtr (ZeHandlePtr, Arg.AccessMode , Device,
916
+ nullptr , 0u ));
914
917
}
915
918
ZE2UR_CALL (zeKernelSetArgumentValue,
916
- (Kernel-> ZeKernel , Arg.Index , Arg.Size , ZeHandlePtr));
919
+ (ZeKernel, Arg.Index , Arg.Size , ZeHandlePtr));
917
920
}
918
- Kernel-> PendingArguments .clear ();
921
+ PendingArguments.clear ();
919
922
920
923
return UR_RESULT_SUCCESS;
921
924
}
@@ -951,21 +954,29 @@ createCommandHandle(ur_exp_command_buffer_handle_t CommandBuffer,
951
954
ZE_MUTABLE_COMMAND_EXP_FLAG_GLOBAL_OFFSET;
952
955
953
956
auto Platform = CommandBuffer->Context ->getPlatform ();
957
+ auto ZeDevice = CommandBuffer->Device ->ZeDevice ;
958
+
954
959
if (NumKernelAlternatives > 0 ) {
955
960
ZeMutableCommandDesc.flags |=
956
961
ZE_MUTABLE_COMMAND_EXP_FLAG_KERNEL_INSTRUCTION;
957
962
958
963
std::vector<ze_kernel_handle_t > TranslatedKernelHandles (
959
964
NumKernelAlternatives + 1 , nullptr );
960
965
966
+ ze_kernel_handle_t ZeMainKernel{};
967
+ UR_CALL (getZeKernel (ZeDevice, Kernel, &ZeMainKernel));
968
+
961
969
// Translate main kernel first
962
970
ZE2UR_CALL (zelLoaderTranslateHandle,
963
- (ZEL_HANDLE_KERNEL, Kernel-> ZeKernel ,
971
+ (ZEL_HANDLE_KERNEL, ZeMainKernel ,
964
972
(void **)&TranslatedKernelHandles[0 ]));
965
973
966
974
for (size_t i = 0 ; i < NumKernelAlternatives; i++) {
975
+ ze_kernel_handle_t ZeAltKernel{};
976
+ UR_CALL (getZeKernel (ZeDevice, KernelAlternatives[i], &ZeAltKernel));
977
+
967
978
ZE2UR_CALL (zelLoaderTranslateHandle,
968
- (ZEL_HANDLE_KERNEL, KernelAlternatives[i]-> ZeKernel ,
979
+ (ZEL_HANDLE_KERNEL, ZeAltKernel ,
969
980
(void **)&TranslatedKernelHandles[i + 1 ]));
970
981
}
971
982
@@ -1022,23 +1033,28 @@ ur_result_t urCommandBufferAppendKernelLaunchExp(
1022
1033
std::scoped_lock<ur_shared_mutex, ur_shared_mutex, ur_shared_mutex> Lock (
1023
1034
Kernel->Mutex , Kernel->Program ->Mutex , CommandBuffer->Mutex );
1024
1035
1036
+ auto Device = CommandBuffer->Device ;
1037
+ ze_kernel_handle_t ZeKernel{};
1038
+ UR_CALL (getZeKernel (Device->ZeDevice , Kernel, &ZeKernel));
1039
+
1025
1040
if (GlobalWorkOffset != NULL ) {
1026
- UR_CALL (setKernelGlobalOffset (CommandBuffer->Context , Kernel-> ZeKernel ,
1027
- WorkDim, GlobalWorkOffset));
1041
+ UR_CALL (setKernelGlobalOffset (CommandBuffer->Context , ZeKernel, WorkDim ,
1042
+ GlobalWorkOffset));
1028
1043
}
1029
1044
1030
1045
// If there are any pending arguments set them now.
1031
1046
if (!Kernel->PendingArguments .empty ()) {
1032
- UR_CALL (setKernelPendingArguments (CommandBuffer, Kernel));
1047
+ UR_CALL (
1048
+ setKernelPendingArguments (Device, Kernel->PendingArguments , ZeKernel));
1033
1049
}
1034
1050
1035
1051
ze_group_count_t ZeThreadGroupDimensions{1 , 1 , 1 };
1036
1052
uint32_t WG[3 ];
1037
- UR_CALL (calculateKernelWorkDimensions (Kernel-> ZeKernel , CommandBuffer-> Device ,
1053
+ UR_CALL (calculateKernelWorkDimensions (ZeKernel, Device,
1038
1054
ZeThreadGroupDimensions, WG, WorkDim,
1039
1055
GlobalWorkSize, LocalWorkSize));
1040
1056
1041
- ZE2UR_CALL (zeKernelSetGroupSize, (Kernel-> ZeKernel , WG[0 ], WG[1 ], WG[2 ]));
1057
+ ZE2UR_CALL (zeKernelSetGroupSize, (ZeKernel, WG[0 ], WG[1 ], WG[2 ]));
1042
1058
1043
1059
CommandBuffer->KernelsList .push_back (Kernel);
1044
1060
for (size_t i = 0 ; i < NumKernelAlternatives; i++) {
@@ -1063,7 +1079,7 @@ ur_result_t urCommandBufferAppendKernelLaunchExp(
1063
1079
SyncPointWaitList, false , RetSyncPoint, ZeEventList, ZeLaunchEvent));
1064
1080
1065
1081
ZE2UR_CALL (zeCommandListAppendLaunchKernel,
1066
- (CommandBuffer->ZeComputeCommandList , Kernel-> ZeKernel ,
1082
+ (CommandBuffer->ZeComputeCommandList , ZeKernel,
1067
1083
&ZeThreadGroupDimensions, ZeLaunchEvent, ZeEventList.size (),
1068
1084
getPointerFromVector (ZeEventList)));
1069
1085
@@ -1836,6 +1852,7 @@ ur_result_t updateKernelCommand(
1836
1852
const auto CommandBuffer = Command->CommandBuffer ;
1837
1853
const void *NextDesc = nullptr ;
1838
1854
auto Platform = CommandBuffer->Context ->getPlatform ();
1855
+ auto ZeDevice = CommandBuffer->Device ->ZeDevice ;
1839
1856
1840
1857
uint32_t Dim = CommandDesc->newWorkDim ;
1841
1858
size_t *NewGlobalWorkOffset = CommandDesc->pNewGlobalWorkOffset ;
@@ -1844,11 +1861,14 @@ ur_result_t updateKernelCommand(
1844
1861
1845
1862
// Kernel handle must be updated first for a given CommandId if required
1846
1863
ur_kernel_handle_t NewKernel = CommandDesc->hNewKernel ;
1864
+
1847
1865
if (NewKernel && Command->Kernel != NewKernel) {
1866
+ ze_kernel_handle_t ZeNewKernel{};
1867
+ UR_CALL (getZeKernel (ZeDevice, NewKernel, &ZeNewKernel));
1868
+
1848
1869
ze_kernel_handle_t ZeKernelTranslated = nullptr ;
1849
- ZE2UR_CALL (
1850
- zelLoaderTranslateHandle,
1851
- (ZEL_HANDLE_KERNEL, NewKernel->ZeKernel , (void **)&ZeKernelTranslated));
1870
+ ZE2UR_CALL (zelLoaderTranslateHandle,
1871
+ (ZEL_HANDLE_KERNEL, ZeNewKernel, (void **)&ZeKernelTranslated));
1852
1872
1853
1873
ZE2UR_CALL (Platform->ZeMutableCmdListExt
1854
1874
.zexCommandListUpdateMutableCommandKernelsExp ,
@@ -1905,10 +1925,13 @@ ur_result_t updateKernelCommand(
1905
1925
// by the driver for the kernel.
1906
1926
bool UpdateWGSize = NewLocalWorkSize == nullptr ;
1907
1927
1928
+ ze_kernel_handle_t ZeKernel{};
1929
+ UR_CALL (getZeKernel (ZeDevice, Command->Kernel , &ZeKernel));
1930
+
1908
1931
uint32_t WG[3 ];
1909
- UR_CALL (calculateKernelWorkDimensions (
1910
- Command-> Kernel -> ZeKernel , CommandBuffer-> Device ,
1911
- ZeThreadGroupDimensions, WG, Dim, NewGlobalWorkSize, NewLocalWorkSize));
1932
+ UR_CALL (calculateKernelWorkDimensions (ZeKernel, CommandBuffer-> Device ,
1933
+ ZeThreadGroupDimensions, WG, Dim ,
1934
+ NewGlobalWorkSize, NewLocalWorkSize));
1912
1935
1913
1936
auto MutableGroupCountDesc =
1914
1937
std::make_unique<ZeStruct<ze_mutable_group_count_exp_desc_t >>();
0 commit comments