@@ -437,30 +437,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferReleaseCommandExp(
437
437
return commandHandleReleaseInternal (hCommand);
438
438
}
439
439
440
- UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp (
441
- ur_exp_command_buffer_command_handle_t hCommand,
440
+ namespace {
441
+ ur_result_t updateKernelExecInfo (
442
+ std::vector<cl_mutable_dispatch_exec_info_khr> &CLExecInfos,
442
443
const ur_exp_command_buffer_update_kernel_launch_desc_t
443
444
*pUpdateKernelLaunch) {
444
-
445
- ur_exp_command_buffer_handle_t hCommandBuffer = hCommand->hCommandBuffer ;
446
- cl_context CLContext = cl_adapter::cast<cl_context>(hCommandBuffer->hContext );
447
- cl_ext::clUpdateMutableCommandsKHR_fn clUpdateMutableCommandsKHR = nullptr ;
448
- cl_int Res =
449
- cl_ext::getExtFuncFromContext<decltype (clUpdateMutableCommandsKHR)>(
450
- CLContext, cl_ext::ExtFuncPtrCache->clUpdateMutableCommandsKHRCache ,
451
- cl_ext::UpdateMutableCommandsName, &clUpdateMutableCommandsKHR);
452
-
453
- if (!clUpdateMutableCommandsKHR || Res != CL_SUCCESS)
454
- return UR_RESULT_ERROR_INVALID_OPERATION;
455
-
456
- if (!hCommandBuffer->IsFinalized || !hCommandBuffer->IsUpdatable )
457
- return UR_RESULT_ERROR_INVALID_OPERATION;
458
-
459
- // Find the CL execution info to update
460
445
const uint32_t NumExecInfos = pUpdateKernelLaunch->numNewExecInfos ;
461
446
const ur_exp_command_buffer_update_exec_info_desc_t *ExecInfoList =
462
447
pUpdateKernelLaunch->pNewExecInfoList ;
463
- std::vector<cl_mutable_dispatch_exec_info_khr> CLExecInfos;
464
448
for (uint32_t i = 0 ; i < NumExecInfos; i++) {
465
449
const ur_exp_command_buffer_update_exec_info_desc_t &URExecInfo =
466
450
ExecInfoList[i];
@@ -488,32 +472,41 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
488
472
return UR_RESULT_ERROR_INVALID_ENUMERATION;
489
473
}
490
474
}
475
+ return UR_RESULT_SUCCESS;
476
+ }
477
+
478
+ void updateKernelPointerArgs (
479
+ std::vector<cl_mutable_dispatch_arg_khr> &CLUSMArgs,
480
+ const ur_exp_command_buffer_update_kernel_launch_desc_t
481
+ *pUpdateKernelLaunch) {
491
482
492
- // Find the CL USM pointer arguments to the kernel.
493
483
// WARNING - This relies on USM and SVM using the same implementation,
494
484
// which is not guaranteed.
495
485
// See https://github.com/KhronosGroup/OpenCL-Docs/issues/843
496
486
const uint32_t NumPointerArgs = pUpdateKernelLaunch->numNewPointerArgs ;
497
487
const ur_exp_command_buffer_update_pointer_arg_desc_t *ArgPointerList =
498
488
pUpdateKernelLaunch->pNewPointerArgList ;
499
- std::vector<cl_mutable_dispatch_arg_khr> CLUSMArgs (NumPointerArgs);
489
+
490
+ CLUSMArgs.resize (NumPointerArgs);
500
491
for (uint32_t i = 0 ; i < NumPointerArgs; i++) {
501
492
const ur_exp_command_buffer_update_pointer_arg_desc_t &URPointerArg =
502
493
ArgPointerList[i];
503
494
cl_mutable_dispatch_arg_khr &USMArg = CLUSMArgs[i];
504
495
USMArg.arg_index = URPointerArg.argIndex ;
505
496
USMArg.arg_value = *(void *const *)URPointerArg.pNewPointerArg ;
506
497
}
498
+ }
507
499
508
- // Find the memory object and scalar arguments to the kernel.
500
+ void updateKernelArgs (std::vector<cl_mutable_dispatch_arg_khr> &CLArgs,
501
+ const ur_exp_command_buffer_update_kernel_launch_desc_t
502
+ *pUpdateKernelLaunch) {
509
503
const uint32_t NumMemobjArgs = pUpdateKernelLaunch->numNewMemObjArgs ;
510
504
const ur_exp_command_buffer_update_memobj_arg_desc_t *ArgMemobjList =
511
505
pUpdateKernelLaunch->pNewMemObjArgList ;
512
506
const uint32_t NumValueArgs = pUpdateKernelLaunch->numNewValueArgs ;
513
507
const ur_exp_command_buffer_update_value_arg_desc_t *ArgValueList =
514
508
pUpdateKernelLaunch->pNewValueArgList ;
515
509
516
- std::vector<cl_mutable_dispatch_arg_khr> CLArgs;
517
510
for (uint32_t i = 0 ; i < NumMemobjArgs; i++) {
518
511
const ur_exp_command_buffer_update_memobj_arg_desc_t &URMemObjArg =
519
512
ArgMemobjList[i];
@@ -537,45 +530,72 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
537
530
};
538
531
CLArgs.push_back (CLArg);
539
532
}
533
+ }
534
+
535
+ } // end anonymous namespace
536
+
537
+ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp (
538
+ ur_exp_command_buffer_command_handle_t hCommand,
539
+ const ur_exp_command_buffer_update_kernel_launch_desc_t
540
+ *pUpdateKernelLaunch) {
541
+
542
+ ur_exp_command_buffer_handle_t hCommandBuffer = hCommand->hCommandBuffer ;
543
+ cl_context CLContext = cl_adapter::cast<cl_context>(hCommandBuffer->hContext );
544
+ cl_ext::clUpdateMutableCommandsKHR_fn clUpdateMutableCommandsKHR = nullptr ;
545
+ cl_int Res =
546
+ cl_ext::getExtFuncFromContext<decltype (clUpdateMutableCommandsKHR)>(
547
+ CLContext, cl_ext::ExtFuncPtrCache->clUpdateMutableCommandsKHRCache ,
548
+ cl_ext::UpdateMutableCommandsName, &clUpdateMutableCommandsKHR);
549
+
550
+ if (!clUpdateMutableCommandsKHR || Res != CL_SUCCESS)
551
+ return UR_RESULT_ERROR_INVALID_OPERATION;
552
+
553
+ if (!hCommandBuffer->IsFinalized || !hCommandBuffer->IsUpdatable )
554
+ return UR_RESULT_ERROR_INVALID_OPERATION;
540
555
541
556
const cl_uint NewWorkDim = pUpdateKernelLaunch->newWorkDim ;
542
- cl_uint &CLWorkDim = hCommand->WorkDim ;
543
- if (NewWorkDim != 0 && NewWorkDim != CLWorkDim) {
544
- // Limitation of the cl_khr_command_buffer_mutable_dispatch specification
545
- // that it is an error to change the ND-Range size.
546
- // https://github.com/KhronosGroup/OpenCL-Docs/issues/1057
547
- return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
557
+ if (NewWorkDim != 0 && NewWorkDim != hCommand->WorkDim ) {
558
+ return UR_RESULT_ERROR_INVALID_OPERATION;
559
+ }
560
+
561
+ // Find the CL execution info to update
562
+ std::vector<cl_mutable_dispatch_exec_info_khr> CLExecInfos;
563
+ if (ur_result_t result =
564
+ updateKernelExecInfo (CLExecInfos, pUpdateKernelLaunch)) {
565
+ return result;
548
566
}
549
567
550
- // Update the ND-Range configuration of the kernel.
551
- const size_t CopySize = sizeof (size_t ) * CLWorkDim;
568
+ // Find the CL USM pointer arguments to the kernel to update
569
+ std::vector<cl_mutable_dispatch_arg_khr> CLUSMArgs;
570
+ updateKernelPointerArgs (CLUSMArgs, pUpdateKernelLaunch);
571
+
572
+ // Find the memory object and scalar arguments to the kernel to update
573
+ std::vector<cl_mutable_dispatch_arg_khr> CLArgs;
574
+
575
+ updateKernelArgs (CLArgs, pUpdateKernelLaunch);
576
+
577
+ // Find the updated ND-Range configuration of the kernel.
552
578
std::vector<size_t > CLGlobalWorkOffset, CLGlobalWorkSize, CLLocalWorkSize;
579
+ cl_uint &CommandWorkDim = hCommand->WorkDim ;
580
+
581
+ // Lambda for N-Dimensional update
582
+ auto updateNDRange = [CommandWorkDim](std::vector<size_t > &NDRange,
583
+ size_t *UpdatePtr) {
584
+ NDRange.resize (CommandWorkDim, 0 );
585
+ const size_t CopySize = sizeof (size_t ) * CommandWorkDim;
586
+ std::memcpy (NDRange.data (), UpdatePtr, CopySize);
587
+ };
553
588
554
589
if (auto GlobalWorkOffsetPtr = pUpdateKernelLaunch->pNewGlobalWorkOffset ) {
555
- CLGlobalWorkOffset.resize (CLWorkDim);
556
- std::memcpy (CLGlobalWorkOffset.data (), GlobalWorkOffsetPtr, CopySize);
557
- if (CLWorkDim < 3 ) {
558
- const size_t ZeroSize = sizeof (size_t ) * (3 - CLWorkDim);
559
- std::memset (CLGlobalWorkOffset.data () + CLWorkDim, 0 , ZeroSize);
560
- }
590
+ updateNDRange (CLGlobalWorkOffset, GlobalWorkOffsetPtr);
561
591
}
562
592
563
593
if (auto GlobalWorkSizePtr = pUpdateKernelLaunch->pNewGlobalWorkSize ) {
564
- CLGlobalWorkSize.resize (CLWorkDim);
565
- std::memcpy (CLGlobalWorkSize.data (), GlobalWorkSizePtr, CopySize);
566
- if (CLWorkDim < 3 ) {
567
- const size_t ZeroSize = sizeof (size_t ) * (3 - CLWorkDim);
568
- std::memset (CLGlobalWorkSize.data () + CLWorkDim, 0 , ZeroSize);
569
- }
594
+ updateNDRange (CLGlobalWorkSize, GlobalWorkSizePtr);
570
595
}
571
596
572
597
if (auto LocalWorkSizePtr = pUpdateKernelLaunch->pNewLocalWorkSize ) {
573
- CLLocalWorkSize.resize (CLWorkDim);
574
- std::memcpy (CLLocalWorkSize.data (), LocalWorkSizePtr, CopySize);
575
- if (CLWorkDim < 3 ) {
576
- const size_t ZeroSize = sizeof (size_t ) * (3 - CLWorkDim);
577
- std::memset (CLLocalWorkSize.data () + CLWorkDim, 0 , ZeroSize);
578
- }
598
+ updateNDRange (CLLocalWorkSize, LocalWorkSizePtr);
579
599
}
580
600
581
601
cl_mutable_command_khr command =
@@ -587,7 +607,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
587
607
static_cast <cl_uint>(CLArgs.size ()), // num_args
588
608
static_cast <cl_uint>(CLUSMArgs.size ()), // num_svm_args
589
609
static_cast <cl_uint>(CLExecInfos.size ()), // num_exec_infos
590
- CLWorkDim, // work_dim
610
+ CommandWorkDim, // work_dim
591
611
CLArgs.data (), // arg_list
592
612
CLUSMArgs.data (), // arg_svm_list
593
613
CLExecInfos.data (), // exec_info_list
0 commit comments