Skip to content

Commit 99c329b

Browse files
committed
[EXP][Command-Buffer] Add kernel command update
This change introduces a new API that allows the kernel commands of a command-buffer to be updated with a new configuration. For example, modified arguments or ND-Range. The only implemented adapter is CUDA. See [cl_khr_command_buffer_mutable_dispatch](https://registry.khronos.org/OpenCL/specs/3.0-unified/html/OpenCL_Ext.html#cl_khr_command_buffer_mutable_dispatch) as prior art. The differences between the proposed API and the above are: * Only the append kernel entry-point returns a command handle. I imagine this will be changed in future to enable other commands to do update. * Only USM and buffer arguments can be updated, there is not equivalent update struct for `urKernelSetArgLocal`, `urKernelSetArgValue`, or `urKernelSetArgSampler` * There is no granularity of optional support for update, an implementer must either implement all the ways to update a kernel configuration, or none of them.
1 parent fe5bc76 commit 99c329b

22 files changed

+1770
-593
lines changed

include/ur.py

Lines changed: 92 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ class ur_function_v(IntEnum):
187187
ADAPTER_RETAIN = 179 ## Enumerator for ::urAdapterRetain
188188
ADAPTER_GET_LAST_ERROR = 180 ## Enumerator for ::urAdapterGetLastError
189189
ADAPTER_GET_INFO = 181 ## Enumerator for ::urAdapterGetInfo
190+
COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_EXP = 182 ## Enumerator for ::urCommandBufferUpdateKernelLaunchExp
190191
PROGRAM_BUILD_EXP = 197 ## Enumerator for ::urProgramBuildExp
191192
PROGRAM_COMPILE_EXP = 198 ## Enumerator for ::urProgramCompileExp
192193
PROGRAM_LINK_EXP = 199 ## Enumerator for ::urProgramLinkExp
@@ -250,6 +251,10 @@ class ur_structure_type_v(IntEnum):
250251
KERNEL_ARG_VALUE_PROPERTIES = 32 ## ::ur_kernel_arg_value_properties_t
251252
KERNEL_ARG_LOCAL_PROPERTIES = 33 ## ::ur_kernel_arg_local_properties_t
252253
EXP_COMMAND_BUFFER_DESC = 0x1000 ## ::ur_exp_command_buffer_desc_t
254+
EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC = 0x1001 ## ::ur_exp_command_buffer_update_kernel_launch_desc_t
255+
EXP_COMMAND_BUFFER_UPDATE_MEMOBJ_ARG_DESC = 0x1002 ## ::ur_exp_command_buffer_update_memobj_arg_desc_t
256+
EXP_COMMAND_BUFFER_UPDATE_POINTER_ARG_DESC = 0x1003 ## ::ur_exp_command_buffer_update_pointer_arg_desc_t
257+
EXP_COMMAND_BUFFER_UPDATE_EXEC_INFO_DESC = 0x1004 ## ::ur_exp_command_buffer_update_exec_info_desc_t
253258
EXP_SAMPLER_MIP_PROPERTIES = 0x2000 ## ::ur_exp_sampler_mip_properties_t
254259
EXP_INTEROP_MEM_DESC = 0x2001 ## ::ur_exp_interop_mem_desc_t
255260
EXP_INTEROP_SEMAPHORE_DESC = 0x2002 ## ::ur_exp_interop_semaphore_desc_t
@@ -455,6 +460,7 @@ class ur_result_v(IntEnum):
455460
ERROR_INVALID_COMMAND_BUFFER_EXP = 0x1000 ## Invalid Command-Buffer
456461
ERROR_INVALID_COMMAND_BUFFER_SYNC_POINT_EXP = 0x1001## Sync point is not valid for the command-buffer
457462
ERROR_INVALID_COMMAND_BUFFER_SYNC_POINT_WAIT_LIST_EXP = 0x1002 ## Sync point wait list is invalid
463+
ERROR_INVALID_COMMAND_BUFFER_COMMAND_HANDLE_EXP = 0x1003## Handle to command-buffer command is invalid
458464
ERROR_UNKNOWN = 0x7ffffffe ## Unknown or internal error
459465

460466
class ur_result_t(c_int):
@@ -865,6 +871,10 @@ class ur_device_info_v(IntEnum):
865871
## version than older devices.
866872
VIRTUAL_MEMORY_SUPPORT = 114 ## [::ur_bool_t] return true if the device supports virtual memory.
867873
ESIMD_SUPPORT = 115 ## [::ur_bool_t] return true if the device supports ESIMD.
874+
COMMAND_BUFFER_SUPPORT_EXP = 0x1000 ## [::ur_bool_t] returns true if the device supports the use of
875+
## command-buffers.
876+
COMMAND_BUFFER_UPDATE_SUPPORT_EXP = 0x1001 ## [::ur_bool_t] returns true if the device supports updating the
877+
## commands in a command-buffer.
868878
BINDLESS_IMAGES_SUPPORT_EXP = 0x2000 ## [::ur_bool_t] returns true if the device supports the creation of
869879
## bindless images
870880
BINDLESS_IMAGES_SHARED_USM_SUPPORT_EXP = 0x2001 ## [::ur_bool_t] returns true if the device supports the creation of
@@ -2300,7 +2310,71 @@ class ur_exp_command_buffer_desc_t(Structure):
23002310
_fields_ = [
23012311
("stype", ur_structure_type_t), ## [in] type of this structure, must be
23022312
## ::UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_DESC
2303-
("pNext", c_void_p) ## [in][optional] pointer to extension-specific structure
2313+
("pNext", c_void_p), ## [in][optional] pointer to extension-specific structure
2314+
("isUpdatable", ur_c_bool_t) ## [in] Commands in a finalized command-buffer can be updated.
2315+
]
2316+
2317+
###############################################################################
2318+
## @brief Descriptor type for updating a kernel command memobj argument.
2319+
class ur_exp_command_buffer_update_memobj_arg_desc_t(Structure):
2320+
_fields_ = [
2321+
("stype", ur_structure_type_t), ## [in] type of this structure, must be
2322+
## ::UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_MEMOBJ_ARG_DESC
2323+
("pNext", c_void_p), ## [in][optional] pointer to extension-specific structure
2324+
("argIndex", c_ulong), ## [in] Argument index.
2325+
("pProperties", *), ## [in][optinal] Pointer to memory object properties.
2326+
("hArgValue", ur_mem_handle_t) ## [in][optional] Handle of memory object.
2327+
]
2328+
2329+
###############################################################################
2330+
## @brief Descriptor type for updating a kernel command pointer argument.
2331+
class ur_exp_command_buffer_update_pointer_arg_desc_t(Structure):
2332+
_fields_ = [
2333+
("stype", ur_structure_type_t), ## [in] type of this structure, must be
2334+
## ::UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_POINTER_ARG_DESC
2335+
("pNext", c_void_p), ## [in][optional] pointer to extension-specific structure
2336+
("argIndex", c_ulong), ## [in] Argument index.
2337+
("pProperties", *), ## [in][optinal] Pointer to USM pointer properties.
2338+
("pArgValue", *) ## [in][optional] USM pointer to memory location holding the argument
2339+
## value.
2340+
]
2341+
2342+
###############################################################################
2343+
## @brief Descriptor type for updating kernel command execution info.
2344+
class ur_exp_command_buffer_update_exec_info_desc_t(Structure):
2345+
_fields_ = [
2346+
("stype", ur_structure_type_t), ## [in] type of this structure, must be
2347+
## ::UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_EXEC_INFO_DESC
2348+
("pNext", c_void_p), ## [in][optional] pointer to extension-specific structure
2349+
("propName", ur_kernel_exec_info_t), ## [in] Name of execution attribute.
2350+
("propSize", c_size_t), ## [in] Size of execution attribute.
2351+
("pProperties", *), ## [in][optional] Pointer to execution info properties.
2352+
("pPropValue", *) ## [in] Pointer to memory location holding the property value.
2353+
]
2354+
2355+
###############################################################################
2356+
## @brief Descriptor type for updating a kernel launch command.
2357+
class ur_exp_command_buffer_update_kernel_launch_desc_t(Structure):
2358+
_fields_ = [
2359+
("stype", ur_structure_type_t), ## [in] type of this structure, must be
2360+
## ::UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC
2361+
("pNext", c_void_p), ## [in][optional] pointer to extension-specific structure
2362+
("numMemobjArgs", c_ulong), ## [in] Length of pArgMemobjList.
2363+
("numPointerArgs", c_ulong), ## [in] Length of pArgPointerList.
2364+
("numExecInfos", c_ulong), ## [in] Length of pExecInfoList.
2365+
("workDim", c_ulong), ## [in] Number of work dimensions in the kernel ND-range, from 1-3.
2366+
("pArgMemobjList", POINTER(ur_exp_command_buffer_update_memobj_arg_desc_t)),## [in] An array describing the new kernel mem obj arguments for the
2367+
## command.
2368+
("pArgPointerList", POINTER(ur_exp_command_buffer_update_pointer_arg_desc_t)), ## [in] An array describing the new kernel pointer arguments for the
2369+
## command.
2370+
("pArgExecInfoList", POINTER(ur_exp_command_buffer_update_exec_info_desc_t)), ## [in] An array describing the execution info objects for the command.
2371+
("pGlobalWorkOffset", POINTER(c_size_t)), ## [in] Array of workDim unsigned values that describe the offset used to
2372+
## calculate the global ID.
2373+
("pGlobalWorkSize", POINTER(c_size_t)), ## [in] Array of workDim unsigned values that describe the number of
2374+
## global work-items.
2375+
("pLocalWorkSize", POINTER(c_size_t)) ## [in] Array of workDim unsigned values that describe the number of
2376+
## work-items that make up a work-group. If nullptr, the runtime
2377+
## implementation will choose the work-group size.
23042378
]
23052379

23062380
###############################################################################
@@ -2314,6 +2388,11 @@ class ur_exp_command_buffer_sync_point_t(c_ulong):
23142388
class ur_exp_command_buffer_handle_t(c_void_p):
23152389
pass
23162390

2391+
###############################################################################
2392+
## @brief Handle of a Command-Buffer command
2393+
class ur_exp_command_buffer_command_handle_t(c_void_p):
2394+
pass
2395+
23172396
###############################################################################
23182397
## @brief The extension string which defines support for cooperative-kernels
23192398
## which is returned when querying device extensions.
@@ -3610,9 +3689,9 @@ class ur_usm_exp_dditable_t(Structure):
36103689
###############################################################################
36113690
## @brief Function-pointer for urCommandBufferAppendKernelLaunchExp
36123691
if __use_win_types:
3613-
_urCommandBufferAppendKernelLaunchExp_t = WINFUNCTYPE( ur_result_t, ur_exp_command_buffer_handle_t, ur_kernel_handle_t, c_ulong, POINTER(c_size_t), POINTER(c_size_t), POINTER(c_size_t), c_ulong, POINTER(ur_exp_command_buffer_sync_point_t), POINTER(ur_exp_command_buffer_sync_point_t) )
3692+
_urCommandBufferAppendKernelLaunchExp_t = WINFUNCTYPE( ur_result_t, ur_exp_command_buffer_handle_t, ur_kernel_handle_t, c_ulong, POINTER(c_size_t), POINTER(c_size_t), POINTER(c_size_t), c_ulong, POINTER(ur_exp_command_buffer_sync_point_t), POINTER(ur_exp_command_buffer_sync_point_t), POINTER(ur_exp_command_buffer_command_handle_t) )
36143693
else:
3615-
_urCommandBufferAppendKernelLaunchExp_t = CFUNCTYPE( ur_result_t, ur_exp_command_buffer_handle_t, ur_kernel_handle_t, c_ulong, POINTER(c_size_t), POINTER(c_size_t), POINTER(c_size_t), c_ulong, POINTER(ur_exp_command_buffer_sync_point_t), POINTER(ur_exp_command_buffer_sync_point_t) )
3694+
_urCommandBufferAppendKernelLaunchExp_t = CFUNCTYPE( ur_result_t, ur_exp_command_buffer_handle_t, ur_kernel_handle_t, c_ulong, POINTER(c_size_t), POINTER(c_size_t), POINTER(c_size_t), c_ulong, POINTER(ur_exp_command_buffer_sync_point_t), POINTER(ur_exp_command_buffer_sync_point_t), POINTER(ur_exp_command_buffer_command_handle_t) )
36163695

36173696
###############################################################################
36183697
## @brief Function-pointer for urCommandBufferAppendUSMMemcpyExp
@@ -3698,6 +3777,13 @@ class ur_usm_exp_dditable_t(Structure):
36983777
else:
36993778
_urCommandBufferEnqueueExp_t = CFUNCTYPE( ur_result_t, ur_exp_command_buffer_handle_t, ur_queue_handle_t, c_ulong, POINTER(ur_event_handle_t), POINTER(ur_event_handle_t) )
37003779

3780+
###############################################################################
3781+
## @brief Function-pointer for urCommandBufferUpdateKernelLaunchExp
3782+
if __use_win_types:
3783+
_urCommandBufferUpdateKernelLaunchExp_t = WINFUNCTYPE( ur_result_t, ur_exp_command_buffer_command_handle_t, POINTER(ur_exp_command_buffer_update_kernel_launch_desc_t) )
3784+
else:
3785+
_urCommandBufferUpdateKernelLaunchExp_t = CFUNCTYPE( ur_result_t, ur_exp_command_buffer_command_handle_t, POINTER(ur_exp_command_buffer_update_kernel_launch_desc_t) )
3786+
37013787

37023788
###############################################################################
37033789
## @brief Table of CommandBufferExp functions pointers
@@ -3719,7 +3805,8 @@ class ur_command_buffer_exp_dditable_t(Structure):
37193805
("pfnAppendMemBufferFillExp", c_void_p), ## _urCommandBufferAppendMemBufferFillExp_t
37203806
("pfnAppendUSMPrefetchExp", c_void_p), ## _urCommandBufferAppendUSMPrefetchExp_t
37213807
("pfnAppendUSMAdviseExp", c_void_p), ## _urCommandBufferAppendUSMAdviseExp_t
3722-
("pfnEnqueueExp", c_void_p) ## _urCommandBufferEnqueueExp_t
3808+
("pfnEnqueueExp", c_void_p), ## _urCommandBufferEnqueueExp_t
3809+
("pfnUpdateKernelLaunchExp", c_void_p) ## _urCommandBufferUpdateKernelLaunchExp_t
37233810
]
37243811

37253812
###############################################################################
@@ -4255,6 +4342,7 @@ def __init__(self, version : ur_api_version_t):
42554342
self.urCommandBufferAppendUSMPrefetchExp = _urCommandBufferAppendUSMPrefetchExp_t(self.__dditable.CommandBufferExp.pfnAppendUSMPrefetchExp)
42564343
self.urCommandBufferAppendUSMAdviseExp = _urCommandBufferAppendUSMAdviseExp_t(self.__dditable.CommandBufferExp.pfnAppendUSMAdviseExp)
42574344
self.urCommandBufferEnqueueExp = _urCommandBufferEnqueueExp_t(self.__dditable.CommandBufferExp.pfnEnqueueExp)
4345+
self.urCommandBufferUpdateKernelLaunchExp = _urCommandBufferUpdateKernelLaunchExp_t(self.__dditable.CommandBufferExp.pfnUpdateKernelLaunchExp)
42584346

42594347
# call driver to get function pointers
42604348
UsmP2PExp = ur_usm_p2p_exp_dditable_t()

0 commit comments

Comments
 (0)