Skip to content

Commit af5aa07

Browse files
authored
Merge pull request #2454 from Bensuo/l0_cmd-buf_multi-device
Fix L0 command-buffer consumption of multi-device kernels
2 parents c2dfd9e + 1120f1c commit af5aa07

File tree

3 files changed

+187
-25
lines changed

3 files changed

+187
-25
lines changed

source/adapters/level_zero/command_buffer.cpp

+48-25
Original file line numberDiff line numberDiff line change
@@ -895,28 +895,31 @@ urCommandBufferFinalizeExp(ur_exp_command_buffer_handle_t CommandBuffer) {
895895
/**
896896
* Sets the kernel arguments for a kernel command that will be appended to the
897897
* command buffer.
898-
* @param[in] CommandBuffer The CommandBuffer where the command will be
898+
* @param[in] Device The Device associated with the command-buffer where the
899+
* kernel command will be appended.
900+
* @param[in,out] Arguments stored in the ur_kernel_handle_t object to be set
901+
* on the /p ZeKernel object.
902+
* @param[in] ZeKernel The handle to the Level-Zero kernel that will be
899903
* appended.
900-
* @param[in] Kernel The handle to the kernel that will be appended.
901904
* @return UR_RESULT_SUCCESS or an error code on failure
902905
*/
903-
ur_result_t
904-
setKernelPendingArguments(ur_exp_command_buffer_handle_t CommandBuffer,
905-
ur_kernel_handle_t Kernel) {
906-
906+
ur_result_t setKernelPendingArguments(
907+
ur_device_handle_t Device,
908+
std::vector<ur_kernel_handle_t_::ArgumentInfo> &PendingArguments,
909+
ze_kernel_handle_t ZeKernel) {
907910
// If there are any pending arguments set them now.
908-
for (auto &Arg : Kernel->PendingArguments) {
911+
for (auto &Arg : PendingArguments) {
909912
// The ArgValue may be a NULL pointer in which case a NULL value is used for
910913
// the kernel argument declared as a pointer to global or constant memory.
911914
char **ZeHandlePtr = nullptr;
912915
if (Arg.Value) {
913-
UR_CALL(Arg.Value->getZeHandlePtr(ZeHandlePtr, Arg.AccessMode,
914-
CommandBuffer->Device, nullptr, 0u));
916+
UR_CALL(Arg.Value->getZeHandlePtr(ZeHandlePtr, Arg.AccessMode, Device,
917+
nullptr, 0u));
915918
}
916919
ZE2UR_CALL(zeKernelSetArgumentValue,
917-
(Kernel->ZeKernel, Arg.Index, Arg.Size, ZeHandlePtr));
920+
(ZeKernel, Arg.Index, Arg.Size, ZeHandlePtr));
918921
}
919-
Kernel->PendingArguments.clear();
922+
PendingArguments.clear();
920923

921924
return UR_RESULT_SUCCESS;
922925
}
@@ -952,21 +955,29 @@ createCommandHandle(ur_exp_command_buffer_handle_t CommandBuffer,
952955
ZE_MUTABLE_COMMAND_EXP_FLAG_GLOBAL_OFFSET;
953956

954957
auto Platform = CommandBuffer->Context->getPlatform();
958+
auto ZeDevice = CommandBuffer->Device->ZeDevice;
959+
955960
if (NumKernelAlternatives > 0) {
956961
ZeMutableCommandDesc.flags |=
957962
ZE_MUTABLE_COMMAND_EXP_FLAG_KERNEL_INSTRUCTION;
958963

959964
std::vector<ze_kernel_handle_t> TranslatedKernelHandles(
960965
NumKernelAlternatives + 1, nullptr);
961966

967+
ze_kernel_handle_t ZeMainKernel{};
968+
UR_CALL(getZeKernel(ZeDevice, Kernel, &ZeMainKernel));
969+
962970
// Translate main kernel first
963971
ZE2UR_CALL(zelLoaderTranslateHandle,
964-
(ZEL_HANDLE_KERNEL, Kernel->ZeKernel,
972+
(ZEL_HANDLE_KERNEL, ZeMainKernel,
965973
(void **)&TranslatedKernelHandles[0]));
966974

967975
for (size_t i = 0; i < NumKernelAlternatives; i++) {
976+
ze_kernel_handle_t ZeAltKernel{};
977+
UR_CALL(getZeKernel(ZeDevice, KernelAlternatives[i], &ZeAltKernel));
978+
968979
ZE2UR_CALL(zelLoaderTranslateHandle,
969-
(ZEL_HANDLE_KERNEL, KernelAlternatives[i]->ZeKernel,
980+
(ZEL_HANDLE_KERNEL, ZeAltKernel,
970981
(void **)&TranslatedKernelHandles[i + 1]));
971982
}
972983

@@ -1023,23 +1034,28 @@ ur_result_t urCommandBufferAppendKernelLaunchExp(
10231034
std::scoped_lock<ur_shared_mutex, ur_shared_mutex, ur_shared_mutex> Lock(
10241035
Kernel->Mutex, Kernel->Program->Mutex, CommandBuffer->Mutex);
10251036

1037+
auto Device = CommandBuffer->Device;
1038+
ze_kernel_handle_t ZeKernel{};
1039+
UR_CALL(getZeKernel(Device->ZeDevice, Kernel, &ZeKernel));
1040+
10261041
if (GlobalWorkOffset != NULL) {
1027-
UR_CALL(setKernelGlobalOffset(CommandBuffer->Context, Kernel->ZeKernel,
1028-
WorkDim, GlobalWorkOffset));
1042+
UR_CALL(setKernelGlobalOffset(CommandBuffer->Context, ZeKernel, WorkDim,
1043+
GlobalWorkOffset));
10291044
}
10301045

10311046
// If there are any pending arguments set them now.
10321047
if (!Kernel->PendingArguments.empty()) {
1033-
UR_CALL(setKernelPendingArguments(CommandBuffer, Kernel));
1048+
UR_CALL(
1049+
setKernelPendingArguments(Device, Kernel->PendingArguments, ZeKernel));
10341050
}
10351051

10361052
ze_group_count_t ZeThreadGroupDimensions{1, 1, 1};
10371053
uint32_t WG[3];
1038-
UR_CALL(calculateKernelWorkDimensions(Kernel->ZeKernel, CommandBuffer->Device,
1054+
UR_CALL(calculateKernelWorkDimensions(ZeKernel, Device,
10391055
ZeThreadGroupDimensions, WG, WorkDim,
10401056
GlobalWorkSize, LocalWorkSize));
10411057

1042-
ZE2UR_CALL(zeKernelSetGroupSize, (Kernel->ZeKernel, WG[0], WG[1], WG[2]));
1058+
ZE2UR_CALL(zeKernelSetGroupSize, (ZeKernel, WG[0], WG[1], WG[2]));
10431059

10441060
CommandBuffer->KernelsList.push_back(Kernel);
10451061
for (size_t i = 0; i < NumKernelAlternatives; i++) {
@@ -1064,7 +1080,7 @@ ur_result_t urCommandBufferAppendKernelLaunchExp(
10641080
SyncPointWaitList, false, RetSyncPoint, ZeEventList, ZeLaunchEvent));
10651081

10661082
ZE2UR_CALL(zeCommandListAppendLaunchKernel,
1067-
(CommandBuffer->ZeComputeCommandList, Kernel->ZeKernel,
1083+
(CommandBuffer->ZeComputeCommandList, ZeKernel,
10681084
&ZeThreadGroupDimensions, ZeLaunchEvent, ZeEventList.size(),
10691085
getPointerFromVector(ZeEventList)));
10701086

@@ -1837,6 +1853,7 @@ ur_result_t updateKernelCommand(
18371853
const auto CommandBuffer = Command->CommandBuffer;
18381854
const void *NextDesc = nullptr;
18391855
auto Platform = CommandBuffer->Context->getPlatform();
1856+
auto ZeDevice = CommandBuffer->Device->ZeDevice;
18401857

18411858
uint32_t Dim = CommandDesc->newWorkDim;
18421859
size_t *NewGlobalWorkOffset = CommandDesc->pNewGlobalWorkOffset;
@@ -1845,11 +1862,14 @@ ur_result_t updateKernelCommand(
18451862

18461863
// Kernel handle must be updated first for a given CommandId if required
18471864
ur_kernel_handle_t NewKernel = CommandDesc->hNewKernel;
1865+
18481866
if (NewKernel && Command->Kernel != NewKernel) {
1867+
ze_kernel_handle_t ZeNewKernel{};
1868+
UR_CALL(getZeKernel(ZeDevice, NewKernel, &ZeNewKernel));
1869+
18491870
ze_kernel_handle_t ZeKernelTranslated = nullptr;
1850-
ZE2UR_CALL(
1851-
zelLoaderTranslateHandle,
1852-
(ZEL_HANDLE_KERNEL, NewKernel->ZeKernel, (void **)&ZeKernelTranslated));
1871+
ZE2UR_CALL(zelLoaderTranslateHandle,
1872+
(ZEL_HANDLE_KERNEL, ZeNewKernel, (void **)&ZeKernelTranslated));
18531873

18541874
ZE2UR_CALL(Platform->ZeMutableCmdListExt
18551875
.zexCommandListUpdateMutableCommandKernelsExp,
@@ -1906,10 +1926,13 @@ ur_result_t updateKernelCommand(
19061926
// by the driver for the kernel.
19071927
bool UpdateWGSize = NewLocalWorkSize == nullptr;
19081928

1929+
ze_kernel_handle_t ZeKernel{};
1930+
UR_CALL(getZeKernel(ZeDevice, Command->Kernel, &ZeKernel));
1931+
19091932
uint32_t WG[3];
1910-
UR_CALL(calculateKernelWorkDimensions(
1911-
Command->Kernel->ZeKernel, CommandBuffer->Device,
1912-
ZeThreadGroupDimensions, WG, Dim, NewGlobalWorkSize, NewLocalWorkSize));
1933+
UR_CALL(calculateKernelWorkDimensions(ZeKernel, CommandBuffer->Device,
1934+
ZeThreadGroupDimensions, WG, Dim,
1935+
NewGlobalWorkSize, NewLocalWorkSize));
19131936

19141937
auto MutableGroupCountDesc =
19151938
std::make_unique<ZeStruct<ze_mutable_group_count_exp_desc_t>>();
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
urProgramSetSpecializationConstantsTest.InvalidValueSize/*
22
urProgramSetSpecializationConstantsTest.InvalidValueId/*
33
urProgramSetSpecializationConstantsTest.InvalidValuePtr/*
4+
{{OPT}}urMultiDeviceCommandBufferExpTest.*

test/conformance/program/urMultiDeviceProgramCreateWithBinary.cpp

+138
Original file line numberDiff line numberDiff line change
@@ -240,3 +240,141 @@ TEST_F(urMultiDeviceProgramCreateWithBinaryTest, CheckProgramGetInfo) {
240240
reinterpret_cast<char *>(property_value.data());
241241
ASSERT_STRNE(returned_kernel_names, "");
242242
}
243+
244+
struct urMultiDeviceCommandBufferExpTest
245+
: urMultiDeviceProgramCreateWithBinaryTest {
246+
void SetUp() override {
247+
UUR_RETURN_ON_FATAL_FAILURE(
248+
urMultiDeviceProgramCreateWithBinaryTest::SetUp());
249+
250+
auto kernelName =
251+
uur::KernelsEnvironment::instance->GetEntryPointNames("foo")[0];
252+
253+
ASSERT_SUCCESS(urProgramBuild(context, binary_program, nullptr));
254+
ASSERT_SUCCESS(
255+
urKernelCreate(binary_program, kernelName.data(), &kernel));
256+
}
257+
258+
void TearDown() override {
259+
if (kernel) {
260+
EXPECT_SUCCESS(urKernelRelease(kernel));
261+
}
262+
UUR_RETURN_ON_FATAL_FAILURE(
263+
urMultiDeviceProgramCreateWithBinaryTest::TearDown());
264+
}
265+
266+
static bool hasCommandBufferSupport(ur_device_handle_t device) {
267+
ur_bool_t cmd_buffer_support = false;
268+
auto res = urDeviceGetInfo(
269+
device, UR_DEVICE_INFO_COMMAND_BUFFER_SUPPORT_EXP,
270+
sizeof(cmd_buffer_support), &cmd_buffer_support, nullptr);
271+
272+
if (res) {
273+
return false;
274+
}
275+
276+
return cmd_buffer_support;
277+
}
278+
279+
static bool hasCommandBufferUpdateSupport(ur_device_handle_t device) {
280+
ur_device_command_buffer_update_capability_flags_t
281+
update_capability_flags;
282+
auto res = urDeviceGetInfo(
283+
device, UR_DEVICE_INFO_COMMAND_BUFFER_UPDATE_CAPABILITIES_EXP,
284+
sizeof(update_capability_flags), &update_capability_flags, nullptr);
285+
286+
if (res) {
287+
return false;
288+
}
289+
290+
return (0 != update_capability_flags);
291+
}
292+
293+
ur_kernel_handle_t kernel = nullptr;
294+
295+
static constexpr size_t global_offset = 0;
296+
static constexpr size_t n_dimensions = 1;
297+
static constexpr size_t global_size = 64;
298+
static constexpr size_t local_size = 4;
299+
};
300+
301+
TEST_F(urMultiDeviceCommandBufferExpTest, Enqueue) {
302+
for (size_t i = 0; i < devices.size(); i++) {
303+
auto device = devices[i];
304+
if (!hasCommandBufferSupport(device)) {
305+
continue;
306+
}
307+
308+
// Create command-buffer
309+
uur::raii::CommandBuffer cmd_buf_handle;
310+
ASSERT_SUCCESS(urCommandBufferCreateExp(context, device, nullptr,
311+
cmd_buf_handle.ptr()));
312+
313+
// Append kernel command to command-buffer and close command-buffer
314+
ASSERT_SUCCESS(urCommandBufferAppendKernelLaunchExp(
315+
cmd_buf_handle, kernel, n_dimensions, &global_offset, &global_size,
316+
&local_size, 0, nullptr, 0, nullptr, 0, nullptr, nullptr, nullptr,
317+
nullptr));
318+
ASSERT_SUCCESS(urCommandBufferFinalizeExp(cmd_buf_handle));
319+
320+
// Verify execution succeeds
321+
ASSERT_SUCCESS(urCommandBufferEnqueueExp(cmd_buf_handle, queues[i], 0,
322+
nullptr, nullptr));
323+
ASSERT_SUCCESS(urQueueFinish(queues[i]));
324+
}
325+
}
326+
327+
TEST_F(urMultiDeviceCommandBufferExpTest, Update) {
328+
for (size_t i = 0; i < devices.size(); i++) {
329+
auto device = devices[i];
330+
if (!(hasCommandBufferSupport(device) &&
331+
hasCommandBufferUpdateSupport(device))) {
332+
continue;
333+
}
334+
335+
// Create a command-buffer with update enabled.
336+
ur_exp_command_buffer_desc_t desc{
337+
UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_DESC, nullptr, true, false,
338+
false};
339+
340+
// Create command-buffer
341+
uur::raii::CommandBuffer cmd_buf_handle;
342+
ASSERT_SUCCESS(urCommandBufferCreateExp(context, device, &desc,
343+
cmd_buf_handle.ptr()));
344+
345+
// Append kernel command to command-buffer and close command-buffer
346+
uur::raii::CommandBufferCommand command;
347+
ASSERT_SUCCESS(urCommandBufferAppendKernelLaunchExp(
348+
cmd_buf_handle, kernel, n_dimensions, &global_offset, &global_size,
349+
&local_size, 0, nullptr, 0, nullptr, 0, nullptr, nullptr, nullptr,
350+
command.ptr()));
351+
ASSERT_SUCCESS(urCommandBufferFinalizeExp(cmd_buf_handle));
352+
353+
// Verify execution succeeds
354+
ASSERT_SUCCESS(urCommandBufferEnqueueExp(cmd_buf_handle, queues[i], 0,
355+
nullptr, nullptr));
356+
ASSERT_SUCCESS(urQueueFinish(queues[i]));
357+
358+
// Update kernel and enqueue command-buffer again
359+
ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = {
360+
UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype
361+
nullptr, // pNext
362+
kernel, // hNewKernel
363+
0, // numNewMemObjArgs
364+
0, // numNewPointerArgs
365+
0, // numNewValueArgs
366+
n_dimensions, // newWorkDim
367+
nullptr, // pNewMemObjArgList
368+
nullptr, // pNewPointerArgList
369+
nullptr, // pNewValueArgList
370+
nullptr, // pNewGlobalWorkOffset
371+
nullptr, // pNewGlobalWorkSize
372+
nullptr, // pNewLocalWorkSize
373+
};
374+
ASSERT_SUCCESS(
375+
urCommandBufferUpdateKernelLaunchExp(command, &update_desc));
376+
ASSERT_SUCCESS(urCommandBufferEnqueueExp(cmd_buf_handle, queues[i], 0,
377+
nullptr, nullptr));
378+
ASSERT_SUCCESS(urQueueFinish(queues[i]));
379+
}
380+
}

0 commit comments

Comments
 (0)