@@ -895,28 +895,31 @@ urCommandBufferFinalizeExp(ur_exp_command_buffer_handle_t CommandBuffer) {
895
895
/* *
896
896
* Sets the kernel arguments for a kernel command that will be appended to the
897
897
* command buffer.
898
- * @param[in] CommandBuffer The CommandBuffer where the command will be
898
+ * @param[in] Device The Device associated with the command-buffer where the
899
+ * kernel command will be appended.
900
+ * @param[in,out] Arguments stored in the ur_kernel_handle_t object to be set
901
+ * on the /p ZeKernel object.
902
+ * @param[in] ZeKernel The handle to the Level-Zero kernel that will be
899
903
* appended.
900
- * @param[in] Kernel The handle to the kernel that will be appended.
901
904
* @return UR_RESULT_SUCCESS or an error code on failure
902
905
*/
903
- ur_result_t
904
- setKernelPendingArguments ( ur_exp_command_buffer_handle_t CommandBuffer ,
905
- ur_kernel_handle_t Kernel) {
906
-
906
+ ur_result_t setKernelPendingArguments (
907
+ ur_device_handle_t Device ,
908
+ std::vector<ur_kernel_handle_t_::ArgumentInfo> &PendingArguments,
909
+ ze_kernel_handle_t ZeKernel) {
907
910
// If there are any pending arguments set them now.
908
- for (auto &Arg : Kernel-> PendingArguments ) {
911
+ for (auto &Arg : PendingArguments) {
909
912
// The ArgValue may be a NULL pointer in which case a NULL value is used for
910
913
// the kernel argument declared as a pointer to global or constant memory.
911
914
char **ZeHandlePtr = nullptr ;
912
915
if (Arg.Value ) {
913
- UR_CALL (Arg.Value ->getZeHandlePtr (ZeHandlePtr, Arg.AccessMode ,
914
- CommandBuffer-> Device , nullptr , 0u ));
916
+ UR_CALL (Arg.Value ->getZeHandlePtr (ZeHandlePtr, Arg.AccessMode , Device,
917
+ nullptr , 0u ));
915
918
}
916
919
ZE2UR_CALL (zeKernelSetArgumentValue,
917
- (Kernel-> ZeKernel , Arg.Index , Arg.Size , ZeHandlePtr));
920
+ (ZeKernel, Arg.Index , Arg.Size , ZeHandlePtr));
918
921
}
919
- Kernel-> PendingArguments .clear ();
922
+ PendingArguments.clear ();
920
923
921
924
return UR_RESULT_SUCCESS;
922
925
}
@@ -952,21 +955,29 @@ createCommandHandle(ur_exp_command_buffer_handle_t CommandBuffer,
952
955
ZE_MUTABLE_COMMAND_EXP_FLAG_GLOBAL_OFFSET;
953
956
954
957
auto Platform = CommandBuffer->Context ->getPlatform ();
958
+ auto ZeDevice = CommandBuffer->Device ->ZeDevice ;
959
+
955
960
if (NumKernelAlternatives > 0 ) {
956
961
ZeMutableCommandDesc.flags |=
957
962
ZE_MUTABLE_COMMAND_EXP_FLAG_KERNEL_INSTRUCTION;
958
963
959
964
std::vector<ze_kernel_handle_t > TranslatedKernelHandles (
960
965
NumKernelAlternatives + 1 , nullptr );
961
966
967
+ ze_kernel_handle_t ZeMainKernel{};
968
+ UR_CALL (getZeKernel (ZeDevice, Kernel, &ZeMainKernel));
969
+
962
970
// Translate main kernel first
963
971
ZE2UR_CALL (zelLoaderTranslateHandle,
964
- (ZEL_HANDLE_KERNEL, Kernel-> ZeKernel ,
972
+ (ZEL_HANDLE_KERNEL, ZeMainKernel ,
965
973
(void **)&TranslatedKernelHandles[0 ]));
966
974
967
975
for (size_t i = 0 ; i < NumKernelAlternatives; i++) {
976
+ ze_kernel_handle_t ZeAltKernel{};
977
+ UR_CALL (getZeKernel (ZeDevice, KernelAlternatives[i], &ZeAltKernel));
978
+
968
979
ZE2UR_CALL (zelLoaderTranslateHandle,
969
- (ZEL_HANDLE_KERNEL, KernelAlternatives[i]-> ZeKernel ,
980
+ (ZEL_HANDLE_KERNEL, ZeAltKernel ,
970
981
(void **)&TranslatedKernelHandles[i + 1 ]));
971
982
}
972
983
@@ -1023,23 +1034,28 @@ ur_result_t urCommandBufferAppendKernelLaunchExp(
1023
1034
std::scoped_lock<ur_shared_mutex, ur_shared_mutex, ur_shared_mutex> Lock (
1024
1035
Kernel->Mutex , Kernel->Program ->Mutex , CommandBuffer->Mutex );
1025
1036
1037
+ auto Device = CommandBuffer->Device ;
1038
+ ze_kernel_handle_t ZeKernel{};
1039
+ UR_CALL (getZeKernel (Device->ZeDevice , Kernel, &ZeKernel));
1040
+
1026
1041
if (GlobalWorkOffset != NULL ) {
1027
- UR_CALL (setKernelGlobalOffset (CommandBuffer->Context , Kernel-> ZeKernel ,
1028
- WorkDim, GlobalWorkOffset));
1042
+ UR_CALL (setKernelGlobalOffset (CommandBuffer->Context , ZeKernel, WorkDim ,
1043
+ GlobalWorkOffset));
1029
1044
}
1030
1045
1031
1046
// If there are any pending arguments set them now.
1032
1047
if (!Kernel->PendingArguments .empty ()) {
1033
- UR_CALL (setKernelPendingArguments (CommandBuffer, Kernel));
1048
+ UR_CALL (
1049
+ setKernelPendingArguments (Device, Kernel->PendingArguments , ZeKernel));
1034
1050
}
1035
1051
1036
1052
ze_group_count_t ZeThreadGroupDimensions{1 , 1 , 1 };
1037
1053
uint32_t WG[3 ];
1038
- UR_CALL (calculateKernelWorkDimensions (Kernel-> ZeKernel , CommandBuffer-> Device ,
1054
+ UR_CALL (calculateKernelWorkDimensions (ZeKernel, Device,
1039
1055
ZeThreadGroupDimensions, WG, WorkDim,
1040
1056
GlobalWorkSize, LocalWorkSize));
1041
1057
1042
- ZE2UR_CALL (zeKernelSetGroupSize, (Kernel-> ZeKernel , WG[0 ], WG[1 ], WG[2 ]));
1058
+ ZE2UR_CALL (zeKernelSetGroupSize, (ZeKernel, WG[0 ], WG[1 ], WG[2 ]));
1043
1059
1044
1060
CommandBuffer->KernelsList .push_back (Kernel);
1045
1061
for (size_t i = 0 ; i < NumKernelAlternatives; i++) {
@@ -1064,7 +1080,7 @@ ur_result_t urCommandBufferAppendKernelLaunchExp(
1064
1080
SyncPointWaitList, false , RetSyncPoint, ZeEventList, ZeLaunchEvent));
1065
1081
1066
1082
ZE2UR_CALL (zeCommandListAppendLaunchKernel,
1067
- (CommandBuffer->ZeComputeCommandList , Kernel-> ZeKernel ,
1083
+ (CommandBuffer->ZeComputeCommandList , ZeKernel,
1068
1084
&ZeThreadGroupDimensions, ZeLaunchEvent, ZeEventList.size (),
1069
1085
getPointerFromVector (ZeEventList)));
1070
1086
@@ -1837,6 +1853,7 @@ ur_result_t updateKernelCommand(
1837
1853
const auto CommandBuffer = Command->CommandBuffer ;
1838
1854
const void *NextDesc = nullptr ;
1839
1855
auto Platform = CommandBuffer->Context ->getPlatform ();
1856
+ auto ZeDevice = CommandBuffer->Device ->ZeDevice ;
1840
1857
1841
1858
uint32_t Dim = CommandDesc->newWorkDim ;
1842
1859
size_t *NewGlobalWorkOffset = CommandDesc->pNewGlobalWorkOffset ;
@@ -1845,11 +1862,14 @@ ur_result_t updateKernelCommand(
1845
1862
1846
1863
// Kernel handle must be updated first for a given CommandId if required
1847
1864
ur_kernel_handle_t NewKernel = CommandDesc->hNewKernel ;
1865
+
1848
1866
if (NewKernel && Command->Kernel != NewKernel) {
1867
+ ze_kernel_handle_t ZeNewKernel{};
1868
+ UR_CALL (getZeKernel (ZeDevice, NewKernel, &ZeNewKernel));
1869
+
1849
1870
ze_kernel_handle_t ZeKernelTranslated = nullptr ;
1850
- ZE2UR_CALL (
1851
- zelLoaderTranslateHandle,
1852
- (ZEL_HANDLE_KERNEL, NewKernel->ZeKernel , (void **)&ZeKernelTranslated));
1871
+ ZE2UR_CALL (zelLoaderTranslateHandle,
1872
+ (ZEL_HANDLE_KERNEL, ZeNewKernel, (void **)&ZeKernelTranslated));
1853
1873
1854
1874
ZE2UR_CALL (Platform->ZeMutableCmdListExt
1855
1875
.zexCommandListUpdateMutableCommandKernelsExp ,
@@ -1906,10 +1926,13 @@ ur_result_t updateKernelCommand(
1906
1926
// by the driver for the kernel.
1907
1927
bool UpdateWGSize = NewLocalWorkSize == nullptr ;
1908
1928
1929
+ ze_kernel_handle_t ZeKernel{};
1930
+ UR_CALL (getZeKernel (ZeDevice, Command->Kernel , &ZeKernel));
1931
+
1909
1932
uint32_t WG[3 ];
1910
- UR_CALL (calculateKernelWorkDimensions (
1911
- Command-> Kernel -> ZeKernel , CommandBuffer-> Device ,
1912
- ZeThreadGroupDimensions, WG, Dim, NewGlobalWorkSize, NewLocalWorkSize));
1933
+ UR_CALL (calculateKernelWorkDimensions (ZeKernel, CommandBuffer-> Device ,
1934
+ ZeThreadGroupDimensions, WG, Dim ,
1935
+ NewGlobalWorkSize, NewLocalWorkSize));
1913
1936
1914
1937
auto MutableGroupCountDesc =
1915
1938
std::make_unique<ZeStruct<ze_mutable_group_count_exp_desc_t >>();
0 commit comments