diff --git a/source/adapters/level_zero/command_buffer.cpp b/source/adapters/level_zero/command_buffer.cpp index ced2d0286b..36cf76d111 100644 --- a/source/adapters/level_zero/command_buffer.cpp +++ b/source/adapters/level_zero/command_buffer.cpp @@ -23,8 +23,8 @@ ur_exp_command_buffer_handle_t_::ur_exp_command_buffer_handle_t_( : Context(Context), Device(Device), ZeCommandList(CommandList), ZeCommandListResetEvents(CommandListResetEvents), ZeCommandListDesc(ZeDesc), ZeFencesList(), QueueProperties(), - SyncPoints(), NextSyncPoint(0) { - (void)Desc; + SyncPoints(), NextSyncPoint(0), + IsUpdatable(Desc ? Desc->isUpdatable : false) { urContextRetain(Context); urDeviceRetain(Device); } @@ -77,59 +77,79 @@ ur_exp_command_buffer_handle_t_::~ur_exp_command_buffer_handle_t_() { } } +ur_exp_command_buffer_command_handle_t_:: + ur_exp_command_buffer_command_handle_t_( + ur_exp_command_buffer_handle_t CommandBuffer, uint64_t CommandId, + ur_kernel_handle_t Kernel = nullptr) + : CommandBuffer(CommandBuffer), CommandId(CommandId), Kernel(Kernel) { + urCommandBufferRetainExp(CommandBuffer); + if (Kernel) + urKernelRetain(Kernel); +} + +ur_exp_command_buffer_command_handle_t_:: + ~ur_exp_command_buffer_command_handle_t_() { + urCommandBufferReleaseExp(CommandBuffer); + if (Kernel) + urKernelRelease(Kernel); +} + /// Helper function for calculating work dimensions for kernels ur_result_t calculateKernelWorkDimensions( ur_kernel_handle_t Kernel, ur_device_handle_t Device, ze_group_count_t &ZeThreadGroupDimensions, uint32_t (&WG)[3], uint32_t WorkDim, const size_t *GlobalWorkSize, const size_t *LocalWorkSize) { - // global_work_size of unused dimensions must be set to 1 - UR_ASSERT(WorkDim == 3 || GlobalWorkSize[2] == 1, - UR_RESULT_ERROR_INVALID_VALUE); - UR_ASSERT(WorkDim >= 2 || GlobalWorkSize[1] == 1, - UR_RESULT_ERROR_INVALID_VALUE); + + UR_ASSERT(GlobalWorkSize, UR_RESULT_ERROR_INVALID_VALUE); + // If LocalWorkSize is not provided then Kernel must be provided to query + // suggested group size. + UR_ASSERT(LocalWorkSize || Kernel, UR_RESULT_ERROR_INVALID_VALUE); + + // New variable needed because GlobalWorkSize parameter might not be of size 3 + size_t GlobalWorkSize3D[3]{1, 1, 1}; + std::copy(GlobalWorkSize, GlobalWorkSize + WorkDim, GlobalWorkSize3D); if (LocalWorkSize) { WG[0] = ur_cast(LocalWorkSize[0]); - WG[1] = ur_cast(LocalWorkSize[1]); - WG[2] = ur_cast(LocalWorkSize[2]); + WG[1] = WorkDim >= 2 ? ur_cast(LocalWorkSize[1]) : 1; + WG[2] = WorkDim == 3 ? ur_cast(LocalWorkSize[2]) : 1; } else { - // We can't call to zeKernelSuggestGroupSize if 64-bit GlobalWorkSize + // We can't call to zeKernelSuggestGroupSize if 64-bit GlobalWorkSize3D // values do not fit to 32-bit that the API only supports currently. bool SuggestGroupSize = true; for (int I : {0, 1, 2}) { - if (GlobalWorkSize[I] > UINT32_MAX) { + if (GlobalWorkSize3D[I] > UINT32_MAX) { SuggestGroupSize = false; } } if (SuggestGroupSize) { ZE2UR_CALL(zeKernelSuggestGroupSize, - (Kernel->ZeKernel, GlobalWorkSize[0], GlobalWorkSize[1], - GlobalWorkSize[2], &WG[0], &WG[1], &WG[2])); + (Kernel->ZeKernel, GlobalWorkSize3D[0], GlobalWorkSize3D[1], + GlobalWorkSize3D[2], &WG[0], &WG[1], &WG[2])); } else { for (int I : {0, 1, 2}) { - // Try to find a I-dimension WG size that the GlobalWorkSize[I] is + // Try to find a I-dimension WG size that the GlobalWorkSize3D[I] is // fully divisable with. Start with the max possible size in // each dimension. 
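          // Illustration (assumed values): with GlobalWorkSize3D[I] == 1030 and a
          // device maximum of 1024 in this dimension, the search starts at
          // min(1024, 1030) == 1024 and counts down to 515, the largest value
          // that divides 1030 evenly, giving 1030 / 515 == 2 thread groups.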
uint32_t GroupSize[] = { Device->ZeDeviceComputeProperties->maxGroupSizeX, Device->ZeDeviceComputeProperties->maxGroupSizeY, Device->ZeDeviceComputeProperties->maxGroupSizeZ}; - GroupSize[I] = (std::min)(size_t(GroupSize[I]), GlobalWorkSize[I]); - while (GlobalWorkSize[I] % GroupSize[I]) { + GroupSize[I] = (std::min)(size_t(GroupSize[I]), GlobalWorkSize3D[I]); + while (GlobalWorkSize3D[I] % GroupSize[I]) { --GroupSize[I]; } - if (GlobalWorkSize[I] / GroupSize[I] > UINT32_MAX) { - urPrint("urCommandBufferAppendKernelLaunchExp: can't find a WG size " + if (GlobalWorkSize3D[I] / GroupSize[I] > UINT32_MAX) { + urPrint("calculateKernelWorkDimensions: can't find a WG size " "suitable for global work size > UINT32_MAX\n"); return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE; } WG[I] = GroupSize[I]; } - urPrint( - "urCommandBufferAppendKernelLaunchExp: using computed WG size = {%d, " - "%d, %d}\n", - WG[0], WG[1], WG[2]); + urPrint("calculateKernelWorkDimensions: using computed WG size = {%d, " + "%d, %d}\n", + WG[0], WG[1], WG[2]); } } @@ -137,48 +157,48 @@ ur_result_t calculateKernelWorkDimensions( switch (WorkDim) { case 3: ZeThreadGroupDimensions.groupCountX = - ur_cast(GlobalWorkSize[0] / WG[0]); + ur_cast(GlobalWorkSize3D[0] / WG[0]); ZeThreadGroupDimensions.groupCountY = - ur_cast(GlobalWorkSize[1] / WG[1]); + ur_cast(GlobalWorkSize3D[1] / WG[1]); ZeThreadGroupDimensions.groupCountZ = - ur_cast(GlobalWorkSize[2] / WG[2]); + ur_cast(GlobalWorkSize3D[2] / WG[2]); break; case 2: ZeThreadGroupDimensions.groupCountX = - ur_cast(GlobalWorkSize[0] / WG[0]); + ur_cast(GlobalWorkSize3D[0] / WG[0]); ZeThreadGroupDimensions.groupCountY = - ur_cast(GlobalWorkSize[1] / WG[1]); + ur_cast(GlobalWorkSize3D[1] / WG[1]); WG[2] = 1; break; case 1: ZeThreadGroupDimensions.groupCountX = - ur_cast(GlobalWorkSize[0] / WG[0]); + ur_cast(GlobalWorkSize3D[0] / WG[0]); WG[1] = WG[2] = 1; break; default: - urPrint("urCommandBufferAppendKernelLaunchExp: unsupported work_dim\n"); + urPrint("calculateKernelWorkDimensions: unsupported work_dim\n"); return UR_RESULT_ERROR_INVALID_VALUE; } // Error handling for non-uniform group size case - if (GlobalWorkSize[0] != + if (GlobalWorkSize3D[0] != size_t(ZeThreadGroupDimensions.groupCountX) * WG[0]) { - urPrint("urCommandBufferAppendKernelLaunchExp: invalid work_dim. The range " + urPrint("calculateKernelWorkDimensions: invalid work_dim. The range " "is not a " "multiple of the group size in the 1st dimension\n"); return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE; } - if (GlobalWorkSize[1] != + if (GlobalWorkSize3D[1] != size_t(ZeThreadGroupDimensions.groupCountY) * WG[1]) { - urPrint("urCommandBufferAppendKernelLaunchExp: invalid work_dim. The range " + urPrint("calculateKernelWorkDimensions: invalid work_dim. The range " "is not a " "multiple of the group size in the 2nd dimension\n"); return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE; } - if (GlobalWorkSize[2] != + if (GlobalWorkSize3D[2] != size_t(ZeThreadGroupDimensions.groupCountZ) * WG[2]) { - urPrint("urCommandBufferAppendKernelLaunchExp: invalid work_dim. The range " + urPrint("calculateKernelWorkDimensions: invalid work_dim. 
The range " "is not a " "multiple of the group size in the 3rd dimension\n"); return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE; @@ -204,6 +224,9 @@ static ur_result_t getEventsFromSyncPoints( size_t NumSyncPointsInWaitList, const ur_exp_command_buffer_sync_point_t *SyncPointWaitList, std::vector &ZeEventList) { + if (!SyncPointWaitList || NumSyncPointsInWaitList == 0) + return UR_RESULT_SUCCESS; + // Map of ur_exp_command_buffer_sync_point_t to ur_event_handle_t defining // the event associated with each sync-point auto SyncPoints = CommandBuffer->SyncPoints; @@ -386,6 +409,12 @@ urCommandBufferCreateExp(ur_context_handle_t Context, ur_device_handle_t Device, // can enable the backend to further optimize the workload ZeCommandListDesc.flags = ZE_COMMAND_LIST_FLAG_RELAXED_ORDERING; + ZeStruct ZeMutableCommandListDesc; + if (CommandBufferDesc && CommandBufferDesc->isUpdatable) { + ZeMutableCommandListDesc.flags = 0; + ZeCommandListDesc.pNext = &ZeMutableCommandListDesc; + } + ze_command_list_handle_t ZeCommandList; // TODO We could optimize this by pooling both Level Zero command-lists and UR // command-buffers, then reusing them. @@ -441,6 +470,10 @@ urCommandBufferReleaseExp(ur_exp_command_buffer_handle_t CommandBuffer) { UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferFinalizeExp(ur_exp_command_buffer_handle_t CommandBuffer) { + UR_ASSERT(CommandBuffer, UR_RESULT_ERROR_INVALID_NULL_POINTER); + // It is not allowed to append to command list from multiple threads. + std::scoped_lock Guard(CommandBuffer->Mutex); + // Create a list of events for our signal event to wait on // This loop also resets the L0 events we use for command-buffer internal // sync-points to the non-signaled state. @@ -465,6 +498,7 @@ urCommandBufferFinalizeExp(ur_exp_command_buffer_handle_t CommandBuffer) { // Close the command lists and have them ready for dispatch. ZE2UR_CALL(zeCommandListClose, (CommandBuffer->ZeCommandList)); ZE2UR_CALL(zeCommandListClose, (CommandBuffer->ZeCommandListResetEvents)); + CommandBuffer->IsFinalized = true; return UR_RESULT_SUCCESS; } @@ -475,10 +509,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp( uint32_t NumSyncPointsInWaitList, const ur_exp_command_buffer_sync_point_t *SyncPointWaitList, ur_exp_command_buffer_sync_point_t *SyncPoint, - ur_exp_command_buffer_command_handle_t *) { + ur_exp_command_buffer_command_handle_t *Command) { + UR_ASSERT(CommandBuffer && Kernel && Kernel->Program, + UR_RESULT_ERROR_INVALID_NULL_POINTER); // Lock automatically releases when this goes out of scope. - std::scoped_lock Lock( - Kernel->Mutex, Kernel->Program->Mutex); + std::scoped_lock Lock( + Kernel->Mutex, Kernel->Program->Mutex, CommandBuffer->Mutex); if (GlobalWorkOffset != NULL) { if (!CommandBuffer->Context->getPlatform() @@ -523,9 +559,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp( EventCreate(CommandBuffer->Context, nullptr, false, false, &LaunchEvent)); LaunchEvent->CommandType = UR_COMMAND_KERNEL_LAUNCH; - // Get sync point and register the event with it. - *SyncPoint = CommandBuffer->GetNextSyncPoint(); - CommandBuffer->RegisterSyncPoint(*SyncPoint, LaunchEvent); + if (SyncPoint) { + // Get sync point and register the event with it. 
+ *SyncPoint = CommandBuffer->GetNextSyncPoint(); + CommandBuffer->RegisterSyncPoint(*SyncPoint, LaunchEvent); + } LaunchEvent->CommandData = (void *)Kernel; // Increment the reference count of the Kernel and indicate that the Kernel @@ -534,6 +572,35 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp( // reference count on the kernel, using the kernel saved in CommandData. UR_CALL(urKernelRetain(Kernel)); + // If command-buffer is updatable then get command id which is going to be + // used if command is updated in the future. This + // zeCommandListGetNextCommandIdExp can be called only if command is + // updatable. + uint64_t CommandId = 0; + if (CommandBuffer->IsUpdatable) { + ZeStruct ZeMutableCommandDesc; + ZeMutableCommandDesc.flags = ZE_MUTABLE_COMMAND_EXP_FLAG_KERNEL_ARGUMENTS | + ZE_MUTABLE_COMMAND_EXP_FLAG_GROUP_COUNT | + ZE_MUTABLE_COMMAND_EXP_FLAG_GROUP_SIZE | + ZE_MUTABLE_COMMAND_EXP_FLAG_GLOBAL_OFFSET; + + auto Plt = CommandBuffer->Context->getPlatform(); + UR_ASSERT(Plt->ZeMutableCmdListExt.Supported, + UR_RESULT_ERROR_UNSUPPORTED_FEATURE); + ZE2UR_CALL( + Plt->ZeMutableCmdListExt.zexCommandListGetNextCommandIdExp, + (CommandBuffer->ZeCommandList, &ZeMutableCommandDesc, &CommandId)); + } + try { + if (Command) + *Command = new ur_exp_command_buffer_command_handle_t_(CommandBuffer, + CommandId, Kernel); + } catch (const std::bad_alloc &) { + return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; + } catch (...) { + return UR_RESULT_ERROR_UNKNOWN; + } + ZE2UR_CALL(zeCommandListAppendLaunchKernel, (CommandBuffer->ZeCommandList, Kernel->ZeKernel, &ZeThreadGroupDimensions, LaunchEvent->ZeEvent, @@ -942,20 +1009,296 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferEnqueueExp( return UR_RESULT_SUCCESS; } -UR_APIEXPORT ur_result_t UR_APICALL -urCommandBufferRetainCommandExp(ur_exp_command_buffer_command_handle_t) { - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; +UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferRetainCommandExp( + ur_exp_command_buffer_command_handle_t Command) { + Command->RefCount.increment(); + return UR_RESULT_SUCCESS; } -UR_APIEXPORT ur_result_t UR_APICALL -urCommandBufferReleaseCommandExp(ur_exp_command_buffer_command_handle_t) { - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; +UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferReleaseCommandExp( + ur_exp_command_buffer_command_handle_t Command) { + if (!Command->RefCount.decrementAndTest()) + return UR_RESULT_SUCCESS; + + delete Command; + return UR_RESULT_SUCCESS; } UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp( - ur_exp_command_buffer_command_handle_t, - const ur_exp_command_buffer_update_kernel_launch_desc_t *) { - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; + ur_exp_command_buffer_command_handle_t Command, + const ur_exp_command_buffer_update_kernel_launch_desc_t *CommandDesc) { + UR_ASSERT(Command, UR_RESULT_ERROR_INVALID_NULL_HANDLE); + UR_ASSERT(Command->Kernel, UR_RESULT_ERROR_INVALID_NULL_HANDLE); + UR_ASSERT(CommandDesc, UR_RESULT_ERROR_INVALID_NULL_POINTER); + UR_ASSERT(CommandDesc->newWorkDim >= 0 && CommandDesc->newWorkDim <= 3, + UR_RESULT_ERROR_INVALID_WORK_DIMENSION); + + // Lock command, kernel and command buffer for update. 
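  // Taking all three mutexes through a single std::scoped_lock acquires them with
  // the standard deadlock-avoidance algorithm, so concurrent updates touching the
  // same kernel or command-buffer are serialized safely.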
+ std::scoped_lock Guard( + Command->Mutex, Command->CommandBuffer->Mutex, Command->Kernel->Mutex); + UR_ASSERT(Command->CommandBuffer->IsUpdatable, + UR_RESULT_ERROR_INVALID_OPERATION); + UR_ASSERT(Command->CommandBuffer->IsFinalized, + UR_RESULT_ERROR_INVALID_OPERATION); + + auto CommandBuffer = Command->CommandBuffer; + uint32_t Dim = CommandDesc->newWorkDim; + const void *NextDesc = nullptr; + auto SupportedFeatures = + Command->CommandBuffer->Device->ZeDeviceMutableCmdListsProperties + ->mutableCommandFlags; + + // We need the created descriptors to live till the point when + // zexCommandListUpdateMutableCommandsExp is called at the end of the + // function. + std::vector>> + ArgDescs; + std::vector>> + OffsetDescs; + std::vector>> + GroupSizeDescs; + std::vector>> + GroupCountDescs; + + // Check if new global offset is provided. + size_t *NewGlobalWorkOffset = CommandDesc->pNewGlobalWorkOffset; + UR_ASSERT(!NewGlobalWorkOffset || + (SupportedFeatures & ZE_MUTABLE_COMMAND_EXP_FLAG_GLOBAL_OFFSET), + UR_RESULT_ERROR_UNSUPPORTED_FEATURE); + if (NewGlobalWorkOffset && Dim > 0) { + if (!CommandBuffer->Context->getPlatform() + ->ZeDriverGlobalOffsetExtensionFound) { + urPrint("No global offset extension found on this driver\n"); + return UR_RESULT_ERROR_INVALID_VALUE; + } + auto MutableGroupOffestDesc = + std::make_unique>(); + MutableGroupOffestDesc->commandId = Command->CommandId; + MutableGroupOffestDesc->pNext = NextDesc; + MutableGroupOffestDesc->offsetX = NewGlobalWorkOffset[0]; + MutableGroupOffestDesc->offsetY = Dim >= 2 ? NewGlobalWorkOffset[1] : 0; + MutableGroupOffestDesc->offsetZ = Dim == 3 ? NewGlobalWorkOffset[2] : 0; + NextDesc = MutableGroupOffestDesc.get(); + OffsetDescs.push_back(std::move(MutableGroupOffestDesc)); + } + + // Check if new group size is provided. + size_t *NewLocalWorkSize = CommandDesc->pNewLocalWorkSize; + UR_ASSERT(!NewLocalWorkSize || + (SupportedFeatures & ZE_MUTABLE_COMMAND_EXP_FLAG_GROUP_SIZE), + UR_RESULT_ERROR_UNSUPPORTED_FEATURE); + if (NewLocalWorkSize && Dim > 0) { + auto MutableGroupSizeDesc = + std::make_unique>(); + MutableGroupSizeDesc->commandId = Command->CommandId; + MutableGroupSizeDesc->pNext = NextDesc; + MutableGroupSizeDesc->groupSizeX = NewLocalWorkSize[0]; + MutableGroupSizeDesc->groupSizeY = Dim >= 2 ? NewLocalWorkSize[1] : 1; + MutableGroupSizeDesc->groupSizeZ = Dim == 3 ? NewLocalWorkSize[2] : 1; + NextDesc = MutableGroupSizeDesc.get(); + GroupSizeDescs.push_back(std::move(MutableGroupSizeDesc)); + } + + // Check if new global size is provided and we need to update group count. + size_t *NewGlobalWorkSize = CommandDesc->pNewGlobalWorkSize; + UR_ASSERT(!NewGlobalWorkSize || + (SupportedFeatures & ZE_MUTABLE_COMMAND_EXP_FLAG_GROUP_COUNT), + UR_RESULT_ERROR_UNSUPPORTED_FEATURE); + UR_ASSERT(!(NewGlobalWorkSize && !NewLocalWorkSize) || + (SupportedFeatures & ZE_MUTABLE_COMMAND_EXP_FLAG_GROUP_SIZE), + UR_RESULT_ERROR_UNSUPPORTED_FEATURE); + if (NewGlobalWorkSize && Dim > 0) { + ze_group_count_t ZeThreadGroupDimensions{1, 1, 1}; + uint32_t WG[3]; + // If new global work size is provided but new local work size is not + // provided then we still need to update local work size based on size + // suggested by the driver for the kernel. 
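    // Illustration (assumed values): updating only pNewGlobalWorkSize from 32 to
    // 64 while leaving pNewLocalWorkSize null re-runs the group-size suggestion
    // in calculateKernelWorkDimensions, so both a group-count and a group-size
    // descriptor end up chained into the mutable-commands descriptor below.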
+ bool UpdateWGSize = NewLocalWorkSize == nullptr; + UR_CALL(calculateKernelWorkDimensions( + Command->Kernel, CommandBuffer->Device, ZeThreadGroupDimensions, WG, + Dim, NewGlobalWorkSize, NewLocalWorkSize)); + auto MutableGroupCountDesc = + std::make_unique>(); + MutableGroupCountDesc->pNext = NextDesc; + MutableGroupCountDesc->commandId = Command->CommandId; + MutableGroupCountDesc->pGroupCount = &ZeThreadGroupDimensions; + NextDesc = MutableGroupCountDesc.get(); + GroupCountDescs.push_back(std::move(MutableGroupCountDesc)); + + if (UpdateWGSize) { + auto MutableGroupSizeDesc = + std::make_unique>(); + MutableGroupSizeDesc->commandId = Command->CommandId; + MutableGroupSizeDesc->pNext = NextDesc; + MutableGroupSizeDesc->groupSizeX = WG[0]; + MutableGroupSizeDesc->groupSizeY = WG[1]; + MutableGroupSizeDesc->groupSizeZ = WG[2]; + NextDesc = MutableGroupSizeDesc.get(); + GroupSizeDescs.push_back(std::move(MutableGroupSizeDesc)); + } + } + + UR_ASSERT( + (!CommandDesc->numNewMemObjArgs && !CommandDesc->numNewPointerArgs && + !CommandDesc->numNewValueArgs) || + (SupportedFeatures & ZE_MUTABLE_COMMAND_EXP_FLAG_KERNEL_ARGUMENTS), + UR_RESULT_ERROR_UNSUPPORTED_FEATURE); + + // Check if new memory object arguments are provided. + for (uint32_t NewMemObjArgNum = CommandDesc->numNewMemObjArgs; + NewMemObjArgNum-- > 0;) { + ur_exp_command_buffer_update_memobj_arg_desc_t NewMemObjArgDesc = + CommandDesc->pNewMemObjArgList[NewMemObjArgNum]; + const ur_kernel_arg_mem_obj_properties_t *Properties = + NewMemObjArgDesc.pProperties; + ur_mem_handle_t_::access_mode_t UrAccessMode = ur_mem_handle_t_::read_write; + if (Properties) { + switch (Properties->memoryAccess) { + case UR_MEM_FLAG_READ_WRITE: + UrAccessMode = ur_mem_handle_t_::read_write; + break; + case UR_MEM_FLAG_WRITE_ONLY: + UrAccessMode = ur_mem_handle_t_::write_only; + break; + case UR_MEM_FLAG_READ_ONLY: + UrAccessMode = ur_mem_handle_t_::read_only; + break; + default: + return UR_RESULT_ERROR_INVALID_ARGUMENT; + } + } + ur_mem_handle_t NewMemObjArg = NewMemObjArgDesc.hNewMemObjArg; + // The NewMemObjArg may be a NULL pointer in which case a NULL value is used + // for the kernel argument declared as a pointer to global or constant + // memory. + char **ZeHandlePtr = nullptr; + if (NewMemObjArg) { + UR_CALL(NewMemObjArg->getZeHandlePtr(ZeHandlePtr, UrAccessMode, + CommandBuffer->Device)); + } + auto ZeMutableArgDesc = + std::make_unique>(); + ZeMutableArgDesc->commandId = Command->CommandId; + ZeMutableArgDesc->pNext = NextDesc; + ZeMutableArgDesc->argIndex = NewMemObjArgDesc.argIndex; + ZeMutableArgDesc->argSize = sizeof(void *); + ZeMutableArgDesc->pArgValue = ZeHandlePtr; + + NextDesc = ZeMutableArgDesc.get(); + ArgDescs.push_back(std::move(ZeMutableArgDesc)); + } + + // Check if there are new pointer arguments. + for (uint32_t NewPointerArgNum = CommandDesc->numNewPointerArgs; + NewPointerArgNum-- > 0;) { + ur_exp_command_buffer_update_pointer_arg_desc_t NewPointerArgDesc = + CommandDesc->pNewPointerArgList[NewPointerArgNum]; + auto ZeMutableArgDesc = + std::make_unique>(); + ZeMutableArgDesc->commandId = Command->CommandId; + ZeMutableArgDesc->pNext = NextDesc; + ZeMutableArgDesc->argIndex = NewPointerArgDesc.argIndex; + ZeMutableArgDesc->argSize = sizeof(void *); + ZeMutableArgDesc->pArgValue = NewPointerArgDesc.pNewPointerArg; + + NextDesc = ZeMutableArgDesc.get(); + ArgDescs.push_back(std::move(ZeMutableArgDesc)); + } + + // Check if there are new value arguments. 
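  // A value argument carries its size and a pointer to the new bytes
  // (argSize / pNewValueArg), e.g. a new scalar fill value for the kernel; the
  // usm_fill_kernel_update conformance test below builds such a descriptor on
  // the caller side.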
+ for (uint32_t NewValueArgNum = CommandDesc->numNewValueArgs; + NewValueArgNum-- > 0;) { + ur_exp_command_buffer_update_value_arg_desc_t NewValueArgDesc = + CommandDesc->pNewValueArgList[NewValueArgNum]; + auto ZeMutableArgDesc = + std::make_unique>(); + ZeMutableArgDesc->commandId = Command->CommandId; + ZeMutableArgDesc->pNext = NextDesc; + ZeMutableArgDesc->argIndex = NewValueArgDesc.argIndex; + ZeMutableArgDesc->argSize = NewValueArgDesc.argSize; + // OpenCL: "the arg_value pointer can be NULL or point to a NULL value + // in which case a NULL value will be used as the value for the argument + // declared as a pointer to global or constant memory in the kernel" + // + // We don't know the type of the argument but it seems that the only time + // SYCL RT would send a pointer to NULL in 'arg_value' is when the argument + // is a NULL pointer. Treat a pointer to NULL in 'arg_value' as a NULL. + const void *ArgValuePtr = NewValueArgDesc.pNewValueArg; + if (NewValueArgDesc.argSize == sizeof(void *) && ArgValuePtr && + *(void **)(const_cast(ArgValuePtr)) == nullptr) { + ArgValuePtr = nullptr; + } + ZeMutableArgDesc->pArgValue = ArgValuePtr; + NextDesc = ZeMutableArgDesc.get(); + ArgDescs.push_back(std::move(ZeMutableArgDesc)); + } + + // Check if there are new exec info flags provided. + for (uint32_t NewExecInfoNum = CommandDesc->numNewExecInfos; + NewExecInfoNum-- > 0;) { + ur_exp_command_buffer_update_exec_info_desc_t NewExecInfoDesc = + CommandDesc->pNewExecInfoList[NewExecInfoNum]; + ur_kernel_exec_info_t PropName = NewExecInfoDesc.propName; + const void *PropValue = NewExecInfoDesc.pNewExecInfo; + if (PropName == UR_KERNEL_EXEC_INFO_USM_INDIRECT_ACCESS) { + // The whole point for users really was to not need to know anything + // about the types of allocations kernel uses. So in DPC++ we always + // just set all 3 modes for each kernel. + if (*(static_cast(PropValue)) == true) { + ze_kernel_indirect_access_flags_t IndirectFlags = + ZE_KERNEL_INDIRECT_ACCESS_FLAG_HOST | + ZE_KERNEL_INDIRECT_ACCESS_FLAG_DEVICE | + ZE_KERNEL_INDIRECT_ACCESS_FLAG_SHARED; + ZE2UR_CALL(zeKernelSetIndirectAccess, + (Command->Kernel->ZeKernel, IndirectFlags)); + } + } else if (PropName == UR_KERNEL_EXEC_INFO_CACHE_CONFIG) { + ze_cache_config_flag_t ZeCacheConfig{}; + auto CacheConfig = + *(static_cast(PropValue)); + switch (CacheConfig) { + case UR_KERNEL_CACHE_CONFIG_LARGE_SLM: + ZeCacheConfig = ZE_CACHE_CONFIG_FLAG_LARGE_SLM; + break; + case UR_KERNEL_CACHE_CONFIG_LARGE_DATA: + ZeCacheConfig = ZE_CACHE_CONFIG_FLAG_LARGE_DATA; + break; + case UR_KERNEL_CACHE_CONFIG_DEFAULT: + ZeCacheConfig = static_cast(0); + break; + default: + // Unexpected cache configuration value. + return UR_RESULT_ERROR_INVALID_VALUE; + } + ZE2UR_CALL(zeKernelSetCacheConfig, + (Command->Kernel->ZeKernel, ZeCacheConfig);); + } else if (PropName == UR_KERNEL_EXEC_INFO_USM_PTRS) { + // Ignore this property as such kernel property is not supported by Level + // Zero. + continue; + } else { + urPrint("urCommandBufferUpdateKernelLaunchExp: unsupported name of " + "execution attribute.\n"); + return UR_RESULT_ERROR_INVALID_VALUE; + } + } + + ZeStruct MutableCommandDesc; + MutableCommandDesc.pNext = NextDesc; + MutableCommandDesc.flags = 0; + + // We must synchronize mutable command list execution before mutating. 
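  // Waiting on the command-buffer's signal event ensures that any previously
  // enqueued execution of this command-buffer has completed on the device before
  // its commands are patched and the list is closed again.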
+ ZE2UR_CALL(zeEventHostSynchronize, + (CommandBuffer->SignalEvent->ZeEvent, UINT64_MAX)); + + auto Plt = Command->CommandBuffer->Context->getPlatform(); + UR_ASSERT(Plt->ZeMutableCmdListExt.Supported, + UR_RESULT_ERROR_UNSUPPORTED_FEATURE); + ZE2UR_CALL(Plt->ZeMutableCmdListExt.zexCommandListUpdateMutableCommandsExp, + (CommandBuffer->ZeCommandList, &MutableCommandDesc)); + ZE2UR_CALL(zeCommandListClose, (CommandBuffer->ZeCommandList)); + + return UR_RESULT_SUCCESS; } UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferGetInfoExp( @@ -975,7 +1318,17 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferGetInfoExp( } UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferCommandGetInfoExp( - ur_exp_command_buffer_command_handle_t, - ur_exp_command_buffer_command_info_t, size_t, void *, size_t *) { - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; + ur_exp_command_buffer_command_handle_t Command, + ur_exp_command_buffer_command_info_t PropName, size_t PropSize, + void *PropValue, size_t *PropSizeRet) { + UrReturnHelper ReturnValue(PropSize, PropValue, PropSizeRet); + + switch (PropName) { + case UR_EXP_COMMAND_BUFFER_COMMAND_INFO_REFERENCE_COUNT: + return ReturnValue(uint32_t{Command->RefCount.load()}); + default: + assert(!"Command-buffer command info request not implemented"); + } + + return UR_RESULT_ERROR_INVALID_ENUMERATION; } diff --git a/source/adapters/level_zero/command_buffer.hpp b/source/adapters/level_zero/command_buffer.hpp index 843d9d3f37..67f4afd54c 100644 --- a/source/adapters/level_zero/command_buffer.hpp +++ b/source/adapters/level_zero/command_buffer.hpp @@ -78,4 +78,21 @@ struct ur_exp_command_buffer_handle_t_ : public _ur_object { // Event which a command-buffer waits on until the main command-list event // have been reset. ur_event_handle_t AllResetEvent = nullptr; + // Indicates if command-buffer commands can be updated after it is closed. + bool IsUpdatable = false; + // Indicates if command buffer was finalized. + bool IsFinalized = false; +}; + +struct ur_exp_command_buffer_command_handle_t_ : public _ur_object { + ur_exp_command_buffer_command_handle_t_(ur_exp_command_buffer_handle_t, + uint64_t, ur_kernel_handle_t); + + ~ur_exp_command_buffer_command_handle_t_(); + + // Command-buffer of this command. 
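+  // (Retained in the constructor and released in the destructor, see
+  // command_buffer.cpp, so the command handle keeps its parent command-buffer
+  // alive.)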
+ ur_exp_command_buffer_handle_t CommandBuffer; + + uint64_t CommandId; + ur_kernel_handle_t Kernel; }; diff --git a/source/adapters/level_zero/common.cpp b/source/adapters/level_zero/common.cpp index 1d3c53ca3b..a927c8b444 100644 --- a/source/adapters/level_zero/common.cpp +++ b/source/adapters/level_zero/common.cpp @@ -58,6 +58,8 @@ ur_result_t ze2urResult(ze_result_t ZeResult) { return UR_RESULT_ERROR_OUT_OF_DEVICE_MEMORY; case ZE_RESULT_ERROR_OUT_OF_HOST_MEMORY: return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; + case ZE_RESULT_ERROR_UNSUPPORTED_FEATURE: + return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; default: return UR_RESULT_ERROR_UNKNOWN; } @@ -171,6 +173,40 @@ template <> ze_structure_type_t getZeStructureType() { template <> ze_structure_type_t getZeStructureType() { return ZE_STRUCTURE_TYPE_COMMAND_LIST_DESC; } +template <> +ze_structure_type_t +getZeStructureType() { + return ZE_STRUCTURE_TYPE_MUTABLE_COMMAND_LIST_EXP_PROPERTIES; +} +template <> +ze_structure_type_t getZeStructureType() { + return ZE_STRUCTURE_TYPE_MUTABLE_COMMAND_LIST_EXP_DESC; +} +template <> +ze_structure_type_t getZeStructureType() { + return ZE_STRUCTURE_TYPE_MUTABLE_COMMAND_ID_EXP_DESC; +} +template <> +ze_structure_type_t getZeStructureType() { + return ZE_STRUCTURE_TYPE_MUTABLE_GROUP_COUNT_EXP_DESC; +} +template <> +ze_structure_type_t getZeStructureType() { + return ZE_STRUCTURE_TYPE_MUTABLE_GROUP_SIZE_EXP_DESC; +} +template <> +ze_structure_type_t getZeStructureType() { + return ZE_STRUCTURE_TYPE_MUTABLE_GLOBAL_OFFSET_EXP_DESC; +} +template <> +ze_structure_type_t +getZeStructureType() { + return ZE_STRUCTURE_TYPE_MUTABLE_KERNEL_ARGUMENT_EXP_DESC; +} +template <> +ze_structure_type_t getZeStructureType() { + return ZE_STRUCTURE_TYPE_MUTABLE_COMMANDS_EXP_DESC; +} template <> ze_structure_type_t getZeStructureType() { return ZE_STRUCTURE_TYPE_CONTEXT_DESC; } diff --git a/source/adapters/level_zero/device.cpp b/source/adapters/level_zero/device.cpp index 9fae1ed4af..5d0fbdbc11 100644 --- a/source/adapters/level_zero/device.cpp +++ b/source/adapters/level_zero/device.cpp @@ -917,8 +917,22 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo( } case UR_DEVICE_INFO_COMMAND_BUFFER_SUPPORT_EXP: return ReturnValue(true); - case UR_DEVICE_INFO_COMMAND_BUFFER_UPDATE_SUPPORT_EXP: - return ReturnValue(false); + case UR_DEVICE_INFO_COMMAND_BUFFER_UPDATE_SUPPORT_EXP: { + // TODO: Level Zero API allows to check support for all sub-features: + // ZE_MUTABLE_COMMAND_EXP_FLAG_KERNEL_ARGUMENTS, + // ZE_MUTABLE_COMMAND_EXP_FLAG_GROUP_COUNT, + // ZE_MUTABLE_COMMAND_EXP_FLAG_GROUP_SIZE, + // ZE_MUTABLE_COMMAND_EXP_FLAG_GLOBAL_OFFSET, + // ZE_MUTABLE_COMMAND_EXP_FLAG_SIGNAL_EVENT, + // ZE_MUTABLE_COMMAND_EXP_FLAG_WAIT_EVENTS + // but UR has only one property to check the mutable command lists feature + // support. For now return true if kernel arguments can be updated. 
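+    // Both the device-reported mutableCommandFlags and the driver-resolved
+    // ZeMutableCmdListExt entry points must be present for the query to report
+    // support.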
+ auto KernelArgUpdateSupport = + Device->ZeDeviceMutableCmdListsProperties->mutableCommandFlags & + ZE_MUTABLE_COMMAND_EXP_FLAG_KERNEL_ARGUMENTS; + return ReturnValue(KernelArgUpdateSupport && + Device->Platform->ZeMutableCmdListExt.Supported); + } case UR_DEVICE_INFO_BINDLESS_IMAGES_SUPPORT_EXP: return ReturnValue(true); case UR_DEVICE_INFO_BINDLESS_IMAGES_SHARED_USM_SUPPORT_EXP: @@ -1142,6 +1156,15 @@ ur_result_t ur_device_handle_t_::initialize(int SubSubDeviceOrdinal, (ZeDevice, &Count, &Properties)); }; + ZeDeviceMutableCmdListsProperties.Compute = + [ZeDevice]( + ZeStruct &Properties) { + ze_device_properties_t P; + P.stype = ZE_STRUCTURE_TYPE_DEVICE_PROPERTIES; + P.pNext = &Properties; + ZE_CALL_NOCHECK(zeDeviceGetProperties, (ZeDevice, &P)); + }; + ImmCommandListUsed = this->useImmediateCommandLists(); uint32_t numQueueGroups = 0; diff --git a/source/adapters/level_zero/device.hpp b/source/adapters/level_zero/device.hpp index a57a97d38d..484890670b 100644 --- a/source/adapters/level_zero/device.hpp +++ b/source/adapters/level_zero/device.hpp @@ -195,4 +195,6 @@ struct ur_device_handle_t_ : _ur_object { ZeCache> ZeDeviceCacheProperties; ZeCache> ZeDeviceIpVersionExt; ZeCache ZeGlobalMemSize; + ZeCache> + ZeDeviceMutableCmdListsProperties; }; diff --git a/source/adapters/level_zero/platform.cpp b/source/adapters/level_zero/platform.cpp index ab577247bd..19ecac7fee 100644 --- a/source/adapters/level_zero/platform.cpp +++ b/source/adapters/level_zero/platform.cpp @@ -206,6 +206,39 @@ ur_result_t ur_platform_handle_t_::initialize() { // If yes, then set up L0 API pointers if the platform supports it. ZeUSMImport.setZeUSMImport(this); + // Check if mutable command list extension is supported and initialize + // function pointers. + ZeMutableCmdListExt.Supported |= + (ZE_CALL_NOCHECK( + zeDriverGetExtensionFunctionAddress, + (ZeDriver, "zeCommandListGetNextCommandIdExp", + reinterpret_cast( + &ZeMutableCmdListExt.zexCommandListGetNextCommandIdExp))) == 0); + + ZeMutableCmdListExt.Supported &= + (ZE_CALL_NOCHECK(zeDriverGetExtensionFunctionAddress, + (ZeDriver, "zeCommandListUpdateMutableCommandsExp", + reinterpret_cast( + &ZeMutableCmdListExt + .zexCommandListUpdateMutableCommandsExp))) == + 0); + + ZeMutableCmdListExt.Supported &= + (ZE_CALL_NOCHECK( + zeDriverGetExtensionFunctionAddress, + (ZeDriver, "zeCommandListUpdateMutableCommandSignalEventExp", + reinterpret_cast( + &ZeMutableCmdListExt + .zexCommandListUpdateMutableCommandSignalEventExp))) == 0); + + ZeMutableCmdListExt.Supported &= + (ZE_CALL_NOCHECK( + zeDriverGetExtensionFunctionAddress, + (ZeDriver, "zeCommandListUpdateMutableCommandWaitEventsExp", + reinterpret_cast( + &ZeMutableCmdListExt + .zexCommandListUpdateMutableCommandWaitEventsExp))) == 0); + return UR_RESULT_SUCCESS; } diff --git a/source/adapters/level_zero/platform.hpp b/source/adapters/level_zero/platform.hpp index 86aa4ec745..d2ef19fd7e 100644 --- a/source/adapters/level_zero/platform.hpp +++ b/source/adapters/level_zero/platform.hpp @@ -55,4 +55,22 @@ struct ur_platform_handle_t_ : public _ur_platform { // in the driver. std::list Contexts; ur_shared_mutex ContextsMutex; + + // Structure with function pointers for mutable command list extension. + // Not all drivers may support it, so considering that the platform object is + // associated with particular Level Zero driver, store this extension here. 
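+  // The entry points are resolved with zeDriverGetExtensionFunctionAddress in
+  // ur_platform_handle_t_::initialize(); Supported stays true only if all four
+  // functions were found.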
+ struct ZeMutableCmdListExtension { + bool Supported = false; + ze_result_t (*zexCommandListGetNextCommandIdExp)( + ze_command_list_handle_t, const ze_mutable_command_id_exp_desc_t *, + uint64_t *) = nullptr; + ze_result_t (*zexCommandListUpdateMutableCommandsExp)( + ze_command_list_handle_t, + const ze_mutable_commands_exp_desc_t *) = nullptr; + ze_result_t (*zexCommandListUpdateMutableCommandSignalEventExp)( + ze_command_list_handle_t, uint64_t, ze_event_handle_t) = nullptr; + ze_result_t (*zexCommandListUpdateMutableCommandWaitEventsExp)( + ze_command_list_handle_t, uint64_t, uint32_t, + ze_event_handle_t *) = nullptr; + } ZeMutableCmdListExt; }; diff --git a/test/conformance/exp_command_buffer/buffer_fill_kernel_update.cpp b/test/conformance/exp_command_buffer/buffer_fill_kernel_update.cpp index e7fac99800..3928e76cf1 100644 --- a/test/conformance/exp_command_buffer/buffer_fill_kernel_update.cpp +++ b/test/conformance/exp_command_buffer/buffer_fill_kernel_update.cpp @@ -123,6 +123,10 @@ TEST_P(BufferFillCommandTest, UpdateParameters) { // Test updating the global size so that the fill outputs to a larger buffer TEST_P(BufferFillCommandTest, UpdateGlobalSize) { + if (!updatable_execution_range_support) { + GTEST_SKIP() << "Execution range update is not supported."; + } + ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, nullptr, nullptr)); ASSERT_SUCCESS(urQueueFinish(queue)); @@ -153,7 +157,7 @@ TEST_P(BufferFillCommandTest, UpdateGlobalSize) { 0, // numNewPointerArgs 0, // numNewValueArgs 0, // numNewExecInfos - 0, // newWorkDim + 1, // newWorkDim &new_output_desc, // pNewMemObjArgList nullptr, // pNewPointerArgList nullptr, // pNewValueArgList @@ -180,7 +184,8 @@ TEST_P(BufferFillCommandTest, SeparateUpdateCalls) { ASSERT_SUCCESS(urQueueFinish(queue)); ValidateBuffer(buffer, sizeof(val) * global_size, val); - size_t new_global_size = 64; + size_t new_global_size = + updatable_execution_range_support ? 
64 : global_size; const size_t new_buffer_size = sizeof(val) * new_global_size; ASSERT_SUCCESS(urMemBufferCreate(context, UR_MEM_FLAG_READ_WRITE, new_buffer_size, nullptr, &new_buffer)); @@ -247,25 +252,28 @@ TEST_P(BufferFillCommandTest, SeparateUpdateCalls) { ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp(command_handle, &input_update_desc)); - ur_exp_command_buffer_update_kernel_launch_desc_t global_size_update_desc = { - UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype - nullptr, // pNext - 0, // numNewMemObjArgs - 0, // numNewPointerArgs - 0, // numNewValueArgs - 0, // numNewExecInfos - 0, // newWorkDim - nullptr, // pNewMemObjArgList - nullptr, // pNewPointerArgList - nullptr, // pNewValueArgList - nullptr, // pNewExecInfoList - nullptr, // pNewGlobalWorkOffset - &new_global_size, // pNewGlobalWorkSize - nullptr, // pNewLocalWorkSize - }; - - ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp( - command_handle, &global_size_update_desc)); + if (updatable_execution_range_support) { + ur_exp_command_buffer_update_kernel_launch_desc_t + global_size_update_desc = { + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype + nullptr, // pNext + 0, // numNewMemObjArgs + 0, // numNewPointerArgs + 0, // numNewValueArgs + 0, // numNewExecInfos + 0, // newWorkDim + nullptr, // pNewMemObjArgList + nullptr, // pNewPointerArgList + nullptr, // pNewValueArgList + nullptr, // pNewExecInfoList + nullptr, // pNewGlobalWorkOffset + &new_global_size, // pNewGlobalWorkSize + nullptr, // pNewLocalWorkSize + }; + + ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp( + command_handle, &global_size_update_desc)); + } ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, nullptr, nullptr)); diff --git a/test/conformance/exp_command_buffer/fixtures.h b/test/conformance/exp_command_buffer/fixtures.h index c8a198224b..2f9656c0f9 100644 --- a/test/conformance/exp_command_buffer/fixtures.h +++ b/test/conformance/exp_command_buffer/fixtures.h @@ -112,6 +112,11 @@ struct urUpdatableCommandBufferExpExecutionTest GTEST_SKIP() << "Updating EXP command-buffers is not supported."; } + // Currently level zero driver doesn't support updating execution range. + if (backend == UR_PLATFORM_BACKEND_LEVEL_ZERO) { + updatable_execution_range_support = false; + } + // Create a command-buffer with update enabled. ur_exp_command_buffer_desc_t desc{ UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_DESC, nullptr, true}; @@ -119,17 +124,38 @@ struct urUpdatableCommandBufferExpExecutionTest ASSERT_SUCCESS(urCommandBufferCreateExp(context, device, &desc, &updatable_cmd_buf_handle)); ASSERT_NE(updatable_cmd_buf_handle, nullptr); + + // Currently there are synchronization issue with immediate submission when used for command buffers. + // So, create queue with batched submission for this test suite if the backend is Level Zero. 
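+    // (The batched queue created here shadows the base fixture's queue and is
+    // released in TearDown below; other backends keep using the inherited
+    // queue.)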
+ if (backend == UR_PLATFORM_BACKEND_LEVEL_ZERO) { + ur_queue_flags_t flags = UR_QUEUE_FLAG_SUBMISSION_BATCHED; + ur_queue_properties_t props = { + /*.stype =*/UR_STRUCTURE_TYPE_QUEUE_PROPERTIES, + /*.pNext =*/nullptr, + /*.flags =*/flags, + }; + ASSERT_SUCCESS(urQueueCreate(context, device, &props, &queue)); + ASSERT_NE(queue, nullptr); + } else { + queue = urCommandBufferExpExecutionTest::queue; + } } void TearDown() override { if (updatable_cmd_buf_handle) { EXPECT_SUCCESS(urCommandBufferReleaseExp(updatable_cmd_buf_handle)); } + if (backend == UR_PLATFORM_BACKEND_LEVEL_ZERO && queue) { + ASSERT_SUCCESS(urQueueRelease(queue)); + } + UUR_RETURN_ON_FATAL_FAILURE( urCommandBufferExpExecutionTest::TearDown()); } ur_exp_command_buffer_handle_t updatable_cmd_buf_handle = nullptr; + ur_bool_t updatable_execution_range_support = true; + ur_queue_handle_t queue = nullptr; }; struct urCommandBufferCommandExpTest diff --git a/test/conformance/exp_command_buffer/ndrange_update.cpp b/test/conformance/exp_command_buffer/ndrange_update.cpp index e5631f9176..217bd87f2a 100644 --- a/test/conformance/exp_command_buffer/ndrange_update.cpp +++ b/test/conformance/exp_command_buffer/ndrange_update.cpp @@ -15,6 +15,10 @@ struct NDRangeUpdateTest UUR_RETURN_ON_FATAL_FAILURE( urUpdatableCommandBufferExpExecutionTest::SetUp()); + if (!updatable_execution_range_support) { + GTEST_SKIP() << "Execution range update is not supported."; + } + ur_device_usm_access_capability_flags_t shared_usm_flags; ASSERT_SUCCESS( uur::GetDeviceUSMSingleSharedSupport(device, shared_usm_flags)); diff --git a/test/conformance/exp_command_buffer/usm_fill_kernel_update.cpp b/test/conformance/exp_command_buffer/usm_fill_kernel_update.cpp index 7e6cab6ee3..5962bd3487 100644 --- a/test/conformance/exp_command_buffer/usm_fill_kernel_update.cpp +++ b/test/conformance/exp_command_buffer/usm_fill_kernel_update.cpp @@ -87,8 +87,9 @@ TEST_P(USMFillCommandTest, UpdateParameters) { ASSERT_SUCCESS(urQueueFinish(queue)); Validate((uint32_t *)shared_ptr, global_size, val); - // Allocate a new USM pointer of larger size - size_t new_global_size = 64; + // Allocate a new USM pointer of larger size if feature is supported. + size_t new_global_size = + updatable_execution_range_support ? 64 : global_size; const size_t new_allocation_size = sizeof(val) * new_global_size; ASSERT_SUCCESS(urUSMSharedAlloc(context, device, nullptr, nullptr, new_allocation_size, &new_shared_ptr)); @@ -128,8 +129,9 @@ TEST_P(USMFillCommandTest, UpdateParameters) { &new_input_desc, // pNewValueArgList nullptr, // pNewExecInfoList nullptr, // pNewGlobalWorkOffset - &new_global_size, // pNewGlobalWorkSize - nullptr, // pNewLocalWorkSize + updatable_execution_range_support ? &new_global_size + : nullptr, // pNewGlobalWorkSize + nullptr, // pNewLocalWorkSize }; // Update kernel and enqueue command-buffer again
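// A minimal caller-side sketch of the update flow these tests exercise, assuming
// the device reports UR_DEVICE_INFO_COMMAND_BUFFER_UPDATE_SUPPORT_EXP. The
// command_handle is the handle returned when the kernel launch was appended to
// updatable_cmd_buf_handle; the argument index and the value-arg stype constant
// are illustrative assumptions, not part of this change.
uint32_t new_val = 42; // hypothetical new scalar kernel argument
ur_exp_command_buffer_update_value_arg_desc_t new_arg{};
new_arg.stype = UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC;
new_arg.argIndex = 2; // assumed index of the scalar argument
new_arg.argSize = sizeof(new_val);
new_arg.pNewValueArg = &new_val;

ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = {
    UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype
    nullptr,   // pNext
    0,         // numNewMemObjArgs
    0,         // numNewPointerArgs
    1,         // numNewValueArgs
    0,         // numNewExecInfos
    0,         // newWorkDim
    nullptr,   // pNewMemObjArgList
    nullptr,   // pNewPointerArgList
    &new_arg,  // pNewValueArgList
    nullptr,   // pNewExecInfoList
    nullptr,   // pNewGlobalWorkOffset
    nullptr,   // pNewGlobalWorkSize
    nullptr,   // pNewLocalWorkSize
};

// The command-buffer must already be finalized; the adapter waits for any
// in-flight execution and re-closes the Level Zero command list itself.
ASSERT_SUCCESS(
    urCommandBufferUpdateKernelLaunchExp(command_handle, &update_desc));
ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0,
                                         nullptr, nullptr));
ASSERT_SUCCESS(urQueueFinish(queue));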