Skip to content

Commit 755a1e7

Browse files
authored
Merge pull request #1385 from yingcong-wu/yc/new-api-suggestgroupsize
Implement urKernelGetSuggestedLocalWorkSize
2 parents 6469b89 + 5593d84 commit 755a1e7

30 files changed

+811
-63
lines changed

include/ur_api.h

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,7 @@ typedef enum ur_function_t {
224224
UR_FUNCTION_COMMAND_BUFFER_COMMAND_GET_INFO_EXP = 222, ///< Enumerator for ::urCommandBufferCommandGetInfoExp
225225
UR_FUNCTION_ENQUEUE_TIMESTAMP_RECORDING_EXP = 223, ///< Enumerator for ::urEnqueueTimestampRecordingExp
226226
UR_FUNCTION_ENQUEUE_KERNEL_LAUNCH_CUSTOM_EXP = 224, ///< Enumerator for ::urEnqueueKernelLaunchCustomExp
227+
UR_FUNCTION_KERNEL_GET_SUGGESTED_LOCAL_WORK_SIZE = 225, ///< Enumerator for ::urKernelGetSuggestedLocalWorkSize
227228
/// @cond
228229
UR_FUNCTION_FORCE_UINT32 = 0x7fffffff
229230
/// @endcond
@@ -5230,6 +5231,43 @@ urKernelCreateWithNativeHandle(
52305231
ur_kernel_handle_t *phKernel ///< [out] pointer to the handle of the kernel object created.
52315232
);
52325233

5234+
///////////////////////////////////////////////////////////////////////////////
5235+
/// @brief Get the suggested local work size for a kernel.
5236+
///
5237+
/// @details
5238+
/// - Query a suggested local work size for a kernel given a global size for
5239+
/// each dimension.
5240+
/// - The application may call this function from simultaneous threads for
5241+
/// the same context.
/// - pSuggestedLocalWorkSize must refer to an array of numWorkDim values;
/// one suggested size is reported per dimension.
5242+
///
5243+
/// @returns
5244+
/// - ::UR_RESULT_SUCCESS
5245+
/// - ::UR_RESULT_ERROR_UNINITIALIZED
5246+
/// - ::UR_RESULT_ERROR_DEVICE_LOST
5247+
/// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC
5248+
/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE
5249+
/// + `NULL == hKernel`
5250+
/// + `NULL == hQueue`
5251+
/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
5252+
/// + `NULL == pGlobalWorkOffset`
5253+
/// + `NULL == pGlobalWorkSize`
5254+
/// + `NULL == pSuggestedLocalWorkSize`
5255+
/// - ::UR_RESULT_ERROR_UNSUPPORTED_FEATURE
5256+
UR_APIEXPORT ur_result_t UR_APICALL
5257+
urKernelGetSuggestedLocalWorkSize(
5258+
ur_kernel_handle_t hKernel, ///< [in] handle of the kernel
5259+
ur_queue_handle_t hQueue, ///< [in] handle of the queue object
5260+
uint32_t numWorkDim, ///< [in] number of dimensions, from 1 to 3, to specify the global
5261+
///< and work-group work-items
5262+
const size_t *pGlobalWorkOffset, ///< [in] pointer to an array of numWorkDim unsigned values that specify
5263+
///< the offset used to calculate the global ID of a work-item
5264+
const size_t *pGlobalWorkSize, ///< [in] pointer to an array of numWorkDim unsigned values that specify
5265+
///< the number of global work-items in each of the numWorkDim dimensions
///< that will execute the
5266+
///< kernel function
5267+
size_t *pSuggestedLocalWorkSize ///< [out] pointer to an array of numWorkDim unsigned values that specify
5268+
///< suggested local work size that will contain the result of the query
5269+
);
5270+
52335271
#if !defined(__GNUC__)
52345272
#pragma endregion
52355273
#endif
@@ -9943,6 +9981,19 @@ typedef struct ur_kernel_create_with_native_handle_params_t {
99439981
ur_kernel_handle_t **pphKernel;
99449982
} ur_kernel_create_with_native_handle_params_t;
99459983

9984+
///////////////////////////////////////////////////////////////////////////////
9985+
/// @brief Function parameters for urKernelGetSuggestedLocalWorkSize
9986+
/// @details Each entry is a pointer to the parameter passed to the function;
9987+
/// allowing the callback the ability to modify the parameter's value
9988+
typedef struct ur_kernel_get_suggested_local_work_size_params_t {
9989+
ur_kernel_handle_t *phKernel; ///< points at the hKernel argument
9990+
ur_queue_handle_t *phQueue; ///< points at the hQueue argument
9991+
uint32_t *pnumWorkDim; ///< points at the numWorkDim argument
9992+
const size_t **ppGlobalWorkOffset; ///< points at the pGlobalWorkOffset argument
9993+
const size_t **ppGlobalWorkSize; ///< points at the pGlobalWorkSize argument
9994+
size_t **ppSuggestedLocalWorkSize; ///< points at the pSuggestedLocalWorkSize argument
9995+
} ur_kernel_get_suggested_local_work_size_params_t;
9996+
99469997
///////////////////////////////////////////////////////////////////////////////
99479998
/// @brief Function parameters for urKernelSetArgValue
99489999
/// @details Each entry is a pointer to the parameter passed to the function;

include/ur_ddi.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -535,6 +535,16 @@ typedef ur_result_t(UR_APICALL *ur_pfnKernelCreateWithNativeHandle_t)(
535535
const ur_kernel_native_properties_t *,
536536
ur_kernel_handle_t *);
537537

538+
///////////////////////////////////////////////////////////////////////////////
539+
/// @brief Function-pointer for urKernelGetSuggestedLocalWorkSize
/// @details Positional parameters follow the urKernelGetSuggestedLocalWorkSize
/// declaration: hKernel, hQueue, numWorkDim, pGlobalWorkOffset,
/// pGlobalWorkSize, pSuggestedLocalWorkSize.
540+
typedef ur_result_t(UR_APICALL *ur_pfnKernelGetSuggestedLocalWorkSize_t)(
541+
ur_kernel_handle_t, ///< hKernel
542+
ur_queue_handle_t, ///< hQueue
543+
uint32_t, ///< numWorkDim
544+
const size_t *, ///< pGlobalWorkOffset
545+
const size_t *, ///< pGlobalWorkSize
546+
size_t *); ///< pSuggestedLocalWorkSize
547+
538548
///////////////////////////////////////////////////////////////////////////////
539549
/// @brief Function-pointer for urKernelSetArgValue
540550
typedef ur_result_t(UR_APICALL *ur_pfnKernelSetArgValue_t)(
@@ -603,6 +613,7 @@ typedef struct ur_kernel_dditable_t {
603613
ur_pfnKernelRelease_t pfnRelease;
604614
ur_pfnKernelGetNativeHandle_t pfnGetNativeHandle;
605615
ur_pfnKernelCreateWithNativeHandle_t pfnCreateWithNativeHandle;
616+
ur_pfnKernelGetSuggestedLocalWorkSize_t pfnGetSuggestedLocalWorkSize;
606617
ur_pfnKernelSetArgValue_t pfnSetArgValue;
607618
ur_pfnKernelSetArgLocal_t pfnSetArgLocal;
608619
ur_pfnKernelSetArgPointer_t pfnSetArgPointer;

include/ur_print.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1442,6 +1442,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urPrintKernelGetNativeHandleParams(const str
14421442
/// - `buff_size < out_size`
14431443
UR_APIEXPORT ur_result_t UR_APICALL urPrintKernelCreateWithNativeHandleParams(const struct ur_kernel_create_with_native_handle_params_t *params, char *buffer, const size_t buff_size, size_t *out_size);
14441444

1445+
///////////////////////////////////////////////////////////////////////////////
1446+
/// @brief Print ur_kernel_get_suggested_local_work_size_params_t struct
/// @details NOTE(review): the implementation is not visible here. Presumably
/// the formatted text is written to `buffer` (capacity `buff_size`) and the
/// required length is reported through `out_size` — confirm against the
/// definition before relying on this.
1447+
/// @returns
1448+
/// - ::UR_RESULT_SUCCESS
1449+
/// - ::UR_RESULT_ERROR_INVALID_SIZE
1450+
/// - `buff_size < out_size`
1451+
UR_APIEXPORT ur_result_t UR_APICALL urPrintKernelGetSuggestedLocalWorkSizeParams(const struct ur_kernel_get_suggested_local_work_size_params_t *params, char *buffer, const size_t buff_size, size_t *out_size);
1452+
14451453
///////////////////////////////////////////////////////////////////////////////
14461454
/// @brief Print ur_kernel_set_arg_value_params_t struct
14471455
/// @returns

include/ur_print.hpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -929,6 +929,9 @@ inline std::ostream &operator<<(std::ostream &os, enum ur_function_t value) {
929929
case UR_FUNCTION_ENQUEUE_KERNEL_LAUNCH_CUSTOM_EXP:
930930
os << "UR_FUNCTION_ENQUEUE_KERNEL_LAUNCH_CUSTOM_EXP";
931931
break;
932+
case UR_FUNCTION_KERNEL_GET_SUGGESTED_LOCAL_WORK_SIZE:
933+
os << "UR_FUNCTION_KERNEL_GET_SUGGESTED_LOCAL_WORK_SIZE";
934+
break;
932935
default:
933936
os << "unknown enumerator";
934937
break;
@@ -11462,6 +11465,49 @@ inline std::ostream &operator<<(std::ostream &os, [[maybe_unused]] const struct
1146211465
return os;
1146311466
}
1146411467

11468+
///////////////////////////////////////////////////////////////////////////////
11469+
/// @brief Print operator for the ur_kernel_get_suggested_local_work_size_params_t type
11470+
/// @returns
11471+
/// std::ostream &
11472+
inline std::ostream &operator<<(std::ostream &os, [[maybe_unused]] const struct ur_kernel_get_suggested_local_work_size_params_t *params) {
11473+
11474+
os << ".hKernel = ";
11475+
11476+
ur::details::printPtr(os,
11477+
*(params->phKernel));
11478+
11479+
os << ", ";
11480+
os << ".hQueue = ";
11481+
11482+
ur::details::printPtr(os,
11483+
*(params->phQueue));
11484+
11485+
os << ", ";
11486+
os << ".numWorkDim = ";
11487+
11488+
os << *(params->pnumWorkDim);
11489+
11490+
os << ", ";
11491+
os << ".pGlobalWorkOffset = ";
11492+
11493+
ur::details::printPtr(os,
11494+
*(params->ppGlobalWorkOffset));
11495+
11496+
os << ", ";
11497+
os << ".pGlobalWorkSize = ";
11498+
11499+
ur::details::printPtr(os,
11500+
*(params->ppGlobalWorkSize));
11501+
11502+
os << ", ";
11503+
os << ".pSuggestedLocalWorkSize = ";
11504+
11505+
ur::details::printPtr(os,
11506+
*(params->ppSuggestedLocalWorkSize));
11507+
11508+
return os;
11509+
}
11510+
1146511511
///////////////////////////////////////////////////////////////////////////////
1146611512
/// @brief Print operator for the ur_kernel_set_arg_value_params_t type
1146711513
/// @returns
@@ -17143,6 +17189,9 @@ inline ur_result_t UR_APICALL printFunctionParams(std::ostream &os, ur_function_
1714317189
case UR_FUNCTION_KERNEL_CREATE_WITH_NATIVE_HANDLE: {
1714417190
os << (const struct ur_kernel_create_with_native_handle_params_t *)params;
1714517191
} break;
17192+
case UR_FUNCTION_KERNEL_GET_SUGGESTED_LOCAL_WORK_SIZE: {
17193+
os << (const struct ur_kernel_get_suggested_local_work_size_params_t *)params;
17194+
} break;
1714617195
case UR_FUNCTION_KERNEL_SET_ARG_VALUE: {
1714717196
os << (const struct ur_kernel_set_arg_value_params_t *)params;
1714817197
} break;

scripts/core/kernel.yml

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -534,3 +534,44 @@ params:
534534
returns:
535535
- $X_RESULT_ERROR_UNSUPPORTED_FEATURE:
536536
- "If the adapter has no underlying equivalent handle."
537+
--- #--------------------------------------------------------------------------
538+
type: function
539+
desc: "Get the suggested local work size for a kernel."
540+
class: $xKernel
541+
name: GetSuggestedLocalWorkSize
542+
ordinal: "0"
543+
details:
544+
- "Query a suggested local work size for a kernel given a global size for each dimension."
545+
- "The application may call this function from simultaneous threads for the same context."
546+
params:
547+
- type: $x_kernel_handle_t
548+
name: hKernel
549+
desc: |
550+
[in] handle of the kernel
551+
- type: $x_queue_handle_t
552+
name: hQueue
553+
desc: |
554+
[in] handle of the queue object
555+
- type: uint32_t
556+
name: numWorkDim
557+
desc: |
558+
[in] number of dimensions, from 1 to 3, to specify the global
559+
and work-group work-items
560+
- type: const size_t*
561+
name: pGlobalWorkOffset
562+
desc: |
563+
[in] pointer to an array of numWorkDim unsigned values that specify
564+
the offset used to calculate the global ID of a work-item
565+
- type: const size_t*
566+
name: pGlobalWorkSize
567+
desc: |
568+
[in] pointer to an array of numWorkDim unsigned values that specify
569+
the number of global work-items in each of the numWorkDim dimensions that will execute the
570+
kernel function
571+
- type: size_t*
572+
name: pSuggestedLocalWorkSize
573+
desc: |
574+
[out] pointer to an array of numWorkDim unsigned values that specify
575+
suggested local work size that will contain the result of the query
576+
returns:
577+
- $X_RESULT_ERROR_UNSUPPORTED_FEATURE

scripts/core/registry.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -586,6 +586,9 @@ etors:
586586
- name: ENQUEUE_KERNEL_LAUNCH_CUSTOM_EXP
587587
desc: Enumerator for $xEnqueueKernelLaunchCustomExp
588588
value: '224'
589+
- name: KERNEL_GET_SUGGESTED_LOCAL_WORK_SIZE
590+
desc: Enumerator for $xKernelGetSuggestedLocalWorkSize
591+
value: '225'
589592
---
590593
type: enum
591594
desc: Defines structure types

source/adapters/cuda/enqueue.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@ ur_result_t enqueueEventsWait(ur_queue_handle_t CommandQueue, CUstream Stream,
1717
uint32_t NumEventsInWaitList,
1818
const ur_event_handle_t *EventWaitList);
1919

20+
// Declared here so that kernel.cpp can call it; the definition is not part of
// this header. urKernelGetSuggestedLocalWorkSize passes a 3-element
// ThreadsPerBlock array, reads the first WorkDim entries back as the suggested
// local work size, and makes the device's context active before calling.
void guessLocalWorkSize(ur_device_handle_t Device, size_t *ThreadsPerBlock,
21+
const size_t *GlobalWorkSize, const uint32_t WorkDim,
22+
ur_kernel_handle_t Kernel);
23+
2024
bool hasExceededMaxRegistersPerBlock(ur_device_handle_t Device,
2125
ur_kernel_handle_t Kernel,
2226
size_t BlockSize);

source/adapters/cuda/kernel.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
//===----------------------------------------------------------------------===//
1010

1111
#include "kernel.hpp"
12+
#include "enqueue.hpp"
1213
#include "memory.hpp"
14+
#include "queue.hpp"
1315
#include "sampler.hpp"
1416

1517
UR_APIEXPORT ur_result_t UR_APICALL
@@ -380,3 +382,30 @@ urKernelSetArgSampler(ur_kernel_handle_t hKernel, uint32_t argIndex,
380382
}
381383
return Result;
382384
}
385+
386+
/// Suggest a local work size for launching hKernel on hQueue's device over
/// the given global range.
///
/// @param hKernel                 kernel to query; must belong to the same
///                                context as hQueue
/// @param hQueue                  queue whose device the suggestion is for
/// @param workDim                 number of dimensions, 1 to 3
/// @param pGlobalWorkOffset       not used by this adapter
/// @param pGlobalWorkSize         array of workDim global sizes, forwarded to
///                                guessLocalWorkSize
/// @param pSuggestedLocalWorkSize [out] array that receives workDim values
/// @returns UR_RESULT_SUCCESS on success,
///          UR_RESULT_ERROR_INVALID_KERNEL if kernel and queue contexts differ,
///          UR_RESULT_ERROR_INVALID_WORK_DIMENSION if workDim is outside 1..3,
///          UR_RESULT_ERROR_INVALID_NULL_POINTER if either size array is null.
UR_APIEXPORT ur_result_t UR_APICALL urKernelGetSuggestedLocalWorkSize(
    ur_kernel_handle_t hKernel, ur_queue_handle_t hQueue, uint32_t workDim,
    [[maybe_unused]] const size_t *pGlobalWorkOffset,
    const size_t *pGlobalWorkSize, size_t *pSuggestedLocalWorkSize) {
  // Preconditions
  UR_ASSERT(hQueue->getContext() == hKernel->getContext(),
            UR_RESULT_ERROR_INVALID_KERNEL);
  UR_ASSERT(workDim > 0, UR_RESULT_ERROR_INVALID_WORK_DIMENSION);
  UR_ASSERT(workDim < 4, UR_RESULT_ERROR_INVALID_WORK_DIMENSION);
  // The API contract lists a null global size as INVALID_NULL_POINTER; check
  // it here instead of handing a null array to guessLocalWorkSize.
  UR_ASSERT(pGlobalWorkSize != nullptr, UR_RESULT_ERROR_INVALID_NULL_POINTER);
  UR_ASSERT(pSuggestedLocalWorkSize != nullptr,
            UR_RESULT_ERROR_INVALID_NULL_POINTER);

  ur_device_handle_t Device = hQueue->Device;
  // Always three entries so the helper can be handed a full array; only the
  // first workDim entries are reported back to the caller.
  size_t ThreadsPerBlock[3] = {};

  // Set the active context here as guessLocalWorkSize needs an active context
  ScopedContext Active(Device);

  guessLocalWorkSize(Device, ThreadsPerBlock, pGlobalWorkSize, workDim,
                     hKernel);

  std::copy(ThreadsPerBlock, ThreadsPerBlock + workDim,
            pSuggestedLocalWorkSize);
  return UR_RESULT_SUCCESS;
}

source/adapters/cuda/ur_interface_loader.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetKernelProcAddrTable(
125125
pDdiTable->pfnSetArgValue = urKernelSetArgValue;
126126
pDdiTable->pfnSetExecInfo = urKernelSetExecInfo;
127127
pDdiTable->pfnSetSpecializationConstants = nullptr;
128+
pDdiTable->pfnGetSuggestedLocalWorkSize = urKernelGetSuggestedLocalWorkSize;
128129
return UR_RESULT_SUCCESS;
129130
}
130131

source/adapters/hip/enqueue.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,7 @@ void setCopyRectParams(ur_rect_region_t Region, const void *SrcPtr,
3030
const hipMemoryType DstType, ur_rect_offset_t DstOffset,
3131
size_t DstRowPitch, size_t DstSlicePitch,
3232
hipMemcpy3DParms &Params);
33+
34+
// Declared here so that kernel.cpp can call it; the definition is not part of
// this header. urKernelGetSuggestedLocalWorkSize seeds ThreadsPerBlock with
// {32, 1, 1}, fills MaxThreadsPerBlock from the device's maximum block
// dimensions (X, Y, Z), and reads the first WorkDim entries of ThreadsPerBlock
// back as the suggested local work size.
void guessLocalWorkSize(ur_device_handle_t Device, size_t *ThreadsPerBlock,
35+
const size_t *GlobalWorkSize, const uint32_t WorkDim,
36+
const size_t MaxThreadsPerBlock[3]);

source/adapters/hip/kernel.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
//===----------------------------------------------------------------------===//
1010

1111
#include "kernel.hpp"
12+
#include "enqueue.hpp"
1213
#include "memory.hpp"
1314
#include "sampler.hpp"
1415

@@ -349,3 +350,31 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetSpecializationConstants(
349350
[[maybe_unused]] const ur_specialization_constant_info_t *pSpecConstants) {
350351
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
351352
}
353+
354+
/// Suggest a local work size for launching hKernel on hQueue's device over
/// the given global range.
///
/// @param hKernel                 kernel to query; must belong to the same
///                                context as hQueue
/// @param hQueue                  queue whose device the suggestion is for
/// @param workDim                 number of dimensions, 1 to 3
/// @param pGlobalWorkOffset       not used by this adapter
/// @param pGlobalWorkSize         array of workDim global sizes, forwarded to
///                                guessLocalWorkSize
/// @param pSuggestedLocalWorkSize [out] array that receives workDim values
/// @returns UR_RESULT_SUCCESS on success,
///          UR_RESULT_ERROR_INVALID_QUEUE if kernel and queue contexts differ,
///          UR_RESULT_ERROR_INVALID_WORK_DIMENSION if workDim is outside 1..3,
///          UR_RESULT_ERROR_INVALID_NULL_POINTER if either size array is null.
UR_APIEXPORT ur_result_t UR_APICALL urKernelGetSuggestedLocalWorkSize(
    ur_kernel_handle_t hKernel, ur_queue_handle_t hQueue, uint32_t workDim,
    [[maybe_unused]] const size_t *pGlobalWorkOffset,
    const size_t *pGlobalWorkSize, size_t *pSuggestedLocalWorkSize) {
  // Preconditions
  // NOTE(review): the CUDA adapter reports UR_RESULT_ERROR_INVALID_KERNEL for
  // this same context mismatch; the code is left as-is so existing callers of
  // the HIP adapter see no change.
  UR_ASSERT(hQueue->getContext() == hKernel->getContext(),
            UR_RESULT_ERROR_INVALID_QUEUE);
  UR_ASSERT(workDim > 0, UR_RESULT_ERROR_INVALID_WORK_DIMENSION);
  UR_ASSERT(workDim < 4, UR_RESULT_ERROR_INVALID_WORK_DIMENSION);
  // The API contract lists a null global size as INVALID_NULL_POINTER; check
  // it here instead of handing a null array to guessLocalWorkSize.
  UR_ASSERT(pGlobalWorkSize != nullptr, UR_RESULT_ERROR_INVALID_NULL_POINTER);
  UR_ASSERT(pSuggestedLocalWorkSize != nullptr,
            UR_RESULT_ERROR_INVALID_NULL_POINTER);

  // Look the device up once and use it for both the limits and the context.
  ur_device_handle_t Device = hQueue->getDevice();

  const size_t MaxThreadsPerBlock[3] = {Device->getMaxBlockDimX(),
                                        Device->getMaxBlockDimY(),
                                        Device->getMaxBlockDimZ()};
  // Starting point handed to guessLocalWorkSize; only the first workDim
  // entries are reported back to the caller.
  size_t ThreadsPerBlock[3] = {32u, 1u, 1u};

  ScopedContext Active(Device);

  guessLocalWorkSize(Device, ThreadsPerBlock, pGlobalWorkSize, workDim,
                     MaxThreadsPerBlock);
  std::copy(ThreadsPerBlock, ThreadsPerBlock + workDim,
            pSuggestedLocalWorkSize);
  return UR_RESULT_SUCCESS;
}

source/adapters/hip/ur_interface_loader.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetKernelProcAddrTable(
125125
pDdiTable->pfnSetArgValue = urKernelSetArgValue;
126126
pDdiTable->pfnSetExecInfo = urKernelSetExecInfo;
127127
pDdiTable->pfnSetSpecializationConstants = urKernelSetSpecializationConstants;
128+
pDdiTable->pfnGetSuggestedLocalWorkSize = urKernelGetSuggestedLocalWorkSize;
128129
return UR_RESULT_SUCCESS;
129130
}
130131

0 commit comments

Comments
 (0)