@@ -949,41 +949,53 @@ createCommandHandle(ur_exp_command_buffer_handle_t CommandBuffer,
949
949
950
950
auto Platform = CommandBuffer->Context ->getPlatform ();
951
951
auto ZeDevice = CommandBuffer->Device ->ZeDevice ;
952
+ ze_command_list_handle_t ZeCommandList =
953
+ CommandBuffer->ZeComputeCommandListTranslated ;
954
+ if (Platform->ZeMutableCmdListExt .LoaderExtension ) {
955
+ ZeCommandList = CommandBuffer->ZeComputeCommandList ;
956
+ }
952
957
953
958
if (NumKernelAlternatives > 0 ) {
954
959
ZeMutableCommandDesc.flags |=
955
960
ZE_MUTABLE_COMMAND_EXP_FLAG_KERNEL_INSTRUCTION;
956
961
957
- std::vector<ze_kernel_handle_t > TranslatedKernelHandles (
958
- NumKernelAlternatives + 1 , nullptr );
962
+ std::vector<ze_kernel_handle_t > KernelHandles (NumKernelAlternatives + 1 ,
963
+ nullptr );
959
964
960
965
ze_kernel_handle_t ZeMainKernel{};
961
966
UR_CALL (getZeKernel (ZeDevice, Kernel, &ZeMainKernel));
962
967
963
- // Translate main kernel first
964
- ZE2UR_CALL (zelLoaderTranslateHandle,
965
- (ZEL_HANDLE_KERNEL, ZeMainKernel,
966
- (void **)&TranslatedKernelHandles[0 ]));
968
+ if (Platform->ZeMutableCmdListExt .LoaderExtension ) {
969
+ KernelHandles[0 ] = ZeMainKernel;
970
+ } else {
971
+ // If the L0 loader is not aware of the MCL extension, the main kernel
972
+ // handle needs to be translated.
973
+ ZE2UR_CALL (zelLoaderTranslateHandle,
974
+ (ZEL_HANDLE_KERNEL, ZeMainKernel, (void **)&KernelHandles[0 ]));
975
+ }
967
976
968
977
for (size_t i = 0 ; i < NumKernelAlternatives; i++) {
969
978
ze_kernel_handle_t ZeAltKernel{};
970
979
UR_CALL (getZeKernel (ZeDevice, KernelAlternatives[i], &ZeAltKernel));
971
980
972
- ZE2UR_CALL (zelLoaderTranslateHandle,
973
- (ZEL_HANDLE_KERNEL, ZeAltKernel,
974
- (void **)&TranslatedKernelHandles[i + 1 ]));
981
+ if (Platform->ZeMutableCmdListExt .LoaderExtension ) {
982
+ KernelHandles[i + 1 ] = ZeAltKernel;
983
+ } else {
984
+ // If the L0 loader is not aware of the MCL extension, the kernel
985
+ // alternatives need to be translated.
986
+ ZE2UR_CALL (zelLoaderTranslateHandle, (ZEL_HANDLE_KERNEL, ZeAltKernel,
987
+ (void **)&KernelHandles[i + 1 ]));
988
+ }
975
989
}
976
990
977
991
ZE2UR_CALL (Platform->ZeMutableCmdListExt
978
992
.zexCommandListGetNextCommandIdWithKernelsExp ,
979
- (CommandBuffer->ZeComputeCommandListTranslated ,
980
- &ZeMutableCommandDesc, NumKernelAlternatives + 1 ,
981
- TranslatedKernelHandles.data (), &CommandId));
993
+ (ZeCommandList, &ZeMutableCommandDesc, NumKernelAlternatives + 1 ,
994
+ KernelHandles.data (), &CommandId));
982
995
983
996
} else {
984
997
ZE2UR_CALL (Platform->ZeMutableCmdListExt .zexCommandListGetNextCommandIdExp ,
985
- (CommandBuffer->ZeComputeCommandListTranslated ,
986
- &ZeMutableCommandDesc, &CommandId));
998
+ (ZeCommandList, &ZeMutableCommandDesc, &CommandId));
987
999
}
988
1000
DEBUG_LOG (CommandId);
989
1001
@@ -1863,17 +1875,22 @@ ur_result_t updateKernelCommand(
1863
1875
ur_kernel_handle_t NewKernel = CommandDesc->hNewKernel ;
1864
1876
1865
1877
if (NewKernel && Command->Kernel != NewKernel) {
1878
+ ze_kernel_handle_t KernelHandle{};
1866
1879
ze_kernel_handle_t ZeNewKernel{};
1867
1880
UR_CALL (getZeKernel (ZeDevice, NewKernel, &ZeNewKernel));
1868
1881
1869
- ze_kernel_handle_t ZeKernelTranslated = nullptr ;
1870
- ZE2UR_CALL (zelLoaderTranslateHandle,
1871
- (ZEL_HANDLE_KERNEL, ZeNewKernel, (void **)&ZeKernelTranslated));
1882
+ ze_command_list_handle_t ZeCommandList =
1883
+ CommandBuffer->ZeComputeCommandList ;
1884
+ KernelHandle = ZeNewKernel;
1885
+ if (!Platform->ZeMutableCmdListExt .LoaderExtension ) {
1886
+ ZeCommandList = CommandBuffer->ZeComputeCommandListTranslated ;
1887
+ ZE2UR_CALL (zelLoaderTranslateHandle,
1888
+ (ZEL_HANDLE_KERNEL, ZeNewKernel, (void **)&KernelHandle));
1889
+ }
1872
1890
1873
1891
ZE2UR_CALL (Platform->ZeMutableCmdListExt
1874
1892
.zexCommandListUpdateMutableCommandKernelsExp ,
1875
- (CommandBuffer->ZeComputeCommandListTranslated , 1 ,
1876
- &Command->CommandId , &ZeKernelTranslated));
1893
+ (ZeCommandList, 1 , &Command->CommandId , &KernelHandle));
1877
1894
// Set current kernel to be the new kernel
1878
1895
Command->Kernel = NewKernel;
1879
1896
}
@@ -2079,9 +2096,15 @@ ur_result_t updateKernelCommand(
2079
2096
MutableCommandDesc.pNext = NextDesc;
2080
2097
MutableCommandDesc.flags = 0 ;
2081
2098
2099
+ ze_command_list_handle_t ZeCommandList =
2100
+ CommandBuffer->ZeComputeCommandListTranslated ;
2101
+ if (Platform->ZeMutableCmdListExt .LoaderExtension ) {
2102
+ ZeCommandList = CommandBuffer->ZeComputeCommandList ;
2103
+ }
2104
+
2082
2105
ZE2UR_CALL (
2083
2106
Platform->ZeMutableCmdListExt .zexCommandListUpdateMutableCommandsExp ,
2084
- (CommandBuffer-> ZeComputeCommandListTranslated , &MutableCommandDesc));
2107
+ (ZeCommandList , &MutableCommandDesc));
2085
2108
2086
2109
return UR_RESULT_SUCCESS;
2087
2110
}
0 commit comments