From fd6cf3af984e720d31b975838a8b5e113766cb96 Mon Sep 17 00:00:00 2001 From: Aaron Greig Date: Thu, 13 Jun 2024 12:14:41 +0100 Subject: [PATCH 01/15] Rework null adapter into mock adapter This also comes with a new accompanying helper header which allows tests to override entry points and create their own reference counted dummy handles. --- include/ur_api.h | 65 +- include/ur_print.h | 8 + include/ur_print.hpp | 25 + scripts/YaML.md | 3 +- scripts/core/INTRO.rst | 33 + scripts/core/adapter.yml | 2 +- scripts/core/context.yml | 2 +- scripts/core/device.yml | 2 +- scripts/core/event.yml | 2 +- scripts/core/exp-command-buffer.yml | 2 +- scripts/core/kernel.yml | 2 +- scripts/core/loader.yml | 28 +- scripts/core/memory.yml | 2 +- scripts/core/program.yml | 2 +- scripts/core/queue.yml | 2 +- scripts/core/registry.yml | 3 + scripts/core/sampler.yml | 2 +- scripts/core/usm.yml | 2 +- scripts/core/virtual_memory.yml | 2 +- scripts/generate_code.py | 10 +- scripts/templates/helper.py | 26 +- scripts/templates/mockddi.cpp.mako | 168 + scripts/templates/nullddi.cpp.mako | 137 - source/CMakeLists.txt | 1 + source/adapters/CMakeLists.txt | 2 +- source/adapters/{null => mock}/CMakeLists.txt | 9 +- .../{null/ur_null.cpp => mock/ur_mock.cpp} | 2 +- .../{null/ur_null.hpp => mock/ur_mock.hpp} | 9 +- .../ur_nullddi.cpp => mock/ur_mockddi.cpp} | 7235 +++++++++++++---- source/loader/CMakeLists.txt | 1 - source/loader/layers/tracing/ur_trcddi.cpp | 27 +- source/loader/layers/validation/ur_valddi.cpp | 27 +- source/loader/loader.def.in | 2 + source/loader/loader.map.in | 2 + source/loader/ur_adapter_registry.hpp | 18 + source/loader/ur_ldrddi.cpp | 172 +- source/loader/ur_lib.cpp | 14 + source/loader/ur_lib.hpp | 4 + source/loader/ur_libapi.cpp | 56 +- source/loader/ur_print.cpp | 8 + source/mock/CMakeLists.txt | 17 + source/mock/ur_mock_helpers.cpp | 19 + source/mock/ur_mock_helpers.hpp | 107 + source/ur_api.cpp | 55 +- test/layers/mock/CMakeLists.txt | 23 + test/layers/mock/mock.cpp | 141 + test/layers/tracing/CMakeLists.txt | 4 +- test/layers/validation/CMakeLists.txt | 2 +- test/loader/CMakeLists.txt | 2 +- test/loader/handles/CMakeLists.txt | 2 +- test/loader/platforms/CMakeLists.txt | 4 +- test/tools/urtrace/CMakeLists.txt | 12 +- test/usm/CMakeLists.txt | 2 +- 53 files changed, 6722 insertions(+), 1787 deletions(-) create mode 100644 scripts/templates/mockddi.cpp.mako delete mode 100644 scripts/templates/nullddi.cpp.mako rename source/adapters/{null => mock}/CMakeLists.txt (72%) rename source/adapters/{null/ur_null.cpp => mock/ur_mock.cpp} (99%) rename source/adapters/{null/ur_null.hpp => mock/ur_mock.hpp} (86%) rename source/adapters/{null/ur_nullddi.cpp => mock/ur_mockddi.cpp} (51%) create mode 100644 source/mock/CMakeLists.txt create mode 100644 source/mock/ur_mock_helpers.cpp create mode 100644 source/mock/ur_mock_helpers.hpp create mode 100644 test/layers/mock/CMakeLists.txt create mode 100644 test/layers/mock/mock.cpp diff --git a/include/ur_api.h b/include/ur_api.h index 7ad05e73da..c896978528 100644 --- a/include/ur_api.h +++ b/include/ur_api.h @@ -226,6 +226,7 @@ typedef enum ur_function_t { UR_FUNCTION_BINDLESS_IMAGES_IMPORT_EXTERNAL_MEMORY_EXP = 226, ///< Enumerator for ::urBindlessImagesImportExternalMemoryExp UR_FUNCTION_BINDLESS_IMAGES_IMPORT_EXTERNAL_SEMAPHORE_EXP = 227, ///< Enumerator for ::urBindlessImagesImportExternalSemaphoreExp UR_FUNCTION_ENQUEUE_NATIVE_COMMAND_EXP = 228, ///< Enumerator for ::urEnqueueNativeCommandExp + UR_FUNCTION_LOADER_CONFIG_SET_MOCKING_ENABLED = 231, ///< Enumerator for ::urLoaderConfigSetMockingEnabled /// @cond UR_FUNCTION_FORCE_UINT32 = 0x7fffffff /// @endcond @@ -601,7 +602,7 @@ urLoaderConfigCreate( /// + `NULL == hLoaderConfig` UR_APIEXPORT ur_result_t UR_APICALL urLoaderConfigRetain( - ur_loader_config_handle_t hLoaderConfig ///< [in] loader config handle to retain + ur_loader_config_handle_t hLoaderConfig ///< [in][retain] loader config handle to retain ); /////////////////////////////////////////////////////////////////////////////// @@ -742,6 +743,35 @@ urLoaderConfigSetCodeLocationCallback( void *pUserData ///< [in][out][optional] pointer to data to be passed to callback. ); +/////////////////////////////////////////////////////////////////////////////// +/// @brief Callback to replace or instrument generic mock functionality in the +/// mock adapter. +typedef ur_result_t (*ur_mock_callback_t)( + void *pParams ///< [in][out] Pointer to the appropriate param struct for the function +); + +/////////////////////////////////////////////////////////////////////////////// +/// @brief The only adapter reported with mock enabled will be the mock adapter. +/// +/// @details +/// - The mock adapter will default to returning ::UR_RESULT_SUCCESS for all +/// entry points. It will also create and correctly reference count dummy +/// handles where appropriate. Its behaviour can be modified by linking +/// the ::ur_mock_headers library and using the callbacks object. +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_UNINITIALIZED +/// - ::UR_RESULT_ERROR_DEVICE_LOST +/// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC +/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE +/// + `NULL == hLoaderConfig` +UR_APIEXPORT ur_result_t UR_APICALL +urLoaderConfigSetMockingEnabled( + ur_loader_config_handle_t hLoaderConfig, ///< [in] Handle to config object mocking will be enabled for. + ur_bool_t enable ///< [in] Handle to config object the layer will be enabled for. +); + /////////////////////////////////////////////////////////////////////////////// /// @brief Initialize the 'oneAPI' loader /// @@ -863,7 +893,7 @@ urAdapterRelease( /// + `NULL == hAdapter` UR_APIEXPORT ur_result_t UR_APICALL urAdapterRetain( - ur_adapter_handle_t hAdapter ///< [in] Adapter handle to retain + ur_adapter_handle_t hAdapter ///< [in][retain] Adapter handle to retain ); /////////////////////////////////////////////////////////////////////////////// @@ -1736,7 +1766,7 @@ urDeviceGetInfo( /// + `NULL == hDevice` UR_APIEXPORT ur_result_t UR_APICALL urDeviceRetain( - ur_device_handle_t hDevice ///< [in] handle of the device to get a reference of. + ur_device_handle_t hDevice ///< [in][retain] handle of the device to get a reference of. ); /////////////////////////////////////////////////////////////////////////////// @@ -2217,7 +2247,7 @@ urContextCreate( /// + `NULL == hContext` UR_APIEXPORT ur_result_t UR_APICALL urContextRetain( - ur_context_handle_t hContext ///< [in] handle of the context to get a reference of. + ur_context_handle_t hContext ///< [in][retain] handle of the context to get a reference of. ); /////////////////////////////////////////////////////////////////////////////// @@ -2739,7 +2769,7 @@ urMemBufferCreate( /// - ::UR_RESULT_ERROR_OUT_OF_RESOURCES UR_APIEXPORT ur_result_t UR_APICALL urMemRetain( - ur_mem_handle_t hMem ///< [in] handle of the memory object to get access + ur_mem_handle_t hMem ///< [in][retain] handle of the memory object to get access ); /////////////////////////////////////////////////////////////////////////////// @@ -3130,7 +3160,7 @@ urSamplerCreate( /// - ::UR_RESULT_ERROR_OUT_OF_RESOURCES UR_APIEXPORT ur_result_t UR_APICALL urSamplerRetain( - ur_sampler_handle_t hSampler ///< [in] handle of the sampler object to get access + ur_sampler_handle_t hSampler ///< [in][retain] handle of the sampler object to get access ); /////////////////////////////////////////////////////////////////////////////// @@ -3690,7 +3720,7 @@ urUSMPoolCreate( /// + `NULL == pPool` UR_APIEXPORT ur_result_t UR_APICALL urUSMPoolRetain( - ur_usm_pool_handle_t pPool ///< [in] pointer to USM memory pool + ur_usm_pool_handle_t pPool ///< [in][retain] pointer to USM memory pool ); /////////////////////////////////////////////////////////////////////////////// @@ -4046,7 +4076,7 @@ urPhysicalMemCreate( /// + `NULL == hPhysicalMem` UR_APIEXPORT ur_result_t UR_APICALL urPhysicalMemRetain( - ur_physical_mem_handle_t hPhysicalMem ///< [in] handle of the physical memory object to retain. + ur_physical_mem_handle_t hPhysicalMem ///< [in][retain] handle of the physical memory object to retain. ); /////////////////////////////////////////////////////////////////////////////// @@ -4334,7 +4364,7 @@ urProgramLink( /// + `NULL == hProgram` UR_APIEXPORT ur_result_t UR_APICALL urProgramRetain( - ur_program_handle_t hProgram ///< [in] handle for the Program to retain + ur_program_handle_t hProgram ///< [in][retain] handle for the Program to retain ); /////////////////////////////////////////////////////////////////////////////// @@ -4985,7 +5015,7 @@ urKernelGetSubGroupInfo( /// + `NULL == hKernel` UR_APIEXPORT ur_result_t UR_APICALL urKernelRetain( - ur_kernel_handle_t hKernel ///< [in] handle for the Kernel to retain + ur_kernel_handle_t hKernel ///< [in][retain] handle for the Kernel to retain ); /////////////////////////////////////////////////////////////////////////////// @@ -5492,7 +5522,7 @@ urQueueCreate( /// - ::UR_RESULT_ERROR_OUT_OF_RESOURCES UR_APIEXPORT ur_result_t UR_APICALL urQueueRetain( - ur_queue_handle_t hQueue ///< [in] handle of the queue object to get access + ur_queue_handle_t hQueue ///< [in][retain] handle of the queue object to get access ); /////////////////////////////////////////////////////////////////////////////// @@ -5887,7 +5917,7 @@ urEventWait( /// - ::UR_RESULT_ERROR_OUT_OF_HOST_MEMORY UR_APIEXPORT ur_result_t UR_APICALL urEventRetain( - ur_event_handle_t hEvent ///< [in] handle of the event object + ur_event_handle_t hEvent ///< [in][retain] handle of the event object ); /////////////////////////////////////////////////////////////////////////////// @@ -8252,7 +8282,7 @@ urCommandBufferCreateExp( /// - ::UR_RESULT_ERROR_OUT_OF_HOST_MEMORY UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferRetainExp( - ur_exp_command_buffer_handle_t hCommandBuffer ///< [in] Handle of the command-buffer object. + ur_exp_command_buffer_handle_t hCommandBuffer ///< [in][retain] Handle of the command-buffer object. ); /////////////////////////////////////////////////////////////////////////////// @@ -9657,6 +9687,15 @@ typedef struct ur_loader_config_set_code_location_callback_params_t { void **ppUserData; } ur_loader_config_set_code_location_callback_params_t; +/////////////////////////////////////////////////////////////////////////////// +/// @brief Function parameters for urLoaderConfigSetMockingEnabled +/// @details Each entry is a pointer to the parameter passed to the function; +/// allowing the callback the ability to modify the parameter's value +typedef struct ur_loader_config_set_mocking_enabled_params_t { + ur_loader_config_handle_t *phLoaderConfig; + ur_bool_t *penable; +} ur_loader_config_set_mocking_enabled_params_t; + /////////////////////////////////////////////////////////////////////////////// /// @brief Function parameters for urPlatformGet /// @details Each entry is a pointer to the parameter passed to the function; diff --git a/include/ur_print.h b/include/ur_print.h index 60aa71f03b..682ec7660b 100644 --- a/include/ur_print.h +++ b/include/ur_print.h @@ -1106,6 +1106,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urPrintLoaderConfigEnableLayerParams(const s /// - `buff_size < out_size` UR_APIEXPORT ur_result_t UR_APICALL urPrintLoaderConfigSetCodeLocationCallbackParams(const struct ur_loader_config_set_code_location_callback_params_t *params, char *buffer, const size_t buff_size, size_t *out_size); +/////////////////////////////////////////////////////////////////////////////// +/// @brief Print ur_loader_config_set_mocking_enabled_params_t struct +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_INVALID_SIZE +/// - `buff_size < out_size` +UR_APIEXPORT ur_result_t UR_APICALL urPrintLoaderConfigSetMockingEnabledParams(const struct ur_loader_config_set_mocking_enabled_params_t *params, char *buffer, const size_t buff_size, size_t *out_size); + /////////////////////////////////////////////////////////////////////////////// /// @brief Print ur_platform_get_params_t struct /// @returns diff --git a/include/ur_print.hpp b/include/ur_print.hpp index 6e8938decb..e45d180698 100644 --- a/include/ur_print.hpp +++ b/include/ur_print.hpp @@ -938,6 +938,9 @@ inline std::ostream &operator<<(std::ostream &os, enum ur_function_t value) { case UR_FUNCTION_ENQUEUE_NATIVE_COMMAND_EXP: os << "UR_FUNCTION_ENQUEUE_NATIVE_COMMAND_EXP"; break; + case UR_FUNCTION_LOADER_CONFIG_SET_MOCKING_ENABLED: + os << "UR_FUNCTION_LOADER_CONFIG_SET_MOCKING_ENABLED"; + break; default: os << "unknown enumerator"; break; @@ -10277,6 +10280,25 @@ inline std::ostream &operator<<(std::ostream &os, [[maybe_unused]] const struct return os; } +/////////////////////////////////////////////////////////////////////////////// +/// @brief Print operator for the ur_loader_config_set_mocking_enabled_params_t type +/// @returns +/// std::ostream & +inline std::ostream &operator<<(std::ostream &os, [[maybe_unused]] const struct ur_loader_config_set_mocking_enabled_params_t *params) { + + os << ".hLoaderConfig = "; + + ur::details::printPtr(os, + *(params->phLoaderConfig)); + + os << ", "; + os << ".enable = "; + + os << *(params->penable); + + return os; +} + /////////////////////////////////////////////////////////////////////////////// /// @brief Print operator for the ur_platform_get_params_t type /// @returns @@ -17332,6 +17354,9 @@ inline ur_result_t UR_APICALL printFunctionParams(std::ostream &os, ur_function_ case UR_FUNCTION_LOADER_CONFIG_SET_CODE_LOCATION_CALLBACK: { os << (const struct ur_loader_config_set_code_location_callback_params_t *)params; } break; + case UR_FUNCTION_LOADER_CONFIG_SET_MOCKING_ENABLED: { + os << (const struct ur_loader_config_set_mocking_enabled_params_t *)params; + } break; case UR_FUNCTION_PLATFORM_GET: { os << (const struct ur_platform_get_params_t *)params; } break; diff --git a/scripts/YaML.md b/scripts/YaML.md index 089115cb1b..2ab5dd80b2 100644 --- a/scripts/YaML.md +++ b/scripts/YaML.md @@ -620,11 +620,12 @@ class ur_name_t(Structure): - `out` is used for params that are write-only; if the param is a pointer, then the memory being pointed to is also write-only - `in,out` is used for params that are both read and write; typically this is used for pointers to other data structures that contain both read and write params - `nocheck` is used to specify that no additional validation checks will be generated. - + `desc` may include one the following annotations: {`"[optional]"`, `"[range(start,end)]"`, `"[release]"`, `"[typename(typeVarName)]"`, `"[bounds(offset,size)]"`} + + `desc` may include one the following annotations: {`"[optional]"`, `"[range(start,end)]"`, `"[retain]"`, `"[release]"`, `"[typename(typeVarName)]"`, `"[bounds(offset,size)]"`} - `optional` is used for params that are handles or pointers where it is legal for the value to be `nullptr` - `range` is used for params that are array pointers to specify the valid range that the is valid to read + `start` and `end` must be an ISO-C standard identifier or literal + `start` is inclusive and `end` is exclusive + - `retain` is used for params that are handles or pointers to handles where the function will increment the reference counter associated with the handle(s). - `release` is used for params that are handles or pointers to handles where the function will decrement the handle's reference count, potentially leaving it in an invalid state if the reference count reaches zero. - `typename` is used to denote the type enum for params that are opaque pointers to values of tagged data types. - `bounds` is used for params that are memory objects or USM allocations. It specifies the range within the memory allocation represented by the param that will be accessed by the operation. diff --git a/scripts/core/INTRO.rst b/scripts/core/INTRO.rst index a81c282070..5e716101ef 100644 --- a/scripts/core/INTRO.rst +++ b/scripts/core/INTRO.rst @@ -256,6 +256,37 @@ Currently, UR looks for these adapter libraries: For more information about the usage of mentioned environment variables see `Environment Variables`_ section. +Mocking +--------------------- +A mock UR adapter can be accessed for test purposes by enabling the ``MOCK`` +layer as described below. When the mock layer is enabled, calls to the API will +still be intercepted by other layers (e.g. validation, tracing), but they will +stop short of the loader - the call chain will end in either a generic fallback +behavior defined by the mock layer itself, or a user defined replacement +callback. + +The default fallback behavior for entry points in the mock layer is to simply +return ``UR_RESULT_SUCCESS``. For entry points concerning handles, i.e. those +that create a new handle or modify the reference count of an existing one, a +dummy handle mechanism is used. This means the layer will return generic +handles that track a reference count, and ``Retain``/``Release`` entry points will +function as expected when used with these handles. + +During global setup the behavior of the mock layer can be customized by setting +chain of structs, with each registering a callback with a given entry point in +the API. Callbacks can be registered to be called ``BEFORE`` or ``AFTER`` the +generic implementation, or they can be registered to entirely ``REPLACE`` it. A +given entry point can only have one of each kind of callback associated with +it, multiple structs with the same function/mode combination will override +eachother. + +The callback signature defined by ``${x}_mock_callback_t`` takes a single +``void *`` parameter. When calling a user callback the layer will pack the +entry point's parameters into the appropriate ``_params_t`` struct (e.g. +``ur_adapter_get_params_t``) and pass a pointer to that struct into the +callback. This allows parameters to be accessed and modified. The definitions +for these parameter structs can be found in the main API header. + Layers --------------------- UR comes with a mechanism that allows various API intercept layers to be enabled, either through the API or with an environment variable (see `Environment Variables`_). @@ -278,6 +309,8 @@ Layers currently included with the runtime are as follows: - Enables the XPTI tracing layer, see Tracing_ for more detail. * - UR_LAYER_ASAN \| UR_LAYER_MSAN \| UR_LAYER_TSAN - Enables the device-side sanitizer layer, see Sanitizers_ for more detail. + * - UR_LAYER_MOCK + - Enables adapter mocking for test purposes. Similar behavior to the null adapter except entry points can be overridden or instrumented with callbacks. See Mocking_ for more detail. Environment Variables --------------------- diff --git a/scripts/core/adapter.yml b/scripts/core/adapter.yml index 746c5b3a60..958d135b78 100644 --- a/scripts/core/adapter.yml +++ b/scripts/core/adapter.yml @@ -70,7 +70,7 @@ params: - type: "$x_adapter_handle_t" name: hAdapter desc: | - [in] Adapter handle to retain + [in][retain] Adapter handle to retain --- #-------------------------------------------------------------------------- type: function desc: "Get the last adapter specific error." diff --git a/scripts/core/context.yml b/scripts/core/context.yml index fdd57c7991..69987cab99 100644 --- a/scripts/core/context.yml +++ b/scripts/core/context.yml @@ -80,7 +80,7 @@ params: - type: "$x_context_handle_t" name: hContext desc: | - [in] handle of the context to get a reference of. + [in][retain] handle of the context to get a reference of. --- #-------------------------------------------------------------------------- type: enum desc: "Supported context info" diff --git a/scripts/core/device.yml b/scripts/core/device.yml index 7249568954..ead3ceeb8d 100644 --- a/scripts/core/device.yml +++ b/scripts/core/device.yml @@ -505,7 +505,7 @@ params: - type: "$x_device_handle_t" name: hDevice desc: | - [in] handle of the device to get a reference of. + [in][retain] handle of the device to get a reference of. --- #-------------------------------------------------------------------------- type: function desc: "Releases the device handle reference indicating end of its usage" diff --git a/scripts/core/event.yml b/scripts/core/event.yml index 6c80d1e9ee..45bcbf7d40 100644 --- a/scripts/core/event.yml +++ b/scripts/core/event.yml @@ -230,7 +230,7 @@ analogue: params: - type: $x_event_handle_t name: hEvent - desc: "[in] handle of the event object" + desc: "[in][retain] handle of the event object" returns: - $X_RESULT_ERROR_INVALID_EVENT - $X_RESULT_ERROR_OUT_OF_RESOURCES diff --git a/scripts/core/exp-command-buffer.yml b/scripts/core/exp-command-buffer.yml index 7c12abfa20..73a76e6d87 100644 --- a/scripts/core/exp-command-buffer.yml +++ b/scripts/core/exp-command-buffer.yml @@ -252,7 +252,7 @@ name: RetainExp params: - type: $x_exp_command_buffer_handle_t name: hCommandBuffer - desc: "[in] Handle of the command-buffer object." + desc: "[in][retain] Handle of the command-buffer object." returns: - $X_RESULT_ERROR_INVALID_COMMAND_BUFFER_EXP - $X_RESULT_ERROR_OUT_OF_RESOURCES diff --git a/scripts/core/kernel.yml b/scripts/core/kernel.yml index 9604102741..e657d806be 100644 --- a/scripts/core/kernel.yml +++ b/scripts/core/kernel.yml @@ -292,7 +292,7 @@ details: params: - type: $x_kernel_handle_t name: hKernel - desc: "[in] handle for the Kernel to retain" + desc: "[in][retain] handle for the Kernel to retain" --- #-------------------------------------------------------------------------- type: function desc: "Release Kernel." diff --git a/scripts/core/loader.yml b/scripts/core/loader.yml index cd3d9931de..1df23607e8 100644 --- a/scripts/core/loader.yml +++ b/scripts/core/loader.yml @@ -1,4 +1,3 @@ -# # Copyright (C) 2022-2023 Intel Corporation # # Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions. @@ -52,7 +51,7 @@ details: params: - type: $x_loader_config_handle_t name: hLoaderConfig - desc: "[in] loader config handle to retain" + desc: "[in][retain] loader config handle to retain" --- #-------------------------------------------------------------------------- type: function desc: "Release config handle." @@ -187,6 +186,31 @@ params: name: pUserData desc: "[in][out][optional] pointer to data to be passed to callback." --- #-------------------------------------------------------------------------- +type: fptr_typedef +desc: "Callback to replace or instrument generic mock functionality in the mock adapter." +name: $x_mock_callback_t +return: $x_result_t +params: + - type: void* + name: pParams + desc: "[in][out] Pointer to the appropriate param struct for the function" +--- #-------------------------------------------------------------------------- +type: function +desc: "The only adapter reported with mock enabled will be the mock adapter." +details: + - "The mock adapter will default to returning $X_RESULT_SUCCESS for all entry points. It will also create and correctly reference count dummy handles where appropriate. Its behaviour can be modified by linking the $x_mock_headers library and using the callbacks object." +class: $xLoaderConfig +loader_only: True +name: "SetMockingEnabled" +decl: static +params: + - type: $x_loader_config_handle_t + name: hLoaderConfig + desc: "[in] Handle to config object mocking will be enabled for." + - type: $x_bool_t + name: enable + desc: "[in] Handle to config object the layer will be enabled for." +--- #-------------------------------------------------------------------------- type: function desc: "Initialize the $OneApi loader" class: $xLoader diff --git a/scripts/core/memory.yml b/scripts/core/memory.yml index 3ec5fd4853..4df4ae0d0b 100644 --- a/scripts/core/memory.yml +++ b/scripts/core/memory.yml @@ -351,7 +351,7 @@ details: params: - type: $x_mem_handle_t name: hMem - desc: "[in] handle of the memory object to get access" + desc: "[in][retain] handle of the memory object to get access" returns: - $X_RESULT_ERROR_INVALID_MEM_OBJECT - $X_RESULT_ERROR_OUT_OF_HOST_MEMORY diff --git a/scripts/core/program.yml b/scripts/core/program.yml index c8b7557480..b7da9d62e7 100644 --- a/scripts/core/program.yml +++ b/scripts/core/program.yml @@ -263,7 +263,7 @@ details: params: - type: $x_program_handle_t name: hProgram - desc: "[in] handle for the Program to retain" + desc: "[in][retain] handle for the Program to retain" --- #-------------------------------------------------------------------------- type: function desc: "Release Program." diff --git a/scripts/core/queue.yml b/scripts/core/queue.yml index 2e64f72c06..74386b911e 100644 --- a/scripts/core/queue.yml +++ b/scripts/core/queue.yml @@ -184,7 +184,7 @@ details: params: - type: $x_queue_handle_t name: hQueue - desc: "[in] handle of the queue object to get access" + desc: "[in][retain] handle of the queue object to get access" returns: - $X_RESULT_ERROR_INVALID_QUEUE - $X_RESULT_ERROR_OUT_OF_HOST_MEMORY diff --git a/scripts/core/registry.yml b/scripts/core/registry.yml index 8157bbb08a..2b1fcf9f32 100644 --- a/scripts/core/registry.yml +++ b/scripts/core/registry.yml @@ -592,6 +592,9 @@ etors: - name: ENQUEUE_NATIVE_COMMAND_EXP desc: Enumerator for $xEnqueueNativeCommandExp value: '228' +- name: LOADER_CONFIG_SET_MOCKING_ENABLED + desc: Enumerator for $xLoaderConfigSetMockingEnabled + value: '231' --- type: enum desc: Defines structure types diff --git a/scripts/core/sampler.yml b/scripts/core/sampler.yml index a6649fa872..6459277c6f 100644 --- a/scripts/core/sampler.yml +++ b/scripts/core/sampler.yml @@ -116,7 +116,7 @@ analogue: params: - type: $x_sampler_handle_t name: hSampler - desc: "[in] handle of the sampler object to get access" + desc: "[in][retain] handle of the sampler object to get access" returns: - $X_RESULT_ERROR_INVALID_SAMPLER - $X_RESULT_ERROR_OUT_OF_HOST_MEMORY diff --git a/scripts/core/usm.yml b/scripts/core/usm.yml index 29be62d7f5..da5cd8c578 100644 --- a/scripts/core/usm.yml +++ b/scripts/core/usm.yml @@ -433,7 +433,7 @@ ordinal: "0" params: - type: $x_usm_pool_handle_t name: pPool - desc: "[in] pointer to USM memory pool" + desc: "[in][retain] pointer to USM memory pool" returns: - $X_RESULT_ERROR_INVALID_NULL_HANDLE --- #-------------------------------------------------------------------------- diff --git a/scripts/core/virtual_memory.yml b/scripts/core/virtual_memory.yml index 6433b7b91d..5b12e1761e 100644 --- a/scripts/core/virtual_memory.yml +++ b/scripts/core/virtual_memory.yml @@ -292,7 +292,7 @@ name: Retain params: - type: $x_physical_mem_handle_t name: hPhysicalMem - desc: "[in] handle of the physical memory object to retain." + desc: "[in][retain] handle of the physical memory object to retain." --- #-------------------------------------------------------------------------- type: function diff --git a/scripts/generate_code.py b/scripts/generate_code.py index 5ff832945b..b3a1146a3d 100644 --- a/scripts/generate_code.py +++ b/scripts/generate_code.py @@ -218,14 +218,14 @@ def _mako_loader_cpp(path, namespace, tags, version, specs, meta): """ generates c/c++ files from the specification documents """ -def _mako_null_adapter_cpp(path, namespace, tags, version, specs, meta): - dstpath = os.path.join(path, "null") +def _mako_mock_adapter_cpp(path, namespace, tags, version, specs, meta): + dstpath = os.path.join(path, "mock") os.makedirs(dstpath, exist_ok=True) - template = "nullddi.cpp.mako" + template = "mockddi.cpp.mako" fin = os.path.join(templates_dir, template) - name = "%s_nullddi"%(namespace) + name = "%s_mockddi"%(namespace) filename = "%s.cpp"%(name) fout = os.path.join(dstpath, filename) @@ -388,7 +388,7 @@ def generate_adapters(path, section, namespace, tags, version, specs, meta): os.makedirs(dstpath, exist_ok=True) loc = 0 - loc += _mako_null_adapter_cpp(dstpath, namespace, tags, version, specs, meta) + loc += _mako_mock_adapter_cpp(dstpath, namespace, tags, version, specs, meta) loc += _mako_linker_scripts( dstpath, "adapter", "map", namespace, tags, version, specs, meta ) diff --git a/scripts/templates/helper.py b/scripts/templates/helper.py index ff5fda4449..9c08f8be11 100644 --- a/scripts/templates/helper.py +++ b/scripts/templates/helper.py @@ -379,6 +379,7 @@ class param_traits: RE_OPTIONAL = r".*\[optional\].*" RE_NOCHECK = r".*\[nocheck\].*" RE_RANGE = r".*\[range\((.+),\s*(.+)\)\][\S\s]*" + RE_RETAIN = r".*\[retain\].*" RE_RELEASE = r".*\[release\].*" RE_TYPENAME = r".*\[typename\((.+),\s(.+)\)\].*" RE_TAGGED = r".*\[tagged_by\((.+)\)].*" @@ -468,6 +469,13 @@ def range_end(cls, item): except: return None + @classmethod + def is_retain(cls, item): + try: + return True if re.match(cls.RE_RETAIN, item['desc']) else False + except: + return False + @classmethod def is_release(cls, item): try: @@ -915,6 +923,19 @@ def make_param_lines(namespace, tags, obj, decl=False, meta=None, format=["type" lines = ["void"] return lines +""" +Public: + searches params of function `obj` for a match to the given regex and + returns its full C++ name +""" +def find_param_name(name_re, namespace, tags, obj): + for param in obj['params']: + param_cpp_name = _get_param_name(namespace, tags, param) + print("searching {0} for pattner {1}".format(param_cpp_name, name_re)) + if re.search(name_re, param_cpp_name): + return param_cpp_name + return UNDEFINED + """ Public: returns a list of strings for the description @@ -1497,7 +1518,7 @@ def get_loader_epilogue(specs, namespace, tags, obj, meta): obj_name = re.sub(r"(\w+)_handle_t", r"\1_object_t", tname) fty_name = re.sub(r"(\w+)_handle_t", r"\1_factory", tname) - if param_traits.is_release(item) or param_traits.is_output(item) or param_traits.is_inoutput(item): + if param_traits.is_retain(item) or param_traits.is_release(item) or param_traits.is_output(item) or param_traits.is_inoutput(item): if type_traits.is_class_handle(item['type'], meta): if param_traits.is_range(item): range_start = param_traits.range_start(item) @@ -1507,6 +1528,7 @@ def get_loader_epilogue(specs, namespace, tags, obj, meta): 'type': tname, 'obj': obj_name, 'factory': fty_name, + 'retain': param_traits.is_retain(item), 'release': param_traits.is_release(item), 'range': (range_start, range_end) }) @@ -1516,6 +1538,7 @@ def get_loader_epilogue(specs, namespace, tags, obj, meta): 'type': tname, 'obj': obj_name, 'factory': fty_name, + 'retain': param_traits.is_retain(item), 'release': param_traits.is_release(item), 'optional': param_traits.is_optional(item) }) @@ -1553,6 +1576,7 @@ def get_loader_epilogue(specs, namespace, tags, obj, meta): epilogue.append({ 'name': name, 'obj': obj_name, + 'retain': False, 'release': False, 'typename': typename, 'size': prop_size, diff --git a/scripts/templates/mockddi.cpp.mako b/scripts/templates/mockddi.cpp.mako new file mode 100644 index 0000000000..56b333f798 --- /dev/null +++ b/scripts/templates/mockddi.cpp.mako @@ -0,0 +1,168 @@ +<%! +import re +from templates import helper as th +%><% + n=namespace + N=n.upper() + + x=tags['$x'] + X=x.upper() +%>/* + * + * Copyright (C) 2019-2024 Intel Corporation + * + * Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions. + * See LICENSE.TXT + * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + * + * @file ${name}.cpp + * + */ +#include "${x}_mock.hpp" +#include "${x}_mock_helpers.hpp" + +namespace driver +{ + %for obj in th.get_adapter_functions(specs): + /////////////////////////////////////////////////////////////////////////////// + <% + fname = th.make_func_name(n, tags, obj) + %>/// @brief Intercept function for ${fname} + %if 'condition' in obj: + #if ${th.subt(n, tags, obj['condition'])} + %endif + __${x}dlllocal ${x}_result_t ${X}_APICALL + ${fname}( + %for line in th.make_param_lines(n, tags, obj): + ${line} + %endfor + ) + try { + ${x}_result_t result = ${X}_RESULT_SUCCESS; + ${th.get_initial_null_set(obj)} + + ${th.make_pfncb_param_type(n, tags, obj)} params = { &${",&".join(th.make_param_lines(n, tags, obj, format=["name"]))} }; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("${fname}")); + if(beforeCallback) { + result = beforeCallback( ¶ms ); + if(result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("${fname}")); + if(replaceCallback) { + result = replaceCallback( ¶ms ); + } + else + { + <% + # We can use the loader epilogue to know when we should be creating mock handles + epilogue = th.get_loader_epilogue(specs, n, tags, obj, meta) + %> + %if 'NativeHandle' in fname: + <% func_class = th.subt(namespace, tags, obj['class'], False, True) %> + %if 'CreateWith' in fname: + *ph${func_class} = reinterpret_cast(hNative${func_class}); + mock::retainDummyHandle(*ph${func_class}); + %else: + *phNative${func_class} = reinterpret_cast(h${func_class}); + %endif + %else: + %if fname == 'urAdapterGet' or fname == 'urDeviceGet' or fname == 'urPlatformGet': + <% + num_param = th.find_param_name(".*pNum.*", n, tags, obj) + %> + if(${num_param}) { + *${num_param} = 1; + } + %endif + %for item in epilogue: + %if item['release']: + mock::releaseDummyHandle(${item['name']}); + %elif item['retain']: + mock::retainDummyHandle(${item['name']}); + %elif 'type' in item: + %if 'range' in item or ('optional' in item and item['optional']): + // optional output handle + if(${item['name']}) { + *${item['name']} = mock::createDummyHandle<${item['type']}>(); + } + %else: + *${item['name']} = mock::createDummyHandle<${item['type']}>(); + %endif + %endif + %endfor + %endif + result = UR_RESULT_SUCCESS; + } + + if(result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("${fname}")); + if(afterCallback) { + return afterCallback( ¶ms ); + } + + return result; + } catch(...) { return exceptionToResult(std::current_exception()); } + %if 'condition' in obj: + #endif // ${th.subt(n, tags, obj['condition'])} + %endif + + %endfor +} // namespace driver + +#if defined(__cplusplus) +extern "C" { +#endif + +%for tbl in th.get_pfntables(specs, meta, n, tags): +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's ${tbl['name']} table +/// with current process' addresses +/// +/// @returns +/// - ::${X}_RESULT_SUCCESS +/// - ::${X}_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::${X}_RESULT_ERROR_UNSUPPORTED_VERSION +${X}_DLLEXPORT ${x}_result_t ${X}_APICALL +${tbl['export']['name']}( + %for line in th.make_param_lines(n, tags, tbl['export']): + ${line} + %endfor + ) +try { + if( nullptr == pDdiTable ) + return ${X}_RESULT_ERROR_INVALID_NULL_POINTER; + + if( driver::d_context.version < version ) + return ${X}_RESULT_ERROR_UNSUPPORTED_VERSION; + + ${x}_result_t result = ${X}_RESULT_SUCCESS; + + %for obj in tbl['functions']: + %if 'condition' in obj: +#if ${th.subt(n, tags, obj['condition'])} + %endif + pDdiTable->${th.append_ws(th.make_pfn_name(n, tags, obj), 41)} = driver::${th.make_func_name(n, tags, obj)}; + %if 'condition' in obj: +#else + pDdiTable->${th.append_ws(th.make_pfn_name(n, tags, obj), 41)} = nullptr; +#endif + %endif + + %endfor + return result; +} catch(...) { return exceptionToResult(std::current_exception()); } + +%endfor +#if defined(__cplusplus) +} +#endif diff --git a/scripts/templates/nullddi.cpp.mako b/scripts/templates/nullddi.cpp.mako deleted file mode 100644 index 2adb62e691..0000000000 --- a/scripts/templates/nullddi.cpp.mako +++ /dev/null @@ -1,137 +0,0 @@ -<%! -import re -from templates import helper as th -%><% - n=namespace - N=n.upper() - - x=tags['$x'] - X=x.upper() -%>/* - * - * Copyright (C) 2019-2022 Intel Corporation - * - * Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions. - * See LICENSE.TXT - * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - * - * @file ${name}.cpp - * - */ -#include "${x}_null.hpp" - -namespace driver -{ - %for obj in th.get_adapter_functions(specs): - /////////////////////////////////////////////////////////////////////////////// - <% - fname = th.make_func_name(n, tags, obj) - %>/// @brief Intercept function for ${fname} - %if 'condition' in obj: - #if ${th.subt(n, tags, obj['condition'])} - %endif - __${x}dlllocal ${x}_result_t ${X}_APICALL - ${fname}( - %for line in th.make_param_lines(n, tags, obj): - ${line} - %endfor - ) - try { - ${x}_result_t result = ${X}_RESULT_SUCCESS; - ${th.get_initial_null_set(obj)} - - // if the driver has created a custom function, then call it instead of using the generic path - auto ${th.make_pfn_name(n, tags, obj)} = d_context.${n}DdiTable.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)}; - if( nullptr != ${th.make_pfn_name(n, tags, obj)} ) - { - result = ${th.make_pfn_name(n, tags, obj)}( ${", ".join(th.make_param_lines(n, tags, obj, format=["name"]))} ); - } - else - { - // generic implementation - %for item in th.get_loader_epilogue(specs, n, tags, obj, meta): - %if 'typename' in item: - if (${item['name']} != nullptr) { - switch (${item['typename']}) { - %for etor in item['etors']: - case ${etor['name']}: { - ${etor['type']} *handles = reinterpret_cast<${etor['type']} *>(${item['name']}); - size_t nelements = ${item['size']} / sizeof(${etor['type']}); - for (size_t i = 0; i < nelements; ++i) { - handles[i] = reinterpret_cast<${etor['type']}>( d_context.get() ); - } - } break; - %endfor - default: {} break; - } - } - %elif 'range' in item: - for( size_t i = ${item['range'][0]}; ( nullptr != ${item['name']} ) && ( i < ${item['range'][1]} ); ++i ) - ${item['name']}[ i ] = reinterpret_cast<${item['type']}>( d_context.get() ); - %elif not item['release']: - %if item['optional']: - if( nullptr != ${item['name']} ) *${item['name']} = reinterpret_cast<${item['type']}>( d_context.get() ); - %else: - *${item['name']} = reinterpret_cast<${item['type']}>( d_context.get() ); - %endif - %endif - - %endfor - } - - return result; - } catch(...) { return exceptionToResult(std::current_exception()); } - %if 'condition' in obj: - #endif // ${th.subt(n, tags, obj['condition'])} - %endif - - %endfor -} // namespace driver - -#if defined(__cplusplus) -extern "C" { -#endif - -%for tbl in th.get_pfntables(specs, meta, n, tags): -/////////////////////////////////////////////////////////////////////////////// -/// @brief Exported function for filling application's ${tbl['name']} table -/// with current process' addresses -/// -/// @returns -/// - ::${X}_RESULT_SUCCESS -/// - ::${X}_RESULT_ERROR_INVALID_NULL_POINTER -/// - ::${X}_RESULT_ERROR_UNSUPPORTED_VERSION -${X}_DLLEXPORT ${x}_result_t ${X}_APICALL -${tbl['export']['name']}( - %for line in th.make_param_lines(n, tags, tbl['export']): - ${line} - %endfor - ) -try { - if( nullptr == pDdiTable ) - return ${X}_RESULT_ERROR_INVALID_NULL_POINTER; - - if( driver::d_context.version < version ) - return ${X}_RESULT_ERROR_UNSUPPORTED_VERSION; - - ${x}_result_t result = ${X}_RESULT_SUCCESS; - - %for obj in tbl['functions']: - %if 'condition' in obj: -#if ${th.subt(n, tags, obj['condition'])} - %endif - pDdiTable->${th.append_ws(th.make_pfn_name(n, tags, obj), 41)} = driver::${th.make_func_name(n, tags, obj)}; - %if 'condition' in obj: -#else - pDdiTable->${th.append_ws(th.make_pfn_name(n, tags, obj), 41)} = nullptr; -#endif - %endif - - %endfor - return result; -} catch(...) { return exceptionToResult(std::current_exception()); } - -%endfor -#if defined(__cplusplus) -} -#endif diff --git a/source/CMakeLists.txt b/source/CMakeLists.txt index cbbc635191..0e994bfb9c 100644 --- a/source/CMakeLists.txt +++ b/source/CMakeLists.txt @@ -9,4 +9,5 @@ add_definitions(-DUR_VALIDATION_LAYER_SUPPORTED_VERSION="${PROJECT_VERSION_MAJOR add_subdirectory(common) add_subdirectory(loader) #add_subdirectory(layers) +add_subdirectory(mock) add_subdirectory(adapters) diff --git a/source/adapters/CMakeLists.txt b/source/adapters/CMakeLists.txt index 71b9baafa2..e45f39fca8 100644 --- a/source/adapters/CMakeLists.txt +++ b/source/adapters/CMakeLists.txt @@ -28,7 +28,7 @@ function(add_ur_adapter name) endif() endfunction() -add_subdirectory(null) +add_subdirectory(mock) function(add_ur_adapter_subdirectory name) string(TOUPPER ${name} NAME) diff --git a/source/adapters/null/CMakeLists.txt b/source/adapters/mock/CMakeLists.txt similarity index 72% rename from source/adapters/null/CMakeLists.txt rename to source/adapters/mock/CMakeLists.txt index 0d4aa13e01..e51add1301 100644 --- a/source/adapters/null/CMakeLists.txt +++ b/source/adapters/mock/CMakeLists.txt @@ -3,13 +3,13 @@ # See LICENSE.TXT # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -set(TARGET_NAME ur_adapter_null) +set(TARGET_NAME ur_adapter_mock) add_ur_adapter(${TARGET_NAME} SHARED - ${CMAKE_CURRENT_SOURCE_DIR}/ur_null.hpp - ${CMAKE_CURRENT_SOURCE_DIR}/ur_null.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/ur_nullddi.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/ur_mock.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/ur_mock.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/ur_mockddi.cpp ) set_target_properties(${TARGET_NAME} PROPERTIES @@ -20,4 +20,5 @@ set_target_properties(${TARGET_NAME} PROPERTIES target_link_libraries(${TARGET_NAME} PRIVATE ${PROJECT_NAME}::headers ${PROJECT_NAME}::common + ${PROJECT_NAME}::mock ) diff --git a/source/adapters/null/ur_null.cpp b/source/adapters/mock/ur_mock.cpp similarity index 99% rename from source/adapters/null/ur_null.cpp rename to source/adapters/mock/ur_mock.cpp index 84ad1ba352..cf95f08aee 100644 --- a/source/adapters/null/ur_null.cpp +++ b/source/adapters/mock/ur_mock.cpp @@ -9,7 +9,7 @@ * @file ur_null.cpp * */ -#include "ur_null.hpp" +#include "ur_mock.hpp" namespace driver { ////////////////////////////////////////////////////////////////////////// diff --git a/source/adapters/null/ur_null.hpp b/source/adapters/mock/ur_mock.hpp similarity index 86% rename from source/adapters/null/ur_null.hpp rename to source/adapters/mock/ur_mock.hpp index b9b997f5bf..f82a56bfcd 100644 --- a/source/adapters/null/ur_null.hpp +++ b/source/adapters/mock/ur_mock.hpp @@ -10,13 +10,11 @@ * */ #include "ur_api.h" -#ifndef UR_ADAPTER_NULL_H -#define UR_ADAPTER_NULL_H 1 +#ifndef UR_ADAPTER_MOCK_H +#define UR_ADAPTER_MOCK_H 1 #include "ur_ddi.h" #include "ur_util.hpp" -#include -#include namespace driver { /////////////////////////////////////////////////////////////////////////////// @@ -37,6 +35,7 @@ class __urdlllocal context_t { }; extern context_t d_context; + } // namespace driver -#endif /* UR_ADAPTER_NULL_H */ +#endif /* UR_ADAPTER_MOCK_H */ diff --git a/source/adapters/null/ur_nullddi.cpp b/source/adapters/mock/ur_mockddi.cpp similarity index 51% rename from source/adapters/null/ur_nullddi.cpp rename to source/adapters/mock/ur_mockddi.cpp index e7d412c7fe..4ecce32beb 100644 --- a/source/adapters/null/ur_nullddi.cpp +++ b/source/adapters/mock/ur_mockddi.cpp @@ -1,15 +1,16 @@ /* * - * Copyright (C) 2019-2022 Intel Corporation + * Copyright (C) 2019-2024 Intel Corporation * * Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions. * See LICENSE.TXT * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception * - * @file ur_nullddi.cpp + * @file ur_mockddi.cpp * */ -#include "ur_null.hpp" +#include "ur_mock.hpp" +#include "ur_mock_helpers.hpp" namespace driver { /////////////////////////////////////////////////////////////////////////////// @@ -29,16 +30,41 @@ __urdlllocal ur_result_t UR_APICALL urAdapterGet( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnAdapterGet = d_context.urDdiTable.Global.pfnAdapterGet; - if (nullptr != pfnAdapterGet) { - result = pfnAdapterGet(NumEntries, phAdapters, pNumAdapters); + ur_adapter_get_params_t params = {&NumEntries, &phAdapters, &pNumAdapters}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urAdapterGet")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urAdapterGet")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - for (size_t i = 0; (nullptr != phAdapters) && (i < NumEntries); ++i) { - phAdapters[i] = - reinterpret_cast(d_context.get()); + + if (pNumAdapters) { + *pNumAdapters = 1; + } + // optional output handle + if (phAdapters) { + *phAdapters = mock::createDummyHandle(); } + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urAdapterGet")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -53,12 +79,35 @@ __urdlllocal ur_result_t UR_APICALL urAdapterRelease( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnAdapterRelease = d_context.urDdiTable.Global.pfnAdapterRelease; - if (nullptr != pfnAdapterRelease) { - result = pfnAdapterRelease(hAdapter); + ur_adapter_release_params_t params = {&hAdapter}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urAdapterRelease")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urAdapterRelease")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + mock::releaseDummyHandle(hAdapter); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urAdapterRelease")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -69,16 +118,39 @@ __urdlllocal ur_result_t UR_APICALL urAdapterRelease( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urAdapterRetain __urdlllocal ur_result_t UR_APICALL urAdapterRetain( - ur_adapter_handle_t hAdapter ///< [in] Adapter handle to retain + ur_adapter_handle_t hAdapter ///< [in][retain] Adapter handle to retain ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnAdapterRetain = d_context.urDdiTable.Global.pfnAdapterRetain; - if (nullptr != pfnAdapterRetain) { - result = pfnAdapterRetain(hAdapter); + ur_adapter_retain_params_t params = {&hAdapter}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urAdapterRetain")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urAdapterRetain")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + mock::retainDummyHandle(hAdapter); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urAdapterRetain")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -99,13 +171,35 @@ __urdlllocal ur_result_t UR_APICALL urAdapterGetLastError( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnAdapterGetLastError = - d_context.urDdiTable.Global.pfnAdapterGetLastError; - if (nullptr != pfnAdapterGetLastError) { - result = pfnAdapterGetLastError(hAdapter, ppMessage, pError); + ur_adapter_get_last_error_params_t params = {&hAdapter, &ppMessage, + &pError}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urAdapterGetLastError")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urAdapterGetLastError")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urAdapterGetLastError")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -130,13 +224,35 @@ __urdlllocal ur_result_t UR_APICALL urAdapterGetInfo( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnAdapterGetInfo = d_context.urDdiTable.Global.pfnAdapterGetInfo; - if (nullptr != pfnAdapterGetInfo) { - result = pfnAdapterGetInfo(hAdapter, propName, propSize, pPropValue, - pPropSizeRet); + ur_adapter_get_info_params_t params = {&hAdapter, &propName, &propSize, + &pPropValue, &pPropSizeRet}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urAdapterGetInfo")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urAdapterGetInfo")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urAdapterGetInfo")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -164,17 +280,42 @@ __urdlllocal ur_result_t UR_APICALL urPlatformGet( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnGet = d_context.urDdiTable.Platform.pfnGet; - if (nullptr != pfnGet) { - result = pfnGet(phAdapters, NumAdapters, NumEntries, phPlatforms, - pNumPlatforms); + ur_platform_get_params_t params = {&phAdapters, &NumAdapters, &NumEntries, + &phPlatforms, &pNumPlatforms}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urPlatformGet")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urPlatformGet")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - for (size_t i = 0; (nullptr != phPlatforms) && (i < NumEntries); ++i) { - phPlatforms[i] = - reinterpret_cast(d_context.get()); + + if (pNumPlatforms) { + *pNumPlatforms = 1; + } + // optional output handle + if (phPlatforms) { + *phPlatforms = mock::createDummyHandle(); } + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urPlatformGet")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -199,13 +340,35 @@ __urdlllocal ur_result_t UR_APICALL urPlatformGetInfo( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnGetInfo = d_context.urDdiTable.Platform.pfnGetInfo; - if (nullptr != pfnGetInfo) { - result = - pfnGetInfo(hPlatform, propName, propSize, pPropValue, pPropSizeRet); + ur_platform_get_info_params_t params = {&hPlatform, &propName, &propSize, + &pPropValue, &pPropSizeRet}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urPlatformGetInfo")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urPlatformGetInfo")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urPlatformGetInfo")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -221,12 +384,34 @@ __urdlllocal ur_result_t UR_APICALL urPlatformGetApiVersion( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnGetApiVersion = d_context.urDdiTable.Platform.pfnGetApiVersion; - if (nullptr != pfnGetApiVersion) { - result = pfnGetApiVersion(hPlatform, pVersion); + ur_platform_get_api_version_params_t params = {&hPlatform, &pVersion}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urPlatformGetApiVersion")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urPlatformGetApiVersion")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urPlatformGetApiVersion")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -243,14 +428,36 @@ __urdlllocal ur_result_t UR_APICALL urPlatformGetNativeHandle( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnGetNativeHandle = d_context.urDdiTable.Platform.pfnGetNativeHandle; - if (nullptr != pfnGetNativeHandle) { - result = pfnGetNativeHandle(hPlatform, phNativePlatform); + ur_platform_get_native_handle_params_t params = {&hPlatform, + &phNativePlatform}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urPlatformGetNativeHandle")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urPlatformGetNativeHandle")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - *phNativePlatform = - reinterpret_cast(d_context.get()); + + *phNativePlatform = reinterpret_cast(hPlatform); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urPlatformGetNativeHandle")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -272,15 +479,39 @@ __urdlllocal ur_result_t UR_APICALL urPlatformCreateWithNativeHandle( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnCreateWithNativeHandle = - d_context.urDdiTable.Platform.pfnCreateWithNativeHandle; - if (nullptr != pfnCreateWithNativeHandle) { - result = pfnCreateWithNativeHandle(hNativePlatform, hAdapter, - pProperties, phPlatform); + ur_platform_create_with_native_handle_params_t params = { + &hNativePlatform, &hAdapter, &pProperties, &phPlatform}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback( + "urPlatformCreateWithNativeHandle")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback( + "urPlatformCreateWithNativeHandle")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - *phPlatform = reinterpret_cast(d_context.get()); + + *phPlatform = reinterpret_cast(hNativePlatform); + mock::retainDummyHandle(*phPlatform); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urPlatformCreateWithNativeHandle")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -300,14 +531,35 @@ __urdlllocal ur_result_t UR_APICALL urPlatformGetBackendOption( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnGetBackendOption = - d_context.urDdiTable.Platform.pfnGetBackendOption; - if (nullptr != pfnGetBackendOption) { - result = - pfnGetBackendOption(hPlatform, pFrontendOption, ppPlatformOption); + ur_platform_get_backend_option_params_t params = { + &hPlatform, &pFrontendOption, &ppPlatformOption}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urPlatformGetBackendOption")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urPlatformGetBackendOption")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urPlatformGetBackendOption")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -334,17 +586,42 @@ __urdlllocal ur_result_t UR_APICALL urDeviceGet( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnGet = d_context.urDdiTable.Device.pfnGet; - if (nullptr != pfnGet) { - result = - pfnGet(hPlatform, DeviceType, NumEntries, phDevices, pNumDevices); + ur_device_get_params_t params = {&hPlatform, &DeviceType, &NumEntries, + &phDevices, &pNumDevices}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urDeviceGet")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urDeviceGet")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - for (size_t i = 0; (nullptr != phDevices) && (i < NumEntries); ++i) { - phDevices[i] = - reinterpret_cast(d_context.get()); + + if (pNumDevices) { + *pNumDevices = 1; } + // optional output handle + if (phDevices) { + *phDevices = mock::createDummyHandle(); + } + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urDeviceGet")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -370,57 +647,37 @@ __urdlllocal ur_result_t UR_APICALL urDeviceGetInfo( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnGetInfo = d_context.urDdiTable.Device.pfnGetInfo; - if (nullptr != pfnGetInfo) { - result = - pfnGetInfo(hDevice, propName, propSize, pPropValue, pPropSizeRet); - } else { - // generic implementation - if (pPropValue != nullptr) { - switch (propName) { - case UR_DEVICE_INFO_PLATFORM: { - ur_platform_handle_t *handles = - reinterpret_cast(pPropValue); - size_t nelements = propSize / sizeof(ur_platform_handle_t); - for (size_t i = 0; i < nelements; ++i) { - handles[i] = - reinterpret_cast(d_context.get()); - } - } break; - case UR_DEVICE_INFO_PARENT_DEVICE: { - ur_device_handle_t *handles = - reinterpret_cast(pPropValue); - size_t nelements = propSize / sizeof(ur_device_handle_t); - for (size_t i = 0; i < nelements; ++i) { - handles[i] = - reinterpret_cast(d_context.get()); - } - } break; - case UR_DEVICE_INFO_COMPONENT_DEVICES: { - ur_device_handle_t *handles = - reinterpret_cast(pPropValue); - size_t nelements = propSize / sizeof(ur_device_handle_t); - for (size_t i = 0; i < nelements; ++i) { - handles[i] = - reinterpret_cast(d_context.get()); - } - } break; - case UR_DEVICE_INFO_COMPOSITE_DEVICE: { - ur_device_handle_t *handles = - reinterpret_cast(pPropValue); - size_t nelements = propSize / sizeof(ur_device_handle_t); - for (size_t i = 0; i < nelements; ++i) { - handles[i] = - reinterpret_cast(d_context.get()); - } - } break; - default: { - } break; - } + ur_device_get_info_params_t params = {&hDevice, &propName, &propSize, + &pPropValue, &pPropSizeRet}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urDeviceGetInfo")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; } } + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urDeviceGetInfo")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urDeviceGetInfo")); + if (afterCallback) { + return afterCallback(¶ms); + } + return result; } catch (...) { return exceptionToResult(std::current_exception()); @@ -430,16 +687,39 @@ __urdlllocal ur_result_t UR_APICALL urDeviceGetInfo( /// @brief Intercept function for urDeviceRetain __urdlllocal ur_result_t UR_APICALL urDeviceRetain( ur_device_handle_t - hDevice ///< [in] handle of the device to get a reference of. + hDevice ///< [in][retain] handle of the device to get a reference of. ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnRetain = d_context.urDdiTable.Device.pfnRetain; - if (nullptr != pfnRetain) { - result = pfnRetain(hDevice); + ur_device_retain_params_t params = {&hDevice}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urDeviceRetain")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urDeviceRetain")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + mock::retainDummyHandle(hDevice); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urDeviceRetain")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -455,12 +735,35 @@ __urdlllocal ur_result_t UR_APICALL urDeviceRelease( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnRelease = d_context.urDdiTable.Device.pfnRelease; - if (nullptr != pfnRelease) { - result = pfnRelease(hDevice); + ur_device_release_params_t params = {&hDevice}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urDeviceRelease")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urDeviceRelease")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + mock::releaseDummyHandle(hDevice); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urDeviceRelease")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -485,17 +788,39 @@ __urdlllocal ur_result_t UR_APICALL urDevicePartition( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnPartition = d_context.urDdiTable.Device.pfnPartition; - if (nullptr != pfnPartition) { - result = pfnPartition(hDevice, pProperties, NumDevices, phSubDevices, - pNumDevicesRet); + ur_device_partition_params_t params = {&hDevice, &pProperties, &NumDevices, + &phSubDevices, &pNumDevicesRet}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urDevicePartition")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urDevicePartition")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - for (size_t i = 0; (nullptr != phSubDevices) && (i < NumDevices); ++i) { - phSubDevices[i] = - reinterpret_cast(d_context.get()); + + // optional output handle + if (phSubDevices) { + *phSubDevices = mock::createDummyHandle(); } + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urDevicePartition")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -519,13 +844,35 @@ __urdlllocal ur_result_t UR_APICALL urDeviceSelectBinary( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnSelectBinary = d_context.urDdiTable.Device.pfnSelectBinary; - if (nullptr != pfnSelectBinary) { - result = - pfnSelectBinary(hDevice, pBinaries, NumBinaries, pSelectedBinary); + ur_device_select_binary_params_t params = {&hDevice, &pBinaries, + &NumBinaries, &pSelectedBinary}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urDeviceSelectBinary")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urDeviceSelectBinary")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urDeviceSelectBinary")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -542,13 +889,35 @@ __urdlllocal ur_result_t UR_APICALL urDeviceGetNativeHandle( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnGetNativeHandle = d_context.urDdiTable.Device.pfnGetNativeHandle; - if (nullptr != pfnGetNativeHandle) { - result = pfnGetNativeHandle(hDevice, phNativeDevice); + ur_device_get_native_handle_params_t params = {&hDevice, &phNativeDevice}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urDeviceGetNativeHandle")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urDeviceGetNativeHandle")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - *phNativeDevice = reinterpret_cast(d_context.get()); + + *phNativeDevice = reinterpret_cast(hDevice); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urDeviceGetNativeHandle")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -569,15 +938,37 @@ __urdlllocal ur_result_t UR_APICALL urDeviceCreateWithNativeHandle( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnCreateWithNativeHandle = - d_context.urDdiTable.Device.pfnCreateWithNativeHandle; - if (nullptr != pfnCreateWithNativeHandle) { - result = pfnCreateWithNativeHandle(hNativeDevice, hPlatform, - pProperties, phDevice); + ur_device_create_with_native_handle_params_t params = { + &hNativeDevice, &hPlatform, &pProperties, &phDevice}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urDeviceCreateWithNativeHandle")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urDeviceCreateWithNativeHandle")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - *phDevice = reinterpret_cast(d_context.get()); + + *phDevice = reinterpret_cast(hNativeDevice); + mock::retainDummyHandle(*phDevice); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urDeviceCreateWithNativeHandle")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -598,14 +989,35 @@ __urdlllocal ur_result_t UR_APICALL urDeviceGetGlobalTimestamps( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnGetGlobalTimestamps = - d_context.urDdiTable.Device.pfnGetGlobalTimestamps; - if (nullptr != pfnGetGlobalTimestamps) { - result = - pfnGetGlobalTimestamps(hDevice, pDeviceTimestamp, pHostTimestamp); + ur_device_get_global_timestamps_params_t params = { + &hDevice, &pDeviceTimestamp, &pHostTimestamp}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urDeviceGetGlobalTimestamps")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urDeviceGetGlobalTimestamps")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urDeviceGetGlobalTimestamps")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -626,13 +1038,36 @@ __urdlllocal ur_result_t UR_APICALL urContextCreate( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnCreate = d_context.urDdiTable.Context.pfnCreate; - if (nullptr != pfnCreate) { - result = pfnCreate(DeviceCount, phDevices, pProperties, phContext); + ur_context_create_params_t params = {&DeviceCount, &phDevices, &pProperties, + &phContext}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urContextCreate")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urContextCreate")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - *phContext = reinterpret_cast(d_context.get()); + + *phContext = mock::createDummyHandle(); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urContextCreate")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -644,16 +1079,39 @@ __urdlllocal ur_result_t UR_APICALL urContextCreate( /// @brief Intercept function for urContextRetain __urdlllocal ur_result_t UR_APICALL urContextRetain( ur_context_handle_t - hContext ///< [in] handle of the context to get a reference of. + hContext ///< [in][retain] handle of the context to get a reference of. ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnRetain = d_context.urDdiTable.Context.pfnRetain; - if (nullptr != pfnRetain) { - result = pfnRetain(hContext); + ur_context_retain_params_t params = {&hContext}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urContextRetain")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urContextRetain")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + mock::retainDummyHandle(hContext); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urContextRetain")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -669,12 +1127,35 @@ __urdlllocal ur_result_t UR_APICALL urContextRelease( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnRelease = d_context.urDdiTable.Context.pfnRelease; - if (nullptr != pfnRelease) { - result = pfnRelease(hContext); + ur_context_release_params_t params = {&hContext}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urContextRelease")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urContextRelease")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + mock::releaseDummyHandle(hContext); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urContextRelease")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -701,30 +1182,37 @@ __urdlllocal ur_result_t UR_APICALL urContextGetInfo( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnGetInfo = d_context.urDdiTable.Context.pfnGetInfo; - if (nullptr != pfnGetInfo) { - result = - pfnGetInfo(hContext, propName, propSize, pPropValue, pPropSizeRet); - } else { - // generic implementation - if (pPropValue != nullptr) { - switch (propName) { - case UR_CONTEXT_INFO_DEVICES: { - ur_device_handle_t *handles = - reinterpret_cast(pPropValue); - size_t nelements = propSize / sizeof(ur_device_handle_t); - for (size_t i = 0; i < nelements; ++i) { - handles[i] = - reinterpret_cast(d_context.get()); - } - } break; - default: { - } break; - } + ur_context_get_info_params_t params = {&hContext, &propName, &propSize, + &pPropValue, &pPropSizeRet}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urContextGetInfo")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; } } + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urContextGetInfo")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urContextGetInfo")); + if (afterCallback) { + return afterCallback(¶ms); + } + return result; } catch (...) { return exceptionToResult(std::current_exception()); @@ -739,14 +1227,36 @@ __urdlllocal ur_result_t UR_APICALL urContextGetNativeHandle( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnGetNativeHandle = d_context.urDdiTable.Context.pfnGetNativeHandle; - if (nullptr != pfnGetNativeHandle) { - result = pfnGetNativeHandle(hContext, phNativeContext); + ur_context_get_native_handle_params_t params = {&hContext, + &phNativeContext}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urContextGetNativeHandle")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urContextGetNativeHandle")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - *phNativeContext = - reinterpret_cast(d_context.get()); + + *phNativeContext = reinterpret_cast(hContext); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urContextGetNativeHandle")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -769,15 +1279,38 @@ __urdlllocal ur_result_t UR_APICALL urContextCreateWithNativeHandle( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnCreateWithNativeHandle = - d_context.urDdiTable.Context.pfnCreateWithNativeHandle; - if (nullptr != pfnCreateWithNativeHandle) { - result = pfnCreateWithNativeHandle(hNativeContext, numDevices, - phDevices, pProperties, phContext); + ur_context_create_with_native_handle_params_t params = { + &hNativeContext, &numDevices, &phDevices, &pProperties, &phContext}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urContextCreateWithNativeHandle")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback( + "urContextCreateWithNativeHandle")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - *phContext = reinterpret_cast(d_context.get()); + + *phContext = reinterpret_cast(hNativeContext); + mock::retainDummyHandle(*phContext); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urContextCreateWithNativeHandle")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -796,13 +1329,35 @@ __urdlllocal ur_result_t UR_APICALL urContextSetExtendedDeleter( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnSetExtendedDeleter = - d_context.urDdiTable.Context.pfnSetExtendedDeleter; - if (nullptr != pfnSetExtendedDeleter) { - result = pfnSetExtendedDeleter(hContext, pfnDeleter, pUserData); + ur_context_set_extended_deleter_params_t params = {&hContext, &pfnDeleter, + &pUserData}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urContextSetExtendedDeleter")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urContextSetExtendedDeleter")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urContextSetExtendedDeleter")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -823,14 +1378,36 @@ __urdlllocal ur_result_t UR_APICALL urMemImageCreate( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnImageCreate = d_context.urDdiTable.Mem.pfnImageCreate; - if (nullptr != pfnImageCreate) { - result = pfnImageCreate(hContext, flags, pImageFormat, pImageDesc, - pHost, phMem); + ur_mem_image_create_params_t params = {&hContext, &flags, &pImageFormat, + &pImageDesc, &pHost, &phMem}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urMemImageCreate")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urMemImageCreate")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - *phMem = reinterpret_cast(d_context.get()); + + *phMem = mock::createDummyHandle(); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urMemImageCreate")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -851,13 +1428,36 @@ __urdlllocal ur_result_t UR_APICALL urMemBufferCreate( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnBufferCreate = d_context.urDdiTable.Mem.pfnBufferCreate; - if (nullptr != pfnBufferCreate) { - result = pfnBufferCreate(hContext, flags, size, pProperties, phBuffer); + ur_mem_buffer_create_params_t params = {&hContext, &flags, &size, + &pProperties, &phBuffer}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urMemBufferCreate")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urMemBufferCreate")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - *phBuffer = reinterpret_cast(d_context.get()); + + *phBuffer = mock::createDummyHandle(); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urMemBufferCreate")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -868,16 +1468,40 @@ __urdlllocal ur_result_t UR_APICALL urMemBufferCreate( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urMemRetain __urdlllocal ur_result_t UR_APICALL urMemRetain( - ur_mem_handle_t hMem ///< [in] handle of the memory object to get access + ur_mem_handle_t + hMem ///< [in][retain] handle of the memory object to get access ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnRetain = d_context.urDdiTable.Mem.pfnRetain; - if (nullptr != pfnRetain) { - result = pfnRetain(hMem); + ur_mem_retain_params_t params = {&hMem}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urMemRetain")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urMemRetain")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + mock::retainDummyHandle(hMem); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urMemRetain")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -893,12 +1517,35 @@ __urdlllocal ur_result_t UR_APICALL urMemRelease( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnRelease = d_context.urDdiTable.Mem.pfnRelease; - if (nullptr != pfnRelease) { - result = pfnRelease(hMem); + ur_mem_release_params_t params = {&hMem}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urMemRelease")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urMemRelease")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + mock::releaseDummyHandle(hMem); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urMemRelease")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -920,14 +1567,36 @@ __urdlllocal ur_result_t UR_APICALL urMemBufferPartition( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnBufferPartition = d_context.urDdiTable.Mem.pfnBufferPartition; - if (nullptr != pfnBufferPartition) { - result = pfnBufferPartition(hBuffer, flags, bufferCreateType, pRegion, - phMem); + ur_mem_buffer_partition_params_t params = { + &hBuffer, &flags, &bufferCreateType, &pRegion, &phMem}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urMemBufferPartition")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urMemBufferPartition")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - *phMem = reinterpret_cast(d_context.get()); + + *phMem = mock::createDummyHandle(); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urMemBufferPartition")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -947,13 +1616,35 @@ __urdlllocal ur_result_t UR_APICALL urMemGetNativeHandle( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnGetNativeHandle = d_context.urDdiTable.Mem.pfnGetNativeHandle; - if (nullptr != pfnGetNativeHandle) { - result = pfnGetNativeHandle(hMem, hDevice, phNativeMem); + ur_mem_get_native_handle_params_t params = {&hMem, &hDevice, &phNativeMem}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urMemGetNativeHandle")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urMemGetNativeHandle")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - *phNativeMem = reinterpret_cast(d_context.get()); + + *phNativeMem = reinterpret_cast(hMem); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urMemGetNativeHandle")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -974,15 +1665,40 @@ __urdlllocal ur_result_t UR_APICALL urMemBufferCreateWithNativeHandle( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnBufferCreateWithNativeHandle = - d_context.urDdiTable.Mem.pfnBufferCreateWithNativeHandle; - if (nullptr != pfnBufferCreateWithNativeHandle) { - result = pfnBufferCreateWithNativeHandle(hNativeMem, hContext, - pProperties, phMem); + ur_mem_buffer_create_with_native_handle_params_t params = { + &hNativeMem, &hContext, &pProperties, &phMem}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback( + "urMemBufferCreateWithNativeHandle")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback( + "urMemBufferCreateWithNativeHandle")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - *phMem = reinterpret_cast(d_context.get()); + + *phMem = reinterpret_cast(hNativeMem); + mock::retainDummyHandle(*phMem); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = + reinterpret_cast(mock::callbacks.get_after_callback( + "urMemBufferCreateWithNativeHandle")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -1006,15 +1722,40 @@ __urdlllocal ur_result_t UR_APICALL urMemImageCreateWithNativeHandle( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnImageCreateWithNativeHandle = - d_context.urDdiTable.Mem.pfnImageCreateWithNativeHandle; - if (nullptr != pfnImageCreateWithNativeHandle) { - result = pfnImageCreateWithNativeHandle( - hNativeMem, hContext, pImageFormat, pImageDesc, pProperties, phMem); + ur_mem_image_create_with_native_handle_params_t params = { + &hNativeMem, &hContext, &pImageFormat, + &pImageDesc, &pProperties, &phMem}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback( + "urMemImageCreateWithNativeHandle")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback( + "urMemImageCreateWithNativeHandle")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - *phMem = reinterpret_cast(d_context.get()); + + *phMem = reinterpret_cast(hNativeMem); + mock::retainDummyHandle(*phMem); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urMemImageCreateWithNativeHandle")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -1041,30 +1782,37 @@ __urdlllocal ur_result_t UR_APICALL urMemGetInfo( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnGetInfo = d_context.urDdiTable.Mem.pfnGetInfo; - if (nullptr != pfnGetInfo) { - result = - pfnGetInfo(hMemory, propName, propSize, pPropValue, pPropSizeRet); - } else { - // generic implementation - if (pPropValue != nullptr) { - switch (propName) { - case UR_MEM_INFO_CONTEXT: { - ur_context_handle_t *handles = - reinterpret_cast(pPropValue); - size_t nelements = propSize / sizeof(ur_context_handle_t); - for (size_t i = 0; i < nelements; ++i) { - handles[i] = - reinterpret_cast(d_context.get()); - } - } break; - default: { - } break; - } + ur_mem_get_info_params_t params = {&hMemory, &propName, &propSize, + &pPropValue, &pPropSizeRet}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urMemGetInfo")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; } } + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urMemGetInfo")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urMemGetInfo")); + if (afterCallback) { + return afterCallback(¶ms); + } + return result; } catch (...) { return exceptionToResult(std::current_exception()); @@ -1088,13 +1836,35 @@ __urdlllocal ur_result_t UR_APICALL urMemImageGetInfo( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnImageGetInfo = d_context.urDdiTable.Mem.pfnImageGetInfo; - if (nullptr != pfnImageGetInfo) { - result = pfnImageGetInfo(hMemory, propName, propSize, pPropValue, - pPropSizeRet); + ur_mem_image_get_info_params_t params = {&hMemory, &propName, &propSize, + &pPropValue, &pPropSizeRet}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urMemImageGetInfo")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urMemImageGetInfo")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urMemImageGetInfo")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -1112,13 +1882,35 @@ __urdlllocal ur_result_t UR_APICALL urSamplerCreate( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnCreate = d_context.urDdiTable.Sampler.pfnCreate; - if (nullptr != pfnCreate) { - result = pfnCreate(hContext, pDesc, phSampler); + ur_sampler_create_params_t params = {&hContext, &pDesc, &phSampler}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urSamplerCreate")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urSamplerCreate")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - *phSampler = reinterpret_cast(d_context.get()); + + *phSampler = mock::createDummyHandle(); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urSamplerCreate")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -1130,16 +1922,39 @@ __urdlllocal ur_result_t UR_APICALL urSamplerCreate( /// @brief Intercept function for urSamplerRetain __urdlllocal ur_result_t UR_APICALL urSamplerRetain( ur_sampler_handle_t - hSampler ///< [in] handle of the sampler object to get access + hSampler ///< [in][retain] handle of the sampler object to get access ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnRetain = d_context.urDdiTable.Sampler.pfnRetain; - if (nullptr != pfnRetain) { - result = pfnRetain(hSampler); + ur_sampler_retain_params_t params = {&hSampler}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urSamplerRetain")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urSamplerRetain")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + mock::retainDummyHandle(hSampler); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urSamplerRetain")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -1155,12 +1970,35 @@ __urdlllocal ur_result_t UR_APICALL urSamplerRelease( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnRelease = d_context.urDdiTable.Sampler.pfnRelease; - if (nullptr != pfnRelease) { - result = pfnRelease(hSampler); + ur_sampler_release_params_t params = {&hSampler}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urSamplerRelease")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urSamplerRelease")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + mock::releaseDummyHandle(hSampler); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urSamplerRelease")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -1183,30 +2021,37 @@ __urdlllocal ur_result_t UR_APICALL urSamplerGetInfo( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnGetInfo = d_context.urDdiTable.Sampler.pfnGetInfo; - if (nullptr != pfnGetInfo) { - result = - pfnGetInfo(hSampler, propName, propSize, pPropValue, pPropSizeRet); - } else { - // generic implementation - if (pPropValue != nullptr) { - switch (propName) { - case UR_SAMPLER_INFO_CONTEXT: { - ur_context_handle_t *handles = - reinterpret_cast(pPropValue); - size_t nelements = propSize / sizeof(ur_context_handle_t); - for (size_t i = 0; i < nelements; ++i) { - handles[i] = - reinterpret_cast(d_context.get()); - } - } break; - default: { - } break; - } + ur_sampler_get_info_params_t params = {&hSampler, &propName, &propSize, + &pPropValue, &pPropSizeRet}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urSamplerGetInfo")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; } } + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urSamplerGetInfo")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urSamplerGetInfo")); + if (afterCallback) { + return afterCallback(¶ms); + } + return result; } catch (...) { return exceptionToResult(std::current_exception()); @@ -1221,14 +2066,36 @@ __urdlllocal ur_result_t UR_APICALL urSamplerGetNativeHandle( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnGetNativeHandle = d_context.urDdiTable.Sampler.pfnGetNativeHandle; - if (nullptr != pfnGetNativeHandle) { - result = pfnGetNativeHandle(hSampler, phNativeSampler); + ur_sampler_get_native_handle_params_t params = {&hSampler, + &phNativeSampler}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urSamplerGetNativeHandle")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urSamplerGetNativeHandle")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - *phNativeSampler = - reinterpret_cast(d_context.get()); + + *phNativeSampler = reinterpret_cast(hSampler); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urSamplerGetNativeHandle")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -1249,15 +2116,38 @@ __urdlllocal ur_result_t UR_APICALL urSamplerCreateWithNativeHandle( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnCreateWithNativeHandle = - d_context.urDdiTable.Sampler.pfnCreateWithNativeHandle; - if (nullptr != pfnCreateWithNativeHandle) { - result = pfnCreateWithNativeHandle(hNativeSampler, hContext, - pProperties, phSampler); + ur_sampler_create_with_native_handle_params_t params = { + &hNativeSampler, &hContext, &pProperties, &phSampler}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urSamplerCreateWithNativeHandle")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback( + "urSamplerCreateWithNativeHandle")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - *phSampler = reinterpret_cast(d_context.get()); + + *phSampler = reinterpret_cast(hNativeSampler); + mock::retainDummyHandle(*phSampler); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urSamplerCreateWithNativeHandle")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -1279,12 +2169,35 @@ __urdlllocal ur_result_t UR_APICALL urUSMHostAlloc( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnHostAlloc = d_context.urDdiTable.USM.pfnHostAlloc; - if (nullptr != pfnHostAlloc) { - result = pfnHostAlloc(hContext, pUSMDesc, pool, size, ppMem); + ur_usm_host_alloc_params_t params = {&hContext, &pUSMDesc, &pool, &size, + &ppMem}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urUSMHostAlloc")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urUSMHostAlloc")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urUSMHostAlloc")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -1307,12 +2220,35 @@ __urdlllocal ur_result_t UR_APICALL urUSMDeviceAlloc( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnDeviceAlloc = d_context.urDdiTable.USM.pfnDeviceAlloc; - if (nullptr != pfnDeviceAlloc) { - result = pfnDeviceAlloc(hContext, hDevice, pUSMDesc, pool, size, ppMem); + ur_usm_device_alloc_params_t params = {&hContext, &hDevice, &pUSMDesc, + &pool, &size, &ppMem}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urUSMDeviceAlloc")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urUSMDeviceAlloc")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urUSMDeviceAlloc")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -1335,12 +2271,35 @@ __urdlllocal ur_result_t UR_APICALL urUSMSharedAlloc( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnSharedAlloc = d_context.urDdiTable.USM.pfnSharedAlloc; - if (nullptr != pfnSharedAlloc) { - result = pfnSharedAlloc(hContext, hDevice, pUSMDesc, pool, size, ppMem); + ur_usm_shared_alloc_params_t params = {&hContext, &hDevice, &pUSMDesc, + &pool, &size, &ppMem}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urUSMSharedAlloc")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urUSMSharedAlloc")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urUSMSharedAlloc")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -1356,12 +2315,34 @@ __urdlllocal ur_result_t UR_APICALL urUSMFree( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnFree = d_context.urDdiTable.USM.pfnFree; - if (nullptr != pfnFree) { - result = pfnFree(hContext, pMem); + ur_usm_free_params_t params = {&hContext, &pMem}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urUSMFree")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urUSMFree")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urUSMFree")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -1386,39 +2367,37 @@ __urdlllocal ur_result_t UR_APICALL urUSMGetMemAllocInfo( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnGetMemAllocInfo = d_context.urDdiTable.USM.pfnGetMemAllocInfo; - if (nullptr != pfnGetMemAllocInfo) { - result = pfnGetMemAllocInfo(hContext, pMem, propName, propSize, - pPropValue, pPropSizeRet); - } else { - // generic implementation - if (pPropValue != nullptr) { - switch (propName) { - case UR_USM_ALLOC_INFO_DEVICE: { - ur_device_handle_t *handles = - reinterpret_cast(pPropValue); - size_t nelements = propSize / sizeof(ur_device_handle_t); - for (size_t i = 0; i < nelements; ++i) { - handles[i] = - reinterpret_cast(d_context.get()); - } - } break; - case UR_USM_ALLOC_INFO_POOL: { - ur_usm_pool_handle_t *handles = - reinterpret_cast(pPropValue); - size_t nelements = propSize / sizeof(ur_usm_pool_handle_t); - for (size_t i = 0; i < nelements; ++i) { - handles[i] = - reinterpret_cast(d_context.get()); - } - } break; - default: { - } break; - } + ur_usm_get_mem_alloc_info_params_t params = { + &hContext, &pMem, &propName, &propSize, &pPropValue, &pPropSizeRet}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urUSMGetMemAllocInfo")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; } } + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urUSMGetMemAllocInfo")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urUSMGetMemAllocInfo")); + if (afterCallback) { + return afterCallback(¶ms); + } + return result; } catch (...) { return exceptionToResult(std::current_exception()); @@ -1435,13 +2414,35 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolCreate( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnPoolCreate = d_context.urDdiTable.USM.pfnPoolCreate; - if (nullptr != pfnPoolCreate) { - result = pfnPoolCreate(hContext, pPoolDesc, ppPool); + ur_usm_pool_create_params_t params = {&hContext, &pPoolDesc, &ppPool}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urUSMPoolCreate")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urUSMPoolCreate")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - *ppPool = reinterpret_cast(d_context.get()); + + *ppPool = mock::createDummyHandle(); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urUSMPoolCreate")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -1452,16 +2453,39 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolCreate( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urUSMPoolRetain __urdlllocal ur_result_t UR_APICALL urUSMPoolRetain( - ur_usm_pool_handle_t pPool ///< [in] pointer to USM memory pool + ur_usm_pool_handle_t pPool ///< [in][retain] pointer to USM memory pool ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnPoolRetain = d_context.urDdiTable.USM.pfnPoolRetain; - if (nullptr != pfnPoolRetain) { - result = pfnPoolRetain(pPool); + ur_usm_pool_retain_params_t params = {&pPool}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urUSMPoolRetain")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urUSMPoolRetain")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + mock::retainDummyHandle(pPool); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urUSMPoolRetain")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -1476,12 +2500,35 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolRelease( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnPoolRelease = d_context.urDdiTable.USM.pfnPoolRelease; - if (nullptr != pfnPoolRelease) { - result = pfnPoolRelease(pPool); + ur_usm_pool_release_params_t params = {&pPool}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urUSMPoolRelease")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urUSMPoolRelease")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + mock::releaseDummyHandle(pPool); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urUSMPoolRelease")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -1503,30 +2550,37 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolGetInfo( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnPoolGetInfo = d_context.urDdiTable.USM.pfnPoolGetInfo; - if (nullptr != pfnPoolGetInfo) { - result = - pfnPoolGetInfo(hPool, propName, propSize, pPropValue, pPropSizeRet); - } else { - // generic implementation - if (pPropValue != nullptr) { - switch (propName) { - case UR_USM_POOL_INFO_CONTEXT: { - ur_context_handle_t *handles = - reinterpret_cast(pPropValue); - size_t nelements = propSize / sizeof(ur_context_handle_t); - for (size_t i = 0; i < nelements; ++i) { - handles[i] = - reinterpret_cast(d_context.get()); - } - } break; - default: { - } break; - } + ur_usm_pool_get_info_params_t params = {&hPool, &propName, &propSize, + &pPropValue, &pPropSizeRet}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urUSMPoolGetInfo")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; } } + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urUSMPoolGetInfo")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urUSMPoolGetInfo")); + if (afterCallback) { + return afterCallback(¶ms); + } + return result; } catch (...) { return exceptionToResult(std::current_exception()); @@ -1553,14 +2607,35 @@ __urdlllocal ur_result_t UR_APICALL urVirtualMemGranularityGetInfo( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnGranularityGetInfo = - d_context.urDdiTable.VirtualMem.pfnGranularityGetInfo; - if (nullptr != pfnGranularityGetInfo) { - result = pfnGranularityGetInfo(hContext, hDevice, propName, propSize, - pPropValue, pPropSizeRet); + ur_virtual_mem_granularity_get_info_params_t params = { + &hContext, &hDevice, &propName, &propSize, &pPropValue, &pPropSizeRet}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urVirtualMemGranularityGetInfo")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urVirtualMemGranularityGetInfo")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urVirtualMemGranularityGetInfo")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -1584,12 +2659,35 @@ __urdlllocal ur_result_t UR_APICALL urVirtualMemReserve( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnReserve = d_context.urDdiTable.VirtualMem.pfnReserve; - if (nullptr != pfnReserve) { - result = pfnReserve(hContext, pStart, size, ppStart); + ur_virtual_mem_reserve_params_t params = {&hContext, &pStart, &size, + &ppStart}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urVirtualMemReserve")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urVirtualMemReserve")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urVirtualMemReserve")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -1607,12 +2705,34 @@ __urdlllocal ur_result_t UR_APICALL urVirtualMemFree( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnFree = d_context.urDdiTable.VirtualMem.pfnFree; - if (nullptr != pfnFree) { - result = pfnFree(hContext, pStart, size); + ur_virtual_mem_free_params_t params = {&hContext, &pStart, &size}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urVirtualMemFree")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urVirtualMemFree")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urVirtualMemFree")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -1636,12 +2756,35 @@ __urdlllocal ur_result_t UR_APICALL urVirtualMemMap( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnMap = d_context.urDdiTable.VirtualMem.pfnMap; - if (nullptr != pfnMap) { - result = pfnMap(hContext, pStart, size, hPhysicalMem, offset, flags); + ur_virtual_mem_map_params_t params = {&hContext, &pStart, &size, + &hPhysicalMem, &offset, &flags}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urVirtualMemMap")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urVirtualMemMap")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urVirtualMemMap")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -1659,12 +2802,34 @@ __urdlllocal ur_result_t UR_APICALL urVirtualMemUnmap( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnUnmap = d_context.urDdiTable.VirtualMem.pfnUnmap; - if (nullptr != pfnUnmap) { - result = pfnUnmap(hContext, pStart, size); + ur_virtual_mem_unmap_params_t params = {&hContext, &pStart, &size}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urVirtualMemUnmap")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urVirtualMemUnmap")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urVirtualMemUnmap")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -1684,12 +2849,35 @@ __urdlllocal ur_result_t UR_APICALL urVirtualMemSetAccess( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnSetAccess = d_context.urDdiTable.VirtualMem.pfnSetAccess; - if (nullptr != pfnSetAccess) { - result = pfnSetAccess(hContext, pStart, size, flags); + ur_virtual_mem_set_access_params_t params = {&hContext, &pStart, &size, + &flags}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urVirtualMemSetAccess")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urVirtualMemSetAccess")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urVirtualMemSetAccess")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -1717,13 +2905,36 @@ __urdlllocal ur_result_t UR_APICALL urVirtualMemGetInfo( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnGetInfo = d_context.urDdiTable.VirtualMem.pfnGetInfo; - if (nullptr != pfnGetInfo) { - result = pfnGetInfo(hContext, pStart, size, propName, propSize, - pPropValue, pPropSizeRet); + ur_virtual_mem_get_info_params_t params = { + &hContext, &pStart, &size, &propName, + &propSize, &pPropValue, &pPropSizeRet}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urVirtualMemGetInfo")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urVirtualMemGetInfo")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urVirtualMemGetInfo")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -1746,14 +2957,36 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemCreate( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnCreate = d_context.urDdiTable.PhysicalMem.pfnCreate; - if (nullptr != pfnCreate) { - result = pfnCreate(hContext, hDevice, size, pProperties, phPhysicalMem); + ur_physical_mem_create_params_t params = {&hContext, &hDevice, &size, + &pProperties, &phPhysicalMem}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urPhysicalMemCreate")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urPhysicalMemCreate")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - *phPhysicalMem = - reinterpret_cast(d_context.get()); + + *phPhysicalMem = mock::createDummyHandle(); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urPhysicalMemCreate")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -1765,16 +2998,39 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemCreate( /// @brief Intercept function for urPhysicalMemRetain __urdlllocal ur_result_t UR_APICALL urPhysicalMemRetain( ur_physical_mem_handle_t - hPhysicalMem ///< [in] handle of the physical memory object to retain. + hPhysicalMem ///< [in][retain] handle of the physical memory object to retain. ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnRetain = d_context.urDdiTable.PhysicalMem.pfnRetain; - if (nullptr != pfnRetain) { - result = pfnRetain(hPhysicalMem); + ur_physical_mem_retain_params_t params = {&hPhysicalMem}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urPhysicalMemRetain")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urPhysicalMemRetain")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + mock::retainDummyHandle(hPhysicalMem); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urPhysicalMemRetain")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -1790,12 +3046,35 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemRelease( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnRelease = d_context.urDdiTable.PhysicalMem.pfnRelease; - if (nullptr != pfnRelease) { - result = pfnRelease(hPhysicalMem); + ur_physical_mem_release_params_t params = {&hPhysicalMem}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urPhysicalMemRelease")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urPhysicalMemRelease")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + mock::releaseDummyHandle(hPhysicalMem); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urPhysicalMemRelease")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -1816,13 +3095,36 @@ __urdlllocal ur_result_t UR_APICALL urProgramCreateWithIL( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnCreateWithIL = d_context.urDdiTable.Program.pfnCreateWithIL; - if (nullptr != pfnCreateWithIL) { - result = pfnCreateWithIL(hContext, pIL, length, pProperties, phProgram); + ur_program_create_with_il_params_t params = {&hContext, &pIL, &length, + &pProperties, &phProgram}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urProgramCreateWithIL")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urProgramCreateWithIL")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - *phProgram = reinterpret_cast(d_context.get()); + + *phProgram = mock::createDummyHandle(); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urProgramCreateWithIL")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -1845,14 +3147,36 @@ __urdlllocal ur_result_t UR_APICALL urProgramCreateWithBinary( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnCreateWithBinary = d_context.urDdiTable.Program.pfnCreateWithBinary; - if (nullptr != pfnCreateWithBinary) { - result = pfnCreateWithBinary(hContext, hDevice, size, pBinary, - pProperties, phProgram); + ur_program_create_with_binary_params_t params = { + &hContext, &hDevice, &size, &pBinary, &pProperties, &phProgram}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urProgramCreateWithBinary")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urProgramCreateWithBinary")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - *phProgram = reinterpret_cast(d_context.get()); + + *phProgram = mock::createDummyHandle(); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urProgramCreateWithBinary")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -1870,12 +3194,34 @@ __urdlllocal ur_result_t UR_APICALL urProgramBuild( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnBuild = d_context.urDdiTable.Program.pfnBuild; - if (nullptr != pfnBuild) { - result = pfnBuild(hContext, hProgram, pOptions); + ur_program_build_params_t params = {&hContext, &hProgram, &pOptions}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urProgramBuild")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urProgramBuild")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urProgramBuild")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -1894,12 +3240,34 @@ __urdlllocal ur_result_t UR_APICALL urProgramCompile( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnCompile = d_context.urDdiTable.Program.pfnCompile; - if (nullptr != pfnCompile) { - result = pfnCompile(hContext, hProgram, pOptions); + ur_program_compile_params_t params = {&hContext, &hProgram, &pOptions}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urProgramCompile")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urProgramCompile")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urProgramCompile")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -1924,13 +3292,36 @@ __urdlllocal ur_result_t UR_APICALL urProgramLink( *phProgram = nullptr; } - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnLink = d_context.urDdiTable.Program.pfnLink; - if (nullptr != pfnLink) { - result = pfnLink(hContext, count, phPrograms, pOptions, phProgram); + ur_program_link_params_t params = {&hContext, &count, &phPrograms, + &pOptions, &phProgram}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urProgramLink")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urProgramLink")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - *phProgram = reinterpret_cast(d_context.get()); + + *phProgram = mock::createDummyHandle(); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urProgramLink")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -1941,16 +3332,40 @@ __urdlllocal ur_result_t UR_APICALL urProgramLink( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urProgramRetain __urdlllocal ur_result_t UR_APICALL urProgramRetain( - ur_program_handle_t hProgram ///< [in] handle for the Program to retain + ur_program_handle_t + hProgram ///< [in][retain] handle for the Program to retain ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnRetain = d_context.urDdiTable.Program.pfnRetain; - if (nullptr != pfnRetain) { - result = pfnRetain(hProgram); + ur_program_retain_params_t params = {&hProgram}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urProgramRetain")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urProgramRetain")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + mock::retainDummyHandle(hProgram); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urProgramRetain")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -1966,12 +3381,35 @@ __urdlllocal ur_result_t UR_APICALL urProgramRelease( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnRelease = d_context.urDdiTable.Program.pfnRelease; - if (nullptr != pfnRelease) { - result = pfnRelease(hProgram); + ur_program_release_params_t params = {&hProgram}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urProgramRelease")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urProgramRelease")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + mock::releaseDummyHandle(hProgram); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urProgramRelease")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -1995,14 +3433,35 @@ __urdlllocal ur_result_t UR_APICALL urProgramGetFunctionPointer( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnGetFunctionPointer = - d_context.urDdiTable.Program.pfnGetFunctionPointer; - if (nullptr != pfnGetFunctionPointer) { - result = pfnGetFunctionPointer(hDevice, hProgram, pFunctionName, - ppFunctionPointer); + ur_program_get_function_pointer_params_t params = { + &hDevice, &hProgram, &pFunctionName, &ppFunctionPointer}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urProgramGetFunctionPointer")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urProgramGetFunctionPointer")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urProgramGetFunctionPointer")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -2027,15 +3486,39 @@ __urdlllocal ur_result_t UR_APICALL urProgramGetGlobalVariablePointer( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnGetGlobalVariablePointer = - d_context.urDdiTable.Program.pfnGetGlobalVariablePointer; - if (nullptr != pfnGetGlobalVariablePointer) { - result = pfnGetGlobalVariablePointer( - hDevice, hProgram, pGlobalVariableName, pGlobalVariableSizeRet, - ppGlobalVariablePointerRet); + ur_program_get_global_variable_pointer_params_t params = { + &hDevice, &hProgram, &pGlobalVariableName, &pGlobalVariableSizeRet, + &ppGlobalVariablePointerRet}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback( + "urProgramGetGlobalVariablePointer")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback( + "urProgramGetGlobalVariablePointer")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = + reinterpret_cast(mock::callbacks.get_after_callback( + "urProgramGetGlobalVariablePointer")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -2061,39 +3544,37 @@ __urdlllocal ur_result_t UR_APICALL urProgramGetInfo( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnGetInfo = d_context.urDdiTable.Program.pfnGetInfo; - if (nullptr != pfnGetInfo) { - result = - pfnGetInfo(hProgram, propName, propSize, pPropValue, pPropSizeRet); - } else { - // generic implementation - if (pPropValue != nullptr) { - switch (propName) { - case UR_PROGRAM_INFO_CONTEXT: { - ur_context_handle_t *handles = - reinterpret_cast(pPropValue); - size_t nelements = propSize / sizeof(ur_context_handle_t); - for (size_t i = 0; i < nelements; ++i) { - handles[i] = - reinterpret_cast(d_context.get()); - } - } break; - case UR_PROGRAM_INFO_DEVICES: { - ur_device_handle_t *handles = - reinterpret_cast(pPropValue); - size_t nelements = propSize / sizeof(ur_device_handle_t); - for (size_t i = 0; i < nelements; ++i) { - handles[i] = - reinterpret_cast(d_context.get()); - } - } break; - default: { - } break; - } + ur_program_get_info_params_t params = {&hProgram, &propName, &propSize, + &pPropValue, &pPropSizeRet}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urProgramGetInfo")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; } } + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urProgramGetInfo")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urProgramGetInfo")); + if (afterCallback) { + return afterCallback(¶ms); + } + return result; } catch (...) { return exceptionToResult(std::current_exception()); @@ -2119,13 +3600,35 @@ __urdlllocal ur_result_t UR_APICALL urProgramGetBuildInfo( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnGetBuildInfo = d_context.urDdiTable.Program.pfnGetBuildInfo; - if (nullptr != pfnGetBuildInfo) { - result = pfnGetBuildInfo(hProgram, hDevice, propName, propSize, - pPropValue, pPropSizeRet); + ur_program_get_build_info_params_t params = { + &hProgram, &hDevice, &propName, &propSize, &pPropValue, &pPropSizeRet}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urProgramGetBuildInfo")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urProgramGetBuildInfo")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urProgramGetBuildInfo")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -2144,13 +3647,38 @@ __urdlllocal ur_result_t UR_APICALL urProgramSetSpecializationConstants( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnSetSpecializationConstants = - d_context.urDdiTable.Program.pfnSetSpecializationConstants; - if (nullptr != pfnSetSpecializationConstants) { - result = pfnSetSpecializationConstants(hProgram, count, pSpecConstants); + ur_program_set_specialization_constants_params_t params = { + &hProgram, &count, &pSpecConstants}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback( + "urProgramSetSpecializationConstants")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback( + "urProgramSetSpecializationConstants")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = + reinterpret_cast(mock::callbacks.get_after_callback( + "urProgramSetSpecializationConstants")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -2167,14 +3695,36 @@ __urdlllocal ur_result_t UR_APICALL urProgramGetNativeHandle( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnGetNativeHandle = d_context.urDdiTable.Program.pfnGetNativeHandle; - if (nullptr != pfnGetNativeHandle) { - result = pfnGetNativeHandle(hProgram, phNativeProgram); + ur_program_get_native_handle_params_t params = {&hProgram, + &phNativeProgram}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urProgramGetNativeHandle")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urProgramGetNativeHandle")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - *phNativeProgram = - reinterpret_cast(d_context.get()); + + *phNativeProgram = reinterpret_cast(hProgram); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urProgramGetNativeHandle")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -2195,15 +3745,38 @@ __urdlllocal ur_result_t UR_APICALL urProgramCreateWithNativeHandle( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnCreateWithNativeHandle = - d_context.urDdiTable.Program.pfnCreateWithNativeHandle; - if (nullptr != pfnCreateWithNativeHandle) { - result = pfnCreateWithNativeHandle(hNativeProgram, hContext, - pProperties, phProgram); + ur_program_create_with_native_handle_params_t params = { + &hNativeProgram, &hContext, &pProperties, &phProgram}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urProgramCreateWithNativeHandle")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback( + "urProgramCreateWithNativeHandle")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - *phProgram = reinterpret_cast(d_context.get()); + + *phProgram = reinterpret_cast(hNativeProgram); + mock::retainDummyHandle(*phProgram); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urProgramCreateWithNativeHandle")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -2221,13 +3794,35 @@ __urdlllocal ur_result_t UR_APICALL urKernelCreate( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnCreate = d_context.urDdiTable.Kernel.pfnCreate; - if (nullptr != pfnCreate) { - result = pfnCreate(hProgram, pKernelName, phKernel); + ur_kernel_create_params_t params = {&hProgram, &pKernelName, &phKernel}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urKernelCreate")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urKernelCreate")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - *phKernel = reinterpret_cast(d_context.get()); + + *phKernel = mock::createDummyHandle(); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urKernelCreate")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -2248,13 +3843,35 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetArgValue( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnSetArgValue = d_context.urDdiTable.Kernel.pfnSetArgValue; - if (nullptr != pfnSetArgValue) { - result = - pfnSetArgValue(hKernel, argIndex, argSize, pProperties, pArgValue); + ur_kernel_set_arg_value_params_t params = {&hKernel, &argIndex, &argSize, + &pProperties, &pArgValue}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urKernelSetArgValue")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urKernelSetArgValue")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urKernelSetArgValue")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -2274,12 +3891,35 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetArgLocal( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnSetArgLocal = d_context.urDdiTable.Kernel.pfnSetArgLocal; - if (nullptr != pfnSetArgLocal) { - result = pfnSetArgLocal(hKernel, argIndex, argSize, pProperties); + ur_kernel_set_arg_local_params_t params = {&hKernel, &argIndex, &argSize, + &pProperties}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urKernelSetArgLocal")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urKernelSetArgLocal")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urKernelSetArgLocal")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -2306,39 +3946,37 @@ __urdlllocal ur_result_t UR_APICALL urKernelGetInfo( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnGetInfo = d_context.urDdiTable.Kernel.pfnGetInfo; - if (nullptr != pfnGetInfo) { - result = - pfnGetInfo(hKernel, propName, propSize, pPropValue, pPropSizeRet); - } else { - // generic implementation - if (pPropValue != nullptr) { - switch (propName) { - case UR_KERNEL_INFO_CONTEXT: { - ur_context_handle_t *handles = - reinterpret_cast(pPropValue); - size_t nelements = propSize / sizeof(ur_context_handle_t); - for (size_t i = 0; i < nelements; ++i) { - handles[i] = - reinterpret_cast(d_context.get()); - } - } break; - case UR_KERNEL_INFO_PROGRAM: { - ur_program_handle_t *handles = - reinterpret_cast(pPropValue); - size_t nelements = propSize / sizeof(ur_program_handle_t); - for (size_t i = 0; i < nelements; ++i) { - handles[i] = - reinterpret_cast(d_context.get()); - } - } break; - default: { - } break; - } + ur_kernel_get_info_params_t params = {&hKernel, &propName, &propSize, + &pPropValue, &pPropSizeRet}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urKernelGetInfo")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; } } + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urKernelGetInfo")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urKernelGetInfo")); + if (afterCallback) { + return afterCallback(¶ms); + } + return result; } catch (...) { return exceptionToResult(std::current_exception()); @@ -2361,13 +3999,35 @@ __urdlllocal ur_result_t UR_APICALL urKernelGetGroupInfo( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnGetGroupInfo = d_context.urDdiTable.Kernel.pfnGetGroupInfo; - if (nullptr != pfnGetGroupInfo) { - result = pfnGetGroupInfo(hKernel, hDevice, propName, propSize, - pPropValue, pPropSizeRet); + ur_kernel_get_group_info_params_t params = { + &hKernel, &hDevice, &propName, &propSize, &pPropValue, &pPropSizeRet}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urKernelGetGroupInfo")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urKernelGetGroupInfo")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urKernelGetGroupInfo")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -2392,13 +4052,35 @@ __urdlllocal ur_result_t UR_APICALL urKernelGetSubGroupInfo( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnGetSubGroupInfo = d_context.urDdiTable.Kernel.pfnGetSubGroupInfo; - if (nullptr != pfnGetSubGroupInfo) { - result = pfnGetSubGroupInfo(hKernel, hDevice, propName, propSize, - pPropValue, pPropSizeRet); + ur_kernel_get_sub_group_info_params_t params = { + &hKernel, &hDevice, &propName, &propSize, &pPropValue, &pPropSizeRet}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urKernelGetSubGroupInfo")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urKernelGetSubGroupInfo")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urKernelGetSubGroupInfo")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -2409,16 +4091,39 @@ __urdlllocal ur_result_t UR_APICALL urKernelGetSubGroupInfo( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urKernelRetain __urdlllocal ur_result_t UR_APICALL urKernelRetain( - ur_kernel_handle_t hKernel ///< [in] handle for the Kernel to retain + ur_kernel_handle_t hKernel ///< [in][retain] handle for the Kernel to retain ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnRetain = d_context.urDdiTable.Kernel.pfnRetain; - if (nullptr != pfnRetain) { - result = pfnRetain(hKernel); + ur_kernel_retain_params_t params = {&hKernel}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urKernelRetain")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urKernelRetain")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + mock::retainDummyHandle(hKernel); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urKernelRetain")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -2434,12 +4139,35 @@ __urdlllocal ur_result_t UR_APICALL urKernelRelease( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnRelease = d_context.urDdiTable.Kernel.pfnRelease; - if (nullptr != pfnRelease) { - result = pfnRelease(hKernel); + ur_kernel_release_params_t params = {&hKernel}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urKernelRelease")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urKernelRelease")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + mock::releaseDummyHandle(hKernel); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urKernelRelease")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -2460,12 +4188,35 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetArgPointer( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnSetArgPointer = d_context.urDdiTable.Kernel.pfnSetArgPointer; - if (nullptr != pfnSetArgPointer) { - result = pfnSetArgPointer(hKernel, argIndex, pProperties, pArgValue); + ur_kernel_set_arg_pointer_params_t params = {&hKernel, &argIndex, + &pProperties, &pArgValue}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urKernelSetArgPointer")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urKernelSetArgPointer")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urKernelSetArgPointer")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -2487,13 +4238,35 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetExecInfo( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnSetExecInfo = d_context.urDdiTable.Kernel.pfnSetExecInfo; - if (nullptr != pfnSetExecInfo) { - result = pfnSetExecInfo(hKernel, propName, propSize, pProperties, - pPropValue); + ur_kernel_set_exec_info_params_t params = {&hKernel, &propName, &propSize, + &pProperties, &pPropValue}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urKernelSetExecInfo")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urKernelSetExecInfo")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urKernelSetExecInfo")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -2512,12 +4285,35 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetArgSampler( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnSetArgSampler = d_context.urDdiTable.Kernel.pfnSetArgSampler; - if (nullptr != pfnSetArgSampler) { - result = pfnSetArgSampler(hKernel, argIndex, pProperties, hArgValue); + ur_kernel_set_arg_sampler_params_t params = {&hKernel, &argIndex, + &pProperties, &hArgValue}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urKernelSetArgSampler")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urKernelSetArgSampler")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urKernelSetArgSampler")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -2536,12 +4332,35 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetArgMemObj( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnSetArgMemObj = d_context.urDdiTable.Kernel.pfnSetArgMemObj; - if (nullptr != pfnSetArgMemObj) { - result = pfnSetArgMemObj(hKernel, argIndex, pProperties, hArgValue); + ur_kernel_set_arg_mem_obj_params_t params = {&hKernel, &argIndex, + &pProperties, &hArgValue}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urKernelSetArgMemObj")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urKernelSetArgMemObj")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urKernelSetArgMemObj")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -2559,13 +4378,38 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetSpecializationConstants( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnSetSpecializationConstants = - d_context.urDdiTable.Kernel.pfnSetSpecializationConstants; - if (nullptr != pfnSetSpecializationConstants) { - result = pfnSetSpecializationConstants(hKernel, count, pSpecConstants); + ur_kernel_set_specialization_constants_params_t params = {&hKernel, &count, + &pSpecConstants}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback( + "urKernelSetSpecializationConstants")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback( + "urKernelSetSpecializationConstants")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = + reinterpret_cast(mock::callbacks.get_after_callback( + "urKernelSetSpecializationConstants")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -2582,13 +4426,35 @@ __urdlllocal ur_result_t UR_APICALL urKernelGetNativeHandle( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnGetNativeHandle = d_context.urDdiTable.Kernel.pfnGetNativeHandle; - if (nullptr != pfnGetNativeHandle) { - result = pfnGetNativeHandle(hKernel, phNativeKernel); + ur_kernel_get_native_handle_params_t params = {&hKernel, &phNativeKernel}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urKernelGetNativeHandle")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urKernelGetNativeHandle")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - *phNativeKernel = reinterpret_cast(d_context.get()); + + *phNativeKernel = reinterpret_cast(hKernel); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urKernelGetNativeHandle")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -2611,15 +4477,37 @@ __urdlllocal ur_result_t UR_APICALL urKernelCreateWithNativeHandle( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnCreateWithNativeHandle = - d_context.urDdiTable.Kernel.pfnCreateWithNativeHandle; - if (nullptr != pfnCreateWithNativeHandle) { - result = pfnCreateWithNativeHandle(hNativeKernel, hContext, hProgram, - pProperties, phKernel); + ur_kernel_create_with_native_handle_params_t params = { + &hNativeKernel, &hContext, &hProgram, &pProperties, &phKernel}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urKernelCreateWithNativeHandle")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urKernelCreateWithNativeHandle")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - *phKernel = reinterpret_cast(d_context.get()); + + *phKernel = reinterpret_cast(hNativeKernel); + mock::retainDummyHandle(*phKernel); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urKernelCreateWithNativeHandle")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -2648,15 +4536,39 @@ __urdlllocal ur_result_t UR_APICALL urKernelGetSuggestedLocalWorkSize( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnGetSuggestedLocalWorkSize = - d_context.urDdiTable.Kernel.pfnGetSuggestedLocalWorkSize; - if (nullptr != pfnGetSuggestedLocalWorkSize) { - result = pfnGetSuggestedLocalWorkSize( - hKernel, hQueue, numWorkDim, pGlobalWorkOffset, pGlobalWorkSize, - pSuggestedLocalWorkSize); + ur_kernel_get_suggested_local_work_size_params_t params = { + &hKernel, &hQueue, &numWorkDim, + &pGlobalWorkOffset, &pGlobalWorkSize, &pSuggestedLocalWorkSize}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback( + "urKernelGetSuggestedLocalWorkSize")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback( + "urKernelGetSuggestedLocalWorkSize")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = + reinterpret_cast(mock::callbacks.get_after_callback( + "urKernelGetSuggestedLocalWorkSize")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -2679,48 +4591,37 @@ __urdlllocal ur_result_t UR_APICALL urQueueGetInfo( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnGetInfo = d_context.urDdiTable.Queue.pfnGetInfo; - if (nullptr != pfnGetInfo) { - result = - pfnGetInfo(hQueue, propName, propSize, pPropValue, pPropSizeRet); - } else { - // generic implementation - if (pPropValue != nullptr) { - switch (propName) { - case UR_QUEUE_INFO_CONTEXT: { - ur_context_handle_t *handles = - reinterpret_cast(pPropValue); - size_t nelements = propSize / sizeof(ur_context_handle_t); - for (size_t i = 0; i < nelements; ++i) { - handles[i] = - reinterpret_cast(d_context.get()); - } - } break; - case UR_QUEUE_INFO_DEVICE: { - ur_device_handle_t *handles = - reinterpret_cast(pPropValue); - size_t nelements = propSize / sizeof(ur_device_handle_t); - for (size_t i = 0; i < nelements; ++i) { - handles[i] = - reinterpret_cast(d_context.get()); - } - } break; - case UR_QUEUE_INFO_DEVICE_DEFAULT: { - ur_queue_handle_t *handles = - reinterpret_cast(pPropValue); - size_t nelements = propSize / sizeof(ur_queue_handle_t); - for (size_t i = 0; i < nelements; ++i) { - handles[i] = - reinterpret_cast(d_context.get()); - } - } break; - default: { - } break; - } + ur_queue_get_info_params_t params = {&hQueue, &propName, &propSize, + &pPropValue, &pPropSizeRet}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urQueueGetInfo")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; } } + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urQueueGetInfo")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urQueueGetInfo")); + if (afterCallback) { + return afterCallback(¶ms); + } + return result; } catch (...) { return exceptionToResult(std::current_exception()); @@ -2738,13 +4639,36 @@ __urdlllocal ur_result_t UR_APICALL urQueueCreate( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnCreate = d_context.urDdiTable.Queue.pfnCreate; - if (nullptr != pfnCreate) { - result = pfnCreate(hContext, hDevice, pProperties, phQueue); + ur_queue_create_params_t params = {&hContext, &hDevice, &pProperties, + &phQueue}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urQueueCreate")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urQueueCreate")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - *phQueue = reinterpret_cast(d_context.get()); + + *phQueue = mock::createDummyHandle(); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urQueueCreate")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -2755,16 +4679,40 @@ __urdlllocal ur_result_t UR_APICALL urQueueCreate( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urQueueRetain __urdlllocal ur_result_t UR_APICALL urQueueRetain( - ur_queue_handle_t hQueue ///< [in] handle of the queue object to get access + ur_queue_handle_t + hQueue ///< [in][retain] handle of the queue object to get access ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnRetain = d_context.urDdiTable.Queue.pfnRetain; - if (nullptr != pfnRetain) { - result = pfnRetain(hQueue); + ur_queue_retain_params_t params = {&hQueue}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urQueueRetain")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urQueueRetain")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + mock::retainDummyHandle(hQueue); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urQueueRetain")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -2780,12 +4728,35 @@ __urdlllocal ur_result_t UR_APICALL urQueueRelease( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnRelease = d_context.urDdiTable.Queue.pfnRelease; - if (nullptr != pfnRelease) { - result = pfnRelease(hQueue); + ur_queue_release_params_t params = {&hQueue}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urQueueRelease")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urQueueRelease")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + mock::releaseDummyHandle(hQueue); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urQueueRelease")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -2804,13 +4775,36 @@ __urdlllocal ur_result_t UR_APICALL urQueueGetNativeHandle( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnGetNativeHandle = d_context.urDdiTable.Queue.pfnGetNativeHandle; - if (nullptr != pfnGetNativeHandle) { - result = pfnGetNativeHandle(hQueue, pDesc, phNativeQueue); + ur_queue_get_native_handle_params_t params = {&hQueue, &pDesc, + &phNativeQueue}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urQueueGetNativeHandle")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urQueueGetNativeHandle")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - *phNativeQueue = reinterpret_cast(d_context.get()); + + *phNativeQueue = reinterpret_cast(hQueue); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urQueueGetNativeHandle")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -2832,15 +4826,37 @@ __urdlllocal ur_result_t UR_APICALL urQueueCreateWithNativeHandle( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnCreateWithNativeHandle = - d_context.urDdiTable.Queue.pfnCreateWithNativeHandle; - if (nullptr != pfnCreateWithNativeHandle) { - result = pfnCreateWithNativeHandle(hNativeQueue, hContext, hDevice, - pProperties, phQueue); + ur_queue_create_with_native_handle_params_t params = { + &hNativeQueue, &hContext, &hDevice, &pProperties, &phQueue}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urQueueCreateWithNativeHandle")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urQueueCreateWithNativeHandle")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - *phQueue = reinterpret_cast(d_context.get()); + + *phQueue = reinterpret_cast(hNativeQueue); + mock::retainDummyHandle(*phQueue); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urQueueCreateWithNativeHandle")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -2855,12 +4871,34 @@ __urdlllocal ur_result_t UR_APICALL urQueueFinish( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnFinish = d_context.urDdiTable.Queue.pfnFinish; - if (nullptr != pfnFinish) { - result = pfnFinish(hQueue); + ur_queue_finish_params_t params = {&hQueue}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urQueueFinish")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urQueueFinish")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urQueueFinish")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -2875,12 +4913,34 @@ __urdlllocal ur_result_t UR_APICALL urQueueFlush( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnFlush = d_context.urDdiTable.Queue.pfnFlush; - if (nullptr != pfnFlush) { - result = pfnFlush(hQueue); - } else { - // generic implementation + ur_queue_flush_params_t params = {&hQueue}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urQueueFlush")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urQueueFlush")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urQueueFlush")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -2901,39 +4961,37 @@ __urdlllocal ur_result_t UR_APICALL urEventGetInfo( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnGetInfo = d_context.urDdiTable.Event.pfnGetInfo; - if (nullptr != pfnGetInfo) { - result = - pfnGetInfo(hEvent, propName, propSize, pPropValue, pPropSizeRet); - } else { - // generic implementation - if (pPropValue != nullptr) { - switch (propName) { - case UR_EVENT_INFO_COMMAND_QUEUE: { - ur_queue_handle_t *handles = - reinterpret_cast(pPropValue); - size_t nelements = propSize / sizeof(ur_queue_handle_t); - for (size_t i = 0; i < nelements; ++i) { - handles[i] = - reinterpret_cast(d_context.get()); - } - } break; - case UR_EVENT_INFO_CONTEXT: { - ur_context_handle_t *handles = - reinterpret_cast(pPropValue); - size_t nelements = propSize / sizeof(ur_context_handle_t); - for (size_t i = 0; i < nelements; ++i) { - handles[i] = - reinterpret_cast(d_context.get()); - } - } break; - default: { - } break; - } + ur_event_get_info_params_t params = {&hEvent, &propName, &propSize, + &pPropValue, &pPropSizeRet}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urEventGetInfo")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; } } + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urEventGetInfo")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urEventGetInfo")); + if (afterCallback) { + return afterCallback(¶ms); + } + return result; } catch (...) { return exceptionToResult(std::current_exception()); @@ -2955,13 +5013,35 @@ __urdlllocal ur_result_t UR_APICALL urEventGetProfilingInfo( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnGetProfilingInfo = d_context.urDdiTable.Event.pfnGetProfilingInfo; - if (nullptr != pfnGetProfilingInfo) { - result = pfnGetProfilingInfo(hEvent, propName, propSize, pPropValue, - pPropSizeRet); + ur_event_get_profiling_info_params_t params = { + &hEvent, &propName, &propSize, &pPropValue, &pPropSizeRet}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urEventGetProfilingInfo")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urEventGetProfilingInfo")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urEventGetProfilingInfo")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -2979,12 +5059,34 @@ __urdlllocal ur_result_t UR_APICALL urEventWait( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnWait = d_context.urDdiTable.Event.pfnWait; - if (nullptr != pfnWait) { - result = pfnWait(numEvents, phEventWaitList); + ur_event_wait_params_t params = {&numEvents, &phEventWaitList}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urEventWait")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urEventWait")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urEventWait")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -2995,16 +5097,39 @@ __urdlllocal ur_result_t UR_APICALL urEventWait( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urEventRetain __urdlllocal ur_result_t UR_APICALL urEventRetain( - ur_event_handle_t hEvent ///< [in] handle of the event object + ur_event_handle_t hEvent ///< [in][retain] handle of the event object ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnRetain = d_context.urDdiTable.Event.pfnRetain; - if (nullptr != pfnRetain) { - result = pfnRetain(hEvent); + ur_event_retain_params_t params = {&hEvent}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urEventRetain")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urEventRetain")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + mock::retainDummyHandle(hEvent); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urEventRetain")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -3019,12 +5144,35 @@ __urdlllocal ur_result_t UR_APICALL urEventRelease( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnRelease = d_context.urDdiTable.Event.pfnRelease; - if (nullptr != pfnRelease) { - result = pfnRelease(hEvent); + ur_event_release_params_t params = {&hEvent}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urEventRelease")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urEventRelease")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + mock::releaseDummyHandle(hEvent); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urEventRelease")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -3041,13 +5189,35 @@ __urdlllocal ur_result_t UR_APICALL urEventGetNativeHandle( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnGetNativeHandle = d_context.urDdiTable.Event.pfnGetNativeHandle; - if (nullptr != pfnGetNativeHandle) { - result = pfnGetNativeHandle(hEvent, phNativeEvent); + ur_event_get_native_handle_params_t params = {&hEvent, &phNativeEvent}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urEventGetNativeHandle")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urEventGetNativeHandle")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - *phNativeEvent = reinterpret_cast(d_context.get()); + + *phNativeEvent = reinterpret_cast(hEvent); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urEventGetNativeHandle")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -3068,15 +5238,37 @@ __urdlllocal ur_result_t UR_APICALL urEventCreateWithNativeHandle( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnCreateWithNativeHandle = - d_context.urDdiTable.Event.pfnCreateWithNativeHandle; - if (nullptr != pfnCreateWithNativeHandle) { - result = pfnCreateWithNativeHandle(hNativeEvent, hContext, pProperties, - phEvent); + ur_event_create_with_native_handle_params_t params = { + &hNativeEvent, &hContext, &pProperties, &phEvent}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urEventCreateWithNativeHandle")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urEventCreateWithNativeHandle")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - *phEvent = reinterpret_cast(d_context.get()); + + *phEvent = reinterpret_cast(hNativeEvent); + mock::retainDummyHandle(*phEvent); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urEventCreateWithNativeHandle")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -3095,12 +5287,35 @@ __urdlllocal ur_result_t UR_APICALL urEventSetCallback( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnSetCallback = d_context.urDdiTable.Event.pfnSetCallback; - if (nullptr != pfnSetCallback) { - result = pfnSetCallback(hEvent, execStatus, pfnNotify, pUserData); + ur_event_set_callback_params_t params = {&hEvent, &execStatus, &pfnNotify, + &pUserData}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urEventSetCallback")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urEventSetCallback")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urEventSetCallback")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -3141,17 +5356,46 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueKernelLaunch( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnKernelLaunch = d_context.urDdiTable.Enqueue.pfnKernelLaunch; - if (nullptr != pfnKernelLaunch) { - result = pfnKernelLaunch(hQueue, hKernel, workDim, pGlobalWorkOffset, - pGlobalWorkSize, pLocalWorkSize, - numEventsInWaitList, phEventWaitList, phEvent); + ur_enqueue_kernel_launch_params_t params = {&hQueue, + &hKernel, + &workDim, + &pGlobalWorkOffset, + &pGlobalWorkSize, + &pLocalWorkSize, + &numEventsInWaitList, + &phEventWaitList, + &phEvent}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urEnqueueKernelLaunch")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urEnqueueKernelLaunch")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - if (nullptr != phEvent) { - *phEvent = reinterpret_cast(d_context.get()); + + // optional output handle + if (phEvent) { + *phEvent = mock::createDummyHandle(); } + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urEnqueueKernelLaunch")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -3176,16 +5420,39 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueEventsWait( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnEventsWait = d_context.urDdiTable.Enqueue.pfnEventsWait; - if (nullptr != pfnEventsWait) { - result = pfnEventsWait(hQueue, numEventsInWaitList, phEventWaitList, - phEvent); + ur_enqueue_events_wait_params_t params = {&hQueue, &numEventsInWaitList, + &phEventWaitList, &phEvent}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urEnqueueEventsWait")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urEnqueueEventsWait")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - if (nullptr != phEvent) { - *phEvent = reinterpret_cast(d_context.get()); + + // optional output handle + if (phEvent) { + *phEvent = mock::createDummyHandle(); } + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urEnqueueEventsWait")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -3210,17 +5477,39 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueEventsWaitWithBarrier( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnEventsWaitWithBarrier = - d_context.urDdiTable.Enqueue.pfnEventsWaitWithBarrier; - if (nullptr != pfnEventsWaitWithBarrier) { - result = pfnEventsWaitWithBarrier(hQueue, numEventsInWaitList, - phEventWaitList, phEvent); + ur_enqueue_events_wait_with_barrier_params_t params = { + &hQueue, &numEventsInWaitList, &phEventWaitList, &phEvent}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urEnqueueEventsWaitWithBarrier")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urEnqueueEventsWaitWithBarrier")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - if (nullptr != phEvent) { - *phEvent = reinterpret_cast(d_context.get()); + + // optional output handle + if (phEvent) { + *phEvent = mock::createDummyHandle(); } + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urEnqueueEventsWaitWithBarrier")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -3250,17 +5539,41 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferRead( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnMemBufferRead = d_context.urDdiTable.Enqueue.pfnMemBufferRead; - if (nullptr != pfnMemBufferRead) { - result = - pfnMemBufferRead(hQueue, hBuffer, blockingRead, offset, size, pDst, - numEventsInWaitList, phEventWaitList, phEvent); + ur_enqueue_mem_buffer_read_params_t params = { + &hQueue, &hBuffer, &blockingRead, &offset, + &size, &pDst, &numEventsInWaitList, &phEventWaitList, + &phEvent}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urEnqueueMemBufferRead")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urEnqueueMemBufferRead")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - if (nullptr != phEvent) { - *phEvent = reinterpret_cast(d_context.get()); + + // optional output handle + if (phEvent) { + *phEvent = mock::createDummyHandle(); } + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urEnqueueMemBufferRead")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -3292,17 +5605,41 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferWrite( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnMemBufferWrite = d_context.urDdiTable.Enqueue.pfnMemBufferWrite; - if (nullptr != pfnMemBufferWrite) { - result = pfnMemBufferWrite(hQueue, hBuffer, blockingWrite, offset, size, - pSrc, numEventsInWaitList, phEventWaitList, - phEvent); + ur_enqueue_mem_buffer_write_params_t params = { + &hQueue, &hBuffer, &blockingWrite, &offset, + &size, &pSrc, &numEventsInWaitList, &phEventWaitList, + &phEvent}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urEnqueueMemBufferWrite")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urEnqueueMemBufferWrite")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - if (nullptr != phEvent) { - *phEvent = reinterpret_cast(d_context.get()); + + // optional output handle + if (phEvent) { + *phEvent = mock::createDummyHandle(); } + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urEnqueueMemBufferWrite")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -3344,19 +5681,51 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferReadRect( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnMemBufferReadRect = - d_context.urDdiTable.Enqueue.pfnMemBufferReadRect; - if (nullptr != pfnMemBufferReadRect) { - result = pfnMemBufferReadRect( - hQueue, hBuffer, blockingRead, bufferOrigin, hostOrigin, region, - bufferRowPitch, bufferSlicePitch, hostRowPitch, hostSlicePitch, - pDst, numEventsInWaitList, phEventWaitList, phEvent); + ur_enqueue_mem_buffer_read_rect_params_t params = {&hQueue, + &hBuffer, + &blockingRead, + &bufferOrigin, + &hostOrigin, + ®ion, + &bufferRowPitch, + &bufferSlicePitch, + &hostRowPitch, + &hostSlicePitch, + &pDst, + &numEventsInWaitList, + &phEventWaitList, + &phEvent}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urEnqueueMemBufferReadRect")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urEnqueueMemBufferReadRect")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - if (nullptr != phEvent) { - *phEvent = reinterpret_cast(d_context.get()); + + // optional output handle + if (phEvent) { + *phEvent = mock::createDummyHandle(); } + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urEnqueueMemBufferReadRect")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -3401,19 +5770,51 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferWriteRect( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnMemBufferWriteRect = - d_context.urDdiTable.Enqueue.pfnMemBufferWriteRect; - if (nullptr != pfnMemBufferWriteRect) { - result = pfnMemBufferWriteRect( - hQueue, hBuffer, blockingWrite, bufferOrigin, hostOrigin, region, - bufferRowPitch, bufferSlicePitch, hostRowPitch, hostSlicePitch, - pSrc, numEventsInWaitList, phEventWaitList, phEvent); + ur_enqueue_mem_buffer_write_rect_params_t params = {&hQueue, + &hBuffer, + &blockingWrite, + &bufferOrigin, + &hostOrigin, + ®ion, + &bufferRowPitch, + &bufferSlicePitch, + &hostRowPitch, + &hostSlicePitch, + &pSrc, + &numEventsInWaitList, + &phEventWaitList, + &phEvent}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urEnqueueMemBufferWriteRect")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urEnqueueMemBufferWriteRect")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - if (nullptr != phEvent) { - *phEvent = reinterpret_cast(d_context.get()); + + // optional output handle + if (phEvent) { + *phEvent = mock::createDummyHandle(); } + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urEnqueueMemBufferWriteRect")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -3444,17 +5845,40 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferCopy( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnMemBufferCopy = d_context.urDdiTable.Enqueue.pfnMemBufferCopy; - if (nullptr != pfnMemBufferCopy) { - result = pfnMemBufferCopy(hQueue, hBufferSrc, hBufferDst, srcOffset, - dstOffset, size, numEventsInWaitList, - phEventWaitList, phEvent); + ur_enqueue_mem_buffer_copy_params_t params = { + &hQueue, &hBufferSrc, &hBufferDst, &srcOffset, &dstOffset, + &size, &numEventsInWaitList, &phEventWaitList, &phEvent}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urEnqueueMemBufferCopy")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urEnqueueMemBufferCopy")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - if (nullptr != phEvent) { - *phEvent = reinterpret_cast(d_context.get()); + + // optional output handle + if (phEvent) { + *phEvent = mock::createDummyHandle(); } + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urEnqueueMemBufferCopy")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -3494,19 +5918,42 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferCopyRect( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnMemBufferCopyRect = - d_context.urDdiTable.Enqueue.pfnMemBufferCopyRect; - if (nullptr != pfnMemBufferCopyRect) { - result = pfnMemBufferCopyRect( - hQueue, hBufferSrc, hBufferDst, srcOrigin, dstOrigin, region, - srcRowPitch, srcSlicePitch, dstRowPitch, dstSlicePitch, - numEventsInWaitList, phEventWaitList, phEvent); + ur_enqueue_mem_buffer_copy_rect_params_t params = { + &hQueue, &hBufferSrc, &hBufferDst, &srcOrigin, + &dstOrigin, ®ion, &srcRowPitch, &srcSlicePitch, + &dstRowPitch, &dstSlicePitch, &numEventsInWaitList, &phEventWaitList, + &phEvent}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urEnqueueMemBufferCopyRect")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urEnqueueMemBufferCopyRect")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - if (nullptr != phEvent) { - *phEvent = reinterpret_cast(d_context.get()); + + // optional output handle + if (phEvent) { + *phEvent = mock::createDummyHandle(); } + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urEnqueueMemBufferCopyRect")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -3536,17 +5983,46 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferFill( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnMemBufferFill = d_context.urDdiTable.Enqueue.pfnMemBufferFill; - if (nullptr != pfnMemBufferFill) { - result = pfnMemBufferFill(hQueue, hBuffer, pPattern, patternSize, - offset, size, numEventsInWaitList, - phEventWaitList, phEvent); + ur_enqueue_mem_buffer_fill_params_t params = {&hQueue, + &hBuffer, + &pPattern, + &patternSize, + &offset, + &size, + &numEventsInWaitList, + &phEventWaitList, + &phEvent}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urEnqueueMemBufferFill")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urEnqueueMemBufferFill")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - if (nullptr != phEvent) { - *phEvent = reinterpret_cast(d_context.get()); + + // optional output handle + if (phEvent) { + *phEvent = mock::createDummyHandle(); } + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urEnqueueMemBufferFill")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -3581,17 +6057,42 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemImageRead( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnMemImageRead = d_context.urDdiTable.Enqueue.pfnMemImageRead; - if (nullptr != pfnMemImageRead) { - result = pfnMemImageRead(hQueue, hImage, blockingRead, origin, region, - rowPitch, slicePitch, pDst, - numEventsInWaitList, phEventWaitList, phEvent); + ur_enqueue_mem_image_read_params_t params = { + &hQueue, &hImage, &blockingRead, + &origin, ®ion, &rowPitch, + &slicePitch, &pDst, &numEventsInWaitList, + &phEventWaitList, &phEvent}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urEnqueueMemImageRead")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urEnqueueMemImageRead")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - if (nullptr != phEvent) { - *phEvent = reinterpret_cast(d_context.get()); + + // optional output handle + if (phEvent) { + *phEvent = mock::createDummyHandle(); } + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urEnqueueMemImageRead")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -3627,17 +6128,42 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemImageWrite( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnMemImageWrite = d_context.urDdiTable.Enqueue.pfnMemImageWrite; - if (nullptr != pfnMemImageWrite) { - result = pfnMemImageWrite( - hQueue, hImage, blockingWrite, origin, region, rowPitch, slicePitch, - pSrc, numEventsInWaitList, phEventWaitList, phEvent); + ur_enqueue_mem_image_write_params_t params = { + &hQueue, &hImage, &blockingWrite, + &origin, ®ion, &rowPitch, + &slicePitch, &pSrc, &numEventsInWaitList, + &phEventWaitList, &phEvent}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urEnqueueMemImageWrite")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urEnqueueMemImageWrite")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - if (nullptr != phEvent) { - *phEvent = reinterpret_cast(d_context.get()); + + // optional output handle + if (phEvent) { + *phEvent = mock::createDummyHandle(); } + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urEnqueueMemImageWrite")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -3674,17 +6200,40 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemImageCopy( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnMemImageCopy = d_context.urDdiTable.Enqueue.pfnMemImageCopy; - if (nullptr != pfnMemImageCopy) { - result = pfnMemImageCopy(hQueue, hImageSrc, hImageDst, srcOrigin, - dstOrigin, region, numEventsInWaitList, - phEventWaitList, phEvent); + ur_enqueue_mem_image_copy_params_t params = { + &hQueue, &hImageSrc, &hImageDst, &srcOrigin, &dstOrigin, + ®ion, &numEventsInWaitList, &phEventWaitList, &phEvent}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urEnqueueMemImageCopy")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urEnqueueMemImageCopy")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - if (nullptr != phEvent) { - *phEvent = reinterpret_cast(d_context.get()); + + // optional output handle + if (phEvent) { + *phEvent = mock::createDummyHandle(); } + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urEnqueueMemImageCopy")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -3716,17 +6265,41 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferMap( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnMemBufferMap = d_context.urDdiTable.Enqueue.pfnMemBufferMap; - if (nullptr != pfnMemBufferMap) { - result = pfnMemBufferMap(hQueue, hBuffer, blockingMap, mapFlags, offset, - size, numEventsInWaitList, phEventWaitList, - phEvent, ppRetMap); + ur_enqueue_mem_buffer_map_params_t params = { + &hQueue, &hBuffer, &blockingMap, &mapFlags, + &offset, &size, &numEventsInWaitList, &phEventWaitList, + &phEvent, &ppRetMap}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urEnqueueMemBufferMap")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urEnqueueMemBufferMap")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - if (nullptr != phEvent) { - *phEvent = reinterpret_cast(d_context.get()); + + // optional output handle + if (phEvent) { + *phEvent = mock::createDummyHandle(); } + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urEnqueueMemBufferMap")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -3753,16 +6326,40 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemUnmap( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnMemUnmap = d_context.urDdiTable.Enqueue.pfnMemUnmap; - if (nullptr != pfnMemUnmap) { - result = pfnMemUnmap(hQueue, hMem, pMappedPtr, numEventsInWaitList, - phEventWaitList, phEvent); + ur_enqueue_mem_unmap_params_t params = { + &hQueue, &hMem, &pMappedPtr, &numEventsInWaitList, + &phEventWaitList, &phEvent}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urEnqueueMemUnmap")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urEnqueueMemUnmap")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - if (nullptr != phEvent) { - *phEvent = reinterpret_cast(d_context.get()); + + // optional output handle + if (phEvent) { + *phEvent = mock::createDummyHandle(); } + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urEnqueueMemUnmap")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -3794,16 +6391,41 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueUSMFill( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnUSMFill = d_context.urDdiTable.Enqueue.pfnUSMFill; - if (nullptr != pfnUSMFill) { - result = pfnUSMFill(hQueue, pMem, patternSize, pPattern, size, - numEventsInWaitList, phEventWaitList, phEvent); + ur_enqueue_usm_fill_params_t params = { + &hQueue, &pMem, &patternSize, + &pPattern, &size, &numEventsInWaitList, + &phEventWaitList, &phEvent}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urEnqueueUSMFill")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urEnqueueUSMFill")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - if (nullptr != phEvent) { - *phEvent = reinterpret_cast(d_context.get()); + + // optional output handle + if (phEvent) { + *phEvent = mock::createDummyHandle(); } + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urEnqueueUSMFill")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -3833,16 +6455,40 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueUSMMemcpy( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnUSMMemcpy = d_context.urDdiTable.Enqueue.pfnUSMMemcpy; - if (nullptr != pfnUSMMemcpy) { - result = pfnUSMMemcpy(hQueue, blocking, pDst, pSrc, size, - numEventsInWaitList, phEventWaitList, phEvent); + ur_enqueue_usm_memcpy_params_t params = { + &hQueue, &blocking, &pDst, &pSrc, &size, &numEventsInWaitList, + &phEventWaitList, &phEvent}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urEnqueueUSMMemcpy")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urEnqueueUSMMemcpy")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - if (nullptr != phEvent) { - *phEvent = reinterpret_cast(d_context.get()); + + // optional output handle + if (phEvent) { + *phEvent = mock::createDummyHandle(); } + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urEnqueueUSMMemcpy")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -3870,16 +6516,40 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueUSMPrefetch( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnUSMPrefetch = d_context.urDdiTable.Enqueue.pfnUSMPrefetch; - if (nullptr != pfnUSMPrefetch) { - result = pfnUSMPrefetch(hQueue, pMem, size, flags, numEventsInWaitList, - phEventWaitList, phEvent); + ur_enqueue_usm_prefetch_params_t params = { + &hQueue, &pMem, &size, &flags, &numEventsInWaitList, + &phEventWaitList, &phEvent}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urEnqueueUSMPrefetch")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urEnqueueUSMPrefetch")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - if (nullptr != phEvent) { - *phEvent = reinterpret_cast(d_context.get()); + + // optional output handle + if (phEvent) { + *phEvent = mock::createDummyHandle(); } + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urEnqueueUSMPrefetch")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -3901,15 +6571,39 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueUSMAdvise( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnUSMAdvise = d_context.urDdiTable.Enqueue.pfnUSMAdvise; - if (nullptr != pfnUSMAdvise) { - result = pfnUSMAdvise(hQueue, pMem, size, advice, phEvent); + ur_enqueue_usm_advise_params_t params = {&hQueue, &pMem, &size, &advice, + &phEvent}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urEnqueueUSMAdvise")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urEnqueueUSMAdvise")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - if (nullptr != phEvent) { - *phEvent = reinterpret_cast(d_context.get()); + + // optional output handle + if (phEvent) { + *phEvent = mock::createDummyHandle(); } + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urEnqueueUSMAdvise")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -3946,17 +6640,41 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueUSMFill2D( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnUSMFill2D = d_context.urDdiTable.Enqueue.pfnUSMFill2D; - if (nullptr != pfnUSMFill2D) { - result = - pfnUSMFill2D(hQueue, pMem, pitch, patternSize, pPattern, width, - height, numEventsInWaitList, phEventWaitList, phEvent); + ur_enqueue_usm_fill_2d_params_t params = { + &hQueue, &pMem, &pitch, &patternSize, + &pPattern, &width, &height, &numEventsInWaitList, + &phEventWaitList, &phEvent}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urEnqueueUSMFill2D")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urEnqueueUSMFill2D")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - if (nullptr != phEvent) { - *phEvent = reinterpret_cast(d_context.get()); + + // optional output handle + if (phEvent) { + *phEvent = mock::createDummyHandle(); } + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urEnqueueUSMFill2D")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -3992,17 +6710,42 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueUSMMemcpy2D( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnUSMMemcpy2D = d_context.urDdiTable.Enqueue.pfnUSMMemcpy2D; - if (nullptr != pfnUSMMemcpy2D) { - result = pfnUSMMemcpy2D(hQueue, blocking, pDst, dstPitch, pSrc, - srcPitch, width, height, numEventsInWaitList, - phEventWaitList, phEvent); + ur_enqueue_usm_memcpy_2d_params_t params = { + &hQueue, &blocking, &pDst, + &dstPitch, &pSrc, &srcPitch, + &width, &height, &numEventsInWaitList, + &phEventWaitList, &phEvent}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urEnqueueUSMMemcpy2D")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urEnqueueUSMMemcpy2D")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - if (nullptr != phEvent) { - *phEvent = reinterpret_cast(d_context.get()); + + // optional output handle + if (phEvent) { + *phEvent = mock::createDummyHandle(); } + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urEnqueueUSMMemcpy2D")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -4035,18 +6778,44 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableWrite( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnDeviceGlobalVariableWrite = - d_context.urDdiTable.Enqueue.pfnDeviceGlobalVariableWrite; - if (nullptr != pfnDeviceGlobalVariableWrite) { - result = pfnDeviceGlobalVariableWrite( - hQueue, hProgram, name, blockingWrite, count, offset, pSrc, - numEventsInWaitList, phEventWaitList, phEvent); + ur_enqueue_device_global_variable_write_params_t params = { + &hQueue, &hProgram, &name, &blockingWrite, + &count, &offset, &pSrc, &numEventsInWaitList, + &phEventWaitList, &phEvent}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback( + "urEnqueueDeviceGlobalVariableWrite")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback( + "urEnqueueDeviceGlobalVariableWrite")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - if (nullptr != phEvent) { - *phEvent = reinterpret_cast(d_context.get()); + + // optional output handle + if (phEvent) { + *phEvent = mock::createDummyHandle(); } + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = + reinterpret_cast(mock::callbacks.get_after_callback( + "urEnqueueDeviceGlobalVariableWrite")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -4079,18 +6848,44 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableRead( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnDeviceGlobalVariableRead = - d_context.urDdiTable.Enqueue.pfnDeviceGlobalVariableRead; - if (nullptr != pfnDeviceGlobalVariableRead) { - result = pfnDeviceGlobalVariableRead( - hQueue, hProgram, name, blockingRead, count, offset, pDst, - numEventsInWaitList, phEventWaitList, phEvent); + ur_enqueue_device_global_variable_read_params_t params = { + &hQueue, &hProgram, &name, &blockingRead, + &count, &offset, &pDst, &numEventsInWaitList, + &phEventWaitList, &phEvent}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback( + "urEnqueueDeviceGlobalVariableRead")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback( + "urEnqueueDeviceGlobalVariableRead")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - if (nullptr != phEvent) { - *phEvent = reinterpret_cast(d_context.get()); + + // optional output handle + if (phEvent) { + *phEvent = mock::createDummyHandle(); } + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = + reinterpret_cast(mock::callbacks.get_after_callback( + "urEnqueueDeviceGlobalVariableRead")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -4127,17 +6922,41 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueReadHostPipe( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnReadHostPipe = d_context.urDdiTable.Enqueue.pfnReadHostPipe; - if (nullptr != pfnReadHostPipe) { - result = - pfnReadHostPipe(hQueue, hProgram, pipe_symbol, blocking, pDst, size, - numEventsInWaitList, phEventWaitList, phEvent); + ur_enqueue_read_host_pipe_params_t params = { + &hQueue, &hProgram, &pipe_symbol, &blocking, + &pDst, &size, &numEventsInWaitList, &phEventWaitList, + &phEvent}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urEnqueueReadHostPipe")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urEnqueueReadHostPipe")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - if (nullptr != phEvent) { - *phEvent = reinterpret_cast(d_context.get()); + + // optional output handle + if (phEvent) { + *phEvent = mock::createDummyHandle(); } + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urEnqueueReadHostPipe")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -4174,17 +6993,41 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueWriteHostPipe( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnWriteHostPipe = d_context.urDdiTable.Enqueue.pfnWriteHostPipe; - if (nullptr != pfnWriteHostPipe) { - result = pfnWriteHostPipe(hQueue, hProgram, pipe_symbol, blocking, pSrc, - size, numEventsInWaitList, phEventWaitList, - phEvent); + ur_enqueue_write_host_pipe_params_t params = { + &hQueue, &hProgram, &pipe_symbol, &blocking, + &pSrc, &size, &numEventsInWaitList, &phEventWaitList, + &phEvent}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urEnqueueWriteHostPipe")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urEnqueueWriteHostPipe")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - if (nullptr != phEvent) { - *phEvent = reinterpret_cast(d_context.get()); + + // optional output handle + if (phEvent) { + *phEvent = mock::createDummyHandle(); } + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urEnqueueWriteHostPipe")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -4211,14 +7054,36 @@ __urdlllocal ur_result_t UR_APICALL urUSMPitchedAllocExp( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnPitchedAllocExp = d_context.urDdiTable.USMExp.pfnPitchedAllocExp; - if (nullptr != pfnPitchedAllocExp) { - result = - pfnPitchedAllocExp(hContext, hDevice, pUSMDesc, pool, widthInBytes, - height, elementSizeBytes, ppMem, pResultPitch); + ur_usm_pitched_alloc_exp_params_t params = { + &hContext, &hDevice, &pUSMDesc, &pool, &widthInBytes, + &height, &elementSizeBytes, &ppMem, &pResultPitch}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urUSMPitchedAllocExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urUSMPitchedAllocExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urUSMPitchedAllocExp")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -4237,14 +7102,39 @@ urBindlessImagesUnsampledImageHandleDestroyExp( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnUnsampledImageHandleDestroyExp = - d_context.urDdiTable.BindlessImagesExp - .pfnUnsampledImageHandleDestroyExp; - if (nullptr != pfnUnsampledImageHandleDestroyExp) { - result = pfnUnsampledImageHandleDestroyExp(hContext, hDevice, hImage); + ur_bindless_images_unsampled_image_handle_destroy_exp_params_t params = { + &hContext, &hDevice, &hImage}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback( + "urBindlessImagesUnsampledImageHandleDestroyExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback( + "urBindlessImagesUnsampledImageHandleDestroyExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + mock::releaseDummyHandle(hImage); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = + reinterpret_cast(mock::callbacks.get_after_callback( + "urBindlessImagesUnsampledImageHandleDestroyExp")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -4263,13 +7153,39 @@ urBindlessImagesSampledImageHandleDestroyExp( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnSampledImageHandleDestroyExp = - d_context.urDdiTable.BindlessImagesExp.pfnSampledImageHandleDestroyExp; - if (nullptr != pfnSampledImageHandleDestroyExp) { - result = pfnSampledImageHandleDestroyExp(hContext, hDevice, hImage); + ur_bindless_images_sampled_image_handle_destroy_exp_params_t params = { + &hContext, &hDevice, &hImage}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback( + "urBindlessImagesSampledImageHandleDestroyExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback( + "urBindlessImagesSampledImageHandleDestroyExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + mock::releaseDummyHandle(hImage); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = + reinterpret_cast(mock::callbacks.get_after_callback( + "urBindlessImagesSampledImageHandleDestroyExp")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -4290,16 +7206,39 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesImageAllocateExp( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnImageAllocateExp = - d_context.urDdiTable.BindlessImagesExp.pfnImageAllocateExp; - if (nullptr != pfnImageAllocateExp) { - result = pfnImageAllocateExp(hContext, hDevice, pImageFormat, - pImageDesc, phImageMem); + ur_bindless_images_image_allocate_exp_params_t params = { + &hContext, &hDevice, &pImageFormat, &pImageDesc, &phImageMem}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback( + "urBindlessImagesImageAllocateExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback( + "urBindlessImagesImageAllocateExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + *phImageMem = - reinterpret_cast(d_context.get()); + mock::createDummyHandle(); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urBindlessImagesImageAllocateExp")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -4317,13 +7256,36 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesImageFreeExp( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnImageFreeExp = - d_context.urDdiTable.BindlessImagesExp.pfnImageFreeExp; - if (nullptr != pfnImageFreeExp) { - result = pfnImageFreeExp(hContext, hDevice, hImageMem); + ur_bindless_images_image_free_exp_params_t params = {&hContext, &hDevice, + &hImageMem}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urBindlessImagesImageFreeExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urBindlessImagesImageFreeExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + mock::releaseDummyHandle(hImageMem); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urBindlessImagesImageFreeExp")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -4346,16 +7308,39 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesUnsampledImageCreateExp( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnUnsampledImageCreateExp = - d_context.urDdiTable.BindlessImagesExp.pfnUnsampledImageCreateExp; - if (nullptr != pfnUnsampledImageCreateExp) { - result = pfnUnsampledImageCreateExp(hContext, hDevice, hImageMem, - pImageFormat, pImageDesc, phImage); + ur_bindless_images_unsampled_image_create_exp_params_t params = { + &hContext, &hDevice, &hImageMem, &pImageFormat, &pImageDesc, &phImage}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback( + "urBindlessImagesUnsampledImageCreateExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback( + "urBindlessImagesUnsampledImageCreateExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - *phImage = - reinterpret_cast(d_context.get()); + + *phImage = mock::createDummyHandle(); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = + reinterpret_cast(mock::callbacks.get_after_callback( + "urBindlessImagesUnsampledImageCreateExp")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -4379,17 +7364,40 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesSampledImageCreateExp( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnSampledImageCreateExp = - d_context.urDdiTable.BindlessImagesExp.pfnSampledImageCreateExp; - if (nullptr != pfnSampledImageCreateExp) { - result = - pfnSampledImageCreateExp(hContext, hDevice, hImageMem, pImageFormat, - pImageDesc, hSampler, phImage); + ur_bindless_images_sampled_image_create_exp_params_t params = { + &hContext, &hDevice, &hImageMem, &pImageFormat, + &pImageDesc, &hSampler, &phImage}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback( + "urBindlessImagesSampledImageCreateExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback( + "urBindlessImagesSampledImageCreateExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - *phImage = - reinterpret_cast(d_context.get()); + + *phImage = mock::createDummyHandle(); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = + reinterpret_cast(mock::callbacks.get_after_callback( + "urBindlessImagesSampledImageCreateExp")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -4433,19 +7441,50 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesImageCopyExp( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnImageCopyExp = - d_context.urDdiTable.BindlessImagesExp.pfnImageCopyExp; - if (nullptr != pfnImageCopyExp) { - result = pfnImageCopyExp(hQueue, pDst, pSrc, pImageFormat, pImageDesc, - imageCopyFlags, srcOffset, dstOffset, - copyExtent, hostExtent, numEventsInWaitList, - phEventWaitList, phEvent); + ur_bindless_images_image_copy_exp_params_t params = {&hQueue, + &pDst, + &pSrc, + &pImageFormat, + &pImageDesc, + &imageCopyFlags, + &srcOffset, + &dstOffset, + ©Extent, + &hostExtent, + &numEventsInWaitList, + &phEventWaitList, + &phEvent}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urBindlessImagesImageCopyExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urBindlessImagesImageCopyExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - if (nullptr != phEvent) { - *phEvent = reinterpret_cast(d_context.get()); + + // optional output handle + if (phEvent) { + *phEvent = mock::createDummyHandle(); } + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urBindlessImagesImageCopyExp")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -4465,14 +7504,36 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesImageGetInfoExp( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnImageGetInfoExp = - d_context.urDdiTable.BindlessImagesExp.pfnImageGetInfoExp; - if (nullptr != pfnImageGetInfoExp) { - result = pfnImageGetInfoExp(hContext, hImageMem, propName, pPropValue, - pPropSizeRet); + ur_bindless_images_image_get_info_exp_params_t params = { + &hContext, &hImageMem, &propName, &pPropValue, &pPropSizeRet}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urBindlessImagesImageGetInfoExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback( + "urBindlessImagesImageGetInfoExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urBindlessImagesImageGetInfoExp")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -4493,16 +7554,40 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesMipmapGetLevelExp( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnMipmapGetLevelExp = - d_context.urDdiTable.BindlessImagesExp.pfnMipmapGetLevelExp; - if (nullptr != pfnMipmapGetLevelExp) { - result = pfnMipmapGetLevelExp(hContext, hDevice, hImageMem, mipmapLevel, - phImageMem); + ur_bindless_images_mipmap_get_level_exp_params_t params = { + &hContext, &hDevice, &hImageMem, &mipmapLevel, &phImageMem}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback( + "urBindlessImagesMipmapGetLevelExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback( + "urBindlessImagesMipmapGetLevelExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + *phImageMem = - reinterpret_cast(d_context.get()); + mock::createDummyHandle(); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = + reinterpret_cast(mock::callbacks.get_after_callback( + "urBindlessImagesMipmapGetLevelExp")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -4520,13 +7605,36 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesMipmapFreeExp( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnMipmapFreeExp = - d_context.urDdiTable.BindlessImagesExp.pfnMipmapFreeExp; - if (nullptr != pfnMipmapFreeExp) { - result = pfnMipmapFreeExp(hContext, hDevice, hMem); + ur_bindless_images_mipmap_free_exp_params_t params = {&hContext, &hDevice, + &hMem}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urBindlessImagesMipmapFreeExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urBindlessImagesMipmapFreeExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + mock::releaseDummyHandle(hMem); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urBindlessImagesMipmapFreeExp")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -4549,17 +7657,40 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesImportExternalMemoryExp( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnImportExternalMemoryExp = - d_context.urDdiTable.BindlessImagesExp.pfnImportExternalMemoryExp; - if (nullptr != pfnImportExternalMemoryExp) { - result = - pfnImportExternalMemoryExp(hContext, hDevice, size, memHandleType, - pInteropMemDesc, phInteropMem); + ur_bindless_images_import_external_memory_exp_params_t params = { + &hContext, &hDevice, &size, + &memHandleType, &pInteropMemDesc, &phInteropMem}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback( + "urBindlessImagesImportExternalMemoryExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback( + "urBindlessImagesImportExternalMemoryExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - *phInteropMem = - reinterpret_cast(d_context.get()); + + *phInteropMem = mock::createDummyHandle(); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = + reinterpret_cast(mock::callbacks.get_after_callback( + "urBindlessImagesImportExternalMemoryExp")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -4582,16 +7713,41 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesMapExternalArrayExp( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnMapExternalArrayExp = - d_context.urDdiTable.BindlessImagesExp.pfnMapExternalArrayExp; - if (nullptr != pfnMapExternalArrayExp) { - result = pfnMapExternalArrayExp(hContext, hDevice, pImageFormat, - pImageDesc, hInteropMem, phImageMem); + ur_bindless_images_map_external_array_exp_params_t params = { + &hContext, &hDevice, &pImageFormat, + &pImageDesc, &hInteropMem, &phImageMem}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback( + "urBindlessImagesMapExternalArrayExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback( + "urBindlessImagesMapExternalArrayExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + *phImageMem = - reinterpret_cast(d_context.get()); + mock::createDummyHandle(); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = + reinterpret_cast(mock::callbacks.get_after_callback( + "urBindlessImagesMapExternalArrayExp")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -4609,13 +7765,39 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesReleaseInteropExp( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnReleaseInteropExp = - d_context.urDdiTable.BindlessImagesExp.pfnReleaseInteropExp; - if (nullptr != pfnReleaseInteropExp) { - result = pfnReleaseInteropExp(hContext, hDevice, hInteropMem); + ur_bindless_images_release_interop_exp_params_t params = { + &hContext, &hDevice, &hInteropMem}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback( + "urBindlessImagesReleaseInteropExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback( + "urBindlessImagesReleaseInteropExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + mock::releaseDummyHandle(hInteropMem); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = + reinterpret_cast(mock::callbacks.get_after_callback( + "urBindlessImagesReleaseInteropExp")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -4637,18 +7819,41 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesImportExternalSemaphoreExp( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnImportExternalSemaphoreExp = - d_context.urDdiTable.BindlessImagesExp.pfnImportExternalSemaphoreExp; - if (nullptr != pfnImportExternalSemaphoreExp) { - result = pfnImportExternalSemaphoreExp(hContext, hDevice, semHandleType, - pInteropSemaphoreDesc, - phInteropSemaphore); + ur_bindless_images_import_external_semaphore_exp_params_t params = { + &hContext, &hDevice, &semHandleType, &pInteropSemaphoreDesc, + &phInteropSemaphore}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback( + "urBindlessImagesImportExternalSemaphoreExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback( + "urBindlessImagesImportExternalSemaphoreExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + *phInteropSemaphore = - reinterpret_cast( - d_context.get()); + mock::createDummyHandle(); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = + reinterpret_cast(mock::callbacks.get_after_callback( + "urBindlessImagesImportExternalSemaphoreExp")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -4666,14 +7871,39 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesDestroyExternalSemaphoreExp( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnDestroyExternalSemaphoreExp = - d_context.urDdiTable.BindlessImagesExp.pfnDestroyExternalSemaphoreExp; - if (nullptr != pfnDestroyExternalSemaphoreExp) { - result = pfnDestroyExternalSemaphoreExp(hContext, hDevice, - hInteropSemaphore); + ur_bindless_images_destroy_external_semaphore_exp_params_t params = { + &hContext, &hDevice, &hInteropSemaphore}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback( + "urBindlessImagesDestroyExternalSemaphoreExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback( + "urBindlessImagesDestroyExternalSemaphoreExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + mock::releaseDummyHandle(hInteropSemaphore); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = + reinterpret_cast(mock::callbacks.get_after_callback( + "urBindlessImagesDestroyExternalSemaphoreExp")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -4706,18 +7936,44 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesWaitExternalSemaphoreExp( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnWaitExternalSemaphoreExp = - d_context.urDdiTable.BindlessImagesExp.pfnWaitExternalSemaphoreExp; - if (nullptr != pfnWaitExternalSemaphoreExp) { - result = pfnWaitExternalSemaphoreExp(hQueue, hSemaphore, hasWaitValue, - waitValue, numEventsInWaitList, - phEventWaitList, phEvent); + ur_bindless_images_wait_external_semaphore_exp_params_t params = { + &hQueue, &hSemaphore, &hasWaitValue, + &waitValue, &numEventsInWaitList, &phEventWaitList, + &phEvent}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback( + "urBindlessImagesWaitExternalSemaphoreExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback( + "urBindlessImagesWaitExternalSemaphoreExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - if (nullptr != phEvent) { - *phEvent = reinterpret_cast(d_context.get()); + + // optional output handle + if (phEvent) { + *phEvent = mock::createDummyHandle(); } + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = + reinterpret_cast(mock::callbacks.get_after_callback( + "urBindlessImagesWaitExternalSemaphoreExp")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -4750,18 +8006,44 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesSignalExternalSemaphoreExp( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnSignalExternalSemaphoreExp = - d_context.urDdiTable.BindlessImagesExp.pfnSignalExternalSemaphoreExp; - if (nullptr != pfnSignalExternalSemaphoreExp) { - result = pfnSignalExternalSemaphoreExp( - hQueue, hSemaphore, hasSignalValue, signalValue, - numEventsInWaitList, phEventWaitList, phEvent); + ur_bindless_images_signal_external_semaphore_exp_params_t params = { + &hQueue, &hSemaphore, &hasSignalValue, + &signalValue, &numEventsInWaitList, &phEventWaitList, + &phEvent}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback( + "urBindlessImagesSignalExternalSemaphoreExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback( + "urBindlessImagesSignalExternalSemaphoreExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - if (nullptr != phEvent) { - *phEvent = reinterpret_cast(d_context.get()); + + // optional output handle + if (phEvent) { + *phEvent = mock::createDummyHandle(); } + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = + reinterpret_cast(mock::callbacks.get_after_callback( + "urBindlessImagesSignalExternalSemaphoreExp")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -4781,15 +8063,37 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferCreateExp( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnCreateExp = d_context.urDdiTable.CommandBufferExp.pfnCreateExp; - if (nullptr != pfnCreateExp) { - result = pfnCreateExp(hContext, hDevice, pCommandBufferDesc, - phCommandBuffer); + ur_command_buffer_create_exp_params_t params = { + &hContext, &hDevice, &pCommandBufferDesc, &phCommandBuffer}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urCommandBufferCreateExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urCommandBufferCreateExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + *phCommandBuffer = - reinterpret_cast(d_context.get()); + mock::createDummyHandle(); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urCommandBufferCreateExp")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -4801,16 +8105,39 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferCreateExp( /// @brief Intercept function for urCommandBufferRetainExp __urdlllocal ur_result_t UR_APICALL urCommandBufferRetainExp( ur_exp_command_buffer_handle_t - hCommandBuffer ///< [in] Handle of the command-buffer object. + hCommandBuffer ///< [in][retain] Handle of the command-buffer object. ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnRetainExp = d_context.urDdiTable.CommandBufferExp.pfnRetainExp; - if (nullptr != pfnRetainExp) { - result = pfnRetainExp(hCommandBuffer); + ur_command_buffer_retain_exp_params_t params = {&hCommandBuffer}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urCommandBufferRetainExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urCommandBufferRetainExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + mock::retainDummyHandle(hCommandBuffer); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urCommandBufferRetainExp")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -4826,12 +8153,35 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferReleaseExp( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnReleaseExp = d_context.urDdiTable.CommandBufferExp.pfnReleaseExp; - if (nullptr != pfnReleaseExp) { - result = pfnReleaseExp(hCommandBuffer); + ur_command_buffer_release_exp_params_t params = {&hCommandBuffer}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urCommandBufferReleaseExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urCommandBufferReleaseExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + mock::releaseDummyHandle(hCommandBuffer); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urCommandBufferReleaseExp")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -4847,12 +8197,34 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferFinalizeExp( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnFinalizeExp = d_context.urDdiTable.CommandBufferExp.pfnFinalizeExp; - if (nullptr != pfnFinalizeExp) { - result = pfnFinalizeExp(hCommandBuffer); + ur_command_buffer_finalize_exp_params_t params = {&hCommandBuffer}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urCommandBufferFinalizeExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urCommandBufferFinalizeExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urCommandBufferFinalizeExp")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -4885,21 +8257,52 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnAppendKernelLaunchExp = - d_context.urDdiTable.CommandBufferExp.pfnAppendKernelLaunchExp; - if (nullptr != pfnAppendKernelLaunchExp) { - result = pfnAppendKernelLaunchExp( - hCommandBuffer, hKernel, workDim, pGlobalWorkOffset, - pGlobalWorkSize, pLocalWorkSize, numSyncPointsInWaitList, - pSyncPointWaitList, pSyncPoint, phCommand); + ur_command_buffer_append_kernel_launch_exp_params_t params = { + &hCommandBuffer, + &hKernel, + &workDim, + &pGlobalWorkOffset, + &pGlobalWorkSize, + &pLocalWorkSize, + &numSyncPointsInWaitList, + &pSyncPointWaitList, + &pSyncPoint, + &phCommand}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback( + "urCommandBufferAppendKernelLaunchExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback( + "urCommandBufferAppendKernelLaunchExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - if (nullptr != phCommand) { - *phCommand = - reinterpret_cast( - d_context.get()); + + // optional output handle + if (phCommand) { + *phCommand = mock::createDummyHandle< + ur_exp_command_buffer_command_handle_t>(); } + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = + reinterpret_cast(mock::callbacks.get_after_callback( + "urCommandBufferAppendKernelLaunchExp")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -4925,15 +8328,39 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferAppendUSMMemcpyExp( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnAppendUSMMemcpyExp = - d_context.urDdiTable.CommandBufferExp.pfnAppendUSMMemcpyExp; - if (nullptr != pfnAppendUSMMemcpyExp) { - result = pfnAppendUSMMemcpyExp(hCommandBuffer, pDst, pSrc, size, - numSyncPointsInWaitList, - pSyncPointWaitList, pSyncPoint); - } else { - // generic implementation + ur_command_buffer_append_usm_memcpy_exp_params_t params = { + &hCommandBuffer, &pDst, &pSrc, &size, &numSyncPointsInWaitList, + &pSyncPointWaitList, &pSyncPoint}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback( + "urCommandBufferAppendUSMMemcpyExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback( + "urCommandBufferAppendUSMMemcpyExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = + reinterpret_cast(mock::callbacks.get_after_callback( + "urCommandBufferAppendUSMMemcpyExp")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -4961,15 +8388,38 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferAppendUSMFillExp( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnAppendUSMFillExp = - d_context.urDdiTable.CommandBufferExp.pfnAppendUSMFillExp; - if (nullptr != pfnAppendUSMFillExp) { - result = pfnAppendUSMFillExp(hCommandBuffer, pMemory, pPattern, - patternSize, size, numSyncPointsInWaitList, - pSyncPointWaitList, pSyncPoint); + ur_command_buffer_append_usm_fill_exp_params_t params = { + &hCommandBuffer, &pMemory, &pPattern, + &patternSize, &size, &numSyncPointsInWaitList, + &pSyncPointWaitList, &pSyncPoint}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urCommandBufferAppendUSMFillExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback( + "urCommandBufferAppendUSMFillExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urCommandBufferAppendUSMFillExp")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -4997,15 +8447,46 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyExp( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnAppendMemBufferCopyExp = - d_context.urDdiTable.CommandBufferExp.pfnAppendMemBufferCopyExp; - if (nullptr != pfnAppendMemBufferCopyExp) { - result = pfnAppendMemBufferCopyExp( - hCommandBuffer, hSrcMem, hDstMem, srcOffset, dstOffset, size, - numSyncPointsInWaitList, pSyncPointWaitList, pSyncPoint); + ur_command_buffer_append_mem_buffer_copy_exp_params_t params = { + &hCommandBuffer, + &hSrcMem, + &hDstMem, + &srcOffset, + &dstOffset, + &size, + &numSyncPointsInWaitList, + &pSyncPointWaitList, + &pSyncPoint}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback( + "urCommandBufferAppendMemBufferCopyExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback( + "urCommandBufferAppendMemBufferCopyExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = + reinterpret_cast(mock::callbacks.get_after_callback( + "urCommandBufferAppendMemBufferCopyExp")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -5033,15 +8514,45 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferAppendMemBufferWriteExp( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnAppendMemBufferWriteExp = - d_context.urDdiTable.CommandBufferExp.pfnAppendMemBufferWriteExp; - if (nullptr != pfnAppendMemBufferWriteExp) { - result = pfnAppendMemBufferWriteExp(hCommandBuffer, hBuffer, offset, - size, pSrc, numSyncPointsInWaitList, - pSyncPointWaitList, pSyncPoint); + ur_command_buffer_append_mem_buffer_write_exp_params_t params = { + &hCommandBuffer, + &hBuffer, + &offset, + &size, + &pSrc, + &numSyncPointsInWaitList, + &pSyncPointWaitList, + &pSyncPoint}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback( + "urCommandBufferAppendMemBufferWriteExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback( + "urCommandBufferAppendMemBufferWriteExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = + reinterpret_cast(mock::callbacks.get_after_callback( + "urCommandBufferAppendMemBufferWriteExp")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -5068,15 +8579,45 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferAppendMemBufferReadExp( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnAppendMemBufferReadExp = - d_context.urDdiTable.CommandBufferExp.pfnAppendMemBufferReadExp; - if (nullptr != pfnAppendMemBufferReadExp) { - result = pfnAppendMemBufferReadExp(hCommandBuffer, hBuffer, offset, - size, pDst, numSyncPointsInWaitList, - pSyncPointWaitList, pSyncPoint); + ur_command_buffer_append_mem_buffer_read_exp_params_t params = { + &hCommandBuffer, + &hBuffer, + &offset, + &size, + &pDst, + &numSyncPointsInWaitList, + &pSyncPointWaitList, + &pSyncPoint}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback( + "urCommandBufferAppendMemBufferReadExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback( + "urCommandBufferAppendMemBufferReadExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = + reinterpret_cast(mock::callbacks.get_after_callback( + "urCommandBufferAppendMemBufferReadExp")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -5111,16 +8652,50 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyRectExp( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnAppendMemBufferCopyRectExp = - d_context.urDdiTable.CommandBufferExp.pfnAppendMemBufferCopyRectExp; - if (nullptr != pfnAppendMemBufferCopyRectExp) { - result = pfnAppendMemBufferCopyRectExp( - hCommandBuffer, hSrcMem, hDstMem, srcOrigin, dstOrigin, region, - srcRowPitch, srcSlicePitch, dstRowPitch, dstSlicePitch, - numSyncPointsInWaitList, pSyncPointWaitList, pSyncPoint); + ur_command_buffer_append_mem_buffer_copy_rect_exp_params_t params = { + &hCommandBuffer, + &hSrcMem, + &hDstMem, + &srcOrigin, + &dstOrigin, + ®ion, + &srcRowPitch, + &srcSlicePitch, + &dstRowPitch, + &dstSlicePitch, + &numSyncPointsInWaitList, + &pSyncPointWaitList, + &pSyncPoint}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback( + "urCommandBufferAppendMemBufferCopyRectExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback( + "urCommandBufferAppendMemBufferCopyRectExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = + reinterpret_cast(mock::callbacks.get_after_callback( + "urCommandBufferAppendMemBufferCopyRectExp")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -5161,16 +8736,50 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferAppendMemBufferWriteRectExp( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnAppendMemBufferWriteRectExp = - d_context.urDdiTable.CommandBufferExp.pfnAppendMemBufferWriteRectExp; - if (nullptr != pfnAppendMemBufferWriteRectExp) { - result = pfnAppendMemBufferWriteRectExp( - hCommandBuffer, hBuffer, bufferOffset, hostOffset, region, - bufferRowPitch, bufferSlicePitch, hostRowPitch, hostSlicePitch, - pSrc, numSyncPointsInWaitList, pSyncPointWaitList, pSyncPoint); + ur_command_buffer_append_mem_buffer_write_rect_exp_params_t params = { + &hCommandBuffer, + &hBuffer, + &bufferOffset, + &hostOffset, + ®ion, + &bufferRowPitch, + &bufferSlicePitch, + &hostRowPitch, + &hostSlicePitch, + &pSrc, + &numSyncPointsInWaitList, + &pSyncPointWaitList, + &pSyncPoint}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback( + "urCommandBufferAppendMemBufferWriteRectExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback( + "urCommandBufferAppendMemBufferWriteRectExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = + reinterpret_cast(mock::callbacks.get_after_callback( + "urCommandBufferAppendMemBufferWriteRectExp")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -5209,16 +8818,50 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferAppendMemBufferReadRectExp( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnAppendMemBufferReadRectExp = - d_context.urDdiTable.CommandBufferExp.pfnAppendMemBufferReadRectExp; - if (nullptr != pfnAppendMemBufferReadRectExp) { - result = pfnAppendMemBufferReadRectExp( - hCommandBuffer, hBuffer, bufferOffset, hostOffset, region, - bufferRowPitch, bufferSlicePitch, hostRowPitch, hostSlicePitch, - pDst, numSyncPointsInWaitList, pSyncPointWaitList, pSyncPoint); + ur_command_buffer_append_mem_buffer_read_rect_exp_params_t params = { + &hCommandBuffer, + &hBuffer, + &bufferOffset, + &hostOffset, + ®ion, + &bufferRowPitch, + &bufferSlicePitch, + &hostRowPitch, + &hostSlicePitch, + &pDst, + &numSyncPointsInWaitList, + &pSyncPointWaitList, + &pSyncPoint}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback( + "urCommandBufferAppendMemBufferReadRectExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback( + "urCommandBufferAppendMemBufferReadRectExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = + reinterpret_cast(mock::callbacks.get_after_callback( + "urCommandBufferAppendMemBufferReadRectExp")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -5247,15 +8890,46 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferAppendMemBufferFillExp( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnAppendMemBufferFillExp = - d_context.urDdiTable.CommandBufferExp.pfnAppendMemBufferFillExp; - if (nullptr != pfnAppendMemBufferFillExp) { - result = pfnAppendMemBufferFillExp( - hCommandBuffer, hBuffer, pPattern, patternSize, offset, size, - numSyncPointsInWaitList, pSyncPointWaitList, pSyncPoint); + ur_command_buffer_append_mem_buffer_fill_exp_params_t params = { + &hCommandBuffer, + &hBuffer, + &pPattern, + &patternSize, + &offset, + &size, + &numSyncPointsInWaitList, + &pSyncPointWaitList, + &pSyncPoint}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback( + "urCommandBufferAppendMemBufferFillExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback( + "urCommandBufferAppendMemBufferFillExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = + reinterpret_cast(mock::callbacks.get_after_callback( + "urCommandBufferAppendMemBufferFillExp")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -5281,15 +8955,44 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferAppendUSMPrefetchExp( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnAppendUSMPrefetchExp = - d_context.urDdiTable.CommandBufferExp.pfnAppendUSMPrefetchExp; - if (nullptr != pfnAppendUSMPrefetchExp) { - result = pfnAppendUSMPrefetchExp(hCommandBuffer, pMemory, size, flags, - numSyncPointsInWaitList, - pSyncPointWaitList, pSyncPoint); + ur_command_buffer_append_usm_prefetch_exp_params_t params = { + &hCommandBuffer, + &pMemory, + &size, + &flags, + &numSyncPointsInWaitList, + &pSyncPointWaitList, + &pSyncPoint}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback( + "urCommandBufferAppendUSMPrefetchExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback( + "urCommandBufferAppendUSMPrefetchExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = + reinterpret_cast(mock::callbacks.get_after_callback( + "urCommandBufferAppendUSMPrefetchExp")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -5315,15 +9018,44 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferAppendUSMAdviseExp( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnAppendUSMAdviseExp = - d_context.urDdiTable.CommandBufferExp.pfnAppendUSMAdviseExp; - if (nullptr != pfnAppendUSMAdviseExp) { - result = pfnAppendUSMAdviseExp(hCommandBuffer, pMemory, size, advice, - numSyncPointsInWaitList, - pSyncPointWaitList, pSyncPoint); + ur_command_buffer_append_usm_advise_exp_params_t params = { + &hCommandBuffer, + &pMemory, + &size, + &advice, + &numSyncPointsInWaitList, + &pSyncPointWaitList, + &pSyncPoint}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback( + "urCommandBufferAppendUSMAdviseExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback( + "urCommandBufferAppendUSMAdviseExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = + reinterpret_cast(mock::callbacks.get_after_callback( + "urCommandBufferAppendUSMAdviseExp")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -5349,16 +9081,40 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferEnqueueExp( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnEnqueueExp = d_context.urDdiTable.CommandBufferExp.pfnEnqueueExp; - if (nullptr != pfnEnqueueExp) { - result = pfnEnqueueExp(hCommandBuffer, hQueue, numEventsInWaitList, - phEventWaitList, phEvent); + ur_command_buffer_enqueue_exp_params_t params = { + &hCommandBuffer, &hQueue, &numEventsInWaitList, &phEventWaitList, + &phEvent}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urCommandBufferEnqueueExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urCommandBufferEnqueueExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - if (nullptr != phEvent) { - *phEvent = reinterpret_cast(d_context.get()); + + // optional output handle + if (phEvent) { + *phEvent = mock::createDummyHandle(); } + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urCommandBufferEnqueueExp")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -5374,13 +9130,35 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferRetainCommandExp( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnRetainCommandExp = - d_context.urDdiTable.CommandBufferExp.pfnRetainCommandExp; - if (nullptr != pfnRetainCommandExp) { - result = pfnRetainCommandExp(hCommand); + ur_command_buffer_retain_command_exp_params_t params = {&hCommand}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urCommandBufferRetainCommandExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback( + "urCommandBufferRetainCommandExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urCommandBufferRetainCommandExp")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -5396,13 +9174,37 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferReleaseCommandExp( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnReleaseCommandExp = - d_context.urDdiTable.CommandBufferExp.pfnReleaseCommandExp; - if (nullptr != pfnReleaseCommandExp) { - result = pfnReleaseCommandExp(hCommand); + ur_command_buffer_release_command_exp_params_t params = {&hCommand}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback( + "urCommandBufferReleaseCommandExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback( + "urCommandBufferReleaseCommandExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + mock::releaseDummyHandle(hCommand); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urCommandBufferReleaseCommandExp")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -5420,13 +9222,38 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnUpdateKernelLaunchExp = - d_context.urDdiTable.CommandBufferExp.pfnUpdateKernelLaunchExp; - if (nullptr != pfnUpdateKernelLaunchExp) { - result = pfnUpdateKernelLaunchExp(hCommand, pUpdateKernelLaunch); + ur_command_buffer_update_kernel_launch_exp_params_t params = { + &hCommand, &pUpdateKernelLaunch}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback( + "urCommandBufferUpdateKernelLaunchExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback( + "urCommandBufferUpdateKernelLaunchExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = + reinterpret_cast(mock::callbacks.get_after_callback( + "urCommandBufferUpdateKernelLaunchExp")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -5451,13 +9278,35 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferGetInfoExp( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnGetInfoExp = d_context.urDdiTable.CommandBufferExp.pfnGetInfoExp; - if (nullptr != pfnGetInfoExp) { - result = pfnGetInfoExp(hCommandBuffer, propName, propSize, pPropValue, - pPropSizeRet); + ur_command_buffer_get_info_exp_params_t params = { + &hCommandBuffer, &propName, &propSize, &pPropValue, &pPropSizeRet}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urCommandBufferGetInfoExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urCommandBufferGetInfoExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urCommandBufferGetInfoExp")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -5482,14 +9331,37 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferCommandGetInfoExp( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnCommandGetInfoExp = - d_context.urDdiTable.CommandBufferExp.pfnCommandGetInfoExp; - if (nullptr != pfnCommandGetInfoExp) { - result = pfnCommandGetInfoExp(hCommand, propName, propSize, pPropValue, - pPropSizeRet); + ur_command_buffer_command_get_info_exp_params_t params = { + &hCommand, &propName, &propSize, &pPropValue, &pPropSizeRet}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback( + "urCommandBufferCommandGetInfoExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback( + "urCommandBufferCommandGetInfoExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urCommandBufferCommandGetInfoExp")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -5530,18 +9402,50 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnCooperativeKernelLaunchExp = - d_context.urDdiTable.EnqueueExp.pfnCooperativeKernelLaunchExp; - if (nullptr != pfnCooperativeKernelLaunchExp) { - result = pfnCooperativeKernelLaunchExp( - hQueue, hKernel, workDim, pGlobalWorkOffset, pGlobalWorkSize, - pLocalWorkSize, numEventsInWaitList, phEventWaitList, phEvent); + ur_enqueue_cooperative_kernel_launch_exp_params_t params = { + &hQueue, + &hKernel, + &workDim, + &pGlobalWorkOffset, + &pGlobalWorkSize, + &pLocalWorkSize, + &numEventsInWaitList, + &phEventWaitList, + &phEvent}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback( + "urEnqueueCooperativeKernelLaunchExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback( + "urEnqueueCooperativeKernelLaunchExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - if (nullptr != phEvent) { - *phEvent = reinterpret_cast(d_context.get()); + + // optional output handle + if (phEvent) { + *phEvent = mock::createDummyHandle(); } + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = + reinterpret_cast(mock::callbacks.get_after_callback( + "urEnqueueCooperativeKernelLaunchExp")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -5563,14 +9467,38 @@ __urdlllocal ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnSuggestMaxCooperativeGroupCountExp = - d_context.urDdiTable.KernelExp.pfnSuggestMaxCooperativeGroupCountExp; - if (nullptr != pfnSuggestMaxCooperativeGroupCountExp) { - result = pfnSuggestMaxCooperativeGroupCountExp( - hKernel, localWorkSize, dynamicSharedMemorySize, pGroupCountRet); + ur_kernel_suggest_max_cooperative_group_count_exp_params_t params = { + &hKernel, &localWorkSize, &dynamicSharedMemorySize, &pGroupCountRet}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback( + "urKernelSuggestMaxCooperativeGroupCountExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback( + "urKernelSuggestMaxCooperativeGroupCountExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = + reinterpret_cast(mock::callbacks.get_after_callback( + "urKernelSuggestMaxCooperativeGroupCountExp")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -5603,15 +9531,36 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueTimestampRecordingExp( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnTimestampRecordingExp = - d_context.urDdiTable.EnqueueExp.pfnTimestampRecordingExp; - if (nullptr != pfnTimestampRecordingExp) { - result = pfnTimestampRecordingExp(hQueue, blocking, numEventsInWaitList, - phEventWaitList, phEvent); + ur_enqueue_timestamp_recording_exp_params_t params = { + &hQueue, &blocking, &numEventsInWaitList, &phEventWaitList, &phEvent}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urEnqueueTimestampRecordingExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urEnqueueTimestampRecordingExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - *phEvent = reinterpret_cast(d_context.get()); + + *phEvent = mock::createDummyHandle(); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urEnqueueTimestampRecordingExp")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -5651,16 +9600,39 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnKernelLaunchCustomExp = - d_context.urDdiTable.EnqueueExp.pfnKernelLaunchCustomExp; - if (nullptr != pfnKernelLaunchCustomExp) { - result = pfnKernelLaunchCustomExp( - hQueue, hKernel, workDim, pGlobalWorkSize, pLocalWorkSize, - numPropsInLaunchPropList, launchPropList, numEventsInWaitList, - phEventWaitList, phEvent); + ur_enqueue_kernel_launch_custom_exp_params_t params = { + &hQueue, &hKernel, + &workDim, &pGlobalWorkSize, + &pLocalWorkSize, &numPropsInLaunchPropList, + &launchPropList, &numEventsInWaitList, + &phEventWaitList, &phEvent}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urEnqueueKernelLaunchCustomExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urEnqueueKernelLaunchCustomExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urEnqueueKernelLaunchCustomExp")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -5680,12 +9652,35 @@ __urdlllocal ur_result_t UR_APICALL urProgramBuildExp( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnBuildExp = d_context.urDdiTable.ProgramExp.pfnBuildExp; - if (nullptr != pfnBuildExp) { - result = pfnBuildExp(hProgram, numDevices, phDevices, pOptions); + ur_program_build_exp_params_t params = {&hProgram, &numDevices, &phDevices, + &pOptions}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urProgramBuildExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urProgramBuildExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urProgramBuildExp")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -5706,12 +9701,35 @@ __urdlllocal ur_result_t UR_APICALL urProgramCompileExp( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnCompileExp = d_context.urDdiTable.ProgramExp.pfnCompileExp; - if (nullptr != pfnCompileExp) { - result = pfnCompileExp(hProgram, numDevices, phDevices, pOptions); + ur_program_compile_exp_params_t params = {&hProgram, &numDevices, + &phDevices, &pOptions}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urProgramCompileExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urProgramCompileExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urProgramCompileExp")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -5739,14 +9757,37 @@ __urdlllocal ur_result_t UR_APICALL urProgramLinkExp( *phProgram = nullptr; } - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnLinkExp = d_context.urDdiTable.ProgramExp.pfnLinkExp; - if (nullptr != pfnLinkExp) { - result = pfnLinkExp(hContext, numDevices, phDevices, count, phPrograms, - pOptions, phProgram); + ur_program_link_exp_params_t params = {&hContext, &numDevices, &phDevices, + &count, &phPrograms, &pOptions, + &phProgram}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urProgramLinkExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urProgramLinkExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - *phProgram = reinterpret_cast(d_context.get()); + + *phProgram = mock::createDummyHandle(); + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urProgramLinkExp")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -5763,12 +9804,34 @@ __urdlllocal ur_result_t UR_APICALL urUSMImportExp( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnImportExp = d_context.urDdiTable.USMExp.pfnImportExp; - if (nullptr != pfnImportExp) { - result = pfnImportExp(hContext, pMem, size); + ur_usm_import_exp_params_t params = {&hContext, &pMem, &size}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urUSMImportExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urUSMImportExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urUSMImportExp")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -5784,12 +9847,34 @@ __urdlllocal ur_result_t UR_APICALL urUSMReleaseExp( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnReleaseExp = d_context.urDdiTable.USMExp.pfnReleaseExp; - if (nullptr != pfnReleaseExp) { - result = pfnReleaseExp(hContext, pMem); + ur_usm_release_exp_params_t params = {&hContext, &pMem}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urUSMReleaseExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urUSMReleaseExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urUSMReleaseExp")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -5806,13 +9891,35 @@ __urdlllocal ur_result_t UR_APICALL urUsmP2PEnablePeerAccessExp( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnEnablePeerAccessExp = - d_context.urDdiTable.UsmP2PExp.pfnEnablePeerAccessExp; - if (nullptr != pfnEnablePeerAccessExp) { - result = pfnEnablePeerAccessExp(commandDevice, peerDevice); + ur_usm_p2p_enable_peer_access_exp_params_t params = {&commandDevice, + &peerDevice}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urUsmP2PEnablePeerAccessExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urUsmP2PEnablePeerAccessExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urUsmP2PEnablePeerAccessExp")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -5829,13 +9936,35 @@ __urdlllocal ur_result_t UR_APICALL urUsmP2PDisablePeerAccessExp( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnDisablePeerAccessExp = - d_context.urDdiTable.UsmP2PExp.pfnDisablePeerAccessExp; - if (nullptr != pfnDisablePeerAccessExp) { - result = pfnDisablePeerAccessExp(commandDevice, peerDevice); + ur_usm_p2p_disable_peer_access_exp_params_t params = {&commandDevice, + &peerDevice}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urUsmP2PDisablePeerAccessExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urUsmP2PDisablePeerAccessExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urUsmP2PDisablePeerAccessExp")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -5863,14 +9992,36 @@ __urdlllocal ur_result_t UR_APICALL urUsmP2PPeerAccessGetInfoExp( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnPeerAccessGetInfoExp = - d_context.urDdiTable.UsmP2PExp.pfnPeerAccessGetInfoExp; - if (nullptr != pfnPeerAccessGetInfoExp) { - result = pfnPeerAccessGetInfoExp(commandDevice, peerDevice, propName, - propSize, pPropValue, pPropSizeRet); + ur_usm_p2p_peer_access_get_info_exp_params_t params = { + &commandDevice, &peerDevice, &propName, + &propSize, &pPropValue, &pPropSizeRet}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urUsmP2PPeerAccessGetInfoExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urUsmP2PPeerAccessGetInfoExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urUsmP2PPeerAccessGetInfoExp")); + if (afterCallback) { + return afterCallback(¶ms); } return result; @@ -5905,18 +10056,46 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueNativeCommandExp( ) try { ur_result_t result = UR_RESULT_SUCCESS; - // if the driver has created a custom function, then call it instead of using the generic path - auto pfnNativeCommandExp = - d_context.urDdiTable.EnqueueExp.pfnNativeCommandExp; - if (nullptr != pfnNativeCommandExp) { - result = pfnNativeCommandExp( - hQueue, pfnNativeEnqueue, data, numMemsInMemList, phMemList, - pProperties, numEventsInWaitList, phEventWaitList, phEvent); + ur_enqueue_native_command_exp_params_t params = {&hQueue, + &pfnNativeEnqueue, + &data, + &numMemsInMemList, + &phMemList, + &pProperties, + &numEventsInWaitList, + &phEventWaitList, + &phEvent}; + + auto beforeCallback = reinterpret_cast( + mock::callbacks.get_before_callback("urEnqueueNativeCommandExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::callbacks.get_replace_callback("urEnqueueNativeCommandExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); } else { - // generic implementation - if (nullptr != phEvent) { - *phEvent = reinterpret_cast(d_context.get()); + + // optional output handle + if (phEvent) { + *phEvent = mock::createDummyHandle(); } + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::callbacks.get_after_callback("urEnqueueNativeCommandExp")); + if (afterCallback) { + return afterCallback(¶ms); } return result; diff --git a/source/loader/CMakeLists.txt b/source/loader/CMakeLists.txt index 075d9909b0..4f9b5b53ec 100644 --- a/source/loader/CMakeLists.txt +++ b/source/loader/CMakeLists.txt @@ -159,7 +159,6 @@ if(UR_ENABLE_SANITIZER) ) endif() - # link validation backtrace dependencies if(UNIX) find_package(Libbacktrace) diff --git a/source/loader/layers/tracing/ur_trcddi.cpp b/source/loader/layers/tracing/ur_trcddi.cpp index e0c69f50e9..9f0f302399 100644 --- a/source/loader/layers/tracing/ur_trcddi.cpp +++ b/source/loader/layers/tracing/ur_trcddi.cpp @@ -87,7 +87,7 @@ __urdlllocal ur_result_t UR_APICALL urAdapterRelease( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urAdapterRetain __urdlllocal ur_result_t UR_APICALL urAdapterRetain( - ur_adapter_handle_t hAdapter ///< [in] Adapter handle to retain + ur_adapter_handle_t hAdapter ///< [in][retain] Adapter handle to retain ) { auto pfnAdapterRetain = context.urDdiTable.Global.pfnAdapterRetain; @@ -518,7 +518,7 @@ __urdlllocal ur_result_t UR_APICALL urDeviceGetInfo( /// @brief Intercept function for urDeviceRetain __urdlllocal ur_result_t UR_APICALL urDeviceRetain( ur_device_handle_t - hDevice ///< [in] handle of the device to get a reference of. + hDevice ///< [in][retain] handle of the device to get a reference of. ) { auto pfnRetain = context.urDdiTable.Device.pfnRetain; @@ -816,7 +816,7 @@ __urdlllocal ur_result_t UR_APICALL urContextCreate( /// @brief Intercept function for urContextRetain __urdlllocal ur_result_t UR_APICALL urContextRetain( ur_context_handle_t - hContext ///< [in] handle of the context to get a reference of. + hContext ///< [in][retain] handle of the context to get a reference of. ) { auto pfnRetain = context.urDdiTable.Context.pfnRetain; @@ -1112,7 +1112,8 @@ __urdlllocal ur_result_t UR_APICALL urMemBufferCreate( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urMemRetain __urdlllocal ur_result_t UR_APICALL urMemRetain( - ur_mem_handle_t hMem ///< [in] handle of the memory object to get access + ur_mem_handle_t + hMem ///< [in][retain] handle of the memory object to get access ) { auto pfnRetain = context.urDdiTable.Mem.pfnRetain; @@ -1452,7 +1453,7 @@ __urdlllocal ur_result_t UR_APICALL urSamplerCreate( /// @brief Intercept function for urSamplerRetain __urdlllocal ur_result_t UR_APICALL urSamplerRetain( ur_sampler_handle_t - hSampler ///< [in] handle of the sampler object to get access + hSampler ///< [in][retain] handle of the sampler object to get access ) { auto pfnRetain = context.urDdiTable.Sampler.pfnRetain; @@ -1852,7 +1853,7 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolCreate( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urUSMPoolRetain __urdlllocal ur_result_t UR_APICALL urUSMPoolRetain( - ur_usm_pool_handle_t pPool ///< [in] pointer to USM memory pool + ur_usm_pool_handle_t pPool ///< [in][retain] pointer to USM memory pool ) { auto pfnPoolRetain = context.urDdiTable.USM.pfnPoolRetain; @@ -2270,7 +2271,7 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemCreate( /// @brief Intercept function for urPhysicalMemRetain __urdlllocal ur_result_t UR_APICALL urPhysicalMemRetain( ur_physical_mem_handle_t - hPhysicalMem ///< [in] handle of the physical memory object to retain. + hPhysicalMem ///< [in][retain] handle of the physical memory object to retain. ) { auto pfnRetain = context.urDdiTable.PhysicalMem.pfnRetain; @@ -2519,7 +2520,8 @@ __urdlllocal ur_result_t UR_APICALL urProgramLink( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urProgramRetain __urdlllocal ur_result_t UR_APICALL urProgramRetain( - ur_program_handle_t hProgram ///< [in] handle for the Program to retain + ur_program_handle_t + hProgram ///< [in][retain] handle for the Program to retain ) { auto pfnRetain = context.urDdiTable.Program.pfnRetain; @@ -3110,7 +3112,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelGetSubGroupInfo( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urKernelRetain __urdlllocal ur_result_t UR_APICALL urKernelRetain( - ur_kernel_handle_t hKernel ///< [in] handle for the Kernel to retain + ur_kernel_handle_t hKernel ///< [in][retain] handle for the Kernel to retain ) { auto pfnRetain = context.urDdiTable.Kernel.pfnRetain; @@ -3561,7 +3563,8 @@ __urdlllocal ur_result_t UR_APICALL urQueueCreate( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urQueueRetain __urdlllocal ur_result_t UR_APICALL urQueueRetain( - ur_queue_handle_t hQueue ///< [in] handle of the queue object to get access + ur_queue_handle_t + hQueue ///< [in][retain] handle of the queue object to get access ) { auto pfnRetain = context.urDdiTable.Queue.pfnRetain; @@ -3870,7 +3873,7 @@ __urdlllocal ur_result_t UR_APICALL urEventWait( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urEventRetain __urdlllocal ur_result_t UR_APICALL urEventRetain( - ur_event_handle_t hEvent ///< [in] handle of the event object + ur_event_handle_t hEvent ///< [in][retain] handle of the event object ) { auto pfnRetain = context.urDdiTable.Event.pfnRetain; @@ -6268,7 +6271,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferCreateExp( /// @brief Intercept function for urCommandBufferRetainExp __urdlllocal ur_result_t UR_APICALL urCommandBufferRetainExp( ur_exp_command_buffer_handle_t - hCommandBuffer ///< [in] Handle of the command-buffer object. + hCommandBuffer ///< [in][retain] Handle of the command-buffer object. ) { auto pfnRetainExp = context.urDdiTable.CommandBufferExp.pfnRetainExp; diff --git a/source/loader/layers/validation/ur_valddi.cpp b/source/loader/layers/validation/ur_valddi.cpp index 46ce0cf24b..163dc76d03 100644 --- a/source/loader/layers/validation/ur_valddi.cpp +++ b/source/loader/layers/validation/ur_valddi.cpp @@ -79,7 +79,7 @@ __urdlllocal ur_result_t UR_APICALL urAdapterRelease( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urAdapterRetain __urdlllocal ur_result_t UR_APICALL urAdapterRetain( - ur_adapter_handle_t hAdapter ///< [in] Adapter handle to retain + ur_adapter_handle_t hAdapter ///< [in][retain] Adapter handle to retain ) { auto pfnAdapterRetain = context.urDdiTable.Global.pfnAdapterRetain; @@ -531,7 +531,7 @@ __urdlllocal ur_result_t UR_APICALL urDeviceGetInfo( /// @brief Intercept function for urDeviceRetain __urdlllocal ur_result_t UR_APICALL urDeviceRetain( ur_device_handle_t - hDevice ///< [in] handle of the device to get a reference of. + hDevice ///< [in][retain] handle of the device to get a reference of. ) { auto pfnRetain = context.urDdiTable.Device.pfnRetain; @@ -827,7 +827,7 @@ __urdlllocal ur_result_t UR_APICALL urContextCreate( /// @brief Intercept function for urContextRetain __urdlllocal ur_result_t UR_APICALL urContextRetain( ur_context_handle_t - hContext ///< [in] handle of the context to get a reference of. + hContext ///< [in][retain] handle of the context to get a reference of. ) { auto pfnRetain = context.urDdiTable.Context.pfnRetain; @@ -1199,7 +1199,8 @@ __urdlllocal ur_result_t UR_APICALL urMemBufferCreate( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urMemRetain __urdlllocal ur_result_t UR_APICALL urMemRetain( - ur_mem_handle_t hMem ///< [in] handle of the memory object to get access + ur_mem_handle_t + hMem ///< [in][retain] handle of the memory object to get access ) { auto pfnRetain = context.urDdiTable.Mem.pfnRetain; @@ -1608,7 +1609,7 @@ __urdlllocal ur_result_t UR_APICALL urSamplerCreate( /// @brief Intercept function for urSamplerRetain __urdlllocal ur_result_t UR_APICALL urSamplerRetain( ur_sampler_handle_t - hSampler ///< [in] handle of the sampler object to get access + hSampler ///< [in][retain] handle of the sampler object to get access ) { auto pfnRetain = context.urDdiTable.Sampler.pfnRetain; @@ -2104,7 +2105,7 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolCreate( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urUSMPoolRetain __urdlllocal ur_result_t UR_APICALL urUSMPoolRetain( - ur_usm_pool_handle_t pPool ///< [in] pointer to USM memory pool + ur_usm_pool_handle_t pPool ///< [in][retain] pointer to USM memory pool ) { auto pfnPoolRetain = context.urDdiTable.USM.pfnPoolRetain; @@ -2581,7 +2582,7 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemCreate( /// @brief Intercept function for urPhysicalMemRetain __urdlllocal ur_result_t UR_APICALL urPhysicalMemRetain( ur_physical_mem_handle_t - hPhysicalMem ///< [in] handle of the physical memory object to retain. + hPhysicalMem ///< [in][retain] handle of the physical memory object to retain. ) { auto pfnRetain = context.urDdiTable.PhysicalMem.pfnRetain; @@ -2890,7 +2891,8 @@ __urdlllocal ur_result_t UR_APICALL urProgramLink( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urProgramRetain __urdlllocal ur_result_t UR_APICALL urProgramRetain( - ur_program_handle_t hProgram ///< [in] handle for the Program to retain + ur_program_handle_t + hProgram ///< [in][retain] handle for the Program to retain ) { auto pfnRetain = context.urDdiTable.Program.pfnRetain; @@ -3552,7 +3554,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelGetSubGroupInfo( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urKernelRetain __urdlllocal ur_result_t UR_APICALL urKernelRetain( - ur_kernel_handle_t hKernel ///< [in] handle for the Kernel to retain + ur_kernel_handle_t hKernel ///< [in][retain] handle for the Kernel to retain ) { auto pfnRetain = context.urDdiTable.Kernel.pfnRetain; @@ -4074,7 +4076,8 @@ __urdlllocal ur_result_t UR_APICALL urQueueCreate( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urQueueRetain __urdlllocal ur_result_t UR_APICALL urQueueRetain( - ur_queue_handle_t hQueue ///< [in] handle of the queue object to get access + ur_queue_handle_t + hQueue ///< [in][retain] handle of the queue object to get access ) { auto pfnRetain = context.urDdiTable.Queue.pfnRetain; @@ -4393,7 +4396,7 @@ __urdlllocal ur_result_t UR_APICALL urEventWait( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urEventRetain __urdlllocal ur_result_t UR_APICALL urEventRetain( - ur_event_handle_t hEvent ///< [in] handle of the event object + ur_event_handle_t hEvent ///< [in][retain] handle of the event object ) { auto pfnRetain = context.urDdiTable.Event.pfnRetain; @@ -7831,7 +7834,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferCreateExp( /// @brief Intercept function for urCommandBufferRetainExp __urdlllocal ur_result_t UR_APICALL urCommandBufferRetainExp( ur_exp_command_buffer_handle_t - hCommandBuffer ///< [in] Handle of the command-buffer object. + hCommandBuffer ///< [in][retain] Handle of the command-buffer object. ) { auto pfnRetainExp = context.urDdiTable.CommandBufferExp.pfnRetainExp; diff --git a/source/loader/loader.def.in b/source/loader/loader.def.in index 45ab3b1caf..f37edad919 100644 --- a/source/loader/loader.def.in +++ b/source/loader/loader.def.in @@ -142,6 +142,7 @@ EXPORTS urLoaderConfigRelease urLoaderConfigRetain urLoaderConfigSetCodeLocationCallback + urLoaderConfigSetMockingEnabled urLoaderInit urLoaderTearDown urMemBufferCreate @@ -361,6 +362,7 @@ EXPORTS urPrintLoaderConfigReleaseParams urPrintLoaderConfigRetainParams urPrintLoaderConfigSetCodeLocationCallbackParams + urPrintLoaderConfigSetMockingEnabledParams urPrintLoaderInitParams urPrintLoaderTearDownParams urPrintMapFlags diff --git a/source/loader/loader.map.in b/source/loader/loader.map.in index 170365ac4b..1ba389a3b5 100644 --- a/source/loader/loader.map.in +++ b/source/loader/loader.map.in @@ -142,6 +142,7 @@ urLoaderConfigRelease; urLoaderConfigRetain; urLoaderConfigSetCodeLocationCallback; + urLoaderConfigSetMockingEnabled; urLoaderInit; urLoaderTearDown; urMemBufferCreate; @@ -361,6 +362,7 @@ urPrintLoaderConfigReleaseParams; urPrintLoaderConfigRetainParams; urPrintLoaderConfigSetCodeLocationCallbackParams; + urPrintLoaderConfigSetMockingEnabledParams; urPrintLoaderInitParams; urPrintLoaderTearDownParams; urPrintMapFlags; diff --git a/source/loader/ur_adapter_registry.hpp b/source/loader/ur_adapter_registry.hpp index 060a5ae8a9..61279e820b 100644 --- a/source/loader/ur_adapter_registry.hpp +++ b/source/loader/ur_adapter_registry.hpp @@ -122,6 +122,9 @@ class AdapterRegistry { MAKE_LIBRARY_NAME("ur_adapter_native_cpu", "0"), }; + static constexpr const char *mockAdapterName = + MAKE_LIBRARY_NAME("ur_adapter_mock", "0"); + std::optional> getEnvAdapterSearchPaths() { std::optional> pathStringsOpt; try { @@ -179,6 +182,21 @@ class AdapterRegistry { adaptersLoadPaths.emplace_back(loadPaths); } } + + public: + void enableMock() { + adaptersLoadPaths.clear(); + + std::vector loadPaths; + auto adapterNamePathOpt = getAdapterNameAsPath(mockAdapterName); + auto loaderLibPathOpt = getLoaderLibPath(); + if (adapterNamePathOpt.has_value() && loaderLibPathOpt.has_value()) { + const auto &adapterNamePath = adapterNamePathOpt.value(); + const auto &loaderLibPath = loaderLibPathOpt.value(); + loadPaths.emplace_back(loaderLibPath / adapterNamePath); + } + adaptersLoadPaths.emplace_back(loadPaths); + } }; } // namespace ur_loader diff --git a/source/loader/ur_ldrddi.cpp b/source/loader/ur_ldrddi.cpp index f0ade9a664..b389b24713 100644 --- a/source/loader/ur_ldrddi.cpp +++ b/source/loader/ur_ldrddi.cpp @@ -105,7 +105,7 @@ __urdlllocal ur_result_t UR_APICALL urAdapterRelease( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urAdapterRetain __urdlllocal ur_result_t UR_APICALL urAdapterRetain( - ur_adapter_handle_t hAdapter ///< [in] Adapter handle to retain + ur_adapter_handle_t hAdapter ///< [in][retain] Adapter handle to retain ) { ur_result_t result = UR_RESULT_SUCCESS; @@ -122,6 +122,18 @@ __urdlllocal ur_result_t UR_APICALL urAdapterRetain( // forward to device-platform result = pfnAdapterRetain(hAdapter); + if (UR_RESULT_SUCCESS != result) { + return result; + } + + try { + // convert platform handle to loader handle + *hAdapter = reinterpret_cast( + ur_adapter_factory.getInstance(*hAdapter, dditable)); + } catch (std::bad_alloc &) { + result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; + } + return result; } @@ -586,7 +598,7 @@ __urdlllocal ur_result_t UR_APICALL urDeviceGetInfo( /// @brief Intercept function for urDeviceRetain __urdlllocal ur_result_t UR_APICALL urDeviceRetain( ur_device_handle_t - hDevice ///< [in] handle of the device to get a reference of. + hDevice ///< [in][retain] handle of the device to get a reference of. ) { ur_result_t result = UR_RESULT_SUCCESS; @@ -603,6 +615,18 @@ __urdlllocal ur_result_t UR_APICALL urDeviceRetain( // forward to device-platform result = pfnRetain(hDevice); + if (UR_RESULT_SUCCESS != result) { + return result; + } + + try { + // convert platform handle to loader handle + *hDevice = reinterpret_cast( + ur_device_factory.getInstance(*hDevice, dditable)); + } catch (std::bad_alloc &) { + result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; + } + return result; } @@ -863,7 +887,7 @@ __urdlllocal ur_result_t UR_APICALL urContextCreate( /// @brief Intercept function for urContextRetain __urdlllocal ur_result_t UR_APICALL urContextRetain( ur_context_handle_t - hContext ///< [in] handle of the context to get a reference of. + hContext ///< [in][retain] handle of the context to get a reference of. ) { ur_result_t result = UR_RESULT_SUCCESS; @@ -880,6 +904,18 @@ __urdlllocal ur_result_t UR_APICALL urContextRetain( // forward to device-platform result = pfnRetain(hContext); + if (UR_RESULT_SUCCESS != result) { + return result; + } + + try { + // convert platform handle to loader handle + *hContext = reinterpret_cast( + ur_context_factory.getInstance(*hContext, dditable)); + } catch (std::bad_alloc &) { + result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; + } + return result; } @@ -1168,7 +1204,8 @@ __urdlllocal ur_result_t UR_APICALL urMemBufferCreate( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urMemRetain __urdlllocal ur_result_t UR_APICALL urMemRetain( - ur_mem_handle_t hMem ///< [in] handle of the memory object to get access + ur_mem_handle_t + hMem ///< [in][retain] handle of the memory object to get access ) { ur_result_t result = UR_RESULT_SUCCESS; @@ -1185,6 +1222,18 @@ __urdlllocal ur_result_t UR_APICALL urMemRetain( // forward to device-platform result = pfnRetain(hMem); + if (UR_RESULT_SUCCESS != result) { + return result; + } + + try { + // convert platform handle to loader handle + *hMem = reinterpret_cast( + ur_mem_factory.getInstance(*hMem, dditable)); + } catch (std::bad_alloc &) { + result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; + } + return result; } @@ -1526,7 +1575,7 @@ __urdlllocal ur_result_t UR_APICALL urSamplerCreate( /// @brief Intercept function for urSamplerRetain __urdlllocal ur_result_t UR_APICALL urSamplerRetain( ur_sampler_handle_t - hSampler ///< [in] handle of the sampler object to get access + hSampler ///< [in][retain] handle of the sampler object to get access ) { ur_result_t result = UR_RESULT_SUCCESS; @@ -1543,6 +1592,18 @@ __urdlllocal ur_result_t UR_APICALL urSamplerRetain( // forward to device-platform result = pfnRetain(hSampler); + if (UR_RESULT_SUCCESS != result) { + return result; + } + + try { + // convert platform handle to loader handle + *hSampler = reinterpret_cast( + ur_sampler_factory.getInstance(*hSampler, dditable)); + } catch (std::bad_alloc &) { + result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; + } + return result; } @@ -1961,7 +2022,7 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolCreate( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urUSMPoolRetain __urdlllocal ur_result_t UR_APICALL urUSMPoolRetain( - ur_usm_pool_handle_t pPool ///< [in] pointer to USM memory pool + ur_usm_pool_handle_t pPool ///< [in][retain] pointer to USM memory pool ) { ur_result_t result = UR_RESULT_SUCCESS; @@ -1978,6 +2039,18 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolRetain( // forward to device-platform result = pfnPoolRetain(pPool); + if (UR_RESULT_SUCCESS != result) { + return result; + } + + try { + // convert platform handle to loader handle + *pPool = reinterpret_cast( + ur_usm_pool_factory.getInstance(*pPool, dditable)); + } catch (std::bad_alloc &) { + result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; + } + return result; } @@ -2346,7 +2419,7 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemCreate( /// @brief Intercept function for urPhysicalMemRetain __urdlllocal ur_result_t UR_APICALL urPhysicalMemRetain( ur_physical_mem_handle_t - hPhysicalMem ///< [in] handle of the physical memory object to retain. + hPhysicalMem ///< [in][retain] handle of the physical memory object to retain. ) { ur_result_t result = UR_RESULT_SUCCESS; @@ -2365,6 +2438,18 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemRetain( // forward to device-platform result = pfnRetain(hPhysicalMem); + if (UR_RESULT_SUCCESS != result) { + return result; + } + + try { + // convert platform handle to loader handle + *hPhysicalMem = reinterpret_cast( + ur_physical_mem_factory.getInstance(*hPhysicalMem, dditable)); + } catch (std::bad_alloc &) { + result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; + } + return result; } @@ -2595,7 +2680,8 @@ __urdlllocal ur_result_t UR_APICALL urProgramLink( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urProgramRetain __urdlllocal ur_result_t UR_APICALL urProgramRetain( - ur_program_handle_t hProgram ///< [in] handle for the Program to retain + ur_program_handle_t + hProgram ///< [in][retain] handle for the Program to retain ) { ur_result_t result = UR_RESULT_SUCCESS; @@ -2612,6 +2698,18 @@ __urdlllocal ur_result_t UR_APICALL urProgramRetain( // forward to device-platform result = pfnRetain(hProgram); + if (UR_RESULT_SUCCESS != result) { + return result; + } + + try { + // convert platform handle to loader handle + *hProgram = reinterpret_cast( + ur_program_factory.getInstance(*hProgram, dditable)); + } catch (std::bad_alloc &) { + result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; + } + return result; } @@ -3185,7 +3283,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelGetSubGroupInfo( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urKernelRetain __urdlllocal ur_result_t UR_APICALL urKernelRetain( - ur_kernel_handle_t hKernel ///< [in] handle for the Kernel to retain + ur_kernel_handle_t hKernel ///< [in][retain] handle for the Kernel to retain ) { ur_result_t result = UR_RESULT_SUCCESS; @@ -3202,6 +3300,18 @@ __urdlllocal ur_result_t UR_APICALL urKernelRetain( // forward to device-platform result = pfnRetain(hKernel); + if (UR_RESULT_SUCCESS != result) { + return result; + } + + try { + // convert platform handle to loader handle + *hKernel = reinterpret_cast( + ur_kernel_factory.getInstance(*hKernel, dditable)); + } catch (std::bad_alloc &) { + result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; + } + return result; } @@ -3631,7 +3741,8 @@ __urdlllocal ur_result_t UR_APICALL urQueueCreate( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urQueueRetain __urdlllocal ur_result_t UR_APICALL urQueueRetain( - ur_queue_handle_t hQueue ///< [in] handle of the queue object to get access + ur_queue_handle_t + hQueue ///< [in][retain] handle of the queue object to get access ) { ur_result_t result = UR_RESULT_SUCCESS; @@ -3648,6 +3759,18 @@ __urdlllocal ur_result_t UR_APICALL urQueueRetain( // forward to device-platform result = pfnRetain(hQueue); + if (UR_RESULT_SUCCESS != result) { + return result; + } + + try { + // convert platform handle to loader handle + *hQueue = reinterpret_cast( + ur_queue_factory.getInstance(*hQueue, dditable)); + } catch (std::bad_alloc &) { + result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; + } + return result; } @@ -3939,7 +4062,7 @@ __urdlllocal ur_result_t UR_APICALL urEventWait( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urEventRetain __urdlllocal ur_result_t UR_APICALL urEventRetain( - ur_event_handle_t hEvent ///< [in] handle of the event object + ur_event_handle_t hEvent ///< [in][retain] handle of the event object ) { ur_result_t result = UR_RESULT_SUCCESS; @@ -3956,6 +4079,18 @@ __urdlllocal ur_result_t UR_APICALL urEventRetain( // forward to device-platform result = pfnRetain(hEvent); + if (UR_RESULT_SUCCESS != result) { + return result; + } + + try { + // convert platform handle to loader handle + *hEvent = reinterpret_cast( + ur_event_factory.getInstance(*hEvent, dditable)); + } catch (std::bad_alloc &) { + result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; + } + return result; } @@ -6600,7 +6735,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferCreateExp( /// @brief Intercept function for urCommandBufferRetainExp __urdlllocal ur_result_t UR_APICALL urCommandBufferRetainExp( ur_exp_command_buffer_handle_t - hCommandBuffer ///< [in] Handle of the command-buffer object. + hCommandBuffer ///< [in][retain] Handle of the command-buffer object. ) { ur_result_t result = UR_RESULT_SUCCESS; @@ -6621,6 +6756,19 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferRetainExp( // forward to device-platform result = pfnRetainExp(hCommandBuffer); + if (UR_RESULT_SUCCESS != result) { + return result; + } + + try { + // convert platform handle to loader handle + *hCommandBuffer = reinterpret_cast( + ur_exp_command_buffer_factory.getInstance(*hCommandBuffer, + dditable)); + } catch (std::bad_alloc &) { + result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; + } + return result; } diff --git a/source/loader/ur_lib.cpp b/source/loader/ur_lib.cpp index d2ed4853a8..a4a391b73f 100644 --- a/source/loader/ur_lib.cpp +++ b/source/loader/ur_lib.cpp @@ -23,6 +23,7 @@ #include // for std::memcpy #include +#include namespace ur_lib { /////////////////////////////////////////////////////////////////////////////// @@ -80,6 +81,12 @@ void context_t::tearDownLayers() const { ////////////////////////////////////////////////////////////////////////// __urdlllocal ur_result_t context_t::Init( ur_device_init_flags_t, ur_loader_config_handle_t hLoaderConfig) { + if (hLoaderConfig->enableMock) { + // This clears default known adapters and replaces them with the mock + // adapter. + ur_loader::context->adapter_registry.enableMock(); + } + ur_result_t result; const char *logger_name = "loader"; logger::init(logger_name); @@ -215,6 +222,13 @@ urLoaderConfigSetCodeLocationCallback(ur_loader_config_handle_t hLoaderConfig, return UR_RESULT_SUCCESS; } +ur_result_t +urLoaderConfigSetMockingEnabled(ur_loader_config_handle_t hLoaderConfig, + ur_bool_t enable) { + hLoaderConfig->enableMock = enable; + return UR_RESULT_SUCCESS; +} + ur_result_t urDeviceGetSelected(ur_platform_handle_t hPlatform, ur_device_type_t DeviceType, uint32_t NumEntries, diff --git a/source/loader/ur_lib.hpp b/source/loader/ur_lib.hpp index 839c0041d9..6fafe3a32d 100644 --- a/source/loader/ur_lib.hpp +++ b/source/loader/ur_lib.hpp @@ -48,6 +48,7 @@ struct ur_loader_config_handle_t_ { std::set &getEnabledLayerNames() { return enabledLayers; } codeloc_data codelocData; + bool enableMock; }; namespace ur_lib { @@ -104,6 +105,9 @@ ur_result_t urLoaderConfigSetCodeLocationCallback(ur_loader_config_handle_t hLoaderConfig, ur_code_location_callback_t pfnCodeloc, void *pUserData); +ur_result_t +urLoaderConfigSetMockingEnabled(ur_loader_config_handle_t hLoaderConfig, + ur_bool_t enable); ur_result_t urDeviceGetSelected(ur_platform_handle_t hPlatform, ur_device_type_t DeviceType, diff --git a/source/loader/ur_libapi.cpp b/source/loader/ur_libapi.cpp index b2eade1771..2fc6b18f9a 100644 --- a/source/loader/ur_libapi.cpp +++ b/source/loader/ur_libapi.cpp @@ -52,7 +52,7 @@ ur_result_t UR_APICALL urLoaderConfigCreate( /// + `NULL == hLoaderConfig` ur_result_t UR_APICALL urLoaderConfigRetain( ur_loader_config_handle_t - hLoaderConfig ///< [in] loader config handle to retain + hLoaderConfig ///< [in][retain] loader config handle to retain ) try { return ur_lib::urLoaderConfigRetain(hLoaderConfig); } catch (...) { @@ -192,6 +192,33 @@ ur_result_t UR_APICALL urLoaderConfigSetCodeLocationCallback( return exceptionToResult(std::current_exception()); } +/////////////////////////////////////////////////////////////////////////////// +/// @brief The only adapter reported with mock enabled will be the mock adapter. +/// +/// @details +/// - The mock adapter will default to returning ::UR_RESULT_SUCCESS for all +/// entry points. It will also create and correctly reference count dummy +/// handles where appropriate. Its behaviour can be modified by linking +/// the ::ur_mock_headers library and using the callbacks object. +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_UNINITIALIZED +/// - ::UR_RESULT_ERROR_DEVICE_LOST +/// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC +/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE +/// + `NULL == hLoaderConfig` +ur_result_t UR_APICALL urLoaderConfigSetMockingEnabled( + ur_loader_config_handle_t + hLoaderConfig, ///< [in] Handle to config object mocking will be enabled for. + ur_bool_t + enable ///< [in] Handle to config object the layer will be enabled for. + ) try { + return ur_lib::urLoaderConfigSetMockingEnabled(hLoaderConfig, enable); +} catch (...) { + return exceptionToResult(std::current_exception()); +} + /////////////////////////////////////////////////////////////////////////////// /// @brief Initialize the 'oneAPI' loader /// @@ -341,7 +368,7 @@ ur_result_t UR_APICALL urAdapterRelease( /// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE /// + `NULL == hAdapter` ur_result_t UR_APICALL urAdapterRetain( - ur_adapter_handle_t hAdapter ///< [in] Adapter handle to retain + ur_adapter_handle_t hAdapter ///< [in][retain] Adapter handle to retain ) try { auto pfnAdapterRetain = ur_lib::context->urDdiTable.Global.pfnAdapterRetain; if (nullptr == pfnAdapterRetain) { @@ -912,7 +939,7 @@ ur_result_t UR_APICALL urDeviceGetInfo( /// + `NULL == hDevice` ur_result_t UR_APICALL urDeviceRetain( ur_device_handle_t - hDevice ///< [in] handle of the device to get a reference of. + hDevice ///< [in][retain] handle of the device to get a reference of. ) try { auto pfnRetain = ur_lib::context->urDdiTable.Device.pfnRetain; if (nullptr == pfnRetain) { @@ -1253,7 +1280,7 @@ ur_result_t UR_APICALL urContextCreate( /// + `NULL == hContext` ur_result_t UR_APICALL urContextRetain( ur_context_handle_t - hContext ///< [in] handle of the context to get a reference of. + hContext ///< [in][retain] handle of the context to get a reference of. ) try { auto pfnRetain = ur_lib::context->urDdiTable.Context.pfnRetain; if (nullptr == pfnRetain) { @@ -1621,7 +1648,8 @@ ur_result_t UR_APICALL urMemBufferCreate( /// - ::UR_RESULT_ERROR_OUT_OF_HOST_MEMORY /// - ::UR_RESULT_ERROR_OUT_OF_RESOURCES ur_result_t UR_APICALL urMemRetain( - ur_mem_handle_t hMem ///< [in] handle of the memory object to get access + ur_mem_handle_t + hMem ///< [in][retain] handle of the memory object to get access ) try { auto pfnRetain = ur_lib::context->urDdiTable.Mem.pfnRetain; if (nullptr == pfnRetain) { @@ -2021,7 +2049,7 @@ ur_result_t UR_APICALL urSamplerCreate( /// - ::UR_RESULT_ERROR_OUT_OF_RESOURCES ur_result_t UR_APICALL urSamplerRetain( ur_sampler_handle_t - hSampler ///< [in] handle of the sampler object to get access + hSampler ///< [in][retain] handle of the sampler object to get access ) try { auto pfnRetain = ur_lib::context->urDdiTable.Sampler.pfnRetain; if (nullptr == pfnRetain) { @@ -2490,7 +2518,7 @@ ur_result_t UR_APICALL urUSMPoolCreate( /// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE /// + `NULL == pPool` ur_result_t UR_APICALL urUSMPoolRetain( - ur_usm_pool_handle_t pPool ///< [in] pointer to USM memory pool + ur_usm_pool_handle_t pPool ///< [in][retain] pointer to USM memory pool ) try { auto pfnPoolRetain = ur_lib::context->urDdiTable.USM.pfnPoolRetain; if (nullptr == pfnPoolRetain) { @@ -2882,7 +2910,7 @@ ur_result_t UR_APICALL urPhysicalMemCreate( /// + `NULL == hPhysicalMem` ur_result_t UR_APICALL urPhysicalMemRetain( ur_physical_mem_handle_t - hPhysicalMem ///< [in] handle of the physical memory object to retain. + hPhysicalMem ///< [in][retain] handle of the physical memory object to retain. ) try { auto pfnRetain = ur_lib::context->urDdiTable.PhysicalMem.pfnRetain; if (nullptr == pfnRetain) { @@ -3186,7 +3214,8 @@ ur_result_t UR_APICALL urProgramLink( /// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE /// + `NULL == hProgram` ur_result_t UR_APICALL urProgramRetain( - ur_program_handle_t hProgram ///< [in] handle for the Program to retain + ur_program_handle_t + hProgram ///< [in][retain] handle for the Program to retain ) try { auto pfnRetain = ur_lib::context->urDdiTable.Program.pfnRetain; if (nullptr == pfnRetain) { @@ -3815,7 +3844,7 @@ ur_result_t UR_APICALL urKernelGetSubGroupInfo( /// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE /// + `NULL == hKernel` ur_result_t UR_APICALL urKernelRetain( - ur_kernel_handle_t hKernel ///< [in] handle for the Kernel to retain + ur_kernel_handle_t hKernel ///< [in][retain] handle for the Kernel to retain ) try { auto pfnRetain = ur_lib::context->urDdiTable.Kernel.pfnRetain; if (nullptr == pfnRetain) { @@ -4313,7 +4342,8 @@ ur_result_t UR_APICALL urQueueCreate( /// - ::UR_RESULT_ERROR_OUT_OF_HOST_MEMORY /// - ::UR_RESULT_ERROR_OUT_OF_RESOURCES ur_result_t UR_APICALL urQueueRetain( - ur_queue_handle_t hQueue ///< [in] handle of the queue object to get access + ur_queue_handle_t + hQueue ///< [in][retain] handle of the queue object to get access ) try { auto pfnRetain = ur_lib::context->urDdiTable.Queue.pfnRetain; if (nullptr == pfnRetain) { @@ -4671,7 +4701,7 @@ ur_result_t UR_APICALL urEventWait( /// - ::UR_RESULT_ERROR_OUT_OF_RESOURCES /// - ::UR_RESULT_ERROR_OUT_OF_HOST_MEMORY ur_result_t UR_APICALL urEventRetain( - ur_event_handle_t hEvent ///< [in] handle of the event object + ur_event_handle_t hEvent ///< [in][retain] handle of the event object ) try { auto pfnRetain = ur_lib::context->urDdiTable.Event.pfnRetain; if (nullptr == pfnRetain) { @@ -7344,7 +7374,7 @@ ur_result_t UR_APICALL urCommandBufferCreateExp( /// - ::UR_RESULT_ERROR_OUT_OF_HOST_MEMORY ur_result_t UR_APICALL urCommandBufferRetainExp( ur_exp_command_buffer_handle_t - hCommandBuffer ///< [in] Handle of the command-buffer object. + hCommandBuffer ///< [in][retain] Handle of the command-buffer object. ) try { auto pfnRetainExp = ur_lib::context->urDdiTable.CommandBufferExp.pfnRetainExp; diff --git a/source/loader/ur_print.cpp b/source/loader/ur_print.cpp index 3f2d017a89..e0c11d2f51 100644 --- a/source/loader/ur_print.cpp +++ b/source/loader/ur_print.cpp @@ -2001,6 +2001,14 @@ ur_result_t urPrintLoaderConfigSetCodeLocationCallbackParams( return str_copy(&ss, buffer, buff_size, out_size); } +ur_result_t urPrintLoaderConfigSetMockingEnabledParams( + const struct ur_loader_config_set_mocking_enabled_params_t *params, + char *buffer, const size_t buff_size, size_t *out_size) { + std::stringstream ss; + ss << params; + return str_copy(&ss, buffer, buff_size, out_size); +} + ur_result_t urPrintMemImageCreateParams(const struct ur_mem_image_create_params_t *params, char *buffer, const size_t buff_size, diff --git a/source/mock/CMakeLists.txt b/source/mock/CMakeLists.txt new file mode 100644 index 0000000000..f3933fbef1 --- /dev/null +++ b/source/mock/CMakeLists.txt @@ -0,0 +1,17 @@ +# Copyright (C) 2024 Intel Corporation +# Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions. +# See LICENSE.TXT +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +add_library (ur_mock_headers SHARED + "${CMAKE_CURRENT_SOURCE_DIR}/ur_mock_helpers.cpp") + +target_include_directories(ur_mock_headers + INTERFACE "${CMAKE_CURRENT_SOURCE_DIR}" +) + +target_link_libraries(ur_mock_headers PRIVATE + ${PROJECT_NAME}::headers +) + +add_library(${PROJECT_NAME}::mock ALIAS ur_mock_headers) diff --git a/source/mock/ur_mock_helpers.cpp b/source/mock/ur_mock_helpers.cpp new file mode 100644 index 0000000000..5fb4391a2d --- /dev/null +++ b/source/mock/ur_mock_helpers.cpp @@ -0,0 +1,19 @@ +/* + * + * Copyright (C) 2024 Intel Corporation + * + * Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions. + * See LICENSE.TXT + * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + * + * @file ur_mock_helpers.cpp + * + */ + +#include "ur_mock_helpers.hpp" + +namespace mock { + +callbacks_t callbacks; + +} // namespace mock diff --git a/source/mock/ur_mock_helpers.hpp b/source/mock/ur_mock_helpers.hpp new file mode 100644 index 0000000000..f1b350b736 --- /dev/null +++ b/source/mock/ur_mock_helpers.hpp @@ -0,0 +1,107 @@ +/* + * + * Copyright (C) 2024 Intel Corporation + * + * Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions. + * See LICENSE.TXT + * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + * + * @file ur_mock_helpers.hpp + * + */ + +#pragma once + +#include "ur_api.h" +#include +#include +#include +#include + +namespace mock { + +struct dummy_handle_t_ { + dummy_handle_t_() {} + std::atomic refCounter = 1; +}; + +using dummy_handle_t = dummy_handle_t_ *; + +// Allocates a dummy handle of type T with support for reference counting. +// The handle has to be deallocated using 'releaseDummyHandle'. +template inline T createDummyHandle() { + dummy_handle_t DummyHandlePtr = new dummy_handle_t_(); + return reinterpret_cast(DummyHandlePtr); +} + +// Decrement reference counter for the handle and deallocates it if the +// reference counter becomes zero +template inline void releaseDummyHandle(T Handle) { + auto DummyHandlePtr = reinterpret_cast(Handle); + const size_t NewValue = --DummyHandlePtr->refCounter; + if (NewValue == 0) { + delete DummyHandlePtr; + } +} + +// Increment reference counter for the handle +template inline void retainDummyHandle(T Handle) { + auto DummyHandlePtr = reinterpret_cast(Handle); + ++DummyHandlePtr->refCounter; +} + +struct callbacks_t { + void set_before_callback(std::string name, ur_mock_callback_t callback) { + beforeCallbacks[name] = callback; + } + + ur_mock_callback_t get_before_callback(std::string name) { + auto callback = beforeCallbacks.find(name); + + if (callback != beforeCallbacks.end()) { + return callback->second; + } + return nullptr; + } + + void set_replace_callback(std::string name, ur_mock_callback_t callback) { + replaceCallbacks[name] = callback; + } + + ur_mock_callback_t get_replace_callback(std::string name) { + auto callback = replaceCallbacks.find(name); + + if (callback != replaceCallbacks.end()) { + return callback->second; + } + return nullptr; + } + + void set_after_callback(std::string name, ur_mock_callback_t callback) { + afterCallbacks[name] = callback; + } + + ur_mock_callback_t get_after_callback(std::string name) { + auto callback = afterCallbacks.find(name); + + if (callback != afterCallbacks.end()) { + return callback->second; + } + return nullptr; + } + + void resetCallbacks() { + beforeCallbacks.clear(); + replaceCallbacks.clear(); + afterCallbacks.clear(); + } + + private: + std::unordered_map beforeCallbacks; + std::unordered_map replaceCallbacks; + std::unordered_map afterCallbacks; +}; + +extern callbacks_t callbacks; + +} // namespace mock diff --git a/source/ur_api.cpp b/source/ur_api.cpp index 43b3c592d5..ffbe3044bc 100644 --- a/source/ur_api.cpp +++ b/source/ur_api.cpp @@ -48,7 +48,7 @@ ur_result_t UR_APICALL urLoaderConfigCreate( /// + `NULL == hLoaderConfig` ur_result_t UR_APICALL urLoaderConfigRetain( ur_loader_config_handle_t - hLoaderConfig ///< [in] loader config handle to retain + hLoaderConfig ///< [in][retain] loader config handle to retain ) { ur_result_t result = UR_RESULT_SUCCESS; return result; @@ -181,6 +181,32 @@ ur_result_t UR_APICALL urLoaderConfigSetCodeLocationCallback( return result; } +/////////////////////////////////////////////////////////////////////////////// +/// @brief The only adapter reported with mock enabled will be the mock adapter. +/// +/// @details +/// - The mock adapter will default to returning ::UR_RESULT_SUCCESS for all +/// entry points. It will also create and correctly reference count dummy +/// handles where appropriate. Its behaviour can be modified by linking +/// the ::ur_mock_headers library and using the callbacks object. +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_UNINITIALIZED +/// - ::UR_RESULT_ERROR_DEVICE_LOST +/// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC +/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE +/// + `NULL == hLoaderConfig` +ur_result_t UR_APICALL urLoaderConfigSetMockingEnabled( + ur_loader_config_handle_t + hLoaderConfig, ///< [in] Handle to config object mocking will be enabled for. + ur_bool_t + enable ///< [in] Handle to config object the layer will be enabled for. +) { + ur_result_t result = UR_RESULT_SUCCESS; + return result; +} + /////////////////////////////////////////////////////////////////////////////// /// @brief Initialize the 'oneAPI' loader /// @@ -305,7 +331,7 @@ ur_result_t UR_APICALL urAdapterRelease( /// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE /// + `NULL == hAdapter` ur_result_t UR_APICALL urAdapterRetain( - ur_adapter_handle_t hAdapter ///< [in] Adapter handle to retain + ur_adapter_handle_t hAdapter ///< [in][retain] Adapter handle to retain ) { ur_result_t result = UR_RESULT_SUCCESS; return result; @@ -799,7 +825,7 @@ ur_result_t UR_APICALL urDeviceGetInfo( /// + `NULL == hDevice` ur_result_t UR_APICALL urDeviceRetain( ur_device_handle_t - hDevice ///< [in] handle of the device to get a reference of. + hDevice ///< [in][retain] handle of the device to get a reference of. ) { ur_result_t result = UR_RESULT_SUCCESS; return result; @@ -1087,7 +1113,7 @@ ur_result_t UR_APICALL urContextCreate( /// + `NULL == hContext` ur_result_t UR_APICALL urContextRetain( ur_context_handle_t - hContext ///< [in] handle of the context to get a reference of. + hContext ///< [in][retain] handle of the context to get a reference of. ) { ur_result_t result = UR_RESULT_SUCCESS; return result; @@ -1402,7 +1428,8 @@ ur_result_t UR_APICALL urMemBufferCreate( /// - ::UR_RESULT_ERROR_OUT_OF_HOST_MEMORY /// - ::UR_RESULT_ERROR_OUT_OF_RESOURCES ur_result_t UR_APICALL urMemRetain( - ur_mem_handle_t hMem ///< [in] handle of the memory object to get access + ur_mem_handle_t + hMem ///< [in][retain] handle of the memory object to get access ) { ur_result_t result = UR_RESULT_SUCCESS; return result; @@ -1741,7 +1768,7 @@ ur_result_t UR_APICALL urSamplerCreate( /// - ::UR_RESULT_ERROR_OUT_OF_RESOURCES ur_result_t UR_APICALL urSamplerRetain( ur_sampler_handle_t - hSampler ///< [in] handle of the sampler object to get access + hSampler ///< [in][retain] handle of the sampler object to get access ) { ur_result_t result = UR_RESULT_SUCCESS; return result; @@ -2139,7 +2166,7 @@ ur_result_t UR_APICALL urUSMPoolCreate( /// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE /// + `NULL == pPool` ur_result_t UR_APICALL urUSMPoolRetain( - ur_usm_pool_handle_t pPool ///< [in] pointer to USM memory pool + ur_usm_pool_handle_t pPool ///< [in][retain] pointer to USM memory pool ) { ur_result_t result = UR_RESULT_SUCCESS; return result; @@ -2462,7 +2489,7 @@ ur_result_t UR_APICALL urPhysicalMemCreate( /// + `NULL == hPhysicalMem` ur_result_t UR_APICALL urPhysicalMemRetain( ur_physical_mem_handle_t - hPhysicalMem ///< [in] handle of the physical memory object to retain. + hPhysicalMem ///< [in][retain] handle of the physical memory object to retain. ) { ur_result_t result = UR_RESULT_SUCCESS; return result; @@ -2719,7 +2746,8 @@ ur_result_t UR_APICALL urProgramLink( /// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE /// + `NULL == hProgram` ur_result_t UR_APICALL urProgramRetain( - ur_program_handle_t hProgram ///< [in] handle for the Program to retain + ur_program_handle_t + hProgram ///< [in][retain] handle for the Program to retain ) { ur_result_t result = UR_RESULT_SUCCESS; return result; @@ -3245,7 +3273,7 @@ ur_result_t UR_APICALL urKernelGetSubGroupInfo( /// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE /// + `NULL == hKernel` ur_result_t UR_APICALL urKernelRetain( - ur_kernel_handle_t hKernel ///< [in] handle for the Kernel to retain + ur_kernel_handle_t hKernel ///< [in][retain] handle for the Kernel to retain ) { ur_result_t result = UR_RESULT_SUCCESS; return result; @@ -3664,7 +3692,8 @@ ur_result_t UR_APICALL urQueueCreate( /// - ::UR_RESULT_ERROR_OUT_OF_HOST_MEMORY /// - ::UR_RESULT_ERROR_OUT_OF_RESOURCES ur_result_t UR_APICALL urQueueRetain( - ur_queue_handle_t hQueue ///< [in] handle of the queue object to get access + ur_queue_handle_t + hQueue ///< [in][retain] handle of the queue object to get access ) { ur_result_t result = UR_RESULT_SUCCESS; return result; @@ -3963,7 +3992,7 @@ ur_result_t UR_APICALL urEventWait( /// - ::UR_RESULT_ERROR_OUT_OF_RESOURCES /// - ::UR_RESULT_ERROR_OUT_OF_HOST_MEMORY ur_result_t UR_APICALL urEventRetain( - ur_event_handle_t hEvent ///< [in] handle of the event object + ur_event_handle_t hEvent ///< [in][retain] handle of the event object ) { ur_result_t result = UR_RESULT_SUCCESS; return result; @@ -6244,7 +6273,7 @@ ur_result_t UR_APICALL urCommandBufferCreateExp( /// - ::UR_RESULT_ERROR_OUT_OF_HOST_MEMORY ur_result_t UR_APICALL urCommandBufferRetainExp( ur_exp_command_buffer_handle_t - hCommandBuffer ///< [in] Handle of the command-buffer object. + hCommandBuffer ///< [in][retain] Handle of the command-buffer object. ) { ur_result_t result = UR_RESULT_SUCCESS; return result; diff --git a/test/layers/mock/CMakeLists.txt b/test/layers/mock/CMakeLists.txt new file mode 100644 index 0000000000..e112ca4203 --- /dev/null +++ b/test/layers/mock/CMakeLists.txt @@ -0,0 +1,23 @@ +# Copyright (C) 2024 Intel Corporation +# Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions. +# See LICENSE.TXT +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +set(MOCK_TEST_NAME test-mock) + +add_ur_executable(${MOCK_TEST_NAME} mock.cpp) +target_link_libraries(${MOCK_TEST_NAME} + PRIVATE + ${PROJECT_NAME}::loader + ${PROJECT_NAME}::headers + ${PROJECT_NAME}::testing + GTest::gtest_main) + +add_test(NAME ${MOCK_TEST_NAME} + COMMAND ${MOCK_TEST_NAME} + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) + +set_tests_properties(${MOCK_TEST_NAME} PROPERTIES LABELS "mock") + +set_property(TEST ${MOCK_TEST_NAME} PROPERTY ENVIRONMENT + "UR_ADAPTERS_FORCE_LOAD=\"$\"") diff --git a/test/layers/mock/mock.cpp b/test/layers/mock/mock.cpp new file mode 100644 index 0000000000..75fcc5ef32 --- /dev/null +++ b/test/layers/mock/mock.cpp @@ -0,0 +1,141 @@ +/* + * + * Copyright (C) 2024 Intel Corporation + * + * Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions. + * See LICENSE.TXT + * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + * + * @file codeloc.cpp + * + */ + +#include "uur/raii.h" +#include +#include + +TEST(Mock, NullProperties) { + uur::raii::LoaderConfig loader_config; + ASSERT_EQ(urLoaderConfigCreate(loader_config.ptr()), UR_RESULT_SUCCESS); + ASSERT_EQ(urLoaderConfigSetMockCallbacks(loader_config, nullptr), + UR_RESULT_ERROR_INVALID_NULL_POINTER); +} + +TEST(Mock, NullCallback) { + uur::raii::LoaderConfig loader_config; + ASSERT_EQ(urLoaderConfigCreate(loader_config.ptr()), UR_RESULT_SUCCESS); + + ur_mock_callback_properties_t callback_properties = { + UR_STRUCTURE_TYPE_MOCK_CALLBACK_PROPERTIES, nullptr, "urAdapterGet", + UR_CALLBACK_OVERRIDE_MODE_REPLACE, nullptr}; + + ASSERT_EQ( + urLoaderConfigSetMockCallbacks(loader_config, &callback_properties), + UR_RESULT_ERROR_INVALID_NULL_POINTER); +} + +ur_result_t generic_callback(void *) { return UR_RESULT_SUCCESS; } + +TEST(Mock, NullHandle) { + ur_mock_callback_properties_t callback_properties = { + UR_STRUCTURE_TYPE_MOCK_CALLBACK_PROPERTIES, nullptr, "urAdapterGet", + UR_CALLBACK_OVERRIDE_MODE_REPLACE, &generic_callback}; + + ASSERT_EQ(urLoaderConfigSetMockCallbacks(nullptr, &callback_properties), + UR_RESULT_ERROR_INVALID_NULL_HANDLE); +} + +TEST(Mock, DefaultBehavior) { + uur::raii::LoaderConfig loader_config; + ASSERT_EQ(urLoaderConfigCreate(loader_config.ptr()), UR_RESULT_SUCCESS); + ASSERT_EQ(urLoaderConfigEnableLayer(loader_config, "UR_LAYER_MOCK"), + UR_RESULT_SUCCESS); + ASSERT_EQ(urLoaderInit(0, loader_config), UR_RESULT_SUCCESS); + + // Set up as far as device and check we're getting sensible, different + // handles created. + ur_adapter_handle_t adapter = nullptr; + ur_platform_handle_t platform = nullptr; + ur_device_handle_t device = nullptr; + + ASSERT_EQ(urAdapterGet(1, &adapter, nullptr), UR_RESULT_SUCCESS); + ASSERT_EQ(urPlatformGet(&adapter, 1, 1, &platform, nullptr), + UR_RESULT_SUCCESS); + ASSERT_EQ(urDeviceGet(platform, UR_DEVICE_TYPE_ALL, 1, &device, nullptr), + UR_RESULT_SUCCESS); + + ASSERT_NE(adapter, nullptr); + ASSERT_NE(platform, nullptr); + ASSERT_NE(device, nullptr); + + ASSERT_NE(static_cast(adapter), static_cast(platform)); + ASSERT_NE(static_cast(adapter), static_cast(device)); + ASSERT_NE(static_cast(platform), static_cast(device)); + + ASSERT_EQ(urDeviceRelease(device), UR_RESULT_SUCCESS); +} + +void checkPreInitAdapter(ur_adapter_handle_t adapter) { + ur_adapter_handle_t preInitAdapter = + reinterpret_cast(0xF00DCAFE); + ASSERT_EQ(adapter, preInitAdapter); +} + +ur_result_t beforeUrAdapterGet(void *pParams) { + auto params = reinterpret_cast(pParams); + checkPreInitAdapter(**params->pphAdapters); + return UR_RESULT_SUCCESS; +} + +ur_result_t replaceUrAdapterGet(void *pParams) { + auto params = reinterpret_cast(pParams); + **params->pphAdapters = reinterpret_cast(0xDEADBEEF); + return UR_RESULT_SUCCESS; +} + +void checkPostInitAdapter(ur_adapter_handle_t adapter) { + ur_adapter_handle_t postInitAdapter = + reinterpret_cast(0xDEADBEEF); + ASSERT_EQ(adapter, postInitAdapter); +} + +ur_result_t afterUrAdapterGet(void *pParams) { + auto params = reinterpret_cast(pParams); + checkPostInitAdapter(**params->pphAdapters); + return UR_RESULT_SUCCESS; +} + +TEST(Mock, Callbacks) { + uur::raii::LoaderConfig loader_config; + ASSERT_EQ(urLoaderConfigCreate(loader_config.ptr()), UR_RESULT_SUCCESS); + + // This callback is set up to check *phAdapters is still the pre-call + // init value we set below + ur_mock_callback_properties_t adapterGetBeforeProperties = { + UR_STRUCTURE_TYPE_MOCK_CALLBACK_PROPERTIES, nullptr, "urAdapterGet", + UR_CALLBACK_OVERRIDE_MODE_BEFORE, &beforeUrAdapterGet}; + + // This callback is set up to return a distinct test value in phAdapters + // rather than the default generic handle + ur_mock_callback_properties_t adapterGetReplaceProperties = { + UR_STRUCTURE_TYPE_MOCK_CALLBACK_PROPERTIES, &adapterGetBeforeProperties, + "urAdapterGet", UR_CALLBACK_OVERRIDE_MODE_REPLACE, + &replaceUrAdapterGet}; + + // This callback is set up to check our replace callback did its job + ur_mock_callback_properties_t adapterGetAfterProperties = { + UR_STRUCTURE_TYPE_MOCK_CALLBACK_PROPERTIES, + &adapterGetReplaceProperties, "urAdapterGet", + UR_CALLBACK_OVERRIDE_MODE_AFTER, &afterUrAdapterGet}; + + ASSERT_EQ(urLoaderConfigSetMockCallbacks(loader_config, + &adapterGetAfterProperties), + UR_RESULT_SUCCESS); + ASSERT_EQ(urLoaderConfigEnableLayer(loader_config, "UR_LAYER_MOCK"), + UR_RESULT_SUCCESS); + ASSERT_EQ(urLoaderInit(0, loader_config), UR_RESULT_SUCCESS); + + ur_adapter_handle_t adapter = + reinterpret_cast(0xF00DCAFE); + ASSERT_EQ(urAdapterGet(1, &adapter, nullptr), UR_RESULT_SUCCESS); +} diff --git a/test/layers/tracing/CMakeLists.txt b/test/layers/tracing/CMakeLists.txt index 27c94b5cfd..969e4318b1 100644 --- a/test/layers/tracing/CMakeLists.txt +++ b/test/layers/tracing/CMakeLists.txt @@ -28,7 +28,7 @@ function(set_tracing_test_props target_name collector_name) "XPTI_TRACE_ENABLE=1" "XPTI_FRAMEWORK_DISPATCHER=$" "XPTI_SUBSCRIBERS=$" - "UR_ADAPTERS_FORCE_LOAD=\"$\"" + "UR_ADAPTERS_FORCE_LOAD=\"$\"" "UR_ENABLE_LAYERS=UR_LAYER_TRACING") endfunction() @@ -54,7 +54,7 @@ add_test(NAME example-logged-hello-world set_tests_properties(example-logged-hello-world PROPERTIES LABELS "tracing") set_property(TEST example-logged-hello-world PROPERTY ENVIRONMENT "UR_LOG_TRACING=level:info\;output:stdout" - "UR_ADAPTERS_FORCE_LOAD=\"$\"" + "UR_ADAPTERS_FORCE_LOAD=\"$\"" "UR_ENABLE_LAYERS=UR_LAYER_TRACING") function(add_tracing_test name) diff --git a/test/layers/validation/CMakeLists.txt b/test/layers/validation/CMakeLists.txt index 3e0446c5e4..63f7de7a8d 100644 --- a/test/layers/validation/CMakeLists.txt +++ b/test/layers/validation/CMakeLists.txt @@ -21,7 +21,7 @@ function(set_validation_test_properties name) set_tests_properties(${name} PROPERTIES LABELS "validation") set_property(TEST ${name} PROPERTY ENVIRONMENT "UR_ENABLE_LAYERS=UR_LAYER_FULL_VALIDATION" - "UR_ADAPTERS_FORCE_LOAD=\"$\"" + "UR_ADAPTERS_FORCE_LOAD=\"$\"" "UR_LOG_VALIDATION=level:debug\;flush:debug\;output:stdout") endfunction() diff --git a/test/loader/CMakeLists.txt b/test/loader/CMakeLists.txt index 5472da74bc..692a5f5d1d 100644 --- a/test/loader/CMakeLists.txt +++ b/test/loader/CMakeLists.txt @@ -5,7 +5,7 @@ add_test(NAME example-hello-world COMMAND hello_world DEPENDS hello_world) set_tests_properties(example-hello-world PROPERTIES LABELS "loader" - ENVIRONMENT "UR_ADAPTERS_FORCE_LOAD=\"$\"" + ENVIRONMENT "UR_ADAPTERS_FORCE_LOAD=\"$\"" ) add_subdirectory(adapter_registry) diff --git a/test/loader/handles/CMakeLists.txt b/test/loader/handles/CMakeLists.txt index 737216fc23..fada9e8ebb 100644 --- a/test/loader/handles/CMakeLists.txt +++ b/test/loader/handles/CMakeLists.txt @@ -23,5 +23,5 @@ add_test(NAME loader-handles set_tests_properties(loader-handles PROPERTIES LABELS "loader" - ENVIRONMENT "UR_ENABLE_LOADER_INTERCEPT=1;UR_ADAPTERS_FORCE_LOAD=\"$\"" + ENVIRONMENT "UR_ENABLE_LOADER_INTERCEPT=1;UR_ADAPTERS_FORCE_LOAD=\"$\"" ) diff --git a/test/loader/platforms/CMakeLists.txt b/test/loader/platforms/CMakeLists.txt index 86f8eea085..92e74856e7 100644 --- a/test/loader/platforms/CMakeLists.txt +++ b/test/loader/platforms/CMakeLists.txt @@ -25,7 +25,7 @@ function(add_loader_platform_test name ENV) -D MODE=stdout -D MATCH_FILE=${CMAKE_CURRENT_SOURCE_DIR}/${name}.match -P ${PROJECT_SOURCE_DIR}/cmake/match.cmake - DEPENDS test-loader-platforms ur_adapter_null + DEPENDS test-loader-platforms ur_adapter_mock ) set_tests_properties(${TEST_NAME} PROPERTIES LABELS "loader" @@ -34,4 +34,4 @@ function(add_loader_platform_test name ENV) endfunction() add_loader_platform_test(no_platforms "UR_ADAPTERS_FORCE_LOAD=\"\"") -add_loader_platform_test(null_platform "UR_ADAPTERS_FORCE_LOAD=\"$\"") +add_loader_platform_test(null_platform "UR_ADAPTERS_FORCE_LOAD=\"$\"") diff --git a/test/tools/urtrace/CMakeLists.txt b/test/tools/urtrace/CMakeLists.txt index 938fcf46f8..18212ce818 100644 --- a/test/tools/urtrace/CMakeLists.txt +++ b/test/tools/urtrace/CMakeLists.txt @@ -24,9 +24,9 @@ function(add_trace_test name CLI_ARGS) set_tests_properties(${TEST_NAME} PROPERTIES LABELS "urtrace") endfunction() -add_trace_test(null_hello "--libpath $ --null") -add_trace_test(null_hello_no_args "--libpath $ --null --no-args") -add_trace_test(null_hello_filter_device "--libpath $ --null --filter \".*Device.*\"") -add_trace_test(null_hello_profiling "--libpath $ --null --profiling --time-unit ns") -add_trace_test(null_hello_begin "--libpath $ --null --print-begin") -add_trace_test(null_hello_json "--libpath $ --null --json") +add_trace_test(null_hello "--libpath $ --null") +add_trace_test(null_hello_no_args "--libpath $ --null --no-args") +add_trace_test(null_hello_filter_device "--libpath $ --null --filter \".*Device.*\"") +add_trace_test(null_hello_profiling "--libpath $ --null --profiling --time-unit ns") +add_trace_test(null_hello_begin "--libpath $ --null --print-begin") +add_trace_test(null_hello_json "--libpath $ --null --json") diff --git a/test/usm/CMakeLists.txt b/test/usm/CMakeLists.txt index 3496328ce9..1e3d3eb78d 100644 --- a/test/usm/CMakeLists.txt +++ b/test/usm/CMakeLists.txt @@ -22,7 +22,7 @@ function(add_usm_test name) WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) set_tests_properties(usm-${name} PROPERTIES LABELS "usm" - ENVIRONMENT "UR_ADAPTERS_FORCE_LOAD=\"$\"") + ENVIRONMENT "UR_ADAPTERS_FORCE_LOAD=\"$\"") target_compile_definitions("usm_test-${name}" PRIVATE DEVICES_ENVIRONMENT) endfunction() From 53acffb7305cd6677c1749d19121459cf5e5dcf5 Mon Sep 17 00:00:00 2001 From: Aaron Greig Date: Wed, 26 Jun 2024 16:45:53 +0100 Subject: [PATCH 02/15] Add storage stuff back to dummy handles. --- source/mock/ur_mock_helpers.hpp | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/source/mock/ur_mock_helpers.hpp b/source/mock/ur_mock_helpers.hpp index f1b350b736..900678d4bd 100644 --- a/source/mock/ur_mock_helpers.hpp +++ b/source/mock/ur_mock_helpers.hpp @@ -13,24 +13,31 @@ #pragma once #include "ur_api.h" + #include #include #include #include +#include namespace mock { struct dummy_handle_t_ { - dummy_handle_t_() {} - std::atomic refCounter = 1; + dummy_handle_t_(size_t DataSize = 0) + : MStorage(DataSize), MData(MStorage.data()) {} + dummy_handle_t_(unsigned char *Data) : MData(Data) {} + std::atomic MRefCounter = 1; + std::vector MStorage; + unsigned char *MData = nullptr; }; using dummy_handle_t = dummy_handle_t_ *; -// Allocates a dummy handle of type T with support for reference counting. -// The handle has to be deallocated using 'releaseDummyHandle'. -template inline T createDummyHandle() { - dummy_handle_t DummyHandlePtr = new dummy_handle_t_(); +// Allocates a dummy handle of type T with support of reference counting. +// Takes optional 'Size' parameter which can be used to allocate additional +// memory. The handle has to be deallocated using 'releaseDummyHandle'. +template inline T createDummyHandle(size_t Size = 0) { + dummy_handle_t DummyHandlePtr = new dummy_handle_t_(Size); return reinterpret_cast(DummyHandlePtr); } @@ -38,7 +45,7 @@ template inline T createDummyHandle() { // reference counter becomes zero template inline void releaseDummyHandle(T Handle) { auto DummyHandlePtr = reinterpret_cast(Handle); - const size_t NewValue = --DummyHandlePtr->refCounter; + const size_t NewValue = --DummyHandlePtr->MRefCounter; if (NewValue == 0) { delete DummyHandlePtr; } @@ -47,7 +54,7 @@ template inline void releaseDummyHandle(T Handle) { // Increment reference counter for the handle template inline void retainDummyHandle(T Handle) { auto DummyHandlePtr = reinterpret_cast(Handle); - ++DummyHandlePtr->refCounter; + ++DummyHandlePtr->MRefCounter; } struct callbacks_t { From 0eca88593f88ccb793c407a7398f575e25ad69e0 Mon Sep 17 00:00:00 2001 From: Aaron Greig Date: Thu, 27 Jun 2024 15:56:30 +0100 Subject: [PATCH 03/15] Add missing dummy handle host pointer helper. --- source/mock/ur_mock_helpers.hpp | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/source/mock/ur_mock_helpers.hpp b/source/mock/ur_mock_helpers.hpp index 900678d4bd..66eb0af588 100644 --- a/source/mock/ur_mock_helpers.hpp +++ b/source/mock/ur_mock_helpers.hpp @@ -41,6 +41,13 @@ template inline T createDummyHandle(size_t Size = 0) { return reinterpret_cast(DummyHandlePtr); } +// Allocates a dummy handle of type T with support of reference counting +// and associates it with the provided Data. +template inline T createDummyHandleWithData(unsigned char *Data) { + auto DummyHandlePtr = new dummy_handle_t_(Data); + return reinterpret_cast(DummyHandlePtr); +} + // Decrement reference counter for the handle and deallocates it if the // reference counter becomes zero template inline void releaseDummyHandle(T Handle) { From 87b8e67399b221cc53365ca8590686e670c1b0fd Mon Sep 17 00:00:00 2001 From: Aaron Greig Date: Thu, 27 Jun 2024 16:08:47 +0100 Subject: [PATCH 04/15] Add size tracking independent of storage buffer. --- source/mock/ur_mock_helpers.hpp | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/source/mock/ur_mock_helpers.hpp b/source/mock/ur_mock_helpers.hpp index 66eb0af588..bb33a5b2b1 100644 --- a/source/mock/ur_mock_helpers.hpp +++ b/source/mock/ur_mock_helpers.hpp @@ -24,11 +24,13 @@ namespace mock { struct dummy_handle_t_ { dummy_handle_t_(size_t DataSize = 0) - : MStorage(DataSize), MData(MStorage.data()) {} - dummy_handle_t_(unsigned char *Data) : MData(Data) {} + : MStorage(DataSize), MData(MStorage.data()), MSize(DataSize) {} + dummy_handle_t_(unsigned char *Data, size_t Size) + : MData(Data), MSize(Size) {} std::atomic MRefCounter = 1; std::vector MStorage; unsigned char *MData = nullptr; + size_t MSize; }; using dummy_handle_t = dummy_handle_t_ *; @@ -43,8 +45,9 @@ template inline T createDummyHandle(size_t Size = 0) { // Allocates a dummy handle of type T with support of reference counting // and associates it with the provided Data. -template inline T createDummyHandleWithData(unsigned char *Data) { - auto DummyHandlePtr = new dummy_handle_t_(Data); +template +inline T createDummyHandleWithData(unsigned char *Data, size_t Size) { + auto DummyHandlePtr = new dummy_handle_t_(Data, Size); return reinterpret_cast(DummyHandlePtr); } From 56476fdc88c6648b329170369c1ff1d0d45fdb5a Mon Sep 17 00:00:00 2001 From: Aaron Greig Date: Mon, 1 Jul 2024 15:15:06 +0100 Subject: [PATCH 05/15] Add the reset of the memory stuff from the default overrides. --- scripts/templates/mockddi.cpp.mako | 25 +++++++++++++++++++++++++ source/adapters/mock/ur_mockddi.cpp | 17 +++++++++++++++-- 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/scripts/templates/mockddi.cpp.mako b/scripts/templates/mockddi.cpp.mako index 56b333f798..70add5645f 100644 --- a/scripts/templates/mockddi.cpp.mako +++ b/scripts/templates/mockddi.cpp.mako @@ -71,6 +71,31 @@ namespace driver %else: *phNative${func_class} = reinterpret_cast(h${func_class}); %endif + ## These special cases handle memory stuff. Use verbose regex matching + ## to limit the possibility of unintentional stuff getting generated for + ## future entry points with similar names. + %elif re.search(r"MemBufferCreate$", fname): + if (pProperties && (pProperties)->pHost && + flags & UR_MEM_FLAG_USE_HOST_POINTER) { + *phBuffer = mock::createDummyHandleWithData( + reinterpret_cast((pProperties)->pHost), + size); + } else { + *phBuffer = + mock::createDummyHandle(size); + } + %elif re.search(r"EnqueueMemBufferMap$", fname): + if(phEvent) { + *phEvent = mock::createDummyHandle(); + } + + auto parentDummyHandle = + reinterpret_cast(hBuffer); + *ppRetMap = (void *)(parentDummyHandle->MData); + %elif re.search(r"USM(Host|Device|Shared)Alloc$", fname): + *ppMem = mock::createDummyHandle(size); + %elif re.search(r"USMPitchedAllocExp$", fname): + *ppMem = mock::createDummyHandle(widthInBytes * height); %else: %if fname == 'urAdapterGet' or fname == 'urDeviceGet' or fname == 'urPlatformGet': <% diff --git a/source/adapters/mock/ur_mockddi.cpp b/source/adapters/mock/ur_mockddi.cpp index 4ecce32beb..76f8fc8393 100644 --- a/source/adapters/mock/ur_mockddi.cpp +++ b/source/adapters/mock/ur_mockddi.cpp @@ -1446,7 +1446,13 @@ __urdlllocal ur_result_t UR_APICALL urMemBufferCreate( result = replaceCallback(¶ms); } else { - *phBuffer = mock::createDummyHandle(); + if (pProperties && (pProperties)->pHost && + flags & UR_MEM_FLAG_USE_HOST_POINTER) { + *phBuffer = mock::createDummyHandleWithData( + reinterpret_cast((pProperties)->pHost), size); + } else { + *phBuffer = mock::createDummyHandle(size); + } result = UR_RESULT_SUCCESS; } @@ -2187,6 +2193,7 @@ __urdlllocal ur_result_t UR_APICALL urUSMHostAlloc( result = replaceCallback(¶ms); } else { + *ppMem = mock::createDummyHandle(size); result = UR_RESULT_SUCCESS; } @@ -2238,6 +2245,7 @@ __urdlllocal ur_result_t UR_APICALL urUSMDeviceAlloc( result = replaceCallback(¶ms); } else { + *ppMem = mock::createDummyHandle(size); result = UR_RESULT_SUCCESS; } @@ -2289,6 +2297,7 @@ __urdlllocal ur_result_t UR_APICALL urUSMSharedAlloc( result = replaceCallback(¶ms); } else { + *ppMem = mock::createDummyHandle(size); result = UR_RESULT_SUCCESS; } @@ -6285,10 +6294,13 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferMap( result = replaceCallback(¶ms); } else { - // optional output handle if (phEvent) { *phEvent = mock::createDummyHandle(); } + + auto parentDummyHandle = + reinterpret_cast(hBuffer); + *ppRetMap = (void *)(parentDummyHandle->MData); result = UR_RESULT_SUCCESS; } @@ -7073,6 +7085,7 @@ __urdlllocal ur_result_t UR_APICALL urUSMPitchedAllocExp( result = replaceCallback(¶ms); } else { + *ppMem = mock::createDummyHandle(widthInBytes * height); result = UR_RESULT_SUCCESS; } From 0a5656955dad5d47a3953f20e6e1a96e0164279b Mon Sep 17 00:00:00 2001 From: Aaron Greig Date: Tue, 2 Jul 2024 11:18:58 +0100 Subject: [PATCH 06/15] Address review feedback --- include/ur_api.h | 9 +-------- scripts/core/INTRO.rst | 2 +- scripts/core/loader.yml | 12 ++---------- source/loader/ur_libapi.cpp | 2 +- source/mock/ur_mock_helpers.hpp | 5 +++++ source/ur_api.cpp | 2 +- 6 files changed, 11 insertions(+), 21 deletions(-) diff --git a/include/ur_api.h b/include/ur_api.h index c896978528..e7f9349b41 100644 --- a/include/ur_api.h +++ b/include/ur_api.h @@ -743,13 +743,6 @@ urLoaderConfigSetCodeLocationCallback( void *pUserData ///< [in][out][optional] pointer to data to be passed to callback. ); -/////////////////////////////////////////////////////////////////////////////// -/// @brief Callback to replace or instrument generic mock functionality in the -/// mock adapter. -typedef ur_result_t (*ur_mock_callback_t)( - void *pParams ///< [in][out] Pointer to the appropriate param struct for the function -); - /////////////////////////////////////////////////////////////////////////////// /// @brief The only adapter reported with mock enabled will be the mock adapter. /// @@ -757,7 +750,7 @@ typedef ur_result_t (*ur_mock_callback_t)( /// - The mock adapter will default to returning ::UR_RESULT_SUCCESS for all /// entry points. It will also create and correctly reference count dummy /// handles where appropriate. Its behaviour can be modified by linking -/// the ::ur_mock_headers library and using the callbacks object. +/// the ::ur_mock_headers library and using the mock::callbacks object. /// /// @returns /// - ::UR_RESULT_SUCCESS diff --git a/scripts/core/INTRO.rst b/scripts/core/INTRO.rst index 5e716101ef..a48b32a8f9 100644 --- a/scripts/core/INTRO.rst +++ b/scripts/core/INTRO.rst @@ -280,7 +280,7 @@ given entry point can only have one of each kind of callback associated with it, multiple structs with the same function/mode combination will override eachother. -The callback signature defined by ``${x}_mock_callback_t`` takes a single +The callback signature defined by ``ur_mock_callback_t`` takes a single ``void *`` parameter. When calling a user callback the layer will pack the entry point's parameters into the appropriate ``_params_t`` struct (e.g. ``ur_adapter_get_params_t``) and pass a pointer to that struct into the diff --git a/scripts/core/loader.yml b/scripts/core/loader.yml index 1df23607e8..e94242718c 100644 --- a/scripts/core/loader.yml +++ b/scripts/core/loader.yml @@ -1,3 +1,4 @@ +# # Copyright (C) 2022-2023 Intel Corporation # # Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions. @@ -186,19 +187,10 @@ params: name: pUserData desc: "[in][out][optional] pointer to data to be passed to callback." --- #-------------------------------------------------------------------------- -type: fptr_typedef -desc: "Callback to replace or instrument generic mock functionality in the mock adapter." -name: $x_mock_callback_t -return: $x_result_t -params: - - type: void* - name: pParams - desc: "[in][out] Pointer to the appropriate param struct for the function" ---- #-------------------------------------------------------------------------- type: function desc: "The only adapter reported with mock enabled will be the mock adapter." details: - - "The mock adapter will default to returning $X_RESULT_SUCCESS for all entry points. It will also create and correctly reference count dummy handles where appropriate. Its behaviour can be modified by linking the $x_mock_headers library and using the callbacks object." + - "The mock adapter will default to returning $X_RESULT_SUCCESS for all entry points. It will also create and correctly reference count dummy handles where appropriate. Its behaviour can be modified by linking the $x_mock_headers library and using the mock::callbacks object." class: $xLoaderConfig loader_only: True name: "SetMockingEnabled" diff --git a/source/loader/ur_libapi.cpp b/source/loader/ur_libapi.cpp index 2fc6b18f9a..0c6097fa86 100644 --- a/source/loader/ur_libapi.cpp +++ b/source/loader/ur_libapi.cpp @@ -199,7 +199,7 @@ ur_result_t UR_APICALL urLoaderConfigSetCodeLocationCallback( /// - The mock adapter will default to returning ::UR_RESULT_SUCCESS for all /// entry points. It will also create and correctly reference count dummy /// handles where appropriate. Its behaviour can be modified by linking -/// the ::ur_mock_headers library and using the callbacks object. +/// the ::ur_mock_headers library and using the mock::callbacks object. /// /// @returns /// - ::UR_RESULT_SUCCESS diff --git a/source/mock/ur_mock_helpers.hpp b/source/mock/ur_mock_helpers.hpp index bb33a5b2b1..df69f9fa7e 100644 --- a/source/mock/ur_mock_helpers.hpp +++ b/source/mock/ur_mock_helpers.hpp @@ -20,6 +20,11 @@ #include #include +// This is the callback function we accept to override or instrument +// entry-points. pParams is expected to be a pointer to the appropriate params_t +// struct for the given entry point. +typedef ur_result_t (*ur_mock_callback_t)(void *pParams); + namespace mock { struct dummy_handle_t_ { diff --git a/source/ur_api.cpp b/source/ur_api.cpp index ffbe3044bc..46296ae4de 100644 --- a/source/ur_api.cpp +++ b/source/ur_api.cpp @@ -188,7 +188,7 @@ ur_result_t UR_APICALL urLoaderConfigSetCodeLocationCallback( /// - The mock adapter will default to returning ::UR_RESULT_SUCCESS for all /// entry points. It will also create and correctly reference count dummy /// handles where appropriate. Its behaviour can be modified by linking -/// the ::ur_mock_headers library and using the callbacks object. +/// the ::ur_mock_headers library and using the mock::callbacks object. /// /// @returns /// - ::UR_RESULT_SUCCESS From af4ddc7fac18bcc67ba425af8c5b6b9f2e9acc5d Mon Sep 17 00:00:00 2001 From: Aaron Greig Date: Tue, 2 Jul 2024 12:23:40 +0100 Subject: [PATCH 07/15] Re-do docs and testing --- scripts/core/INTRO.rst | 32 +++++++--------- scripts/templates/mockddi.cpp.mako | 2 +- source/CMakeLists.txt | 1 - source/loader/CMakeLists.txt | 1 + source/loader/ur_lib.cpp | 3 ++ test/CMakeLists.txt | 1 + test/{layers => }/mock/CMakeLists.txt | 4 +- test/{layers => }/mock/mock.cpp | 54 +++++---------------------- 8 files changed, 29 insertions(+), 69 deletions(-) rename test/{layers => }/mock/CMakeLists.txt (83%) rename test/{layers => }/mock/mock.cpp (61%) diff --git a/scripts/core/INTRO.rst b/scripts/core/INTRO.rst index a48b32a8f9..07e9ae65f1 100644 --- a/scripts/core/INTRO.rst +++ b/scripts/core/INTRO.rst @@ -258,30 +258,26 @@ For more information about the usage of mentioned environment variables see `Env Mocking --------------------- -A mock UR adapter can be accessed for test purposes by enabling the ``MOCK`` -layer as described below. When the mock layer is enabled, calls to the API will -still be intercepted by other layers (e.g. validation, tracing), but they will -stop short of the loader - the call chain will end in either a generic fallback -behavior defined by the mock layer itself, or a user defined replacement -callback. - -The default fallback behavior for entry points in the mock layer is to simply +A mock UR adapter can be accessed for test purposes by enabling it via +${x}LoaderConfigSetMockingEnabled. + +The default fallback behavior for entry points in the mock adapter is to simply return ``UR_RESULT_SUCCESS``. For entry points concerning handles, i.e. those that create a new handle or modify the reference count of an existing one, a -dummy handle mechanism is used. This means the layer will return generic +dummy handle mechanism is used. This means the adapter will return generic handles that track a reference count, and ``Retain``/``Release`` entry points will function as expected when used with these handles. -During global setup the behavior of the mock layer can be customized by setting -chain of structs, with each registering a callback with a given entry point in -the API. Callbacks can be registered to be called ``BEFORE`` or ``AFTER`` the -generic implementation, or they can be registered to entirely ``REPLACE`` it. A -given entry point can only have one of each kind of callback associated with -it, multiple structs with the same function/mode combination will override -eachother. +The behavior of the mock adapter can be customized by linking the +``unified-runtime::mock`` library and making use of the ``mock::callbacks`` +object. Callbacks can be passed into this object to run either before or after a +given entry point, or they can be set to entirely replace the default behavior. +Only one callback of each type (before, replace, after) can be set per entry +point, with subsequent callbacks set in the same "slot" overwriting any set +previously. The callback signature defined by ``ur_mock_callback_t`` takes a single -``void *`` parameter. When calling a user callback the layer will pack the +``void *`` parameter. When calling a user callback the adapter will pack the entry point's parameters into the appropriate ``_params_t`` struct (e.g. ``ur_adapter_get_params_t``) and pass a pointer to that struct into the callback. This allows parameters to be accessed and modified. The definitions @@ -309,8 +305,6 @@ Layers currently included with the runtime are as follows: - Enables the XPTI tracing layer, see Tracing_ for more detail. * - UR_LAYER_ASAN \| UR_LAYER_MSAN \| UR_LAYER_TSAN - Enables the device-side sanitizer layer, see Sanitizers_ for more detail. - * - UR_LAYER_MOCK - - Enables adapter mocking for test purposes. Similar behavior to the null adapter except entry points can be overridden or instrumented with callbacks. See Mocking_ for more detail. Environment Variables --------------------- diff --git a/scripts/templates/mockddi.cpp.mako b/scripts/templates/mockddi.cpp.mako index 70add5645f..8cc7f4db2f 100644 --- a/scripts/templates/mockddi.cpp.mako +++ b/scripts/templates/mockddi.cpp.mako @@ -9,7 +9,7 @@ from templates import helper as th X=x.upper() %>/* * - * Copyright (C) 2019-2024 Intel Corporation + * Copyright (C) 2024 Intel Corporation * * Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions. * See LICENSE.TXT diff --git a/source/CMakeLists.txt b/source/CMakeLists.txt index 0e994bfb9c..f0dd315313 100644 --- a/source/CMakeLists.txt +++ b/source/CMakeLists.txt @@ -8,6 +8,5 @@ add_definitions(-DUR_VALIDATION_LAYER_SUPPORTED_VERSION="${PROJECT_VERSION_MAJOR add_subdirectory(common) add_subdirectory(loader) -#add_subdirectory(layers) add_subdirectory(mock) add_subdirectory(adapters) diff --git a/source/loader/CMakeLists.txt b/source/loader/CMakeLists.txt index 4f9b5b53ec..075d9909b0 100644 --- a/source/loader/CMakeLists.txt +++ b/source/loader/CMakeLists.txt @@ -159,6 +159,7 @@ if(UR_ENABLE_SANITIZER) ) endif() + # link validation backtrace dependencies if(UNIX) find_package(Libbacktrace) diff --git a/source/loader/ur_lib.cpp b/source/loader/ur_lib.cpp index a4a391b73f..6baaa9fb02 100644 --- a/source/loader/ur_lib.cpp +++ b/source/loader/ur_lib.cpp @@ -225,6 +225,9 @@ urLoaderConfigSetCodeLocationCallback(ur_loader_config_handle_t hLoaderConfig, ur_result_t urLoaderConfigSetMockingEnabled(ur_loader_config_handle_t hLoaderConfig, ur_bool_t enable) { + if (!hLoaderConfig) { + return UR_RESULT_ERROR_INVALID_NULL_HANDLE; + } hLoaderConfig->enableMock = enable; return UR_RESULT_SUCCESS; } diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 3df71a081d..e648bac44a 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -26,6 +26,7 @@ add_subdirectory(adapters) add_subdirectory(usm) add_subdirectory(layers) add_subdirectory(unit) +add_subdirectory(mock) if(UR_BUILD_TOOLS) add_subdirectory(tools) endif() diff --git a/test/layers/mock/CMakeLists.txt b/test/mock/CMakeLists.txt similarity index 83% rename from test/layers/mock/CMakeLists.txt rename to test/mock/CMakeLists.txt index e112ca4203..59b099505b 100644 --- a/test/layers/mock/CMakeLists.txt +++ b/test/mock/CMakeLists.txt @@ -11,6 +11,7 @@ target_link_libraries(${MOCK_TEST_NAME} ${PROJECT_NAME}::loader ${PROJECT_NAME}::headers ${PROJECT_NAME}::testing + ${PROJECT_NAME}::mock GTest::gtest_main) add_test(NAME ${MOCK_TEST_NAME} @@ -18,6 +19,3 @@ add_test(NAME ${MOCK_TEST_NAME} WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) set_tests_properties(${MOCK_TEST_NAME} PROPERTIES LABELS "mock") - -set_property(TEST ${MOCK_TEST_NAME} PROPERTY ENVIRONMENT - "UR_ADAPTERS_FORCE_LOAD=\"$\"") diff --git a/test/layers/mock/mock.cpp b/test/mock/mock.cpp similarity index 61% rename from test/layers/mock/mock.cpp rename to test/mock/mock.cpp index 75fcc5ef32..36c3bab113 100644 --- a/test/layers/mock/mock.cpp +++ b/test/mock/mock.cpp @@ -14,41 +14,17 @@ #include #include -TEST(Mock, NullProperties) { - uur::raii::LoaderConfig loader_config; - ASSERT_EQ(urLoaderConfigCreate(loader_config.ptr()), UR_RESULT_SUCCESS); - ASSERT_EQ(urLoaderConfigSetMockCallbacks(loader_config, nullptr), - UR_RESULT_ERROR_INVALID_NULL_POINTER); -} - -TEST(Mock, NullCallback) { - uur::raii::LoaderConfig loader_config; - ASSERT_EQ(urLoaderConfigCreate(loader_config.ptr()), UR_RESULT_SUCCESS); - - ur_mock_callback_properties_t callback_properties = { - UR_STRUCTURE_TYPE_MOCK_CALLBACK_PROPERTIES, nullptr, "urAdapterGet", - UR_CALLBACK_OVERRIDE_MODE_REPLACE, nullptr}; - - ASSERT_EQ( - urLoaderConfigSetMockCallbacks(loader_config, &callback_properties), - UR_RESULT_ERROR_INVALID_NULL_POINTER); -} - -ur_result_t generic_callback(void *) { return UR_RESULT_SUCCESS; } +#include TEST(Mock, NullHandle) { - ur_mock_callback_properties_t callback_properties = { - UR_STRUCTURE_TYPE_MOCK_CALLBACK_PROPERTIES, nullptr, "urAdapterGet", - UR_CALLBACK_OVERRIDE_MODE_REPLACE, &generic_callback}; - - ASSERT_EQ(urLoaderConfigSetMockCallbacks(nullptr, &callback_properties), + ASSERT_EQ(urLoaderConfigSetMockingEnabled(nullptr, true), UR_RESULT_ERROR_INVALID_NULL_HANDLE); } TEST(Mock, DefaultBehavior) { uur::raii::LoaderConfig loader_config; ASSERT_EQ(urLoaderConfigCreate(loader_config.ptr()), UR_RESULT_SUCCESS); - ASSERT_EQ(urLoaderConfigEnableLayer(loader_config, "UR_LAYER_MOCK"), + ASSERT_EQ(urLoaderConfigSetMockingEnabled(loader_config, true), UR_RESULT_SUCCESS); ASSERT_EQ(urLoaderInit(0, loader_config), UR_RESULT_SUCCESS); @@ -108,32 +84,20 @@ ur_result_t afterUrAdapterGet(void *pParams) { TEST(Mock, Callbacks) { uur::raii::LoaderConfig loader_config; ASSERT_EQ(urLoaderConfigCreate(loader_config.ptr()), UR_RESULT_SUCCESS); + ASSERT_EQ(urLoaderConfigSetMockingEnabled(loader_config, true), + UR_RESULT_SUCCESS); + ASSERT_EQ(urLoaderInit(0, loader_config), UR_RESULT_SUCCESS); // This callback is set up to check *phAdapters is still the pre-call // init value we set below - ur_mock_callback_properties_t adapterGetBeforeProperties = { - UR_STRUCTURE_TYPE_MOCK_CALLBACK_PROPERTIES, nullptr, "urAdapterGet", - UR_CALLBACK_OVERRIDE_MODE_BEFORE, &beforeUrAdapterGet}; + mock::callbacks.set_before_callback("urAdapterGet", &beforeUrAdapterGet); // This callback is set up to return a distinct test value in phAdapters // rather than the default generic handle - ur_mock_callback_properties_t adapterGetReplaceProperties = { - UR_STRUCTURE_TYPE_MOCK_CALLBACK_PROPERTIES, &adapterGetBeforeProperties, - "urAdapterGet", UR_CALLBACK_OVERRIDE_MODE_REPLACE, - &replaceUrAdapterGet}; + mock::callbacks.set_replace_callback("urAdapterGet", &replaceUrAdapterGet); // This callback is set up to check our replace callback did its job - ur_mock_callback_properties_t adapterGetAfterProperties = { - UR_STRUCTURE_TYPE_MOCK_CALLBACK_PROPERTIES, - &adapterGetReplaceProperties, "urAdapterGet", - UR_CALLBACK_OVERRIDE_MODE_AFTER, &afterUrAdapterGet}; - - ASSERT_EQ(urLoaderConfigSetMockCallbacks(loader_config, - &adapterGetAfterProperties), - UR_RESULT_SUCCESS); - ASSERT_EQ(urLoaderConfigEnableLayer(loader_config, "UR_LAYER_MOCK"), - UR_RESULT_SUCCESS); - ASSERT_EQ(urLoaderInit(0, loader_config), UR_RESULT_SUCCESS); + mock::callbacks.set_after_callback("urAdapterGet", &afterUrAdapterGet); ur_adapter_handle_t adapter = reinterpret_cast(0xF00DCAFE); From c90c6d04b9e861e547c2e0ff080955c704b97fa5 Mon Sep 17 00:00:00 2001 From: Aaron Greig Date: Tue, 2 Jul 2024 14:31:34 +0100 Subject: [PATCH 08/15] Add missing retain and fix some null adapter mentions. --- examples/collector/README.md | 2 +- include/ur_api.h | 2 +- scripts/core/exp-command-buffer.yml | 2 +- source/adapters/mock/ur_mockddi.cpp | 5 +++-- source/loader/layers/tracing/ur_trcddi.cpp | 2 +- source/loader/layers/validation/ur_valddi.cpp | 2 +- source/loader/ur_ldrddi.cpp | 8 +++++++- source/loader/ur_libapi.cpp | 2 +- source/ur_api.cpp | 2 +- test/conformance/CMakeLists.txt | 2 +- test/fuzz/README.md | 4 ++-- tools/urtrace/urtrace.py | 16 ++++++++-------- 12 files changed, 28 insertions(+), 21 deletions(-) diff --git a/examples/collector/README.md b/examples/collector/README.md index fbdf18a8ae..de7755fa16 100644 --- a/examples/collector/README.md +++ b/examples/collector/README.md @@ -19,7 +19,7 @@ $ mkdir build $ cd build $ cmake .. -DUR_ENABLE_TRACING=ON $ make -$ UR_ADAPTERS_FORCE_LOAD=./lib/libur_adapter_null.so UR_ENABLE_LAYERS=UR_LAYER_TRACING XPTI_TRACE_ENABLE=1 XPTI_FRAMEWORK_DISPATCHER=./lib/libxptifw.so XPTI_SUBSCRIBERS=./lib/libcollector.so ./bin/hello_world +$ UR_ADAPTERS_FORCE_LOAD=./lib/libur_adapter_mock.so UR_ENABLE_LAYERS=UR_LAYER_TRACING XPTI_TRACE_ENABLE=1 XPTI_FRAMEWORK_DISPATCHER=./lib/libxptifw.so XPTI_SUBSCRIBERS=./lib/libcollector.so ./bin/hello_world ``` See [XPTI framework documentation](https://github.com/intel/llvm/blob/sycl/xptifw/doc/XPTI_Framework.md) for more information. diff --git a/include/ur_api.h b/include/ur_api.h index e7f9349b41..b8ceb18997 100644 --- a/include/ur_api.h +++ b/include/ur_api.h @@ -8822,7 +8822,7 @@ urCommandBufferEnqueueExp( /// - ::UR_RESULT_ERROR_OUT_OF_HOST_MEMORY UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferRetainCommandExp( - ur_exp_command_buffer_command_handle_t hCommand ///< [in] Handle of the command-buffer command. + ur_exp_command_buffer_command_handle_t hCommand ///< [in][retain] Handle of the command-buffer command. ); /////////////////////////////////////////////////////////////////////////////// diff --git a/scripts/core/exp-command-buffer.yml b/scripts/core/exp-command-buffer.yml index 73a76e6d87..72b4e63f74 100644 --- a/scripts/core/exp-command-buffer.yml +++ b/scripts/core/exp-command-buffer.yml @@ -879,7 +879,7 @@ name: RetainCommandExp params: - type: $x_exp_command_buffer_command_handle_t name: hCommand - desc: "[in] Handle of the command-buffer command." + desc: "[in][retain] Handle of the command-buffer command." returns: - $X_RESULT_ERROR_INVALID_COMMAND_BUFFER_COMMAND_HANDLE_EXP - $X_RESULT_ERROR_OUT_OF_RESOURCES diff --git a/source/adapters/mock/ur_mockddi.cpp b/source/adapters/mock/ur_mockddi.cpp index 76f8fc8393..07757c4036 100644 --- a/source/adapters/mock/ur_mockddi.cpp +++ b/source/adapters/mock/ur_mockddi.cpp @@ -1,6 +1,6 @@ /* * - * Copyright (C) 2019-2024 Intel Corporation + * Copyright (C) 2024 Intel Corporation * * Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions. * See LICENSE.TXT @@ -9139,7 +9139,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferEnqueueExp( /// @brief Intercept function for urCommandBufferRetainCommandExp __urdlllocal ur_result_t UR_APICALL urCommandBufferRetainCommandExp( ur_exp_command_buffer_command_handle_t - hCommand ///< [in] Handle of the command-buffer command. + hCommand ///< [in][retain] Handle of the command-buffer command. ) try { ur_result_t result = UR_RESULT_SUCCESS; @@ -9161,6 +9161,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferRetainCommandExp( result = replaceCallback(¶ms); } else { + mock::retainDummyHandle(hCommand); result = UR_RESULT_SUCCESS; } diff --git a/source/loader/layers/tracing/ur_trcddi.cpp b/source/loader/layers/tracing/ur_trcddi.cpp index 9f0f302399..e20a4dec25 100644 --- a/source/loader/layers/tracing/ur_trcddi.cpp +++ b/source/loader/layers/tracing/ur_trcddi.cpp @@ -7135,7 +7135,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferEnqueueExp( /// @brief Intercept function for urCommandBufferRetainCommandExp __urdlllocal ur_result_t UR_APICALL urCommandBufferRetainCommandExp( ur_exp_command_buffer_command_handle_t - hCommand ///< [in] Handle of the command-buffer command. + hCommand ///< [in][retain] Handle of the command-buffer command. ) { auto pfnRetainCommandExp = context.urDdiTable.CommandBufferExp.pfnRetainCommandExp; diff --git a/source/loader/layers/validation/ur_valddi.cpp b/source/loader/layers/validation/ur_valddi.cpp index 163dc76d03..d1029ea7ed 100644 --- a/source/loader/layers/validation/ur_valddi.cpp +++ b/source/loader/layers/validation/ur_valddi.cpp @@ -8722,7 +8722,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferEnqueueExp( /// @brief Intercept function for urCommandBufferRetainCommandExp __urdlllocal ur_result_t UR_APICALL urCommandBufferRetainCommandExp( ur_exp_command_buffer_command_handle_t - hCommand ///< [in] Handle of the command-buffer command. + hCommand ///< [in][retain] Handle of the command-buffer command. ) { auto pfnRetainCommandExp = context.urDdiTable.CommandBufferExp.pfnRetainCommandExp; diff --git a/source/loader/ur_ldrddi.cpp b/source/loader/ur_ldrddi.cpp index b389b24713..ac6b9b2915 100644 --- a/source/loader/ur_ldrddi.cpp +++ b/source/loader/ur_ldrddi.cpp @@ -7492,7 +7492,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferEnqueueExp( /// @brief Intercept function for urCommandBufferRetainCommandExp __urdlllocal ur_result_t UR_APICALL urCommandBufferRetainCommandExp( ur_exp_command_buffer_command_handle_t - hCommand ///< [in] Handle of the command-buffer command. + hCommand ///< [in][retain] Handle of the command-buffer command. ) { ur_result_t result = UR_RESULT_SUCCESS; @@ -7514,6 +7514,12 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferRetainCommandExp( // forward to device-platform result = pfnRetainCommandExp(hCommand); + if (UR_RESULT_SUCCESS != result) { + return result; + } + + // TODO: do we need to ref count the loader handles? + return result; } diff --git a/source/loader/ur_libapi.cpp b/source/loader/ur_libapi.cpp index 0c6097fa86..95cf231ac4 100644 --- a/source/loader/ur_libapi.cpp +++ b/source/loader/ur_libapi.cpp @@ -8173,7 +8173,7 @@ ur_result_t UR_APICALL urCommandBufferEnqueueExp( /// - ::UR_RESULT_ERROR_OUT_OF_HOST_MEMORY ur_result_t UR_APICALL urCommandBufferRetainCommandExp( ur_exp_command_buffer_command_handle_t - hCommand ///< [in] Handle of the command-buffer command. + hCommand ///< [in][retain] Handle of the command-buffer command. ) try { auto pfnRetainCommandExp = ur_lib::context->urDdiTable.CommandBufferExp.pfnRetainCommandExp; diff --git a/source/ur_api.cpp b/source/ur_api.cpp index 46296ae4de..0460c3d663 100644 --- a/source/ur_api.cpp +++ b/source/ur_api.cpp @@ -6928,7 +6928,7 @@ ur_result_t UR_APICALL urCommandBufferEnqueueExp( /// - ::UR_RESULT_ERROR_OUT_OF_HOST_MEMORY ur_result_t UR_APICALL urCommandBufferRetainCommandExp( ur_exp_command_buffer_command_handle_t - hCommand ///< [in] Handle of the command-buffer command. + hCommand ///< [in][retain] Handle of the command-buffer command. ) { ur_result_t result = UR_RESULT_SUCCESS; return result; diff --git a/test/conformance/CMakeLists.txt b/test/conformance/CMakeLists.txt index e7bc86f048..894ff93632 100644 --- a/test/conformance/CMakeLists.txt +++ b/test/conformance/CMakeLists.txt @@ -87,7 +87,7 @@ function(add_conformance_test name) if(NOT (UR_BUILD_ADAPTER_CUDA OR UR_BUILD_ADAPTER_HIP OR UR_BUILD_ADAPTER_L0 OR UR_BUILD_ADAPTER_OPENCL OR UR_BUILD_ADAPTER_NATIVE_CPU OR UR_BUILD_ADAPTER_ALL)) - add_test_adapter(${name} adapter_null) + add_test_adapter(${name} adapter_mock) endif() endfunction() diff --git a/test/fuzz/README.md b/test/fuzz/README.md index 9acc3f57ac..a919237e11 100644 --- a/test/fuzz/README.md +++ b/test/fuzz/README.md @@ -9,7 +9,7 @@ which provides the path where any new corpus will be saved. The path has to exis It's worth running the test with tracing enabled while picking scenarios to be added to the repository for future short fuzz tests runs. Example of running the test with generating new corpus files: ``` -UR_ADAPTERS_FORCE_LOAD=build/lib/libur_adapter_null.so \ +UR_ADAPTERS_FORCE_LOAD=build/lib/libur_adapter_mock.so \ XPTI_TRACE_ENABLE=1 \ XPTI_FRAMEWORK_DISPATCHER=build/lib/libxptifw.so \ XPTI_SUBSCRIBERS=build/lib/libcollector.so \ @@ -19,7 +19,7 @@ UR_ENABLE_LAYERS=UR_LAYER_TRACING \ Pass path to a corpus file instead to run a single scenario: ``` -UR_ADAPTERS_FORCE_LOAD=build/lib/libur_adapter_null.so \ +UR_ADAPTERS_FORCE_LOAD=build/lib/libur_adapter_mock.so \ XPTI_TRACE_ENABLE=1 \ XPTI_FRAMEWORK_DISPATCHER=build/lib/libxptifw.so \ XPTI_SUBSCRIBERS=build/lib/libcollector.so \ diff --git a/tools/urtrace/urtrace.py b/tools/urtrace/urtrace.py index 48a3ca72d2..afae095489 100755 --- a/tools/urtrace/urtrace.py +++ b/tools/urtrace/urtrace.py @@ -40,13 +40,13 @@ def get_dynamic_library_name(name): epilog='''examples: %(prog)s ./myapp --myapp-arg - %(prog)s --null --profiling --filter ".*(Device|Platform).*" ./hello_world + %(prog)s --mock --profiling --filter ".*(Device|Platform).*" ./hello_world %(prog)s --adapter libur_adapter_cuda.so --begin ./sycl_app''', formatter_class=argparse.RawDescriptionHelpFormatter) parser.add_argument("command", help="Command to run, including arguments.", nargs=argparse.REMAINDER) parser.add_argument("--profiling", help="Measure function execution time.", action="store_true") parser.add_argument("--filter", help="Only trace functions that match the provided regex filter.") -parser.add_argument("--null", help="Force the use of the null adapter.", action="store_true") +parser.add_argument("--mock", help="Force the use of the mock adapter.", action="store_true") parser.add_argument("--adapter", help="Force the use of the provided adapter.", action="append", default=[]) parser.add_argument("--json", help="Write output in a JSON Trace Event Format.", action="store_true") group = parser.add_mutually_exclusive_group() @@ -113,12 +113,12 @@ def get_dynamic_library_name(name): force_load = None -if args.null: - null_lib = get_dynamic_library_name("ur_adapter_null") - null_adapter = find_library(args.libpath, null_lib, args.recursive) - if null_adapter is None: - sys.exit("unable to find the null adapter - " + null_lib) - force_load = "\"" + null_adapter + "\"" +if args.mock: + mock_lib = get_dynamic_library_name("ur_adapter_mock") + mock_adapter = find_library(args.libpath, mock_lib, args.recursive) + if mock_adapter is None: + sys.exit("unable to find the mock adapter - " + mock_lib) + force_load = "\"" + mock_adapter + "\"" for adapter in args.adapter: adapter_path = find_library(args.libpath, adapter, args.recursive) if is_filename(adapter) else adapter From 088c6c6f063376a7009743ef752709ffa3282d81 Mon Sep 17 00:00:00 2001 From: Aaron Greig Date: Tue, 2 Jul 2024 14:40:54 +0100 Subject: [PATCH 09/15] Make callbacks singleton into a helper function --- include/ur_api.h | 3 +- scripts/core/INTRO.rst | 12 +- scripts/core/loader.yml | 2 +- scripts/templates/mockddi.cpp.mako | 6 +- source/adapters/mock/ur_mockddi.cpp | 1252 ++++++++++++++------------- source/loader/ur_libapi.cpp | 3 +- source/mock/ur_mock_helpers.cpp | 5 +- source/mock/ur_mock_helpers.hpp | 8 +- source/ur_api.cpp | 3 +- test/mock/mock.cpp | 8 +- 10 files changed, 690 insertions(+), 612 deletions(-) diff --git a/include/ur_api.h b/include/ur_api.h index b8ceb18997..861984d36f 100644 --- a/include/ur_api.h +++ b/include/ur_api.h @@ -750,7 +750,8 @@ urLoaderConfigSetCodeLocationCallback( /// - The mock adapter will default to returning ::UR_RESULT_SUCCESS for all /// entry points. It will also create and correctly reference count dummy /// handles where appropriate. Its behaviour can be modified by linking -/// the ::ur_mock_headers library and using the mock::callbacks object. +/// the mock library and using the object accessed via +/// mock::getCallbacks(). /// /// @returns /// - ::UR_RESULT_SUCCESS diff --git a/scripts/core/INTRO.rst b/scripts/core/INTRO.rst index 07e9ae65f1..56f1a7db19 100644 --- a/scripts/core/INTRO.rst +++ b/scripts/core/INTRO.rst @@ -269,12 +269,12 @@ handles that track a reference count, and ``Retain``/``Release`` entry points wi function as expected when used with these handles. The behavior of the mock adapter can be customized by linking the -``unified-runtime::mock`` library and making use of the ``mock::callbacks`` -object. Callbacks can be passed into this object to run either before or after a -given entry point, or they can be set to entirely replace the default behavior. -Only one callback of each type (before, replace, after) can be set per entry -point, with subsequent callbacks set in the same "slot" overwriting any set -previously. +``unified-runtime::mock`` library and making use of the object accessed via the +``mock::getCallbacks()`` helper. Callbacks can be passed into this object to +run either before or after a given entry point, or they can be set to entirely +replace the default behavior. Only one callback of each type (before, replace, +after) can be set per entry point, with subsequent callbacks set in the same +"slot" overwriting any set previously. The callback signature defined by ``ur_mock_callback_t`` takes a single ``void *`` parameter. When calling a user callback the adapter will pack the diff --git a/scripts/core/loader.yml b/scripts/core/loader.yml index e94242718c..fc02e60ef4 100644 --- a/scripts/core/loader.yml +++ b/scripts/core/loader.yml @@ -190,7 +190,7 @@ params: type: function desc: "The only adapter reported with mock enabled will be the mock adapter." details: - - "The mock adapter will default to returning $X_RESULT_SUCCESS for all entry points. It will also create and correctly reference count dummy handles where appropriate. Its behaviour can be modified by linking the $x_mock_headers library and using the mock::callbacks object." + - "The mock adapter will default to returning $X_RESULT_SUCCESS for all entry points. It will also create and correctly reference count dummy handles where appropriate. Its behaviour can be modified by linking the mock library and using the object accessed via mock::getCallbacks()." class: $xLoaderConfig loader_only: True name: "SetMockingEnabled" diff --git a/scripts/templates/mockddi.cpp.mako b/scripts/templates/mockddi.cpp.mako index 8cc7f4db2f..539d3f8bc1 100644 --- a/scripts/templates/mockddi.cpp.mako +++ b/scripts/templates/mockddi.cpp.mako @@ -44,7 +44,7 @@ namespace driver ${th.make_pfncb_param_type(n, tags, obj)} params = { &${",&".join(th.make_param_lines(n, tags, obj, format=["name"]))} }; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("${fname}")); + mock::getCallbacks().get_before_callback("${fname}")); if(beforeCallback) { result = beforeCallback( ¶ms ); if(result != UR_RESULT_SUCCESS) { @@ -53,7 +53,7 @@ namespace driver } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("${fname}")); + mock::getCallbacks().get_replace_callback("${fname}")); if(replaceCallback) { result = replaceCallback( ¶ms ); } @@ -130,7 +130,7 @@ namespace driver } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("${fname}")); + mock::getCallbacks().get_after_callback("${fname}")); if(afterCallback) { return afterCallback( ¶ms ); } diff --git a/source/adapters/mock/ur_mockddi.cpp b/source/adapters/mock/ur_mockddi.cpp index 07757c4036..22bfe69b85 100644 --- a/source/adapters/mock/ur_mockddi.cpp +++ b/source/adapters/mock/ur_mockddi.cpp @@ -33,7 +33,7 @@ __urdlllocal ur_result_t UR_APICALL urAdapterGet( ur_adapter_get_params_t params = {&NumEntries, &phAdapters, &pNumAdapters}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urAdapterGet")); + mock::getCallbacks().get_before_callback("urAdapterGet")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -42,7 +42,7 @@ __urdlllocal ur_result_t UR_APICALL urAdapterGet( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urAdapterGet")); + mock::getCallbacks().get_replace_callback("urAdapterGet")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -62,7 +62,7 @@ __urdlllocal ur_result_t UR_APICALL urAdapterGet( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urAdapterGet")); + mock::getCallbacks().get_after_callback("urAdapterGet")); if (afterCallback) { return afterCallback(¶ms); } @@ -82,7 +82,7 @@ __urdlllocal ur_result_t UR_APICALL urAdapterRelease( ur_adapter_release_params_t params = {&hAdapter}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urAdapterRelease")); + mock::getCallbacks().get_before_callback("urAdapterRelease")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -91,7 +91,7 @@ __urdlllocal ur_result_t UR_APICALL urAdapterRelease( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urAdapterRelease")); + mock::getCallbacks().get_replace_callback("urAdapterRelease")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -105,7 +105,7 @@ __urdlllocal ur_result_t UR_APICALL urAdapterRelease( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urAdapterRelease")); + mock::getCallbacks().get_after_callback("urAdapterRelease")); if (afterCallback) { return afterCallback(¶ms); } @@ -125,7 +125,7 @@ __urdlllocal ur_result_t UR_APICALL urAdapterRetain( ur_adapter_retain_params_t params = {&hAdapter}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urAdapterRetain")); + mock::getCallbacks().get_before_callback("urAdapterRetain")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -134,7 +134,7 @@ __urdlllocal ur_result_t UR_APICALL urAdapterRetain( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urAdapterRetain")); + mock::getCallbacks().get_replace_callback("urAdapterRetain")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -148,7 +148,7 @@ __urdlllocal ur_result_t UR_APICALL urAdapterRetain( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urAdapterRetain")); + mock::getCallbacks().get_after_callback("urAdapterRetain")); if (afterCallback) { return afterCallback(¶ms); } @@ -175,7 +175,7 @@ __urdlllocal ur_result_t UR_APICALL urAdapterGetLastError( &pError}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urAdapterGetLastError")); + mock::getCallbacks().get_before_callback("urAdapterGetLastError")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -184,7 +184,7 @@ __urdlllocal ur_result_t UR_APICALL urAdapterGetLastError( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urAdapterGetLastError")); + mock::getCallbacks().get_replace_callback("urAdapterGetLastError")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -197,7 +197,7 @@ __urdlllocal ur_result_t UR_APICALL urAdapterGetLastError( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urAdapterGetLastError")); + mock::getCallbacks().get_after_callback("urAdapterGetLastError")); if (afterCallback) { return afterCallback(¶ms); } @@ -228,7 +228,7 @@ __urdlllocal ur_result_t UR_APICALL urAdapterGetInfo( &pPropValue, &pPropSizeRet}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urAdapterGetInfo")); + mock::getCallbacks().get_before_callback("urAdapterGetInfo")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -237,7 +237,7 @@ __urdlllocal ur_result_t UR_APICALL urAdapterGetInfo( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urAdapterGetInfo")); + mock::getCallbacks().get_replace_callback("urAdapterGetInfo")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -250,7 +250,7 @@ __urdlllocal ur_result_t UR_APICALL urAdapterGetInfo( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urAdapterGetInfo")); + mock::getCallbacks().get_after_callback("urAdapterGetInfo")); if (afterCallback) { return afterCallback(¶ms); } @@ -284,7 +284,7 @@ __urdlllocal ur_result_t UR_APICALL urPlatformGet( &phPlatforms, &pNumPlatforms}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urPlatformGet")); + mock::getCallbacks().get_before_callback("urPlatformGet")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -293,7 +293,7 @@ __urdlllocal ur_result_t UR_APICALL urPlatformGet( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urPlatformGet")); + mock::getCallbacks().get_replace_callback("urPlatformGet")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -313,7 +313,7 @@ __urdlllocal ur_result_t UR_APICALL urPlatformGet( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urPlatformGet")); + mock::getCallbacks().get_after_callback("urPlatformGet")); if (afterCallback) { return afterCallback(¶ms); } @@ -344,7 +344,7 @@ __urdlllocal ur_result_t UR_APICALL urPlatformGetInfo( &pPropValue, &pPropSizeRet}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urPlatformGetInfo")); + mock::getCallbacks().get_before_callback("urPlatformGetInfo")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -353,7 +353,7 @@ __urdlllocal ur_result_t UR_APICALL urPlatformGetInfo( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urPlatformGetInfo")); + mock::getCallbacks().get_replace_callback("urPlatformGetInfo")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -366,7 +366,7 @@ __urdlllocal ur_result_t UR_APICALL urPlatformGetInfo( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urPlatformGetInfo")); + mock::getCallbacks().get_after_callback("urPlatformGetInfo")); if (afterCallback) { return afterCallback(¶ms); } @@ -387,7 +387,7 @@ __urdlllocal ur_result_t UR_APICALL urPlatformGetApiVersion( ur_platform_get_api_version_params_t params = {&hPlatform, &pVersion}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urPlatformGetApiVersion")); + mock::getCallbacks().get_before_callback("urPlatformGetApiVersion")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -396,7 +396,7 @@ __urdlllocal ur_result_t UR_APICALL urPlatformGetApiVersion( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urPlatformGetApiVersion")); + mock::getCallbacks().get_replace_callback("urPlatformGetApiVersion")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -409,7 +409,7 @@ __urdlllocal ur_result_t UR_APICALL urPlatformGetApiVersion( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urPlatformGetApiVersion")); + mock::getCallbacks().get_after_callback("urPlatformGetApiVersion")); if (afterCallback) { return afterCallback(¶ms); } @@ -432,7 +432,7 @@ __urdlllocal ur_result_t UR_APICALL urPlatformGetNativeHandle( &phNativePlatform}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urPlatformGetNativeHandle")); + mock::getCallbacks().get_before_callback("urPlatformGetNativeHandle")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -441,7 +441,7 @@ __urdlllocal ur_result_t UR_APICALL urPlatformGetNativeHandle( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urPlatformGetNativeHandle")); + mock::getCallbacks().get_replace_callback("urPlatformGetNativeHandle")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -455,7 +455,7 @@ __urdlllocal ur_result_t UR_APICALL urPlatformGetNativeHandle( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urPlatformGetNativeHandle")); + mock::getCallbacks().get_after_callback("urPlatformGetNativeHandle")); if (afterCallback) { return afterCallback(¶ms); } @@ -483,7 +483,7 @@ __urdlllocal ur_result_t UR_APICALL urPlatformCreateWithNativeHandle( &hNativePlatform, &hAdapter, &pProperties, &phPlatform}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback( + mock::getCallbacks().get_before_callback( "urPlatformCreateWithNativeHandle")); if (beforeCallback) { result = beforeCallback(¶ms); @@ -493,7 +493,7 @@ __urdlllocal ur_result_t UR_APICALL urPlatformCreateWithNativeHandle( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback( + mock::getCallbacks().get_replace_callback( "urPlatformCreateWithNativeHandle")); if (replaceCallback) { result = replaceCallback(¶ms); @@ -509,7 +509,8 @@ __urdlllocal ur_result_t UR_APICALL urPlatformCreateWithNativeHandle( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urPlatformCreateWithNativeHandle")); + mock::getCallbacks().get_after_callback( + "urPlatformCreateWithNativeHandle")); if (afterCallback) { return afterCallback(¶ms); } @@ -535,7 +536,7 @@ __urdlllocal ur_result_t UR_APICALL urPlatformGetBackendOption( &hPlatform, &pFrontendOption, &ppPlatformOption}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urPlatformGetBackendOption")); + mock::getCallbacks().get_before_callback("urPlatformGetBackendOption")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -544,7 +545,8 @@ __urdlllocal ur_result_t UR_APICALL urPlatformGetBackendOption( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urPlatformGetBackendOption")); + mock::getCallbacks().get_replace_callback( + "urPlatformGetBackendOption")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -557,7 +559,7 @@ __urdlllocal ur_result_t UR_APICALL urPlatformGetBackendOption( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urPlatformGetBackendOption")); + mock::getCallbacks().get_after_callback("urPlatformGetBackendOption")); if (afterCallback) { return afterCallback(¶ms); } @@ -590,7 +592,7 @@ __urdlllocal ur_result_t UR_APICALL urDeviceGet( &phDevices, &pNumDevices}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urDeviceGet")); + mock::getCallbacks().get_before_callback("urDeviceGet")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -599,7 +601,7 @@ __urdlllocal ur_result_t UR_APICALL urDeviceGet( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urDeviceGet")); + mock::getCallbacks().get_replace_callback("urDeviceGet")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -619,7 +621,7 @@ __urdlllocal ur_result_t UR_APICALL urDeviceGet( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urDeviceGet")); + mock::getCallbacks().get_after_callback("urDeviceGet")); if (afterCallback) { return afterCallback(¶ms); } @@ -651,7 +653,7 @@ __urdlllocal ur_result_t UR_APICALL urDeviceGetInfo( &pPropValue, &pPropSizeRet}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urDeviceGetInfo")); + mock::getCallbacks().get_before_callback("urDeviceGetInfo")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -660,7 +662,7 @@ __urdlllocal ur_result_t UR_APICALL urDeviceGetInfo( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urDeviceGetInfo")); + mock::getCallbacks().get_replace_callback("urDeviceGetInfo")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -673,7 +675,7 @@ __urdlllocal ur_result_t UR_APICALL urDeviceGetInfo( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urDeviceGetInfo")); + mock::getCallbacks().get_after_callback("urDeviceGetInfo")); if (afterCallback) { return afterCallback(¶ms); } @@ -694,7 +696,7 @@ __urdlllocal ur_result_t UR_APICALL urDeviceRetain( ur_device_retain_params_t params = {&hDevice}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urDeviceRetain")); + mock::getCallbacks().get_before_callback("urDeviceRetain")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -703,7 +705,7 @@ __urdlllocal ur_result_t UR_APICALL urDeviceRetain( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urDeviceRetain")); + mock::getCallbacks().get_replace_callback("urDeviceRetain")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -717,7 +719,7 @@ __urdlllocal ur_result_t UR_APICALL urDeviceRetain( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urDeviceRetain")); + mock::getCallbacks().get_after_callback("urDeviceRetain")); if (afterCallback) { return afterCallback(¶ms); } @@ -738,7 +740,7 @@ __urdlllocal ur_result_t UR_APICALL urDeviceRelease( ur_device_release_params_t params = {&hDevice}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urDeviceRelease")); + mock::getCallbacks().get_before_callback("urDeviceRelease")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -747,7 +749,7 @@ __urdlllocal ur_result_t UR_APICALL urDeviceRelease( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urDeviceRelease")); + mock::getCallbacks().get_replace_callback("urDeviceRelease")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -761,7 +763,7 @@ __urdlllocal ur_result_t UR_APICALL urDeviceRelease( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urDeviceRelease")); + mock::getCallbacks().get_after_callback("urDeviceRelease")); if (afterCallback) { return afterCallback(¶ms); } @@ -792,7 +794,7 @@ __urdlllocal ur_result_t UR_APICALL urDevicePartition( &phSubDevices, &pNumDevicesRet}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urDevicePartition")); + mock::getCallbacks().get_before_callback("urDevicePartition")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -801,7 +803,7 @@ __urdlllocal ur_result_t UR_APICALL urDevicePartition( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urDevicePartition")); + mock::getCallbacks().get_replace_callback("urDevicePartition")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -818,7 +820,7 @@ __urdlllocal ur_result_t UR_APICALL urDevicePartition( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urDevicePartition")); + mock::getCallbacks().get_after_callback("urDevicePartition")); if (afterCallback) { return afterCallback(¶ms); } @@ -848,7 +850,7 @@ __urdlllocal ur_result_t UR_APICALL urDeviceSelectBinary( &NumBinaries, &pSelectedBinary}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urDeviceSelectBinary")); + mock::getCallbacks().get_before_callback("urDeviceSelectBinary")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -857,7 +859,7 @@ __urdlllocal ur_result_t UR_APICALL urDeviceSelectBinary( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urDeviceSelectBinary")); + mock::getCallbacks().get_replace_callback("urDeviceSelectBinary")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -870,7 +872,7 @@ __urdlllocal ur_result_t UR_APICALL urDeviceSelectBinary( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urDeviceSelectBinary")); + mock::getCallbacks().get_after_callback("urDeviceSelectBinary")); if (afterCallback) { return afterCallback(¶ms); } @@ -892,7 +894,7 @@ __urdlllocal ur_result_t UR_APICALL urDeviceGetNativeHandle( ur_device_get_native_handle_params_t params = {&hDevice, &phNativeDevice}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urDeviceGetNativeHandle")); + mock::getCallbacks().get_before_callback("urDeviceGetNativeHandle")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -901,7 +903,7 @@ __urdlllocal ur_result_t UR_APICALL urDeviceGetNativeHandle( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urDeviceGetNativeHandle")); + mock::getCallbacks().get_replace_callback("urDeviceGetNativeHandle")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -915,7 +917,7 @@ __urdlllocal ur_result_t UR_APICALL urDeviceGetNativeHandle( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urDeviceGetNativeHandle")); + mock::getCallbacks().get_after_callback("urDeviceGetNativeHandle")); if (afterCallback) { return afterCallback(¶ms); } @@ -942,7 +944,8 @@ __urdlllocal ur_result_t UR_APICALL urDeviceCreateWithNativeHandle( &hNativeDevice, &hPlatform, &pProperties, &phDevice}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urDeviceCreateWithNativeHandle")); + mock::getCallbacks().get_before_callback( + "urDeviceCreateWithNativeHandle")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -951,7 +954,8 @@ __urdlllocal ur_result_t UR_APICALL urDeviceCreateWithNativeHandle( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urDeviceCreateWithNativeHandle")); + mock::getCallbacks().get_replace_callback( + "urDeviceCreateWithNativeHandle")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -966,7 +970,8 @@ __urdlllocal ur_result_t UR_APICALL urDeviceCreateWithNativeHandle( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urDeviceCreateWithNativeHandle")); + mock::getCallbacks().get_after_callback( + "urDeviceCreateWithNativeHandle")); if (afterCallback) { return afterCallback(¶ms); } @@ -993,7 +998,8 @@ __urdlllocal ur_result_t UR_APICALL urDeviceGetGlobalTimestamps( &hDevice, &pDeviceTimestamp, &pHostTimestamp}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urDeviceGetGlobalTimestamps")); + mock::getCallbacks().get_before_callback( + "urDeviceGetGlobalTimestamps")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -1002,7 +1008,8 @@ __urdlllocal ur_result_t UR_APICALL urDeviceGetGlobalTimestamps( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urDeviceGetGlobalTimestamps")); + mock::getCallbacks().get_replace_callback( + "urDeviceGetGlobalTimestamps")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -1015,7 +1022,7 @@ __urdlllocal ur_result_t UR_APICALL urDeviceGetGlobalTimestamps( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urDeviceGetGlobalTimestamps")); + mock::getCallbacks().get_after_callback("urDeviceGetGlobalTimestamps")); if (afterCallback) { return afterCallback(¶ms); } @@ -1042,7 +1049,7 @@ __urdlllocal ur_result_t UR_APICALL urContextCreate( &phContext}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urContextCreate")); + mock::getCallbacks().get_before_callback("urContextCreate")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -1051,7 +1058,7 @@ __urdlllocal ur_result_t UR_APICALL urContextCreate( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urContextCreate")); + mock::getCallbacks().get_replace_callback("urContextCreate")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -1065,7 +1072,7 @@ __urdlllocal ur_result_t UR_APICALL urContextCreate( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urContextCreate")); + mock::getCallbacks().get_after_callback("urContextCreate")); if (afterCallback) { return afterCallback(¶ms); } @@ -1086,7 +1093,7 @@ __urdlllocal ur_result_t UR_APICALL urContextRetain( ur_context_retain_params_t params = {&hContext}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urContextRetain")); + mock::getCallbacks().get_before_callback("urContextRetain")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -1095,7 +1102,7 @@ __urdlllocal ur_result_t UR_APICALL urContextRetain( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urContextRetain")); + mock::getCallbacks().get_replace_callback("urContextRetain")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -1109,7 +1116,7 @@ __urdlllocal ur_result_t UR_APICALL urContextRetain( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urContextRetain")); + mock::getCallbacks().get_after_callback("urContextRetain")); if (afterCallback) { return afterCallback(¶ms); } @@ -1130,7 +1137,7 @@ __urdlllocal ur_result_t UR_APICALL urContextRelease( ur_context_release_params_t params = {&hContext}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urContextRelease")); + mock::getCallbacks().get_before_callback("urContextRelease")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -1139,7 +1146,7 @@ __urdlllocal ur_result_t UR_APICALL urContextRelease( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urContextRelease")); + mock::getCallbacks().get_replace_callback("urContextRelease")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -1153,7 +1160,7 @@ __urdlllocal ur_result_t UR_APICALL urContextRelease( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urContextRelease")); + mock::getCallbacks().get_after_callback("urContextRelease")); if (afterCallback) { return afterCallback(¶ms); } @@ -1186,7 +1193,7 @@ __urdlllocal ur_result_t UR_APICALL urContextGetInfo( &pPropValue, &pPropSizeRet}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urContextGetInfo")); + mock::getCallbacks().get_before_callback("urContextGetInfo")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -1195,7 +1202,7 @@ __urdlllocal ur_result_t UR_APICALL urContextGetInfo( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urContextGetInfo")); + mock::getCallbacks().get_replace_callback("urContextGetInfo")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -1208,7 +1215,7 @@ __urdlllocal ur_result_t UR_APICALL urContextGetInfo( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urContextGetInfo")); + mock::getCallbacks().get_after_callback("urContextGetInfo")); if (afterCallback) { return afterCallback(¶ms); } @@ -1231,7 +1238,7 @@ __urdlllocal ur_result_t UR_APICALL urContextGetNativeHandle( &phNativeContext}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urContextGetNativeHandle")); + mock::getCallbacks().get_before_callback("urContextGetNativeHandle")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -1240,7 +1247,7 @@ __urdlllocal ur_result_t UR_APICALL urContextGetNativeHandle( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urContextGetNativeHandle")); + mock::getCallbacks().get_replace_callback("urContextGetNativeHandle")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -1254,7 +1261,7 @@ __urdlllocal ur_result_t UR_APICALL urContextGetNativeHandle( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urContextGetNativeHandle")); + mock::getCallbacks().get_after_callback("urContextGetNativeHandle")); if (afterCallback) { return afterCallback(¶ms); } @@ -1283,7 +1290,8 @@ __urdlllocal ur_result_t UR_APICALL urContextCreateWithNativeHandle( &hNativeContext, &numDevices, &phDevices, &pProperties, &phContext}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urContextCreateWithNativeHandle")); + mock::getCallbacks().get_before_callback( + "urContextCreateWithNativeHandle")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -1292,7 +1300,7 @@ __urdlllocal ur_result_t UR_APICALL urContextCreateWithNativeHandle( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback( + mock::getCallbacks().get_replace_callback( "urContextCreateWithNativeHandle")); if (replaceCallback) { result = replaceCallback(¶ms); @@ -1308,7 +1316,8 @@ __urdlllocal ur_result_t UR_APICALL urContextCreateWithNativeHandle( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urContextCreateWithNativeHandle")); + mock::getCallbacks().get_after_callback( + "urContextCreateWithNativeHandle")); if (afterCallback) { return afterCallback(¶ms); } @@ -1333,7 +1342,8 @@ __urdlllocal ur_result_t UR_APICALL urContextSetExtendedDeleter( &pUserData}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urContextSetExtendedDeleter")); + mock::getCallbacks().get_before_callback( + "urContextSetExtendedDeleter")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -1342,7 +1352,8 @@ __urdlllocal ur_result_t UR_APICALL urContextSetExtendedDeleter( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urContextSetExtendedDeleter")); + mock::getCallbacks().get_replace_callback( + "urContextSetExtendedDeleter")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -1355,7 +1366,7 @@ __urdlllocal ur_result_t UR_APICALL urContextSetExtendedDeleter( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urContextSetExtendedDeleter")); + mock::getCallbacks().get_after_callback("urContextSetExtendedDeleter")); if (afterCallback) { return afterCallback(¶ms); } @@ -1382,7 +1393,7 @@ __urdlllocal ur_result_t UR_APICALL urMemImageCreate( &pImageDesc, &pHost, &phMem}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urMemImageCreate")); + mock::getCallbacks().get_before_callback("urMemImageCreate")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -1391,7 +1402,7 @@ __urdlllocal ur_result_t UR_APICALL urMemImageCreate( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urMemImageCreate")); + mock::getCallbacks().get_replace_callback("urMemImageCreate")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -1405,7 +1416,7 @@ __urdlllocal ur_result_t UR_APICALL urMemImageCreate( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urMemImageCreate")); + mock::getCallbacks().get_after_callback("urMemImageCreate")); if (afterCallback) { return afterCallback(¶ms); } @@ -1432,7 +1443,7 @@ __urdlllocal ur_result_t UR_APICALL urMemBufferCreate( &pProperties, &phBuffer}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urMemBufferCreate")); + mock::getCallbacks().get_before_callback("urMemBufferCreate")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -1441,7 +1452,7 @@ __urdlllocal ur_result_t UR_APICALL urMemBufferCreate( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urMemBufferCreate")); + mock::getCallbacks().get_replace_callback("urMemBufferCreate")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -1461,7 +1472,7 @@ __urdlllocal ur_result_t UR_APICALL urMemBufferCreate( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urMemBufferCreate")); + mock::getCallbacks().get_after_callback("urMemBufferCreate")); if (afterCallback) { return afterCallback(¶ms); } @@ -1482,7 +1493,7 @@ __urdlllocal ur_result_t UR_APICALL urMemRetain( ur_mem_retain_params_t params = {&hMem}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urMemRetain")); + mock::getCallbacks().get_before_callback("urMemRetain")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -1491,7 +1502,7 @@ __urdlllocal ur_result_t UR_APICALL urMemRetain( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urMemRetain")); + mock::getCallbacks().get_replace_callback("urMemRetain")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -1505,7 +1516,7 @@ __urdlllocal ur_result_t UR_APICALL urMemRetain( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urMemRetain")); + mock::getCallbacks().get_after_callback("urMemRetain")); if (afterCallback) { return afterCallback(¶ms); } @@ -1526,7 +1537,7 @@ __urdlllocal ur_result_t UR_APICALL urMemRelease( ur_mem_release_params_t params = {&hMem}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urMemRelease")); + mock::getCallbacks().get_before_callback("urMemRelease")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -1535,7 +1546,7 @@ __urdlllocal ur_result_t UR_APICALL urMemRelease( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urMemRelease")); + mock::getCallbacks().get_replace_callback("urMemRelease")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -1549,7 +1560,7 @@ __urdlllocal ur_result_t UR_APICALL urMemRelease( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urMemRelease")); + mock::getCallbacks().get_after_callback("urMemRelease")); if (afterCallback) { return afterCallback(¶ms); } @@ -1577,7 +1588,7 @@ __urdlllocal ur_result_t UR_APICALL urMemBufferPartition( &hBuffer, &flags, &bufferCreateType, &pRegion, &phMem}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urMemBufferPartition")); + mock::getCallbacks().get_before_callback("urMemBufferPartition")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -1586,7 +1597,7 @@ __urdlllocal ur_result_t UR_APICALL urMemBufferPartition( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urMemBufferPartition")); + mock::getCallbacks().get_replace_callback("urMemBufferPartition")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -1600,7 +1611,7 @@ __urdlllocal ur_result_t UR_APICALL urMemBufferPartition( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urMemBufferPartition")); + mock::getCallbacks().get_after_callback("urMemBufferPartition")); if (afterCallback) { return afterCallback(¶ms); } @@ -1625,7 +1636,7 @@ __urdlllocal ur_result_t UR_APICALL urMemGetNativeHandle( ur_mem_get_native_handle_params_t params = {&hMem, &hDevice, &phNativeMem}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urMemGetNativeHandle")); + mock::getCallbacks().get_before_callback("urMemGetNativeHandle")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -1634,7 +1645,7 @@ __urdlllocal ur_result_t UR_APICALL urMemGetNativeHandle( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urMemGetNativeHandle")); + mock::getCallbacks().get_replace_callback("urMemGetNativeHandle")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -1648,7 +1659,7 @@ __urdlllocal ur_result_t UR_APICALL urMemGetNativeHandle( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urMemGetNativeHandle")); + mock::getCallbacks().get_after_callback("urMemGetNativeHandle")); if (afterCallback) { return afterCallback(¶ms); } @@ -1675,7 +1686,7 @@ __urdlllocal ur_result_t UR_APICALL urMemBufferCreateWithNativeHandle( &hNativeMem, &hContext, &pProperties, &phMem}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback( + mock::getCallbacks().get_before_callback( "urMemBufferCreateWithNativeHandle")); if (beforeCallback) { result = beforeCallback(¶ms); @@ -1685,7 +1696,7 @@ __urdlllocal ur_result_t UR_APICALL urMemBufferCreateWithNativeHandle( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback( + mock::getCallbacks().get_replace_callback( "urMemBufferCreateWithNativeHandle")); if (replaceCallback) { result = replaceCallback(¶ms); @@ -1700,8 +1711,8 @@ __urdlllocal ur_result_t UR_APICALL urMemBufferCreateWithNativeHandle( return result; } - auto afterCallback = - reinterpret_cast(mock::callbacks.get_after_callback( + auto afterCallback = reinterpret_cast( + mock::getCallbacks().get_after_callback( "urMemBufferCreateWithNativeHandle")); if (afterCallback) { return afterCallback(¶ms); @@ -1733,7 +1744,7 @@ __urdlllocal ur_result_t UR_APICALL urMemImageCreateWithNativeHandle( &pImageDesc, &pProperties, &phMem}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback( + mock::getCallbacks().get_before_callback( "urMemImageCreateWithNativeHandle")); if (beforeCallback) { result = beforeCallback(¶ms); @@ -1743,7 +1754,7 @@ __urdlllocal ur_result_t UR_APICALL urMemImageCreateWithNativeHandle( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback( + mock::getCallbacks().get_replace_callback( "urMemImageCreateWithNativeHandle")); if (replaceCallback) { result = replaceCallback(¶ms); @@ -1759,7 +1770,8 @@ __urdlllocal ur_result_t UR_APICALL urMemImageCreateWithNativeHandle( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urMemImageCreateWithNativeHandle")); + mock::getCallbacks().get_after_callback( + "urMemImageCreateWithNativeHandle")); if (afterCallback) { return afterCallback(¶ms); } @@ -1792,7 +1804,7 @@ __urdlllocal ur_result_t UR_APICALL urMemGetInfo( &pPropValue, &pPropSizeRet}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urMemGetInfo")); + mock::getCallbacks().get_before_callback("urMemGetInfo")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -1801,7 +1813,7 @@ __urdlllocal ur_result_t UR_APICALL urMemGetInfo( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urMemGetInfo")); + mock::getCallbacks().get_replace_callback("urMemGetInfo")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -1814,7 +1826,7 @@ __urdlllocal ur_result_t UR_APICALL urMemGetInfo( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urMemGetInfo")); + mock::getCallbacks().get_after_callback("urMemGetInfo")); if (afterCallback) { return afterCallback(¶ms); } @@ -1846,7 +1858,7 @@ __urdlllocal ur_result_t UR_APICALL urMemImageGetInfo( &pPropValue, &pPropSizeRet}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urMemImageGetInfo")); + mock::getCallbacks().get_before_callback("urMemImageGetInfo")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -1855,7 +1867,7 @@ __urdlllocal ur_result_t UR_APICALL urMemImageGetInfo( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urMemImageGetInfo")); + mock::getCallbacks().get_replace_callback("urMemImageGetInfo")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -1868,7 +1880,7 @@ __urdlllocal ur_result_t UR_APICALL urMemImageGetInfo( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urMemImageGetInfo")); + mock::getCallbacks().get_after_callback("urMemImageGetInfo")); if (afterCallback) { return afterCallback(¶ms); } @@ -1891,7 +1903,7 @@ __urdlllocal ur_result_t UR_APICALL urSamplerCreate( ur_sampler_create_params_t params = {&hContext, &pDesc, &phSampler}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urSamplerCreate")); + mock::getCallbacks().get_before_callback("urSamplerCreate")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -1900,7 +1912,7 @@ __urdlllocal ur_result_t UR_APICALL urSamplerCreate( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urSamplerCreate")); + mock::getCallbacks().get_replace_callback("urSamplerCreate")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -1914,7 +1926,7 @@ __urdlllocal ur_result_t UR_APICALL urSamplerCreate( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urSamplerCreate")); + mock::getCallbacks().get_after_callback("urSamplerCreate")); if (afterCallback) { return afterCallback(¶ms); } @@ -1935,7 +1947,7 @@ __urdlllocal ur_result_t UR_APICALL urSamplerRetain( ur_sampler_retain_params_t params = {&hSampler}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urSamplerRetain")); + mock::getCallbacks().get_before_callback("urSamplerRetain")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -1944,7 +1956,7 @@ __urdlllocal ur_result_t UR_APICALL urSamplerRetain( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urSamplerRetain")); + mock::getCallbacks().get_replace_callback("urSamplerRetain")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -1958,7 +1970,7 @@ __urdlllocal ur_result_t UR_APICALL urSamplerRetain( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urSamplerRetain")); + mock::getCallbacks().get_after_callback("urSamplerRetain")); if (afterCallback) { return afterCallback(¶ms); } @@ -1979,7 +1991,7 @@ __urdlllocal ur_result_t UR_APICALL urSamplerRelease( ur_sampler_release_params_t params = {&hSampler}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urSamplerRelease")); + mock::getCallbacks().get_before_callback("urSamplerRelease")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -1988,7 +2000,7 @@ __urdlllocal ur_result_t UR_APICALL urSamplerRelease( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urSamplerRelease")); + mock::getCallbacks().get_replace_callback("urSamplerRelease")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -2002,7 +2014,7 @@ __urdlllocal ur_result_t UR_APICALL urSamplerRelease( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urSamplerRelease")); + mock::getCallbacks().get_after_callback("urSamplerRelease")); if (afterCallback) { return afterCallback(¶ms); } @@ -2031,7 +2043,7 @@ __urdlllocal ur_result_t UR_APICALL urSamplerGetInfo( &pPropValue, &pPropSizeRet}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urSamplerGetInfo")); + mock::getCallbacks().get_before_callback("urSamplerGetInfo")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -2040,7 +2052,7 @@ __urdlllocal ur_result_t UR_APICALL urSamplerGetInfo( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urSamplerGetInfo")); + mock::getCallbacks().get_replace_callback("urSamplerGetInfo")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -2053,7 +2065,7 @@ __urdlllocal ur_result_t UR_APICALL urSamplerGetInfo( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urSamplerGetInfo")); + mock::getCallbacks().get_after_callback("urSamplerGetInfo")); if (afterCallback) { return afterCallback(¶ms); } @@ -2076,7 +2088,7 @@ __urdlllocal ur_result_t UR_APICALL urSamplerGetNativeHandle( &phNativeSampler}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urSamplerGetNativeHandle")); + mock::getCallbacks().get_before_callback("urSamplerGetNativeHandle")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -2085,7 +2097,7 @@ __urdlllocal ur_result_t UR_APICALL urSamplerGetNativeHandle( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urSamplerGetNativeHandle")); + mock::getCallbacks().get_replace_callback("urSamplerGetNativeHandle")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -2099,7 +2111,7 @@ __urdlllocal ur_result_t UR_APICALL urSamplerGetNativeHandle( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urSamplerGetNativeHandle")); + mock::getCallbacks().get_after_callback("urSamplerGetNativeHandle")); if (afterCallback) { return afterCallback(¶ms); } @@ -2126,7 +2138,8 @@ __urdlllocal ur_result_t UR_APICALL urSamplerCreateWithNativeHandle( &hNativeSampler, &hContext, &pProperties, &phSampler}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urSamplerCreateWithNativeHandle")); + mock::getCallbacks().get_before_callback( + "urSamplerCreateWithNativeHandle")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -2135,7 +2148,7 @@ __urdlllocal ur_result_t UR_APICALL urSamplerCreateWithNativeHandle( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback( + mock::getCallbacks().get_replace_callback( "urSamplerCreateWithNativeHandle")); if (replaceCallback) { result = replaceCallback(¶ms); @@ -2151,7 +2164,8 @@ __urdlllocal ur_result_t UR_APICALL urSamplerCreateWithNativeHandle( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urSamplerCreateWithNativeHandle")); + mock::getCallbacks().get_after_callback( + "urSamplerCreateWithNativeHandle")); if (afterCallback) { return afterCallback(¶ms); } @@ -2179,7 +2193,7 @@ __urdlllocal ur_result_t UR_APICALL urUSMHostAlloc( &ppMem}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urUSMHostAlloc")); + mock::getCallbacks().get_before_callback("urUSMHostAlloc")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -2188,7 +2202,7 @@ __urdlllocal ur_result_t UR_APICALL urUSMHostAlloc( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urUSMHostAlloc")); + mock::getCallbacks().get_replace_callback("urUSMHostAlloc")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -2202,7 +2216,7 @@ __urdlllocal ur_result_t UR_APICALL urUSMHostAlloc( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urUSMHostAlloc")); + mock::getCallbacks().get_after_callback("urUSMHostAlloc")); if (afterCallback) { return afterCallback(¶ms); } @@ -2231,7 +2245,7 @@ __urdlllocal ur_result_t UR_APICALL urUSMDeviceAlloc( &pool, &size, &ppMem}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urUSMDeviceAlloc")); + mock::getCallbacks().get_before_callback("urUSMDeviceAlloc")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -2240,7 +2254,7 @@ __urdlllocal ur_result_t UR_APICALL urUSMDeviceAlloc( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urUSMDeviceAlloc")); + mock::getCallbacks().get_replace_callback("urUSMDeviceAlloc")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -2254,7 +2268,7 @@ __urdlllocal ur_result_t UR_APICALL urUSMDeviceAlloc( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urUSMDeviceAlloc")); + mock::getCallbacks().get_after_callback("urUSMDeviceAlloc")); if (afterCallback) { return afterCallback(¶ms); } @@ -2283,7 +2297,7 @@ __urdlllocal ur_result_t UR_APICALL urUSMSharedAlloc( &pool, &size, &ppMem}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urUSMSharedAlloc")); + mock::getCallbacks().get_before_callback("urUSMSharedAlloc")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -2292,7 +2306,7 @@ __urdlllocal ur_result_t UR_APICALL urUSMSharedAlloc( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urUSMSharedAlloc")); + mock::getCallbacks().get_replace_callback("urUSMSharedAlloc")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -2306,7 +2320,7 @@ __urdlllocal ur_result_t UR_APICALL urUSMSharedAlloc( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urUSMSharedAlloc")); + mock::getCallbacks().get_after_callback("urUSMSharedAlloc")); if (afterCallback) { return afterCallback(¶ms); } @@ -2327,7 +2341,7 @@ __urdlllocal ur_result_t UR_APICALL urUSMFree( ur_usm_free_params_t params = {&hContext, &pMem}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urUSMFree")); + mock::getCallbacks().get_before_callback("urUSMFree")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -2336,7 +2350,7 @@ __urdlllocal ur_result_t UR_APICALL urUSMFree( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urUSMFree")); + mock::getCallbacks().get_replace_callback("urUSMFree")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -2349,7 +2363,7 @@ __urdlllocal ur_result_t UR_APICALL urUSMFree( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urUSMFree")); + mock::getCallbacks().get_after_callback("urUSMFree")); if (afterCallback) { return afterCallback(¶ms); } @@ -2380,7 +2394,7 @@ __urdlllocal ur_result_t UR_APICALL urUSMGetMemAllocInfo( &hContext, &pMem, &propName, &propSize, &pPropValue, &pPropSizeRet}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urUSMGetMemAllocInfo")); + mock::getCallbacks().get_before_callback("urUSMGetMemAllocInfo")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -2389,7 +2403,7 @@ __urdlllocal ur_result_t UR_APICALL urUSMGetMemAllocInfo( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urUSMGetMemAllocInfo")); + mock::getCallbacks().get_replace_callback("urUSMGetMemAllocInfo")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -2402,7 +2416,7 @@ __urdlllocal ur_result_t UR_APICALL urUSMGetMemAllocInfo( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urUSMGetMemAllocInfo")); + mock::getCallbacks().get_after_callback("urUSMGetMemAllocInfo")); if (afterCallback) { return afterCallback(¶ms); } @@ -2426,7 +2440,7 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolCreate( ur_usm_pool_create_params_t params = {&hContext, &pPoolDesc, &ppPool}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urUSMPoolCreate")); + mock::getCallbacks().get_before_callback("urUSMPoolCreate")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -2435,7 +2449,7 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolCreate( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urUSMPoolCreate")); + mock::getCallbacks().get_replace_callback("urUSMPoolCreate")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -2449,7 +2463,7 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolCreate( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urUSMPoolCreate")); + mock::getCallbacks().get_after_callback("urUSMPoolCreate")); if (afterCallback) { return afterCallback(¶ms); } @@ -2469,7 +2483,7 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolRetain( ur_usm_pool_retain_params_t params = {&pPool}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urUSMPoolRetain")); + mock::getCallbacks().get_before_callback("urUSMPoolRetain")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -2478,7 +2492,7 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolRetain( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urUSMPoolRetain")); + mock::getCallbacks().get_replace_callback("urUSMPoolRetain")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -2492,7 +2506,7 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolRetain( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urUSMPoolRetain")); + mock::getCallbacks().get_after_callback("urUSMPoolRetain")); if (afterCallback) { return afterCallback(¶ms); } @@ -2512,7 +2526,7 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolRelease( ur_usm_pool_release_params_t params = {&pPool}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urUSMPoolRelease")); + mock::getCallbacks().get_before_callback("urUSMPoolRelease")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -2521,7 +2535,7 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolRelease( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urUSMPoolRelease")); + mock::getCallbacks().get_replace_callback("urUSMPoolRelease")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -2535,7 +2549,7 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolRelease( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urUSMPoolRelease")); + mock::getCallbacks().get_after_callback("urUSMPoolRelease")); if (afterCallback) { return afterCallback(¶ms); } @@ -2563,7 +2577,7 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolGetInfo( &pPropValue, &pPropSizeRet}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urUSMPoolGetInfo")); + mock::getCallbacks().get_before_callback("urUSMPoolGetInfo")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -2572,7 +2586,7 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolGetInfo( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urUSMPoolGetInfo")); + mock::getCallbacks().get_replace_callback("urUSMPoolGetInfo")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -2585,7 +2599,7 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolGetInfo( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urUSMPoolGetInfo")); + mock::getCallbacks().get_after_callback("urUSMPoolGetInfo")); if (afterCallback) { return afterCallback(¶ms); } @@ -2620,7 +2634,8 @@ __urdlllocal ur_result_t UR_APICALL urVirtualMemGranularityGetInfo( &hContext, &hDevice, &propName, &propSize, &pPropValue, &pPropSizeRet}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urVirtualMemGranularityGetInfo")); + mock::getCallbacks().get_before_callback( + "urVirtualMemGranularityGetInfo")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -2629,7 +2644,8 @@ __urdlllocal ur_result_t UR_APICALL urVirtualMemGranularityGetInfo( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urVirtualMemGranularityGetInfo")); + mock::getCallbacks().get_replace_callback( + "urVirtualMemGranularityGetInfo")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -2642,7 +2658,8 @@ __urdlllocal ur_result_t UR_APICALL urVirtualMemGranularityGetInfo( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urVirtualMemGranularityGetInfo")); + mock::getCallbacks().get_after_callback( + "urVirtualMemGranularityGetInfo")); if (afterCallback) { return afterCallback(¶ms); } @@ -2672,7 +2689,7 @@ __urdlllocal ur_result_t UR_APICALL urVirtualMemReserve( &ppStart}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urVirtualMemReserve")); + mock::getCallbacks().get_before_callback("urVirtualMemReserve")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -2681,7 +2698,7 @@ __urdlllocal ur_result_t UR_APICALL urVirtualMemReserve( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urVirtualMemReserve")); + mock::getCallbacks().get_replace_callback("urVirtualMemReserve")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -2694,7 +2711,7 @@ __urdlllocal ur_result_t UR_APICALL urVirtualMemReserve( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urVirtualMemReserve")); + mock::getCallbacks().get_after_callback("urVirtualMemReserve")); if (afterCallback) { return afterCallback(¶ms); } @@ -2717,7 +2734,7 @@ __urdlllocal ur_result_t UR_APICALL urVirtualMemFree( ur_virtual_mem_free_params_t params = {&hContext, &pStart, &size}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urVirtualMemFree")); + mock::getCallbacks().get_before_callback("urVirtualMemFree")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -2726,7 +2743,7 @@ __urdlllocal ur_result_t UR_APICALL urVirtualMemFree( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urVirtualMemFree")); + mock::getCallbacks().get_replace_callback("urVirtualMemFree")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -2739,7 +2756,7 @@ __urdlllocal ur_result_t UR_APICALL urVirtualMemFree( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urVirtualMemFree")); + mock::getCallbacks().get_after_callback("urVirtualMemFree")); if (afterCallback) { return afterCallback(¶ms); } @@ -2769,7 +2786,7 @@ __urdlllocal ur_result_t UR_APICALL urVirtualMemMap( &hPhysicalMem, &offset, &flags}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urVirtualMemMap")); + mock::getCallbacks().get_before_callback("urVirtualMemMap")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -2778,7 +2795,7 @@ __urdlllocal ur_result_t UR_APICALL urVirtualMemMap( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urVirtualMemMap")); + mock::getCallbacks().get_replace_callback("urVirtualMemMap")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -2791,7 +2808,7 @@ __urdlllocal ur_result_t UR_APICALL urVirtualMemMap( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urVirtualMemMap")); + mock::getCallbacks().get_after_callback("urVirtualMemMap")); if (afterCallback) { return afterCallback(¶ms); } @@ -2814,7 +2831,7 @@ __urdlllocal ur_result_t UR_APICALL urVirtualMemUnmap( ur_virtual_mem_unmap_params_t params = {&hContext, &pStart, &size}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urVirtualMemUnmap")); + mock::getCallbacks().get_before_callback("urVirtualMemUnmap")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -2823,7 +2840,7 @@ __urdlllocal ur_result_t UR_APICALL urVirtualMemUnmap( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urVirtualMemUnmap")); + mock::getCallbacks().get_replace_callback("urVirtualMemUnmap")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -2836,7 +2853,7 @@ __urdlllocal ur_result_t UR_APICALL urVirtualMemUnmap( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urVirtualMemUnmap")); + mock::getCallbacks().get_after_callback("urVirtualMemUnmap")); if (afterCallback) { return afterCallback(¶ms); } @@ -2862,7 +2879,7 @@ __urdlllocal ur_result_t UR_APICALL urVirtualMemSetAccess( &flags}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urVirtualMemSetAccess")); + mock::getCallbacks().get_before_callback("urVirtualMemSetAccess")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -2871,7 +2888,7 @@ __urdlllocal ur_result_t UR_APICALL urVirtualMemSetAccess( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urVirtualMemSetAccess")); + mock::getCallbacks().get_replace_callback("urVirtualMemSetAccess")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -2884,7 +2901,7 @@ __urdlllocal ur_result_t UR_APICALL urVirtualMemSetAccess( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urVirtualMemSetAccess")); + mock::getCallbacks().get_after_callback("urVirtualMemSetAccess")); if (afterCallback) { return afterCallback(¶ms); } @@ -2919,7 +2936,7 @@ __urdlllocal ur_result_t UR_APICALL urVirtualMemGetInfo( &propSize, &pPropValue, &pPropSizeRet}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urVirtualMemGetInfo")); + mock::getCallbacks().get_before_callback("urVirtualMemGetInfo")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -2928,7 +2945,7 @@ __urdlllocal ur_result_t UR_APICALL urVirtualMemGetInfo( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urVirtualMemGetInfo")); + mock::getCallbacks().get_replace_callback("urVirtualMemGetInfo")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -2941,7 +2958,7 @@ __urdlllocal ur_result_t UR_APICALL urVirtualMemGetInfo( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urVirtualMemGetInfo")); + mock::getCallbacks().get_after_callback("urVirtualMemGetInfo")); if (afterCallback) { return afterCallback(¶ms); } @@ -2970,7 +2987,7 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemCreate( &pProperties, &phPhysicalMem}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urPhysicalMemCreate")); + mock::getCallbacks().get_before_callback("urPhysicalMemCreate")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -2979,7 +2996,7 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemCreate( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urPhysicalMemCreate")); + mock::getCallbacks().get_replace_callback("urPhysicalMemCreate")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -2993,7 +3010,7 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemCreate( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urPhysicalMemCreate")); + mock::getCallbacks().get_after_callback("urPhysicalMemCreate")); if (afterCallback) { return afterCallback(¶ms); } @@ -3014,7 +3031,7 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemRetain( ur_physical_mem_retain_params_t params = {&hPhysicalMem}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urPhysicalMemRetain")); + mock::getCallbacks().get_before_callback("urPhysicalMemRetain")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -3023,7 +3040,7 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemRetain( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urPhysicalMemRetain")); + mock::getCallbacks().get_replace_callback("urPhysicalMemRetain")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -3037,7 +3054,7 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemRetain( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urPhysicalMemRetain")); + mock::getCallbacks().get_after_callback("urPhysicalMemRetain")); if (afterCallback) { return afterCallback(¶ms); } @@ -3058,7 +3075,7 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemRelease( ur_physical_mem_release_params_t params = {&hPhysicalMem}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urPhysicalMemRelease")); + mock::getCallbacks().get_before_callback("urPhysicalMemRelease")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -3067,7 +3084,7 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemRelease( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urPhysicalMemRelease")); + mock::getCallbacks().get_replace_callback("urPhysicalMemRelease")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -3081,7 +3098,7 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemRelease( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urPhysicalMemRelease")); + mock::getCallbacks().get_after_callback("urPhysicalMemRelease")); if (afterCallback) { return afterCallback(¶ms); } @@ -3108,7 +3125,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramCreateWithIL( &pProperties, &phProgram}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urProgramCreateWithIL")); + mock::getCallbacks().get_before_callback("urProgramCreateWithIL")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -3117,7 +3134,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramCreateWithIL( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urProgramCreateWithIL")); + mock::getCallbacks().get_replace_callback("urProgramCreateWithIL")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -3131,7 +3148,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramCreateWithIL( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urProgramCreateWithIL")); + mock::getCallbacks().get_after_callback("urProgramCreateWithIL")); if (afterCallback) { return afterCallback(¶ms); } @@ -3160,7 +3177,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramCreateWithBinary( &hContext, &hDevice, &size, &pBinary, &pProperties, &phProgram}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urProgramCreateWithBinary")); + mock::getCallbacks().get_before_callback("urProgramCreateWithBinary")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -3169,7 +3186,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramCreateWithBinary( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urProgramCreateWithBinary")); + mock::getCallbacks().get_replace_callback("urProgramCreateWithBinary")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -3183,7 +3200,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramCreateWithBinary( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urProgramCreateWithBinary")); + mock::getCallbacks().get_after_callback("urProgramCreateWithBinary")); if (afterCallback) { return afterCallback(¶ms); } @@ -3206,7 +3223,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramBuild( ur_program_build_params_t params = {&hContext, &hProgram, &pOptions}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urProgramBuild")); + mock::getCallbacks().get_before_callback("urProgramBuild")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -3215,7 +3232,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramBuild( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urProgramBuild")); + mock::getCallbacks().get_replace_callback("urProgramBuild")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -3228,7 +3245,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramBuild( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urProgramBuild")); + mock::getCallbacks().get_after_callback("urProgramBuild")); if (afterCallback) { return afterCallback(¶ms); } @@ -3252,7 +3269,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramCompile( ur_program_compile_params_t params = {&hContext, &hProgram, &pOptions}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urProgramCompile")); + mock::getCallbacks().get_before_callback("urProgramCompile")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -3261,7 +3278,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramCompile( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urProgramCompile")); + mock::getCallbacks().get_replace_callback("urProgramCompile")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -3274,7 +3291,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramCompile( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urProgramCompile")); + mock::getCallbacks().get_after_callback("urProgramCompile")); if (afterCallback) { return afterCallback(¶ms); } @@ -3305,7 +3322,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramLink( &pOptions, &phProgram}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urProgramLink")); + mock::getCallbacks().get_before_callback("urProgramLink")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -3314,7 +3331,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramLink( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urProgramLink")); + mock::getCallbacks().get_replace_callback("urProgramLink")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -3328,7 +3345,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramLink( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urProgramLink")); + mock::getCallbacks().get_after_callback("urProgramLink")); if (afterCallback) { return afterCallback(¶ms); } @@ -3349,7 +3366,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramRetain( ur_program_retain_params_t params = {&hProgram}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urProgramRetain")); + mock::getCallbacks().get_before_callback("urProgramRetain")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -3358,7 +3375,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramRetain( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urProgramRetain")); + mock::getCallbacks().get_replace_callback("urProgramRetain")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -3372,7 +3389,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramRetain( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urProgramRetain")); + mock::getCallbacks().get_after_callback("urProgramRetain")); if (afterCallback) { return afterCallback(¶ms); } @@ -3393,7 +3410,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramRelease( ur_program_release_params_t params = {&hProgram}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urProgramRelease")); + mock::getCallbacks().get_before_callback("urProgramRelease")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -3402,7 +3419,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramRelease( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urProgramRelease")); + mock::getCallbacks().get_replace_callback("urProgramRelease")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -3416,7 +3433,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramRelease( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urProgramRelease")); + mock::getCallbacks().get_after_callback("urProgramRelease")); if (afterCallback) { return afterCallback(¶ms); } @@ -3446,7 +3463,8 @@ __urdlllocal ur_result_t UR_APICALL urProgramGetFunctionPointer( &hDevice, &hProgram, &pFunctionName, &ppFunctionPointer}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urProgramGetFunctionPointer")); + mock::getCallbacks().get_before_callback( + "urProgramGetFunctionPointer")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -3455,7 +3473,8 @@ __urdlllocal ur_result_t UR_APICALL urProgramGetFunctionPointer( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urProgramGetFunctionPointer")); + mock::getCallbacks().get_replace_callback( + "urProgramGetFunctionPointer")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -3468,7 +3487,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramGetFunctionPointer( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urProgramGetFunctionPointer")); + mock::getCallbacks().get_after_callback("urProgramGetFunctionPointer")); if (afterCallback) { return afterCallback(¶ms); } @@ -3500,7 +3519,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramGetGlobalVariablePointer( &ppGlobalVariablePointerRet}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback( + mock::getCallbacks().get_before_callback( "urProgramGetGlobalVariablePointer")); if (beforeCallback) { result = beforeCallback(¶ms); @@ -3510,7 +3529,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramGetGlobalVariablePointer( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback( + mock::getCallbacks().get_replace_callback( "urProgramGetGlobalVariablePointer")); if (replaceCallback) { result = replaceCallback(¶ms); @@ -3523,8 +3542,8 @@ __urdlllocal ur_result_t UR_APICALL urProgramGetGlobalVariablePointer( return result; } - auto afterCallback = - reinterpret_cast(mock::callbacks.get_after_callback( + auto afterCallback = reinterpret_cast( + mock::getCallbacks().get_after_callback( "urProgramGetGlobalVariablePointer")); if (afterCallback) { return afterCallback(¶ms); @@ -3557,7 +3576,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramGetInfo( &pPropValue, &pPropSizeRet}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urProgramGetInfo")); + mock::getCallbacks().get_before_callback("urProgramGetInfo")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -3566,7 +3585,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramGetInfo( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urProgramGetInfo")); + mock::getCallbacks().get_replace_callback("urProgramGetInfo")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -3579,7 +3598,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramGetInfo( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urProgramGetInfo")); + mock::getCallbacks().get_after_callback("urProgramGetInfo")); if (afterCallback) { return afterCallback(¶ms); } @@ -3613,7 +3632,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramGetBuildInfo( &hProgram, &hDevice, &propName, &propSize, &pPropValue, &pPropSizeRet}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urProgramGetBuildInfo")); + mock::getCallbacks().get_before_callback("urProgramGetBuildInfo")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -3622,7 +3641,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramGetBuildInfo( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urProgramGetBuildInfo")); + mock::getCallbacks().get_replace_callback("urProgramGetBuildInfo")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -3635,7 +3654,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramGetBuildInfo( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urProgramGetBuildInfo")); + mock::getCallbacks().get_after_callback("urProgramGetBuildInfo")); if (afterCallback) { return afterCallback(¶ms); } @@ -3660,7 +3679,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramSetSpecializationConstants( &hProgram, &count, &pSpecConstants}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback( + mock::getCallbacks().get_before_callback( "urProgramSetSpecializationConstants")); if (beforeCallback) { result = beforeCallback(¶ms); @@ -3670,7 +3689,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramSetSpecializationConstants( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback( + mock::getCallbacks().get_replace_callback( "urProgramSetSpecializationConstants")); if (replaceCallback) { result = replaceCallback(¶ms); @@ -3683,8 +3702,8 @@ __urdlllocal ur_result_t UR_APICALL urProgramSetSpecializationConstants( return result; } - auto afterCallback = - reinterpret_cast(mock::callbacks.get_after_callback( + auto afterCallback = reinterpret_cast( + mock::getCallbacks().get_after_callback( "urProgramSetSpecializationConstants")); if (afterCallback) { return afterCallback(¶ms); @@ -3708,7 +3727,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramGetNativeHandle( &phNativeProgram}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urProgramGetNativeHandle")); + mock::getCallbacks().get_before_callback("urProgramGetNativeHandle")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -3717,7 +3736,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramGetNativeHandle( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urProgramGetNativeHandle")); + mock::getCallbacks().get_replace_callback("urProgramGetNativeHandle")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -3731,7 +3750,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramGetNativeHandle( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urProgramGetNativeHandle")); + mock::getCallbacks().get_after_callback("urProgramGetNativeHandle")); if (afterCallback) { return afterCallback(¶ms); } @@ -3758,7 +3777,8 @@ __urdlllocal ur_result_t UR_APICALL urProgramCreateWithNativeHandle( &hNativeProgram, &hContext, &pProperties, &phProgram}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urProgramCreateWithNativeHandle")); + mock::getCallbacks().get_before_callback( + "urProgramCreateWithNativeHandle")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -3767,7 +3787,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramCreateWithNativeHandle( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback( + mock::getCallbacks().get_replace_callback( "urProgramCreateWithNativeHandle")); if (replaceCallback) { result = replaceCallback(¶ms); @@ -3783,7 +3803,8 @@ __urdlllocal ur_result_t UR_APICALL urProgramCreateWithNativeHandle( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urProgramCreateWithNativeHandle")); + mock::getCallbacks().get_after_callback( + "urProgramCreateWithNativeHandle")); if (afterCallback) { return afterCallback(¶ms); } @@ -3806,7 +3827,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelCreate( ur_kernel_create_params_t params = {&hProgram, &pKernelName, &phKernel}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urKernelCreate")); + mock::getCallbacks().get_before_callback("urKernelCreate")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -3815,7 +3836,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelCreate( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urKernelCreate")); + mock::getCallbacks().get_replace_callback("urKernelCreate")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -3829,7 +3850,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelCreate( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urKernelCreate")); + mock::getCallbacks().get_after_callback("urKernelCreate")); if (afterCallback) { return afterCallback(¶ms); } @@ -3856,7 +3877,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetArgValue( &pProperties, &pArgValue}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urKernelSetArgValue")); + mock::getCallbacks().get_before_callback("urKernelSetArgValue")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -3865,7 +3886,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetArgValue( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urKernelSetArgValue")); + mock::getCallbacks().get_replace_callback("urKernelSetArgValue")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -3878,7 +3899,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetArgValue( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urKernelSetArgValue")); + mock::getCallbacks().get_after_callback("urKernelSetArgValue")); if (afterCallback) { return afterCallback(¶ms); } @@ -3904,7 +3925,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetArgLocal( &pProperties}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urKernelSetArgLocal")); + mock::getCallbacks().get_before_callback("urKernelSetArgLocal")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -3913,7 +3934,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetArgLocal( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urKernelSetArgLocal")); + mock::getCallbacks().get_replace_callback("urKernelSetArgLocal")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -3926,7 +3947,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetArgLocal( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urKernelSetArgLocal")); + mock::getCallbacks().get_after_callback("urKernelSetArgLocal")); if (afterCallback) { return afterCallback(¶ms); } @@ -3959,7 +3980,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelGetInfo( &pPropValue, &pPropSizeRet}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urKernelGetInfo")); + mock::getCallbacks().get_before_callback("urKernelGetInfo")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -3968,7 +3989,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelGetInfo( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urKernelGetInfo")); + mock::getCallbacks().get_replace_callback("urKernelGetInfo")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -3981,7 +4002,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelGetInfo( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urKernelGetInfo")); + mock::getCallbacks().get_after_callback("urKernelGetInfo")); if (afterCallback) { return afterCallback(¶ms); } @@ -4012,7 +4033,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelGetGroupInfo( &hKernel, &hDevice, &propName, &propSize, &pPropValue, &pPropSizeRet}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urKernelGetGroupInfo")); + mock::getCallbacks().get_before_callback("urKernelGetGroupInfo")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -4021,7 +4042,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelGetGroupInfo( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urKernelGetGroupInfo")); + mock::getCallbacks().get_replace_callback("urKernelGetGroupInfo")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -4034,7 +4055,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelGetGroupInfo( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urKernelGetGroupInfo")); + mock::getCallbacks().get_after_callback("urKernelGetGroupInfo")); if (afterCallback) { return afterCallback(¶ms); } @@ -4065,7 +4086,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelGetSubGroupInfo( &hKernel, &hDevice, &propName, &propSize, &pPropValue, &pPropSizeRet}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urKernelGetSubGroupInfo")); + mock::getCallbacks().get_before_callback("urKernelGetSubGroupInfo")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -4074,7 +4095,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelGetSubGroupInfo( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urKernelGetSubGroupInfo")); + mock::getCallbacks().get_replace_callback("urKernelGetSubGroupInfo")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -4087,7 +4108,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelGetSubGroupInfo( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urKernelGetSubGroupInfo")); + mock::getCallbacks().get_after_callback("urKernelGetSubGroupInfo")); if (afterCallback) { return afterCallback(¶ms); } @@ -4107,7 +4128,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelRetain( ur_kernel_retain_params_t params = {&hKernel}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urKernelRetain")); + mock::getCallbacks().get_before_callback("urKernelRetain")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -4116,7 +4137,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelRetain( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urKernelRetain")); + mock::getCallbacks().get_replace_callback("urKernelRetain")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -4130,7 +4151,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelRetain( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urKernelRetain")); + mock::getCallbacks().get_after_callback("urKernelRetain")); if (afterCallback) { return afterCallback(¶ms); } @@ -4151,7 +4172,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelRelease( ur_kernel_release_params_t params = {&hKernel}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urKernelRelease")); + mock::getCallbacks().get_before_callback("urKernelRelease")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -4160,7 +4181,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelRelease( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urKernelRelease")); + mock::getCallbacks().get_replace_callback("urKernelRelease")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -4174,7 +4195,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelRelease( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urKernelRelease")); + mock::getCallbacks().get_after_callback("urKernelRelease")); if (afterCallback) { return afterCallback(¶ms); } @@ -4201,7 +4222,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetArgPointer( &pProperties, &pArgValue}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urKernelSetArgPointer")); + mock::getCallbacks().get_before_callback("urKernelSetArgPointer")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -4210,7 +4231,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetArgPointer( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urKernelSetArgPointer")); + mock::getCallbacks().get_replace_callback("urKernelSetArgPointer")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -4223,7 +4244,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetArgPointer( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urKernelSetArgPointer")); + mock::getCallbacks().get_after_callback("urKernelSetArgPointer")); if (afterCallback) { return afterCallback(¶ms); } @@ -4251,7 +4272,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetExecInfo( &pProperties, &pPropValue}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urKernelSetExecInfo")); + mock::getCallbacks().get_before_callback("urKernelSetExecInfo")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -4260,7 +4281,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetExecInfo( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urKernelSetExecInfo")); + mock::getCallbacks().get_replace_callback("urKernelSetExecInfo")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -4273,7 +4294,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetExecInfo( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urKernelSetExecInfo")); + mock::getCallbacks().get_after_callback("urKernelSetExecInfo")); if (afterCallback) { return afterCallback(¶ms); } @@ -4298,7 +4319,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetArgSampler( &pProperties, &hArgValue}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urKernelSetArgSampler")); + mock::getCallbacks().get_before_callback("urKernelSetArgSampler")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -4307,7 +4328,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetArgSampler( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urKernelSetArgSampler")); + mock::getCallbacks().get_replace_callback("urKernelSetArgSampler")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -4320,7 +4341,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetArgSampler( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urKernelSetArgSampler")); + mock::getCallbacks().get_after_callback("urKernelSetArgSampler")); if (afterCallback) { return afterCallback(¶ms); } @@ -4345,7 +4366,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetArgMemObj( &pProperties, &hArgValue}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urKernelSetArgMemObj")); + mock::getCallbacks().get_before_callback("urKernelSetArgMemObj")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -4354,7 +4375,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetArgMemObj( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urKernelSetArgMemObj")); + mock::getCallbacks().get_replace_callback("urKernelSetArgMemObj")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -4367,7 +4388,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetArgMemObj( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urKernelSetArgMemObj")); + mock::getCallbacks().get_after_callback("urKernelSetArgMemObj")); if (afterCallback) { return afterCallback(¶ms); } @@ -4391,7 +4412,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetSpecializationConstants( &pSpecConstants}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback( + mock::getCallbacks().get_before_callback( "urKernelSetSpecializationConstants")); if (beforeCallback) { result = beforeCallback(¶ms); @@ -4401,7 +4422,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetSpecializationConstants( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback( + mock::getCallbacks().get_replace_callback( "urKernelSetSpecializationConstants")); if (replaceCallback) { result = replaceCallback(¶ms); @@ -4414,8 +4435,8 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetSpecializationConstants( return result; } - auto afterCallback = - reinterpret_cast(mock::callbacks.get_after_callback( + auto afterCallback = reinterpret_cast( + mock::getCallbacks().get_after_callback( "urKernelSetSpecializationConstants")); if (afterCallback) { return afterCallback(¶ms); @@ -4438,7 +4459,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelGetNativeHandle( ur_kernel_get_native_handle_params_t params = {&hKernel, &phNativeKernel}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urKernelGetNativeHandle")); + mock::getCallbacks().get_before_callback("urKernelGetNativeHandle")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -4447,7 +4468,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelGetNativeHandle( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urKernelGetNativeHandle")); + mock::getCallbacks().get_replace_callback("urKernelGetNativeHandle")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -4461,7 +4482,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelGetNativeHandle( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urKernelGetNativeHandle")); + mock::getCallbacks().get_after_callback("urKernelGetNativeHandle")); if (afterCallback) { return afterCallback(¶ms); } @@ -4490,7 +4511,8 @@ __urdlllocal ur_result_t UR_APICALL urKernelCreateWithNativeHandle( &hNativeKernel, &hContext, &hProgram, &pProperties, &phKernel}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urKernelCreateWithNativeHandle")); + mock::getCallbacks().get_before_callback( + "urKernelCreateWithNativeHandle")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -4499,7 +4521,8 @@ __urdlllocal ur_result_t UR_APICALL urKernelCreateWithNativeHandle( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urKernelCreateWithNativeHandle")); + mock::getCallbacks().get_replace_callback( + "urKernelCreateWithNativeHandle")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -4514,7 +4537,8 @@ __urdlllocal ur_result_t UR_APICALL urKernelCreateWithNativeHandle( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urKernelCreateWithNativeHandle")); + mock::getCallbacks().get_after_callback( + "urKernelCreateWithNativeHandle")); if (afterCallback) { return afterCallback(¶ms); } @@ -4550,7 +4574,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelGetSuggestedLocalWorkSize( &pGlobalWorkOffset, &pGlobalWorkSize, &pSuggestedLocalWorkSize}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback( + mock::getCallbacks().get_before_callback( "urKernelGetSuggestedLocalWorkSize")); if (beforeCallback) { result = beforeCallback(¶ms); @@ -4560,7 +4584,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelGetSuggestedLocalWorkSize( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback( + mock::getCallbacks().get_replace_callback( "urKernelGetSuggestedLocalWorkSize")); if (replaceCallback) { result = replaceCallback(¶ms); @@ -4573,8 +4597,8 @@ __urdlllocal ur_result_t UR_APICALL urKernelGetSuggestedLocalWorkSize( return result; } - auto afterCallback = - reinterpret_cast(mock::callbacks.get_after_callback( + auto afterCallback = reinterpret_cast( + mock::getCallbacks().get_after_callback( "urKernelGetSuggestedLocalWorkSize")); if (afterCallback) { return afterCallback(¶ms); @@ -4604,7 +4628,7 @@ __urdlllocal ur_result_t UR_APICALL urQueueGetInfo( &pPropValue, &pPropSizeRet}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urQueueGetInfo")); + mock::getCallbacks().get_before_callback("urQueueGetInfo")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -4613,7 +4637,7 @@ __urdlllocal ur_result_t UR_APICALL urQueueGetInfo( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urQueueGetInfo")); + mock::getCallbacks().get_replace_callback("urQueueGetInfo")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -4626,7 +4650,7 @@ __urdlllocal ur_result_t UR_APICALL urQueueGetInfo( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urQueueGetInfo")); + mock::getCallbacks().get_after_callback("urQueueGetInfo")); if (afterCallback) { return afterCallback(¶ms); } @@ -4652,7 +4676,7 @@ __urdlllocal ur_result_t UR_APICALL urQueueCreate( &phQueue}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urQueueCreate")); + mock::getCallbacks().get_before_callback("urQueueCreate")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -4661,7 +4685,7 @@ __urdlllocal ur_result_t UR_APICALL urQueueCreate( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urQueueCreate")); + mock::getCallbacks().get_replace_callback("urQueueCreate")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -4675,7 +4699,7 @@ __urdlllocal ur_result_t UR_APICALL urQueueCreate( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urQueueCreate")); + mock::getCallbacks().get_after_callback("urQueueCreate")); if (afterCallback) { return afterCallback(¶ms); } @@ -4696,7 +4720,7 @@ __urdlllocal ur_result_t UR_APICALL urQueueRetain( ur_queue_retain_params_t params = {&hQueue}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urQueueRetain")); + mock::getCallbacks().get_before_callback("urQueueRetain")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -4705,7 +4729,7 @@ __urdlllocal ur_result_t UR_APICALL urQueueRetain( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urQueueRetain")); + mock::getCallbacks().get_replace_callback("urQueueRetain")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -4719,7 +4743,7 @@ __urdlllocal ur_result_t UR_APICALL urQueueRetain( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urQueueRetain")); + mock::getCallbacks().get_after_callback("urQueueRetain")); if (afterCallback) { return afterCallback(¶ms); } @@ -4740,7 +4764,7 @@ __urdlllocal ur_result_t UR_APICALL urQueueRelease( ur_queue_release_params_t params = {&hQueue}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urQueueRelease")); + mock::getCallbacks().get_before_callback("urQueueRelease")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -4749,7 +4773,7 @@ __urdlllocal ur_result_t UR_APICALL urQueueRelease( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urQueueRelease")); + mock::getCallbacks().get_replace_callback("urQueueRelease")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -4763,7 +4787,7 @@ __urdlllocal ur_result_t UR_APICALL urQueueRelease( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urQueueRelease")); + mock::getCallbacks().get_after_callback("urQueueRelease")); if (afterCallback) { return afterCallback(¶ms); } @@ -4788,7 +4812,7 @@ __urdlllocal ur_result_t UR_APICALL urQueueGetNativeHandle( &phNativeQueue}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urQueueGetNativeHandle")); + mock::getCallbacks().get_before_callback("urQueueGetNativeHandle")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -4797,7 +4821,7 @@ __urdlllocal ur_result_t UR_APICALL urQueueGetNativeHandle( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urQueueGetNativeHandle")); + mock::getCallbacks().get_replace_callback("urQueueGetNativeHandle")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -4811,7 +4835,7 @@ __urdlllocal ur_result_t UR_APICALL urQueueGetNativeHandle( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urQueueGetNativeHandle")); + mock::getCallbacks().get_after_callback("urQueueGetNativeHandle")); if (afterCallback) { return afterCallback(¶ms); } @@ -4839,7 +4863,8 @@ __urdlllocal ur_result_t UR_APICALL urQueueCreateWithNativeHandle( &hNativeQueue, &hContext, &hDevice, &pProperties, &phQueue}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urQueueCreateWithNativeHandle")); + mock::getCallbacks().get_before_callback( + "urQueueCreateWithNativeHandle")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -4848,7 +4873,8 @@ __urdlllocal ur_result_t UR_APICALL urQueueCreateWithNativeHandle( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urQueueCreateWithNativeHandle")); + mock::getCallbacks().get_replace_callback( + "urQueueCreateWithNativeHandle")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -4863,7 +4889,8 @@ __urdlllocal ur_result_t UR_APICALL urQueueCreateWithNativeHandle( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urQueueCreateWithNativeHandle")); + mock::getCallbacks().get_after_callback( + "urQueueCreateWithNativeHandle")); if (afterCallback) { return afterCallback(¶ms); } @@ -4883,7 +4910,7 @@ __urdlllocal ur_result_t UR_APICALL urQueueFinish( ur_queue_finish_params_t params = {&hQueue}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urQueueFinish")); + mock::getCallbacks().get_before_callback("urQueueFinish")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -4892,7 +4919,7 @@ __urdlllocal ur_result_t UR_APICALL urQueueFinish( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urQueueFinish")); + mock::getCallbacks().get_replace_callback("urQueueFinish")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -4905,7 +4932,7 @@ __urdlllocal ur_result_t UR_APICALL urQueueFinish( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urQueueFinish")); + mock::getCallbacks().get_after_callback("urQueueFinish")); if (afterCallback) { return afterCallback(¶ms); } @@ -4925,7 +4952,7 @@ __urdlllocal ur_result_t UR_APICALL urQueueFlush( ur_queue_flush_params_t params = {&hQueue}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urQueueFlush")); + mock::getCallbacks().get_before_callback("urQueueFlush")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -4934,7 +4961,7 @@ __urdlllocal ur_result_t UR_APICALL urQueueFlush( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urQueueFlush")); + mock::getCallbacks().get_replace_callback("urQueueFlush")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -4947,7 +4974,7 @@ __urdlllocal ur_result_t UR_APICALL urQueueFlush( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urQueueFlush")); + mock::getCallbacks().get_after_callback("urQueueFlush")); if (afterCallback) { return afterCallback(¶ms); } @@ -4974,7 +5001,7 @@ __urdlllocal ur_result_t UR_APICALL urEventGetInfo( &pPropValue, &pPropSizeRet}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urEventGetInfo")); + mock::getCallbacks().get_before_callback("urEventGetInfo")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -4983,7 +5010,7 @@ __urdlllocal ur_result_t UR_APICALL urEventGetInfo( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urEventGetInfo")); + mock::getCallbacks().get_replace_callback("urEventGetInfo")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -4996,7 +5023,7 @@ __urdlllocal ur_result_t UR_APICALL urEventGetInfo( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urEventGetInfo")); + mock::getCallbacks().get_after_callback("urEventGetInfo")); if (afterCallback) { return afterCallback(¶ms); } @@ -5026,7 +5053,7 @@ __urdlllocal ur_result_t UR_APICALL urEventGetProfilingInfo( &hEvent, &propName, &propSize, &pPropValue, &pPropSizeRet}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urEventGetProfilingInfo")); + mock::getCallbacks().get_before_callback("urEventGetProfilingInfo")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -5035,7 +5062,7 @@ __urdlllocal ur_result_t UR_APICALL urEventGetProfilingInfo( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urEventGetProfilingInfo")); + mock::getCallbacks().get_replace_callback("urEventGetProfilingInfo")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -5048,7 +5075,7 @@ __urdlllocal ur_result_t UR_APICALL urEventGetProfilingInfo( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urEventGetProfilingInfo")); + mock::getCallbacks().get_after_callback("urEventGetProfilingInfo")); if (afterCallback) { return afterCallback(¶ms); } @@ -5071,7 +5098,7 @@ __urdlllocal ur_result_t UR_APICALL urEventWait( ur_event_wait_params_t params = {&numEvents, &phEventWaitList}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urEventWait")); + mock::getCallbacks().get_before_callback("urEventWait")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -5080,7 +5107,7 @@ __urdlllocal ur_result_t UR_APICALL urEventWait( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urEventWait")); + mock::getCallbacks().get_replace_callback("urEventWait")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -5093,7 +5120,7 @@ __urdlllocal ur_result_t UR_APICALL urEventWait( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urEventWait")); + mock::getCallbacks().get_after_callback("urEventWait")); if (afterCallback) { return afterCallback(¶ms); } @@ -5113,7 +5140,7 @@ __urdlllocal ur_result_t UR_APICALL urEventRetain( ur_event_retain_params_t params = {&hEvent}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urEventRetain")); + mock::getCallbacks().get_before_callback("urEventRetain")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -5122,7 +5149,7 @@ __urdlllocal ur_result_t UR_APICALL urEventRetain( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urEventRetain")); + mock::getCallbacks().get_replace_callback("urEventRetain")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -5136,7 +5163,7 @@ __urdlllocal ur_result_t UR_APICALL urEventRetain( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urEventRetain")); + mock::getCallbacks().get_after_callback("urEventRetain")); if (afterCallback) { return afterCallback(¶ms); } @@ -5156,7 +5183,7 @@ __urdlllocal ur_result_t UR_APICALL urEventRelease( ur_event_release_params_t params = {&hEvent}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urEventRelease")); + mock::getCallbacks().get_before_callback("urEventRelease")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -5165,7 +5192,7 @@ __urdlllocal ur_result_t UR_APICALL urEventRelease( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urEventRelease")); + mock::getCallbacks().get_replace_callback("urEventRelease")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -5179,7 +5206,7 @@ __urdlllocal ur_result_t UR_APICALL urEventRelease( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urEventRelease")); + mock::getCallbacks().get_after_callback("urEventRelease")); if (afterCallback) { return afterCallback(¶ms); } @@ -5201,7 +5228,7 @@ __urdlllocal ur_result_t UR_APICALL urEventGetNativeHandle( ur_event_get_native_handle_params_t params = {&hEvent, &phNativeEvent}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urEventGetNativeHandle")); + mock::getCallbacks().get_before_callback("urEventGetNativeHandle")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -5210,7 +5237,7 @@ __urdlllocal ur_result_t UR_APICALL urEventGetNativeHandle( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urEventGetNativeHandle")); + mock::getCallbacks().get_replace_callback("urEventGetNativeHandle")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -5224,7 +5251,7 @@ __urdlllocal ur_result_t UR_APICALL urEventGetNativeHandle( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urEventGetNativeHandle")); + mock::getCallbacks().get_after_callback("urEventGetNativeHandle")); if (afterCallback) { return afterCallback(¶ms); } @@ -5251,7 +5278,8 @@ __urdlllocal ur_result_t UR_APICALL urEventCreateWithNativeHandle( &hNativeEvent, &hContext, &pProperties, &phEvent}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urEventCreateWithNativeHandle")); + mock::getCallbacks().get_before_callback( + "urEventCreateWithNativeHandle")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -5260,7 +5288,8 @@ __urdlllocal ur_result_t UR_APICALL urEventCreateWithNativeHandle( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urEventCreateWithNativeHandle")); + mock::getCallbacks().get_replace_callback( + "urEventCreateWithNativeHandle")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -5275,7 +5304,8 @@ __urdlllocal ur_result_t UR_APICALL urEventCreateWithNativeHandle( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urEventCreateWithNativeHandle")); + mock::getCallbacks().get_after_callback( + "urEventCreateWithNativeHandle")); if (afterCallback) { return afterCallback(¶ms); } @@ -5300,7 +5330,7 @@ __urdlllocal ur_result_t UR_APICALL urEventSetCallback( &pUserData}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urEventSetCallback")); + mock::getCallbacks().get_before_callback("urEventSetCallback")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -5309,7 +5339,7 @@ __urdlllocal ur_result_t UR_APICALL urEventSetCallback( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urEventSetCallback")); + mock::getCallbacks().get_replace_callback("urEventSetCallback")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -5322,7 +5352,7 @@ __urdlllocal ur_result_t UR_APICALL urEventSetCallback( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urEventSetCallback")); + mock::getCallbacks().get_after_callback("urEventSetCallback")); if (afterCallback) { return afterCallback(¶ms); } @@ -5376,7 +5406,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueKernelLaunch( &phEvent}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urEnqueueKernelLaunch")); + mock::getCallbacks().get_before_callback("urEnqueueKernelLaunch")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -5385,7 +5415,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueKernelLaunch( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urEnqueueKernelLaunch")); + mock::getCallbacks().get_replace_callback("urEnqueueKernelLaunch")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -5402,7 +5432,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueKernelLaunch( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urEnqueueKernelLaunch")); + mock::getCallbacks().get_after_callback("urEnqueueKernelLaunch")); if (afterCallback) { return afterCallback(¶ms); } @@ -5433,7 +5463,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueEventsWait( &phEventWaitList, &phEvent}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urEnqueueEventsWait")); + mock::getCallbacks().get_before_callback("urEnqueueEventsWait")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -5442,7 +5472,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueEventsWait( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urEnqueueEventsWait")); + mock::getCallbacks().get_replace_callback("urEnqueueEventsWait")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -5459,7 +5489,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueEventsWait( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urEnqueueEventsWait")); + mock::getCallbacks().get_after_callback("urEnqueueEventsWait")); if (afterCallback) { return afterCallback(¶ms); } @@ -5490,7 +5520,8 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueEventsWaitWithBarrier( &hQueue, &numEventsInWaitList, &phEventWaitList, &phEvent}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urEnqueueEventsWaitWithBarrier")); + mock::getCallbacks().get_before_callback( + "urEnqueueEventsWaitWithBarrier")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -5499,7 +5530,8 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueEventsWaitWithBarrier( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urEnqueueEventsWaitWithBarrier")); + mock::getCallbacks().get_replace_callback( + "urEnqueueEventsWaitWithBarrier")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -5516,7 +5548,8 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueEventsWaitWithBarrier( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urEnqueueEventsWaitWithBarrier")); + mock::getCallbacks().get_after_callback( + "urEnqueueEventsWaitWithBarrier")); if (afterCallback) { return afterCallback(¶ms); } @@ -5554,7 +5587,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferRead( &phEvent}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urEnqueueMemBufferRead")); + mock::getCallbacks().get_before_callback("urEnqueueMemBufferRead")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -5563,7 +5596,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferRead( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urEnqueueMemBufferRead")); + mock::getCallbacks().get_replace_callback("urEnqueueMemBufferRead")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -5580,7 +5613,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferRead( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urEnqueueMemBufferRead")); + mock::getCallbacks().get_after_callback("urEnqueueMemBufferRead")); if (afterCallback) { return afterCallback(¶ms); } @@ -5620,7 +5653,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferWrite( &phEvent}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urEnqueueMemBufferWrite")); + mock::getCallbacks().get_before_callback("urEnqueueMemBufferWrite")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -5629,7 +5662,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferWrite( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urEnqueueMemBufferWrite")); + mock::getCallbacks().get_replace_callback("urEnqueueMemBufferWrite")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -5646,7 +5679,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferWrite( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urEnqueueMemBufferWrite")); + mock::getCallbacks().get_after_callback("urEnqueueMemBufferWrite")); if (afterCallback) { return afterCallback(¶ms); } @@ -5706,7 +5739,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferReadRect( &phEvent}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urEnqueueMemBufferReadRect")); + mock::getCallbacks().get_before_callback("urEnqueueMemBufferReadRect")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -5715,7 +5748,8 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferReadRect( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urEnqueueMemBufferReadRect")); + mock::getCallbacks().get_replace_callback( + "urEnqueueMemBufferReadRect")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -5732,7 +5766,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferReadRect( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urEnqueueMemBufferReadRect")); + mock::getCallbacks().get_after_callback("urEnqueueMemBufferReadRect")); if (afterCallback) { return afterCallback(¶ms); } @@ -5795,7 +5829,8 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferWriteRect( &phEvent}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urEnqueueMemBufferWriteRect")); + mock::getCallbacks().get_before_callback( + "urEnqueueMemBufferWriteRect")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -5804,7 +5839,8 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferWriteRect( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urEnqueueMemBufferWriteRect")); + mock::getCallbacks().get_replace_callback( + "urEnqueueMemBufferWriteRect")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -5821,7 +5857,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferWriteRect( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urEnqueueMemBufferWriteRect")); + mock::getCallbacks().get_after_callback("urEnqueueMemBufferWriteRect")); if (afterCallback) { return afterCallback(¶ms); } @@ -5859,7 +5895,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferCopy( &size, &numEventsInWaitList, &phEventWaitList, &phEvent}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urEnqueueMemBufferCopy")); + mock::getCallbacks().get_before_callback("urEnqueueMemBufferCopy")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -5868,7 +5904,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferCopy( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urEnqueueMemBufferCopy")); + mock::getCallbacks().get_replace_callback("urEnqueueMemBufferCopy")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -5885,7 +5921,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferCopy( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urEnqueueMemBufferCopy")); + mock::getCallbacks().get_after_callback("urEnqueueMemBufferCopy")); if (afterCallback) { return afterCallback(¶ms); } @@ -5934,7 +5970,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferCopyRect( &phEvent}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urEnqueueMemBufferCopyRect")); + mock::getCallbacks().get_before_callback("urEnqueueMemBufferCopyRect")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -5943,7 +5979,8 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferCopyRect( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urEnqueueMemBufferCopyRect")); + mock::getCallbacks().get_replace_callback( + "urEnqueueMemBufferCopyRect")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -5960,7 +5997,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferCopyRect( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urEnqueueMemBufferCopyRect")); + mock::getCallbacks().get_after_callback("urEnqueueMemBufferCopyRect")); if (afterCallback) { return afterCallback(¶ms); } @@ -6003,7 +6040,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferFill( &phEvent}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urEnqueueMemBufferFill")); + mock::getCallbacks().get_before_callback("urEnqueueMemBufferFill")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -6012,7 +6049,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferFill( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urEnqueueMemBufferFill")); + mock::getCallbacks().get_replace_callback("urEnqueueMemBufferFill")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -6029,7 +6066,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferFill( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urEnqueueMemBufferFill")); + mock::getCallbacks().get_after_callback("urEnqueueMemBufferFill")); if (afterCallback) { return afterCallback(¶ms); } @@ -6073,7 +6110,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemImageRead( &phEventWaitList, &phEvent}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urEnqueueMemImageRead")); + mock::getCallbacks().get_before_callback("urEnqueueMemImageRead")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -6082,7 +6119,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemImageRead( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urEnqueueMemImageRead")); + mock::getCallbacks().get_replace_callback("urEnqueueMemImageRead")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -6099,7 +6136,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemImageRead( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urEnqueueMemImageRead")); + mock::getCallbacks().get_after_callback("urEnqueueMemImageRead")); if (afterCallback) { return afterCallback(¶ms); } @@ -6144,7 +6181,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemImageWrite( &phEventWaitList, &phEvent}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urEnqueueMemImageWrite")); + mock::getCallbacks().get_before_callback("urEnqueueMemImageWrite")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -6153,7 +6190,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemImageWrite( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urEnqueueMemImageWrite")); + mock::getCallbacks().get_replace_callback("urEnqueueMemImageWrite")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -6170,7 +6207,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemImageWrite( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urEnqueueMemImageWrite")); + mock::getCallbacks().get_after_callback("urEnqueueMemImageWrite")); if (afterCallback) { return afterCallback(¶ms); } @@ -6214,7 +6251,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemImageCopy( ®ion, &numEventsInWaitList, &phEventWaitList, &phEvent}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urEnqueueMemImageCopy")); + mock::getCallbacks().get_before_callback("urEnqueueMemImageCopy")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -6223,7 +6260,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemImageCopy( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urEnqueueMemImageCopy")); + mock::getCallbacks().get_replace_callback("urEnqueueMemImageCopy")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -6240,7 +6277,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemImageCopy( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urEnqueueMemImageCopy")); + mock::getCallbacks().get_after_callback("urEnqueueMemImageCopy")); if (afterCallback) { return afterCallback(¶ms); } @@ -6280,7 +6317,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferMap( &phEvent, &ppRetMap}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urEnqueueMemBufferMap")); + mock::getCallbacks().get_before_callback("urEnqueueMemBufferMap")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -6289,7 +6326,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferMap( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urEnqueueMemBufferMap")); + mock::getCallbacks().get_replace_callback("urEnqueueMemBufferMap")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -6309,7 +6346,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferMap( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urEnqueueMemBufferMap")); + mock::getCallbacks().get_after_callback("urEnqueueMemBufferMap")); if (afterCallback) { return afterCallback(¶ms); } @@ -6343,7 +6380,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemUnmap( &phEventWaitList, &phEvent}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urEnqueueMemUnmap")); + mock::getCallbacks().get_before_callback("urEnqueueMemUnmap")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -6352,7 +6389,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemUnmap( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urEnqueueMemUnmap")); + mock::getCallbacks().get_replace_callback("urEnqueueMemUnmap")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -6369,7 +6406,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemUnmap( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urEnqueueMemUnmap")); + mock::getCallbacks().get_after_callback("urEnqueueMemUnmap")); if (afterCallback) { return afterCallback(¶ms); } @@ -6409,7 +6446,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueUSMFill( &phEventWaitList, &phEvent}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urEnqueueUSMFill")); + mock::getCallbacks().get_before_callback("urEnqueueUSMFill")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -6418,7 +6455,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueUSMFill( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urEnqueueUSMFill")); + mock::getCallbacks().get_replace_callback("urEnqueueUSMFill")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -6435,7 +6472,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueUSMFill( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urEnqueueUSMFill")); + mock::getCallbacks().get_after_callback("urEnqueueUSMFill")); if (afterCallback) { return afterCallback(¶ms); } @@ -6472,7 +6509,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueUSMMemcpy( &phEventWaitList, &phEvent}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urEnqueueUSMMemcpy")); + mock::getCallbacks().get_before_callback("urEnqueueUSMMemcpy")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -6481,7 +6518,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueUSMMemcpy( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urEnqueueUSMMemcpy")); + mock::getCallbacks().get_replace_callback("urEnqueueUSMMemcpy")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -6498,7 +6535,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueUSMMemcpy( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urEnqueueUSMMemcpy")); + mock::getCallbacks().get_after_callback("urEnqueueUSMMemcpy")); if (afterCallback) { return afterCallback(¶ms); } @@ -6533,7 +6570,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueUSMPrefetch( &phEventWaitList, &phEvent}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urEnqueueUSMPrefetch")); + mock::getCallbacks().get_before_callback("urEnqueueUSMPrefetch")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -6542,7 +6579,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueUSMPrefetch( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urEnqueueUSMPrefetch")); + mock::getCallbacks().get_replace_callback("urEnqueueUSMPrefetch")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -6559,7 +6596,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueUSMPrefetch( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urEnqueueUSMPrefetch")); + mock::getCallbacks().get_after_callback("urEnqueueUSMPrefetch")); if (afterCallback) { return afterCallback(¶ms); } @@ -6587,7 +6624,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueUSMAdvise( &phEvent}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urEnqueueUSMAdvise")); + mock::getCallbacks().get_before_callback("urEnqueueUSMAdvise")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -6596,7 +6633,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueUSMAdvise( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urEnqueueUSMAdvise")); + mock::getCallbacks().get_replace_callback("urEnqueueUSMAdvise")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -6613,7 +6650,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueUSMAdvise( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urEnqueueUSMAdvise")); + mock::getCallbacks().get_after_callback("urEnqueueUSMAdvise")); if (afterCallback) { return afterCallback(¶ms); } @@ -6658,7 +6695,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueUSMFill2D( &phEventWaitList, &phEvent}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urEnqueueUSMFill2D")); + mock::getCallbacks().get_before_callback("urEnqueueUSMFill2D")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -6667,7 +6704,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueUSMFill2D( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urEnqueueUSMFill2D")); + mock::getCallbacks().get_replace_callback("urEnqueueUSMFill2D")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -6684,7 +6721,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueUSMFill2D( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urEnqueueUSMFill2D")); + mock::getCallbacks().get_after_callback("urEnqueueUSMFill2D")); if (afterCallback) { return afterCallback(¶ms); } @@ -6729,7 +6766,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueUSMMemcpy2D( &phEventWaitList, &phEvent}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urEnqueueUSMMemcpy2D")); + mock::getCallbacks().get_before_callback("urEnqueueUSMMemcpy2D")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -6738,7 +6775,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueUSMMemcpy2D( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urEnqueueUSMMemcpy2D")); + mock::getCallbacks().get_replace_callback("urEnqueueUSMMemcpy2D")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -6755,7 +6792,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueUSMMemcpy2D( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urEnqueueUSMMemcpy2D")); + mock::getCallbacks().get_after_callback("urEnqueueUSMMemcpy2D")); if (afterCallback) { return afterCallback(¶ms); } @@ -6796,7 +6833,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableWrite( &phEventWaitList, &phEvent}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback( + mock::getCallbacks().get_before_callback( "urEnqueueDeviceGlobalVariableWrite")); if (beforeCallback) { result = beforeCallback(¶ms); @@ -6806,7 +6843,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableWrite( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback( + mock::getCallbacks().get_replace_callback( "urEnqueueDeviceGlobalVariableWrite")); if (replaceCallback) { result = replaceCallback(¶ms); @@ -6823,8 +6860,8 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableWrite( return result; } - auto afterCallback = - reinterpret_cast(mock::callbacks.get_after_callback( + auto afterCallback = reinterpret_cast( + mock::getCallbacks().get_after_callback( "urEnqueueDeviceGlobalVariableWrite")); if (afterCallback) { return afterCallback(¶ms); @@ -6866,7 +6903,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableRead( &phEventWaitList, &phEvent}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback( + mock::getCallbacks().get_before_callback( "urEnqueueDeviceGlobalVariableRead")); if (beforeCallback) { result = beforeCallback(¶ms); @@ -6876,7 +6913,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableRead( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback( + mock::getCallbacks().get_replace_callback( "urEnqueueDeviceGlobalVariableRead")); if (replaceCallback) { result = replaceCallback(¶ms); @@ -6893,8 +6930,8 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableRead( return result; } - auto afterCallback = - reinterpret_cast(mock::callbacks.get_after_callback( + auto afterCallback = reinterpret_cast( + mock::getCallbacks().get_after_callback( "urEnqueueDeviceGlobalVariableRead")); if (afterCallback) { return afterCallback(¶ms); @@ -6940,7 +6977,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueReadHostPipe( &phEvent}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urEnqueueReadHostPipe")); + mock::getCallbacks().get_before_callback("urEnqueueReadHostPipe")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -6949,7 +6986,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueReadHostPipe( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urEnqueueReadHostPipe")); + mock::getCallbacks().get_replace_callback("urEnqueueReadHostPipe")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -6966,7 +7003,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueReadHostPipe( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urEnqueueReadHostPipe")); + mock::getCallbacks().get_after_callback("urEnqueueReadHostPipe")); if (afterCallback) { return afterCallback(¶ms); } @@ -7011,7 +7048,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueWriteHostPipe( &phEvent}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urEnqueueWriteHostPipe")); + mock::getCallbacks().get_before_callback("urEnqueueWriteHostPipe")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -7020,7 +7057,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueWriteHostPipe( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urEnqueueWriteHostPipe")); + mock::getCallbacks().get_replace_callback("urEnqueueWriteHostPipe")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -7037,7 +7074,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueWriteHostPipe( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urEnqueueWriteHostPipe")); + mock::getCallbacks().get_after_callback("urEnqueueWriteHostPipe")); if (afterCallback) { return afterCallback(¶ms); } @@ -7071,7 +7108,7 @@ __urdlllocal ur_result_t UR_APICALL urUSMPitchedAllocExp( &height, &elementSizeBytes, &ppMem, &pResultPitch}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urUSMPitchedAllocExp")); + mock::getCallbacks().get_before_callback("urUSMPitchedAllocExp")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -7080,7 +7117,7 @@ __urdlllocal ur_result_t UR_APICALL urUSMPitchedAllocExp( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urUSMPitchedAllocExp")); + mock::getCallbacks().get_replace_callback("urUSMPitchedAllocExp")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -7094,7 +7131,7 @@ __urdlllocal ur_result_t UR_APICALL urUSMPitchedAllocExp( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urUSMPitchedAllocExp")); + mock::getCallbacks().get_after_callback("urUSMPitchedAllocExp")); if (afterCallback) { return afterCallback(¶ms); } @@ -7119,7 +7156,7 @@ urBindlessImagesUnsampledImageHandleDestroyExp( &hContext, &hDevice, &hImage}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback( + mock::getCallbacks().get_before_callback( "urBindlessImagesUnsampledImageHandleDestroyExp")); if (beforeCallback) { result = beforeCallback(¶ms); @@ -7129,7 +7166,7 @@ urBindlessImagesUnsampledImageHandleDestroyExp( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback( + mock::getCallbacks().get_replace_callback( "urBindlessImagesUnsampledImageHandleDestroyExp")); if (replaceCallback) { result = replaceCallback(¶ms); @@ -7143,8 +7180,8 @@ urBindlessImagesUnsampledImageHandleDestroyExp( return result; } - auto afterCallback = - reinterpret_cast(mock::callbacks.get_after_callback( + auto afterCallback = reinterpret_cast( + mock::getCallbacks().get_after_callback( "urBindlessImagesUnsampledImageHandleDestroyExp")); if (afterCallback) { return afterCallback(¶ms); @@ -7170,7 +7207,7 @@ urBindlessImagesSampledImageHandleDestroyExp( &hContext, &hDevice, &hImage}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback( + mock::getCallbacks().get_before_callback( "urBindlessImagesSampledImageHandleDestroyExp")); if (beforeCallback) { result = beforeCallback(¶ms); @@ -7180,7 +7217,7 @@ urBindlessImagesSampledImageHandleDestroyExp( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback( + mock::getCallbacks().get_replace_callback( "urBindlessImagesSampledImageHandleDestroyExp")); if (replaceCallback) { result = replaceCallback(¶ms); @@ -7194,8 +7231,8 @@ urBindlessImagesSampledImageHandleDestroyExp( return result; } - auto afterCallback = - reinterpret_cast(mock::callbacks.get_after_callback( + auto afterCallback = reinterpret_cast( + mock::getCallbacks().get_after_callback( "urBindlessImagesSampledImageHandleDestroyExp")); if (afterCallback) { return afterCallback(¶ms); @@ -7223,7 +7260,7 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesImageAllocateExp( &hContext, &hDevice, &pImageFormat, &pImageDesc, &phImageMem}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback( + mock::getCallbacks().get_before_callback( "urBindlessImagesImageAllocateExp")); if (beforeCallback) { result = beforeCallback(¶ms); @@ -7233,7 +7270,7 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesImageAllocateExp( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback( + mock::getCallbacks().get_replace_callback( "urBindlessImagesImageAllocateExp")); if (replaceCallback) { result = replaceCallback(¶ms); @@ -7249,7 +7286,8 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesImageAllocateExp( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urBindlessImagesImageAllocateExp")); + mock::getCallbacks().get_after_callback( + "urBindlessImagesImageAllocateExp")); if (afterCallback) { return afterCallback(¶ms); } @@ -7273,7 +7311,8 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesImageFreeExp( &hImageMem}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urBindlessImagesImageFreeExp")); + mock::getCallbacks().get_before_callback( + "urBindlessImagesImageFreeExp")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -7282,7 +7321,8 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesImageFreeExp( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urBindlessImagesImageFreeExp")); + mock::getCallbacks().get_replace_callback( + "urBindlessImagesImageFreeExp")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -7296,7 +7336,8 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesImageFreeExp( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urBindlessImagesImageFreeExp")); + mock::getCallbacks().get_after_callback( + "urBindlessImagesImageFreeExp")); if (afterCallback) { return afterCallback(¶ms); } @@ -7325,7 +7366,7 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesUnsampledImageCreateExp( &hContext, &hDevice, &hImageMem, &pImageFormat, &pImageDesc, &phImage}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback( + mock::getCallbacks().get_before_callback( "urBindlessImagesUnsampledImageCreateExp")); if (beforeCallback) { result = beforeCallback(¶ms); @@ -7335,7 +7376,7 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesUnsampledImageCreateExp( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback( + mock::getCallbacks().get_replace_callback( "urBindlessImagesUnsampledImageCreateExp")); if (replaceCallback) { result = replaceCallback(¶ms); @@ -7349,8 +7390,8 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesUnsampledImageCreateExp( return result; } - auto afterCallback = - reinterpret_cast(mock::callbacks.get_after_callback( + auto afterCallback = reinterpret_cast( + mock::getCallbacks().get_after_callback( "urBindlessImagesUnsampledImageCreateExp")); if (afterCallback) { return afterCallback(¶ms); @@ -7382,7 +7423,7 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesSampledImageCreateExp( &pImageDesc, &hSampler, &phImage}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback( + mock::getCallbacks().get_before_callback( "urBindlessImagesSampledImageCreateExp")); if (beforeCallback) { result = beforeCallback(¶ms); @@ -7392,7 +7433,7 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesSampledImageCreateExp( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback( + mock::getCallbacks().get_replace_callback( "urBindlessImagesSampledImageCreateExp")); if (replaceCallback) { result = replaceCallback(¶ms); @@ -7406,8 +7447,8 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesSampledImageCreateExp( return result; } - auto afterCallback = - reinterpret_cast(mock::callbacks.get_after_callback( + auto afterCallback = reinterpret_cast( + mock::getCallbacks().get_after_callback( "urBindlessImagesSampledImageCreateExp")); if (afterCallback) { return afterCallback(¶ms); @@ -7469,7 +7510,8 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesImageCopyExp( &phEvent}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urBindlessImagesImageCopyExp")); + mock::getCallbacks().get_before_callback( + "urBindlessImagesImageCopyExp")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -7478,7 +7520,8 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesImageCopyExp( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urBindlessImagesImageCopyExp")); + mock::getCallbacks().get_replace_callback( + "urBindlessImagesImageCopyExp")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -7495,7 +7538,8 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesImageCopyExp( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urBindlessImagesImageCopyExp")); + mock::getCallbacks().get_after_callback( + "urBindlessImagesImageCopyExp")); if (afterCallback) { return afterCallback(¶ms); } @@ -7521,7 +7565,8 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesImageGetInfoExp( &hContext, &hImageMem, &propName, &pPropValue, &pPropSizeRet}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urBindlessImagesImageGetInfoExp")); + mock::getCallbacks().get_before_callback( + "urBindlessImagesImageGetInfoExp")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -7530,7 +7575,7 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesImageGetInfoExp( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback( + mock::getCallbacks().get_replace_callback( "urBindlessImagesImageGetInfoExp")); if (replaceCallback) { result = replaceCallback(¶ms); @@ -7544,7 +7589,8 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesImageGetInfoExp( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urBindlessImagesImageGetInfoExp")); + mock::getCallbacks().get_after_callback( + "urBindlessImagesImageGetInfoExp")); if (afterCallback) { return afterCallback(¶ms); } @@ -7571,7 +7617,7 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesMipmapGetLevelExp( &hContext, &hDevice, &hImageMem, &mipmapLevel, &phImageMem}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback( + mock::getCallbacks().get_before_callback( "urBindlessImagesMipmapGetLevelExp")); if (beforeCallback) { result = beforeCallback(¶ms); @@ -7581,7 +7627,7 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesMipmapGetLevelExp( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback( + mock::getCallbacks().get_replace_callback( "urBindlessImagesMipmapGetLevelExp")); if (replaceCallback) { result = replaceCallback(¶ms); @@ -7596,8 +7642,8 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesMipmapGetLevelExp( return result; } - auto afterCallback = - reinterpret_cast(mock::callbacks.get_after_callback( + auto afterCallback = reinterpret_cast( + mock::getCallbacks().get_after_callback( "urBindlessImagesMipmapGetLevelExp")); if (afterCallback) { return afterCallback(¶ms); @@ -7622,7 +7668,8 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesMipmapFreeExp( &hMem}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urBindlessImagesMipmapFreeExp")); + mock::getCallbacks().get_before_callback( + "urBindlessImagesMipmapFreeExp")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -7631,7 +7678,8 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesMipmapFreeExp( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urBindlessImagesMipmapFreeExp")); + mock::getCallbacks().get_replace_callback( + "urBindlessImagesMipmapFreeExp")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -7645,7 +7693,8 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesMipmapFreeExp( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urBindlessImagesMipmapFreeExp")); + mock::getCallbacks().get_after_callback( + "urBindlessImagesMipmapFreeExp")); if (afterCallback) { return afterCallback(¶ms); } @@ -7675,7 +7724,7 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesImportExternalMemoryExp( &memHandleType, &pInteropMemDesc, &phInteropMem}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback( + mock::getCallbacks().get_before_callback( "urBindlessImagesImportExternalMemoryExp")); if (beforeCallback) { result = beforeCallback(¶ms); @@ -7685,7 +7734,7 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesImportExternalMemoryExp( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback( + mock::getCallbacks().get_replace_callback( "urBindlessImagesImportExternalMemoryExp")); if (replaceCallback) { result = replaceCallback(¶ms); @@ -7699,8 +7748,8 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesImportExternalMemoryExp( return result; } - auto afterCallback = - reinterpret_cast(mock::callbacks.get_after_callback( + auto afterCallback = reinterpret_cast( + mock::getCallbacks().get_after_callback( "urBindlessImagesImportExternalMemoryExp")); if (afterCallback) { return afterCallback(¶ms); @@ -7731,7 +7780,7 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesMapExternalArrayExp( &pImageDesc, &hInteropMem, &phImageMem}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback( + mock::getCallbacks().get_before_callback( "urBindlessImagesMapExternalArrayExp")); if (beforeCallback) { result = beforeCallback(¶ms); @@ -7741,7 +7790,7 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesMapExternalArrayExp( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback( + mock::getCallbacks().get_replace_callback( "urBindlessImagesMapExternalArrayExp")); if (replaceCallback) { result = replaceCallback(¶ms); @@ -7756,8 +7805,8 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesMapExternalArrayExp( return result; } - auto afterCallback = - reinterpret_cast(mock::callbacks.get_after_callback( + auto afterCallback = reinterpret_cast( + mock::getCallbacks().get_after_callback( "urBindlessImagesMapExternalArrayExp")); if (afterCallback) { return afterCallback(¶ms); @@ -7782,7 +7831,7 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesReleaseInteropExp( &hContext, &hDevice, &hInteropMem}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback( + mock::getCallbacks().get_before_callback( "urBindlessImagesReleaseInteropExp")); if (beforeCallback) { result = beforeCallback(¶ms); @@ -7792,7 +7841,7 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesReleaseInteropExp( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback( + mock::getCallbacks().get_replace_callback( "urBindlessImagesReleaseInteropExp")); if (replaceCallback) { result = replaceCallback(¶ms); @@ -7806,8 +7855,8 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesReleaseInteropExp( return result; } - auto afterCallback = - reinterpret_cast(mock::callbacks.get_after_callback( + auto afterCallback = reinterpret_cast( + mock::getCallbacks().get_after_callback( "urBindlessImagesReleaseInteropExp")); if (afterCallback) { return afterCallback(¶ms); @@ -7837,7 +7886,7 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesImportExternalSemaphoreExp( &phInteropSemaphore}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback( + mock::getCallbacks().get_before_callback( "urBindlessImagesImportExternalSemaphoreExp")); if (beforeCallback) { result = beforeCallback(¶ms); @@ -7847,7 +7896,7 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesImportExternalSemaphoreExp( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback( + mock::getCallbacks().get_replace_callback( "urBindlessImagesImportExternalSemaphoreExp")); if (replaceCallback) { result = replaceCallback(¶ms); @@ -7862,8 +7911,8 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesImportExternalSemaphoreExp( return result; } - auto afterCallback = - reinterpret_cast(mock::callbacks.get_after_callback( + auto afterCallback = reinterpret_cast( + mock::getCallbacks().get_after_callback( "urBindlessImagesImportExternalSemaphoreExp")); if (afterCallback) { return afterCallback(¶ms); @@ -7888,7 +7937,7 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesDestroyExternalSemaphoreExp( &hContext, &hDevice, &hInteropSemaphore}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback( + mock::getCallbacks().get_before_callback( "urBindlessImagesDestroyExternalSemaphoreExp")); if (beforeCallback) { result = beforeCallback(¶ms); @@ -7898,7 +7947,7 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesDestroyExternalSemaphoreExp( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback( + mock::getCallbacks().get_replace_callback( "urBindlessImagesDestroyExternalSemaphoreExp")); if (replaceCallback) { result = replaceCallback(¶ms); @@ -7912,8 +7961,8 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesDestroyExternalSemaphoreExp( return result; } - auto afterCallback = - reinterpret_cast(mock::callbacks.get_after_callback( + auto afterCallback = reinterpret_cast( + mock::getCallbacks().get_after_callback( "urBindlessImagesDestroyExternalSemaphoreExp")); if (afterCallback) { return afterCallback(¶ms); @@ -7955,7 +8004,7 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesWaitExternalSemaphoreExp( &phEvent}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback( + mock::getCallbacks().get_before_callback( "urBindlessImagesWaitExternalSemaphoreExp")); if (beforeCallback) { result = beforeCallback(¶ms); @@ -7965,7 +8014,7 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesWaitExternalSemaphoreExp( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback( + mock::getCallbacks().get_replace_callback( "urBindlessImagesWaitExternalSemaphoreExp")); if (replaceCallback) { result = replaceCallback(¶ms); @@ -7982,8 +8031,8 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesWaitExternalSemaphoreExp( return result; } - auto afterCallback = - reinterpret_cast(mock::callbacks.get_after_callback( + auto afterCallback = reinterpret_cast( + mock::getCallbacks().get_after_callback( "urBindlessImagesWaitExternalSemaphoreExp")); if (afterCallback) { return afterCallback(¶ms); @@ -8025,7 +8074,7 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesSignalExternalSemaphoreExp( &phEvent}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback( + mock::getCallbacks().get_before_callback( "urBindlessImagesSignalExternalSemaphoreExp")); if (beforeCallback) { result = beforeCallback(¶ms); @@ -8035,7 +8084,7 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesSignalExternalSemaphoreExp( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback( + mock::getCallbacks().get_replace_callback( "urBindlessImagesSignalExternalSemaphoreExp")); if (replaceCallback) { result = replaceCallback(¶ms); @@ -8052,8 +8101,8 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesSignalExternalSemaphoreExp( return result; } - auto afterCallback = - reinterpret_cast(mock::callbacks.get_after_callback( + auto afterCallback = reinterpret_cast( + mock::getCallbacks().get_after_callback( "urBindlessImagesSignalExternalSemaphoreExp")); if (afterCallback) { return afterCallback(¶ms); @@ -8080,7 +8129,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferCreateExp( &hContext, &hDevice, &pCommandBufferDesc, &phCommandBuffer}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urCommandBufferCreateExp")); + mock::getCallbacks().get_before_callback("urCommandBufferCreateExp")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -8089,7 +8138,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferCreateExp( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urCommandBufferCreateExp")); + mock::getCallbacks().get_replace_callback("urCommandBufferCreateExp")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -8104,7 +8153,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferCreateExp( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urCommandBufferCreateExp")); + mock::getCallbacks().get_after_callback("urCommandBufferCreateExp")); if (afterCallback) { return afterCallback(¶ms); } @@ -8125,7 +8174,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferRetainExp( ur_command_buffer_retain_exp_params_t params = {&hCommandBuffer}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urCommandBufferRetainExp")); + mock::getCallbacks().get_before_callback("urCommandBufferRetainExp")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -8134,7 +8183,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferRetainExp( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urCommandBufferRetainExp")); + mock::getCallbacks().get_replace_callback("urCommandBufferRetainExp")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -8148,7 +8197,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferRetainExp( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urCommandBufferRetainExp")); + mock::getCallbacks().get_after_callback("urCommandBufferRetainExp")); if (afterCallback) { return afterCallback(¶ms); } @@ -8169,7 +8218,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferReleaseExp( ur_command_buffer_release_exp_params_t params = {&hCommandBuffer}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urCommandBufferReleaseExp")); + mock::getCallbacks().get_before_callback("urCommandBufferReleaseExp")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -8178,7 +8227,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferReleaseExp( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urCommandBufferReleaseExp")); + mock::getCallbacks().get_replace_callback("urCommandBufferReleaseExp")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -8192,7 +8241,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferReleaseExp( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urCommandBufferReleaseExp")); + mock::getCallbacks().get_after_callback("urCommandBufferReleaseExp")); if (afterCallback) { return afterCallback(¶ms); } @@ -8213,7 +8262,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferFinalizeExp( ur_command_buffer_finalize_exp_params_t params = {&hCommandBuffer}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urCommandBufferFinalizeExp")); + mock::getCallbacks().get_before_callback("urCommandBufferFinalizeExp")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -8222,7 +8271,8 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferFinalizeExp( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urCommandBufferFinalizeExp")); + mock::getCallbacks().get_replace_callback( + "urCommandBufferFinalizeExp")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -8235,7 +8285,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferFinalizeExp( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urCommandBufferFinalizeExp")); + mock::getCallbacks().get_after_callback("urCommandBufferFinalizeExp")); if (afterCallback) { return afterCallback(¶ms); } @@ -8283,7 +8333,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp( &phCommand}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback( + mock::getCallbacks().get_before_callback( "urCommandBufferAppendKernelLaunchExp")); if (beforeCallback) { result = beforeCallback(¶ms); @@ -8293,7 +8343,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback( + mock::getCallbacks().get_replace_callback( "urCommandBufferAppendKernelLaunchExp")); if (replaceCallback) { result = replaceCallback(¶ms); @@ -8311,8 +8361,8 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp( return result; } - auto afterCallback = - reinterpret_cast(mock::callbacks.get_after_callback( + auto afterCallback = reinterpret_cast( + mock::getCallbacks().get_after_callback( "urCommandBufferAppendKernelLaunchExp")); if (afterCallback) { return afterCallback(¶ms); @@ -8346,7 +8396,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferAppendUSMMemcpyExp( &pSyncPointWaitList, &pSyncPoint}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback( + mock::getCallbacks().get_before_callback( "urCommandBufferAppendUSMMemcpyExp")); if (beforeCallback) { result = beforeCallback(¶ms); @@ -8356,7 +8406,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferAppendUSMMemcpyExp( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback( + mock::getCallbacks().get_replace_callback( "urCommandBufferAppendUSMMemcpyExp")); if (replaceCallback) { result = replaceCallback(¶ms); @@ -8369,8 +8419,8 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferAppendUSMMemcpyExp( return result; } - auto afterCallback = - reinterpret_cast(mock::callbacks.get_after_callback( + auto afterCallback = reinterpret_cast( + mock::getCallbacks().get_after_callback( "urCommandBufferAppendUSMMemcpyExp")); if (afterCallback) { return afterCallback(¶ms); @@ -8407,7 +8457,8 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferAppendUSMFillExp( &pSyncPointWaitList, &pSyncPoint}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urCommandBufferAppendUSMFillExp")); + mock::getCallbacks().get_before_callback( + "urCommandBufferAppendUSMFillExp")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -8416,7 +8467,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferAppendUSMFillExp( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback( + mock::getCallbacks().get_replace_callback( "urCommandBufferAppendUSMFillExp")); if (replaceCallback) { result = replaceCallback(¶ms); @@ -8430,7 +8481,8 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferAppendUSMFillExp( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urCommandBufferAppendUSMFillExp")); + mock::getCallbacks().get_after_callback( + "urCommandBufferAppendUSMFillExp")); if (afterCallback) { return afterCallback(¶ms); } @@ -8472,7 +8524,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyExp( &pSyncPoint}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback( + mock::getCallbacks().get_before_callback( "urCommandBufferAppendMemBufferCopyExp")); if (beforeCallback) { result = beforeCallback(¶ms); @@ -8482,7 +8534,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyExp( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback( + mock::getCallbacks().get_replace_callback( "urCommandBufferAppendMemBufferCopyExp")); if (replaceCallback) { result = replaceCallback(¶ms); @@ -8495,8 +8547,8 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyExp( return result; } - auto afterCallback = - reinterpret_cast(mock::callbacks.get_after_callback( + auto afterCallback = reinterpret_cast( + mock::getCallbacks().get_after_callback( "urCommandBufferAppendMemBufferCopyExp")); if (afterCallback) { return afterCallback(¶ms); @@ -8538,7 +8590,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferAppendMemBufferWriteExp( &pSyncPoint}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback( + mock::getCallbacks().get_before_callback( "urCommandBufferAppendMemBufferWriteExp")); if (beforeCallback) { result = beforeCallback(¶ms); @@ -8548,7 +8600,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferAppendMemBufferWriteExp( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback( + mock::getCallbacks().get_replace_callback( "urCommandBufferAppendMemBufferWriteExp")); if (replaceCallback) { result = replaceCallback(¶ms); @@ -8561,8 +8613,8 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferAppendMemBufferWriteExp( return result; } - auto afterCallback = - reinterpret_cast(mock::callbacks.get_after_callback( + auto afterCallback = reinterpret_cast( + mock::getCallbacks().get_after_callback( "urCommandBufferAppendMemBufferWriteExp")); if (afterCallback) { return afterCallback(¶ms); @@ -8603,7 +8655,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferAppendMemBufferReadExp( &pSyncPoint}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback( + mock::getCallbacks().get_before_callback( "urCommandBufferAppendMemBufferReadExp")); if (beforeCallback) { result = beforeCallback(¶ms); @@ -8613,7 +8665,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferAppendMemBufferReadExp( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback( + mock::getCallbacks().get_replace_callback( "urCommandBufferAppendMemBufferReadExp")); if (replaceCallback) { result = replaceCallback(¶ms); @@ -8626,8 +8678,8 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferAppendMemBufferReadExp( return result; } - auto afterCallback = - reinterpret_cast(mock::callbacks.get_after_callback( + auto afterCallback = reinterpret_cast( + mock::getCallbacks().get_after_callback( "urCommandBufferAppendMemBufferReadExp")); if (afterCallback) { return afterCallback(¶ms); @@ -8681,7 +8733,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyRectExp( &pSyncPoint}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback( + mock::getCallbacks().get_before_callback( "urCommandBufferAppendMemBufferCopyRectExp")); if (beforeCallback) { result = beforeCallback(¶ms); @@ -8691,7 +8743,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyRectExp( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback( + mock::getCallbacks().get_replace_callback( "urCommandBufferAppendMemBufferCopyRectExp")); if (replaceCallback) { result = replaceCallback(¶ms); @@ -8704,8 +8756,8 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyRectExp( return result; } - auto afterCallback = - reinterpret_cast(mock::callbacks.get_after_callback( + auto afterCallback = reinterpret_cast( + mock::getCallbacks().get_after_callback( "urCommandBufferAppendMemBufferCopyRectExp")); if (afterCallback) { return afterCallback(¶ms); @@ -8765,7 +8817,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferAppendMemBufferWriteRectExp( &pSyncPoint}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback( + mock::getCallbacks().get_before_callback( "urCommandBufferAppendMemBufferWriteRectExp")); if (beforeCallback) { result = beforeCallback(¶ms); @@ -8775,7 +8827,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferAppendMemBufferWriteRectExp( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback( + mock::getCallbacks().get_replace_callback( "urCommandBufferAppendMemBufferWriteRectExp")); if (replaceCallback) { result = replaceCallback(¶ms); @@ -8788,8 +8840,8 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferAppendMemBufferWriteRectExp( return result; } - auto afterCallback = - reinterpret_cast(mock::callbacks.get_after_callback( + auto afterCallback = reinterpret_cast( + mock::getCallbacks().get_after_callback( "urCommandBufferAppendMemBufferWriteRectExp")); if (afterCallback) { return afterCallback(¶ms); @@ -8847,7 +8899,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferAppendMemBufferReadRectExp( &pSyncPoint}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback( + mock::getCallbacks().get_before_callback( "urCommandBufferAppendMemBufferReadRectExp")); if (beforeCallback) { result = beforeCallback(¶ms); @@ -8857,7 +8909,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferAppendMemBufferReadRectExp( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback( + mock::getCallbacks().get_replace_callback( "urCommandBufferAppendMemBufferReadRectExp")); if (replaceCallback) { result = replaceCallback(¶ms); @@ -8870,8 +8922,8 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferAppendMemBufferReadRectExp( return result; } - auto afterCallback = - reinterpret_cast(mock::callbacks.get_after_callback( + auto afterCallback = reinterpret_cast( + mock::getCallbacks().get_after_callback( "urCommandBufferAppendMemBufferReadRectExp")); if (afterCallback) { return afterCallback(¶ms); @@ -8915,7 +8967,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferAppendMemBufferFillExp( &pSyncPoint}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback( + mock::getCallbacks().get_before_callback( "urCommandBufferAppendMemBufferFillExp")); if (beforeCallback) { result = beforeCallback(¶ms); @@ -8925,7 +8977,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferAppendMemBufferFillExp( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback( + mock::getCallbacks().get_replace_callback( "urCommandBufferAppendMemBufferFillExp")); if (replaceCallback) { result = replaceCallback(¶ms); @@ -8938,8 +8990,8 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferAppendMemBufferFillExp( return result; } - auto afterCallback = - reinterpret_cast(mock::callbacks.get_after_callback( + auto afterCallback = reinterpret_cast( + mock::getCallbacks().get_after_callback( "urCommandBufferAppendMemBufferFillExp")); if (afterCallback) { return afterCallback(¶ms); @@ -8978,7 +9030,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferAppendUSMPrefetchExp( &pSyncPoint}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback( + mock::getCallbacks().get_before_callback( "urCommandBufferAppendUSMPrefetchExp")); if (beforeCallback) { result = beforeCallback(¶ms); @@ -8988,7 +9040,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferAppendUSMPrefetchExp( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback( + mock::getCallbacks().get_replace_callback( "urCommandBufferAppendUSMPrefetchExp")); if (replaceCallback) { result = replaceCallback(¶ms); @@ -9001,8 +9053,8 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferAppendUSMPrefetchExp( return result; } - auto afterCallback = - reinterpret_cast(mock::callbacks.get_after_callback( + auto afterCallback = reinterpret_cast( + mock::getCallbacks().get_after_callback( "urCommandBufferAppendUSMPrefetchExp")); if (afterCallback) { return afterCallback(¶ms); @@ -9041,7 +9093,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferAppendUSMAdviseExp( &pSyncPoint}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback( + mock::getCallbacks().get_before_callback( "urCommandBufferAppendUSMAdviseExp")); if (beforeCallback) { result = beforeCallback(¶ms); @@ -9051,7 +9103,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferAppendUSMAdviseExp( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback( + mock::getCallbacks().get_replace_callback( "urCommandBufferAppendUSMAdviseExp")); if (replaceCallback) { result = replaceCallback(¶ms); @@ -9064,8 +9116,8 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferAppendUSMAdviseExp( return result; } - auto afterCallback = - reinterpret_cast(mock::callbacks.get_after_callback( + auto afterCallback = reinterpret_cast( + mock::getCallbacks().get_after_callback( "urCommandBufferAppendUSMAdviseExp")); if (afterCallback) { return afterCallback(¶ms); @@ -9099,7 +9151,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferEnqueueExp( &phEvent}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urCommandBufferEnqueueExp")); + mock::getCallbacks().get_before_callback("urCommandBufferEnqueueExp")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -9108,7 +9160,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferEnqueueExp( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urCommandBufferEnqueueExp")); + mock::getCallbacks().get_replace_callback("urCommandBufferEnqueueExp")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -9125,7 +9177,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferEnqueueExp( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urCommandBufferEnqueueExp")); + mock::getCallbacks().get_after_callback("urCommandBufferEnqueueExp")); if (afterCallback) { return afterCallback(¶ms); } @@ -9146,7 +9198,8 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferRetainCommandExp( ur_command_buffer_retain_command_exp_params_t params = {&hCommand}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urCommandBufferRetainCommandExp")); + mock::getCallbacks().get_before_callback( + "urCommandBufferRetainCommandExp")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -9155,7 +9208,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferRetainCommandExp( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback( + mock::getCallbacks().get_replace_callback( "urCommandBufferRetainCommandExp")); if (replaceCallback) { result = replaceCallback(¶ms); @@ -9170,7 +9223,8 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferRetainCommandExp( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urCommandBufferRetainCommandExp")); + mock::getCallbacks().get_after_callback( + "urCommandBufferRetainCommandExp")); if (afterCallback) { return afterCallback(¶ms); } @@ -9191,7 +9245,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferReleaseCommandExp( ur_command_buffer_release_command_exp_params_t params = {&hCommand}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback( + mock::getCallbacks().get_before_callback( "urCommandBufferReleaseCommandExp")); if (beforeCallback) { result = beforeCallback(¶ms); @@ -9201,7 +9255,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferReleaseCommandExp( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback( + mock::getCallbacks().get_replace_callback( "urCommandBufferReleaseCommandExp")); if (replaceCallback) { result = replaceCallback(¶ms); @@ -9216,7 +9270,8 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferReleaseCommandExp( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urCommandBufferReleaseCommandExp")); + mock::getCallbacks().get_after_callback( + "urCommandBufferReleaseCommandExp")); if (afterCallback) { return afterCallback(¶ms); } @@ -9240,7 +9295,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp( &hCommand, &pUpdateKernelLaunch}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback( + mock::getCallbacks().get_before_callback( "urCommandBufferUpdateKernelLaunchExp")); if (beforeCallback) { result = beforeCallback(¶ms); @@ -9250,7 +9305,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback( + mock::getCallbacks().get_replace_callback( "urCommandBufferUpdateKernelLaunchExp")); if (replaceCallback) { result = replaceCallback(¶ms); @@ -9263,8 +9318,8 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp( return result; } - auto afterCallback = - reinterpret_cast(mock::callbacks.get_after_callback( + auto afterCallback = reinterpret_cast( + mock::getCallbacks().get_after_callback( "urCommandBufferUpdateKernelLaunchExp")); if (afterCallback) { return afterCallback(¶ms); @@ -9296,7 +9351,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferGetInfoExp( &hCommandBuffer, &propName, &propSize, &pPropValue, &pPropSizeRet}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urCommandBufferGetInfoExp")); + mock::getCallbacks().get_before_callback("urCommandBufferGetInfoExp")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -9305,7 +9360,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferGetInfoExp( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urCommandBufferGetInfoExp")); + mock::getCallbacks().get_replace_callback("urCommandBufferGetInfoExp")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -9318,7 +9373,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferGetInfoExp( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urCommandBufferGetInfoExp")); + mock::getCallbacks().get_after_callback("urCommandBufferGetInfoExp")); if (afterCallback) { return afterCallback(¶ms); } @@ -9349,7 +9404,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferCommandGetInfoExp( &hCommand, &propName, &propSize, &pPropValue, &pPropSizeRet}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback( + mock::getCallbacks().get_before_callback( "urCommandBufferCommandGetInfoExp")); if (beforeCallback) { result = beforeCallback(¶ms); @@ -9359,7 +9414,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferCommandGetInfoExp( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback( + mock::getCallbacks().get_replace_callback( "urCommandBufferCommandGetInfoExp")); if (replaceCallback) { result = replaceCallback(¶ms); @@ -9373,7 +9428,8 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferCommandGetInfoExp( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urCommandBufferCommandGetInfoExp")); + mock::getCallbacks().get_after_callback( + "urCommandBufferCommandGetInfoExp")); if (afterCallback) { return afterCallback(¶ms); } @@ -9428,7 +9484,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp( &phEvent}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback( + mock::getCallbacks().get_before_callback( "urEnqueueCooperativeKernelLaunchExp")); if (beforeCallback) { result = beforeCallback(¶ms); @@ -9438,7 +9494,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback( + mock::getCallbacks().get_replace_callback( "urEnqueueCooperativeKernelLaunchExp")); if (replaceCallback) { result = replaceCallback(¶ms); @@ -9455,8 +9511,8 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp( return result; } - auto afterCallback = - reinterpret_cast(mock::callbacks.get_after_callback( + auto afterCallback = reinterpret_cast( + mock::getCallbacks().get_after_callback( "urEnqueueCooperativeKernelLaunchExp")); if (afterCallback) { return afterCallback(¶ms); @@ -9485,7 +9541,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp( &hKernel, &localWorkSize, &dynamicSharedMemorySize, &pGroupCountRet}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback( + mock::getCallbacks().get_before_callback( "urKernelSuggestMaxCooperativeGroupCountExp")); if (beforeCallback) { result = beforeCallback(¶ms); @@ -9495,7 +9551,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback( + mock::getCallbacks().get_replace_callback( "urKernelSuggestMaxCooperativeGroupCountExp")); if (replaceCallback) { result = replaceCallback(¶ms); @@ -9508,8 +9564,8 @@ __urdlllocal ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp( return result; } - auto afterCallback = - reinterpret_cast(mock::callbacks.get_after_callback( + auto afterCallback = reinterpret_cast( + mock::getCallbacks().get_after_callback( "urKernelSuggestMaxCooperativeGroupCountExp")); if (afterCallback) { return afterCallback(¶ms); @@ -9549,7 +9605,8 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueTimestampRecordingExp( &hQueue, &blocking, &numEventsInWaitList, &phEventWaitList, &phEvent}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urEnqueueTimestampRecordingExp")); + mock::getCallbacks().get_before_callback( + "urEnqueueTimestampRecordingExp")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -9558,7 +9615,8 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueTimestampRecordingExp( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urEnqueueTimestampRecordingExp")); + mock::getCallbacks().get_replace_callback( + "urEnqueueTimestampRecordingExp")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -9572,7 +9630,8 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueTimestampRecordingExp( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urEnqueueTimestampRecordingExp")); + mock::getCallbacks().get_after_callback( + "urEnqueueTimestampRecordingExp")); if (afterCallback) { return afterCallback(¶ms); } @@ -9622,7 +9681,8 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp( &phEventWaitList, &phEvent}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urEnqueueKernelLaunchCustomExp")); + mock::getCallbacks().get_before_callback( + "urEnqueueKernelLaunchCustomExp")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -9631,7 +9691,8 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urEnqueueKernelLaunchCustomExp")); + mock::getCallbacks().get_replace_callback( + "urEnqueueKernelLaunchCustomExp")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -9644,7 +9705,8 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urEnqueueKernelLaunchCustomExp")); + mock::getCallbacks().get_after_callback( + "urEnqueueKernelLaunchCustomExp")); if (afterCallback) { return afterCallback(¶ms); } @@ -9670,7 +9732,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramBuildExp( &pOptions}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urProgramBuildExp")); + mock::getCallbacks().get_before_callback("urProgramBuildExp")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -9679,7 +9741,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramBuildExp( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urProgramBuildExp")); + mock::getCallbacks().get_replace_callback("urProgramBuildExp")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -9692,7 +9754,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramBuildExp( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urProgramBuildExp")); + mock::getCallbacks().get_after_callback("urProgramBuildExp")); if (afterCallback) { return afterCallback(¶ms); } @@ -9719,7 +9781,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramCompileExp( &phDevices, &pOptions}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urProgramCompileExp")); + mock::getCallbacks().get_before_callback("urProgramCompileExp")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -9728,7 +9790,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramCompileExp( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urProgramCompileExp")); + mock::getCallbacks().get_replace_callback("urProgramCompileExp")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -9741,7 +9803,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramCompileExp( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urProgramCompileExp")); + mock::getCallbacks().get_after_callback("urProgramCompileExp")); if (afterCallback) { return afterCallback(¶ms); } @@ -9776,7 +9838,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramLinkExp( &phProgram}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urProgramLinkExp")); + mock::getCallbacks().get_before_callback("urProgramLinkExp")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -9785,7 +9847,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramLinkExp( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urProgramLinkExp")); + mock::getCallbacks().get_replace_callback("urProgramLinkExp")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -9799,7 +9861,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramLinkExp( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urProgramLinkExp")); + mock::getCallbacks().get_after_callback("urProgramLinkExp")); if (afterCallback) { return afterCallback(¶ms); } @@ -9821,7 +9883,7 @@ __urdlllocal ur_result_t UR_APICALL urUSMImportExp( ur_usm_import_exp_params_t params = {&hContext, &pMem, &size}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urUSMImportExp")); + mock::getCallbacks().get_before_callback("urUSMImportExp")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -9830,7 +9892,7 @@ __urdlllocal ur_result_t UR_APICALL urUSMImportExp( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urUSMImportExp")); + mock::getCallbacks().get_replace_callback("urUSMImportExp")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -9843,7 +9905,7 @@ __urdlllocal ur_result_t UR_APICALL urUSMImportExp( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urUSMImportExp")); + mock::getCallbacks().get_after_callback("urUSMImportExp")); if (afterCallback) { return afterCallback(¶ms); } @@ -9864,7 +9926,7 @@ __urdlllocal ur_result_t UR_APICALL urUSMReleaseExp( ur_usm_release_exp_params_t params = {&hContext, &pMem}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urUSMReleaseExp")); + mock::getCallbacks().get_before_callback("urUSMReleaseExp")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -9873,7 +9935,7 @@ __urdlllocal ur_result_t UR_APICALL urUSMReleaseExp( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urUSMReleaseExp")); + mock::getCallbacks().get_replace_callback("urUSMReleaseExp")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -9886,7 +9948,7 @@ __urdlllocal ur_result_t UR_APICALL urUSMReleaseExp( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urUSMReleaseExp")); + mock::getCallbacks().get_after_callback("urUSMReleaseExp")); if (afterCallback) { return afterCallback(¶ms); } @@ -9909,7 +9971,8 @@ __urdlllocal ur_result_t UR_APICALL urUsmP2PEnablePeerAccessExp( &peerDevice}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urUsmP2PEnablePeerAccessExp")); + mock::getCallbacks().get_before_callback( + "urUsmP2PEnablePeerAccessExp")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -9918,7 +9981,8 @@ __urdlllocal ur_result_t UR_APICALL urUsmP2PEnablePeerAccessExp( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urUsmP2PEnablePeerAccessExp")); + mock::getCallbacks().get_replace_callback( + "urUsmP2PEnablePeerAccessExp")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -9931,7 +9995,7 @@ __urdlllocal ur_result_t UR_APICALL urUsmP2PEnablePeerAccessExp( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urUsmP2PEnablePeerAccessExp")); + mock::getCallbacks().get_after_callback("urUsmP2PEnablePeerAccessExp")); if (afterCallback) { return afterCallback(¶ms); } @@ -9954,7 +10018,8 @@ __urdlllocal ur_result_t UR_APICALL urUsmP2PDisablePeerAccessExp( &peerDevice}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urUsmP2PDisablePeerAccessExp")); + mock::getCallbacks().get_before_callback( + "urUsmP2PDisablePeerAccessExp")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -9963,7 +10028,8 @@ __urdlllocal ur_result_t UR_APICALL urUsmP2PDisablePeerAccessExp( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urUsmP2PDisablePeerAccessExp")); + mock::getCallbacks().get_replace_callback( + "urUsmP2PDisablePeerAccessExp")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -9976,7 +10042,8 @@ __urdlllocal ur_result_t UR_APICALL urUsmP2PDisablePeerAccessExp( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urUsmP2PDisablePeerAccessExp")); + mock::getCallbacks().get_after_callback( + "urUsmP2PDisablePeerAccessExp")); if (afterCallback) { return afterCallback(¶ms); } @@ -10011,7 +10078,8 @@ __urdlllocal ur_result_t UR_APICALL urUsmP2PPeerAccessGetInfoExp( &propSize, &pPropValue, &pPropSizeRet}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urUsmP2PPeerAccessGetInfoExp")); + mock::getCallbacks().get_before_callback( + "urUsmP2PPeerAccessGetInfoExp")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -10020,7 +10088,8 @@ __urdlllocal ur_result_t UR_APICALL urUsmP2PPeerAccessGetInfoExp( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urUsmP2PPeerAccessGetInfoExp")); + mock::getCallbacks().get_replace_callback( + "urUsmP2PPeerAccessGetInfoExp")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -10033,7 +10102,8 @@ __urdlllocal ur_result_t UR_APICALL urUsmP2PPeerAccessGetInfoExp( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urUsmP2PPeerAccessGetInfoExp")); + mock::getCallbacks().get_after_callback( + "urUsmP2PPeerAccessGetInfoExp")); if (afterCallback) { return afterCallback(¶ms); } @@ -10081,7 +10151,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueNativeCommandExp( &phEvent}; auto beforeCallback = reinterpret_cast( - mock::callbacks.get_before_callback("urEnqueueNativeCommandExp")); + mock::getCallbacks().get_before_callback("urEnqueueNativeCommandExp")); if (beforeCallback) { result = beforeCallback(¶ms); if (result != UR_RESULT_SUCCESS) { @@ -10090,7 +10160,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueNativeCommandExp( } auto replaceCallback = reinterpret_cast( - mock::callbacks.get_replace_callback("urEnqueueNativeCommandExp")); + mock::getCallbacks().get_replace_callback("urEnqueueNativeCommandExp")); if (replaceCallback) { result = replaceCallback(¶ms); } else { @@ -10107,7 +10177,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueNativeCommandExp( } auto afterCallback = reinterpret_cast( - mock::callbacks.get_after_callback("urEnqueueNativeCommandExp")); + mock::getCallbacks().get_after_callback("urEnqueueNativeCommandExp")); if (afterCallback) { return afterCallback(¶ms); } diff --git a/source/loader/ur_libapi.cpp b/source/loader/ur_libapi.cpp index 95cf231ac4..ab4fe098b4 100644 --- a/source/loader/ur_libapi.cpp +++ b/source/loader/ur_libapi.cpp @@ -199,7 +199,8 @@ ur_result_t UR_APICALL urLoaderConfigSetCodeLocationCallback( /// - The mock adapter will default to returning ::UR_RESULT_SUCCESS for all /// entry points. It will also create and correctly reference count dummy /// handles where appropriate. Its behaviour can be modified by linking -/// the ::ur_mock_headers library and using the mock::callbacks object. +/// the mock library and using the object accessed via +/// mock::getCallbacks(). /// /// @returns /// - ::UR_RESULT_SUCCESS diff --git a/source/mock/ur_mock_helpers.cpp b/source/mock/ur_mock_helpers.cpp index 5fb4391a2d..a8304492a6 100644 --- a/source/mock/ur_mock_helpers.cpp +++ b/source/mock/ur_mock_helpers.cpp @@ -14,6 +14,9 @@ namespace mock { -callbacks_t callbacks; +callbacks_t &getCallbacks() { + static callbacks_t callbacks; + return callbacks; +} } // namespace mock diff --git a/source/mock/ur_mock_helpers.hpp b/source/mock/ur_mock_helpers.hpp index df69f9fa7e..06ef4e5739 100644 --- a/source/mock/ur_mock_helpers.hpp +++ b/source/mock/ur_mock_helpers.hpp @@ -77,7 +77,7 @@ struct callbacks_t { beforeCallbacks[name] = callback; } - ur_mock_callback_t get_before_callback(std::string name) { + ur_mock_callback_t get_before_callback(std::string name) const { auto callback = beforeCallbacks.find(name); if (callback != beforeCallbacks.end()) { @@ -90,7 +90,7 @@ struct callbacks_t { replaceCallbacks[name] = callback; } - ur_mock_callback_t get_replace_callback(std::string name) { + ur_mock_callback_t get_replace_callback(std::string name) const { auto callback = replaceCallbacks.find(name); if (callback != replaceCallbacks.end()) { @@ -103,7 +103,7 @@ struct callbacks_t { afterCallbacks[name] = callback; } - ur_mock_callback_t get_after_callback(std::string name) { + ur_mock_callback_t get_after_callback(std::string name) const { auto callback = afterCallbacks.find(name); if (callback != afterCallbacks.end()) { @@ -124,6 +124,6 @@ struct callbacks_t { std::unordered_map afterCallbacks; }; -extern callbacks_t callbacks; +callbacks_t &getCallbacks(); } // namespace mock diff --git a/source/ur_api.cpp b/source/ur_api.cpp index 0460c3d663..7fc87f9eae 100644 --- a/source/ur_api.cpp +++ b/source/ur_api.cpp @@ -188,7 +188,8 @@ ur_result_t UR_APICALL urLoaderConfigSetCodeLocationCallback( /// - The mock adapter will default to returning ::UR_RESULT_SUCCESS for all /// entry points. It will also create and correctly reference count dummy /// handles where appropriate. Its behaviour can be modified by linking -/// the ::ur_mock_headers library and using the mock::callbacks object. +/// the mock library and using the object accessed via +/// mock::getCallbacks(). /// /// @returns /// - ::UR_RESULT_SUCCESS diff --git a/test/mock/mock.cpp b/test/mock/mock.cpp index 36c3bab113..0dac92e8b4 100644 --- a/test/mock/mock.cpp +++ b/test/mock/mock.cpp @@ -90,14 +90,16 @@ TEST(Mock, Callbacks) { // This callback is set up to check *phAdapters is still the pre-call // init value we set below - mock::callbacks.set_before_callback("urAdapterGet", &beforeUrAdapterGet); + mock::getCallbacks().set_before_callback("urAdapterGet", + &beforeUrAdapterGet); // This callback is set up to return a distinct test value in phAdapters // rather than the default generic handle - mock::callbacks.set_replace_callback("urAdapterGet", &replaceUrAdapterGet); + mock::getCallbacks().set_replace_callback("urAdapterGet", + &replaceUrAdapterGet); // This callback is set up to check our replace callback did its job - mock::callbacks.set_after_callback("urAdapterGet", &afterUrAdapterGet); + mock::getCallbacks().set_after_callback("urAdapterGet", &afterUrAdapterGet); ur_adapter_handle_t adapter = reinterpret_cast(0xF00DCAFE); From cca747d68302563d9f959b4e3bcef4207ae1c5cf Mon Sep 17 00:00:00 2001 From: Aaron Greig Date: Thu, 4 Jul 2024 10:32:01 +0100 Subject: [PATCH 10/15] Fix tests that relied on old null adapter behaviour. --- source/adapters/mock/ur_mock.cpp | 270 +++++------------- source/adapters/mock/ur_mock.hpp | 7 - source/loader/ur_lib.cpp | 2 +- .../tracing/hello_world.out.logged.match | 4 +- test/layers/tracing/hello_world.out.match | 2 +- test/layers/validation/CMakeLists.txt | 1 + test/layers/validation/fixtures.hpp | 29 +- test/layers/validation/leaks.cpp | 30 +- test/layers/validation/leaks.out.match | 6 +- test/loader/handles/CMakeLists.txt | 1 + test/loader/handles/fixtures.hpp | 34 +++ test/loader/platforms/null_platform.match | 2 +- test/tools/urtrace/CMakeLists.txt | 12 +- .../{null_hello.match => mock_hello.match} | 2 +- ...llo_begin.match => mock_hello_begin.match} | 2 +- ...e.match => mock_hello_filter_device.match} | 2 +- ...hello_json.match => mock_hello_json.match} | 4 +- ...no_args.match => mock_hello_no_args.match} | 2 +- ...iling.match => mock_hello_profiling.match} | 2 +- tools/urtrace/README.md | 4 +- 20 files changed, 178 insertions(+), 240 deletions(-) rename test/tools/urtrace/{null_hello.match => mock_hello.match} (98%) rename test/tools/urtrace/{null_hello_begin.match => mock_hello_begin.match} (99%) rename test/tools/urtrace/{null_hello_filter_device.match => mock_hello_filter_device.match} (96%) rename test/tools/urtrace/{null_hello_json.match => mock_hello_json.match} (97%) rename test/tools/urtrace/{null_hello_no_args.match => mock_hello_no_args.match} (94%) rename test/tools/urtrace/{null_hello_profiling.match => mock_hello_profiling.match} (98%) diff --git a/source/adapters/mock/ur_mock.cpp b/source/adapters/mock/ur_mock.cpp index cf95f08aee..502b532470 100644 --- a/source/adapters/mock/ur_mock.cpp +++ b/source/adapters/mock/ur_mock.cpp @@ -10,230 +10,88 @@ * */ #include "ur_mock.hpp" +#include "ur_mock_helpers.hpp" namespace driver { ////////////////////////////////////////////////////////////////////////// context_t d_context; -////////////////////////////////////////////////////////////////////////// -context_t::context_t() { - platform = get(); - ////////////////////////////////////////////////////////////////////////// - urDdiTable.Global.pfnAdapterGet = [](uint32_t NumAdapters, - ur_adapter_handle_t *phAdapters, - uint32_t *pNumAdapters) { - if (phAdapters != nullptr && NumAdapters != 1) { - return UR_RESULT_ERROR_INVALID_SIZE; - } - if (pNumAdapters != nullptr) { - *pNumAdapters = 1; - } - if (nullptr != phAdapters) { - *reinterpret_cast(phAdapters) = d_context.platform; - } - - return UR_RESULT_SUCCESS; - }; - ////////////////////////////////////////////////////////////////////////// - urDdiTable.Global.pfnAdapterRelease = [](ur_adapter_handle_t) { - return UR_RESULT_SUCCESS; - }; - ////////////////////////////////////////////////////////////////////////// - urDdiTable.Platform.pfnGet = - [](ur_adapter_handle_t *, uint32_t, uint32_t NumEntries, - ur_platform_handle_t *phPlatforms, uint32_t *pNumPlatforms) { - if (phPlatforms != nullptr && NumEntries != 1) { - return UR_RESULT_ERROR_INVALID_SIZE; - } - if (pNumPlatforms != nullptr) { - *pNumPlatforms = 1; - } - if (nullptr != phPlatforms) { - *reinterpret_cast(phPlatforms) = d_context.platform; - } - return UR_RESULT_SUCCESS; - }; - - ////////////////////////////////////////////////////////////////////////// - urDdiTable.Platform.pfnGetApiVersion = [](ur_platform_handle_t, - ur_api_version_t *version) { - *version = d_context.version; - return UR_RESULT_SUCCESS; - }; +ur_result_t mock_urPlatformGetApiVersion(void *pParams) { + auto params = *static_cast(pParams); + **params.ppVersion = d_context.version; + return UR_RESULT_SUCCESS; +} - ////////////////////////////////////////////////////////////////////////// - urDdiTable.Platform.pfnGetInfo = - [](ur_platform_handle_t hPlatform, ur_platform_info_t PlatformInfoType, - size_t Size, void *pPlatformInfo, size_t *pSizeRet) { - if (!hPlatform) { - return UR_RESULT_ERROR_INVALID_NULL_HANDLE; - } +ur_result_t mock_urPlatformGetInfo(void *pParams) { + auto params = *static_cast(pParams); + if (!*params.phPlatform) { + return UR_RESULT_ERROR_INVALID_NULL_HANDLE; + } - switch (PlatformInfoType) { - case UR_PLATFORM_INFO_NAME: { - const char null_platform_name[] = "UR_PLATFORM_NULL"; - if (pSizeRet) { - *pSizeRet = sizeof(null_platform_name); - } - if (pPlatformInfo && Size != sizeof(null_platform_name)) { - return UR_RESULT_ERROR_INVALID_SIZE; - } - if (pPlatformInfo) { + if (*params.ppropName == UR_PLATFORM_INFO_NAME) { + const char mock_platform_name[] = "UR_PLATFORM_MOCK"; + if (*params.ppPropSizeRet) { + **params.ppPropSizeRet = sizeof(mock_platform_name); + } + if (*params.ppPropValue) { #if defined(_WIN32) - strncpy_s(reinterpret_cast(pPlatformInfo), Size, - null_platform_name, sizeof(null_platform_name)); + strncpy_s(reinterpret_cast(*params.ppPropValue), + *params.ppropSize, null_platform_name, + sizeof(mock_platform_name)); #else - strncpy(reinterpret_cast(pPlatformInfo), - null_platform_name, Size); + strncpy(reinterpret_cast(*params.ppPropValue), + mock_platform_name, *params.ppropSize); #endif - } - } break; - - default: - return UR_RESULT_ERROR_INVALID_ENUMERATION; - } - - return UR_RESULT_SUCCESS; - }; - - ////////////////////////////////////////////////////////////////////////// - urDdiTable.Device.pfnGet = - [](ur_platform_handle_t hPlatform, ur_device_type_t DevicesType, - uint32_t NumEntries, ur_device_handle_t *phDevices, - uint32_t *pNumDevices) { - (void)DevicesType; - if (hPlatform == nullptr) { - return UR_RESULT_ERROR_INVALID_NULL_HANDLE; - } - if (UR_DEVICE_TYPE_VPU < DevicesType) { - return UR_RESULT_ERROR_INVALID_ENUMERATION; - } - if (phDevices != nullptr && NumEntries != 1) { - return UR_RESULT_ERROR_INVALID_SIZE; - } - if (pNumDevices != nullptr) { - *pNumDevices = 1; - } - if (nullptr != phDevices) { - *reinterpret_cast(phDevices) = d_context.get(); - } - return UR_RESULT_SUCCESS; - }; - - ////////////////////////////////////////////////////////////////////////// - urDdiTable.Device.pfnGetInfo = [](ur_device_handle_t, - ur_device_info_t infoType, - size_t propSize, void *pDeviceInfo, - size_t *pPropSizeRet) { - switch (infoType) { - case UR_DEVICE_INFO_TYPE: - if (pDeviceInfo && propSize != sizeof(ur_device_type_t)) { - return UR_RESULT_ERROR_INVALID_SIZE; - } + } + } + return UR_RESULT_SUCCESS; +} - if (pDeviceInfo != nullptr) { - *reinterpret_cast(pDeviceInfo) = - UR_DEVICE_TYPE_GPU; - } - if (pPropSizeRet != nullptr) { - *pPropSizeRet = sizeof(ur_device_type_t); - } - break; +////////////////////////////////////////////////////////////////////////// +ur_result_t mock_urDeviceGetInfo(void *pParams) { + auto params = *static_cast(pParams); + switch (*params.ppropName) { + case UR_DEVICE_INFO_TYPE: + if (*params.ppPropValue != nullptr) { + *reinterpret_cast(*params.ppPropValue) = + UR_DEVICE_TYPE_GPU; + } + if (*params.ppPropSizeRet != nullptr) { + **params.ppPropSizeRet = sizeof(ur_device_type_t); + } + break; - case UR_DEVICE_INFO_NAME: { - char deviceName[] = "Null Device"; - if (pDeviceInfo && propSize < sizeof(deviceName)) { - return UR_RESULT_ERROR_INVALID_SIZE; - } - if (pDeviceInfo != nullptr) { + case UR_DEVICE_INFO_NAME: { + char deviceName[] = "Mock Device"; + if (*params.ppPropValue && *params.ppropSize < sizeof(deviceName)) { + return UR_RESULT_ERROR_INVALID_SIZE; + } + if (*params.ppPropValue != nullptr) { #if defined(_WIN32) - strncpy_s(reinterpret_cast(pDeviceInfo), propSize, - deviceName, sizeof(deviceName)); + strncpy_s(reinterpret_cast(*params.ppPropValue), + *params.ppropSize, deviceName, sizeof(deviceName)); #else - strncpy(reinterpret_cast(pDeviceInfo), deviceName, - propSize); + strncpy(reinterpret_cast(*params.ppPropValue), deviceName, + *params.ppropSize); #endif - } - if (pPropSizeRet != nullptr) { - *pPropSizeRet = sizeof(deviceName); - } - } break; - case UR_DEVICE_INFO_PLATFORM: { - if (pDeviceInfo && propSize < sizeof(pDeviceInfo)) { - return UR_RESULT_ERROR_INVALID_SIZE; - } - if (pDeviceInfo != nullptr) { - *reinterpret_cast(pDeviceInfo) = d_context.platform; - } - if (pPropSizeRet != nullptr) { - *pPropSizeRet = sizeof(intptr_t); - } - } break; - default: - return UR_RESULT_ERROR_INVALID_ARGUMENT; - } - return UR_RESULT_SUCCESS; - }; - - ////////////////////////////////////////////////////////////////////////// - urDdiTable.USM.pfnHostAlloc = [](ur_context_handle_t, const ur_usm_desc_t *, - ur_usm_pool_handle_t, size_t size, - void **ppMem) { - if (size == 0) { - *ppMem = nullptr; - return UR_RESULT_ERROR_UNSUPPORTED_SIZE; } - *ppMem = malloc(size); - if (*ppMem == nullptr) { - return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; + if (*params.ppPropSizeRet != nullptr) { + **params.ppPropSizeRet = sizeof(deviceName); } + } break; + default: return UR_RESULT_SUCCESS; - }; - - ////////////////////////////////////////////////////////////////////////// - urDdiTable.USM.pfnDeviceAlloc = - [](ur_context_handle_t, ur_device_handle_t, const ur_usm_desc_t *, - ur_usm_pool_handle_t, size_t size, void **ppMem) { - if (size == 0) { - *ppMem = nullptr; - return UR_RESULT_ERROR_UNSUPPORTED_SIZE; - } - *ppMem = malloc(size); - if (*ppMem == nullptr) { - return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; - } - return UR_RESULT_SUCCESS; - }; - - ////////////////////////////////////////////////////////////////////////// - urDdiTable.USM.pfnFree = [](ur_context_handle_t, void *pMem) { - free(pMem); - return UR_RESULT_SUCCESS; - }; + } + return UR_RESULT_SUCCESS; +} - ////////////////////////////////////////////////////////////////////////// - urDdiTable.USM.pfnGetMemAllocInfo = - [](ur_context_handle_t, const void *pMem, ur_usm_alloc_info_t propName, - size_t, void *pPropValue, size_t *pPropSizeRet) { - switch (propName) { - case UR_USM_ALLOC_INFO_TYPE: - *reinterpret_cast(pPropValue) = - pMem ? UR_USM_TYPE_DEVICE : UR_USM_TYPE_UNKNOWN; - if (pPropSizeRet != nullptr) { - *pPropSizeRet = sizeof(ur_usm_type_t); - } - break; - case UR_USM_ALLOC_INFO_SIZE: - *reinterpret_cast(pPropValue) = pMem ? SIZE_MAX : 0; - if (pPropSizeRet != nullptr) { - *pPropSizeRet = sizeof(size_t); - } - break; - default: - pPropValue = nullptr; - break; - } - return UR_RESULT_SUCCESS; - }; +////////////////////////////////////////////////////////////////////////// +context_t::context_t() { + mock::getCallbacks().set_replace_callback("urPlatformGetApiVersion", + &mock_urPlatformGetApiVersion); + mock::getCallbacks().set_replace_callback("urPlatformGetInfo", + &mock_urPlatformGetInfo); + mock::getCallbacks().set_replace_callback("urDeviceGetInfo", + &mock_urDeviceGetInfo); } } // namespace driver diff --git a/source/adapters/mock/ur_mock.hpp b/source/adapters/mock/ur_mock.hpp index f82a56bfcd..a4a458fbef 100644 --- a/source/adapters/mock/ur_mock.hpp +++ b/source/adapters/mock/ur_mock.hpp @@ -25,13 +25,6 @@ class __urdlllocal context_t { ur_dditable_t urDdiTable = {}; context_t(); ~context_t() = default; - - void *platform; - - void *get() { - static uint64_t count = 0x80800000; - return reinterpret_cast(++count); - } }; extern context_t d_context; diff --git a/source/loader/ur_lib.cpp b/source/loader/ur_lib.cpp index 6baaa9fb02..1fbfacb897 100644 --- a/source/loader/ur_lib.cpp +++ b/source/loader/ur_lib.cpp @@ -81,7 +81,7 @@ void context_t::tearDownLayers() const { ////////////////////////////////////////////////////////////////////////// __urdlllocal ur_result_t context_t::Init( ur_device_init_flags_t, ur_loader_config_handle_t hLoaderConfig) { - if (hLoaderConfig->enableMock) { + if (hLoaderConfig && hLoaderConfig->enableMock) { // This clears default known adapters and replaces them with the mock // adapter. ur_loader::context->adapter_registry.enableMock(); diff --git a/test/layers/tracing/hello_world.out.logged.match b/test/layers/tracing/hello_world.out.logged.match index 5bd7a33eae..99b9cac909 100644 --- a/test/layers/tracing/hello_world.out.logged.match +++ b/test/layers/tracing/hello_world.out.logged.match @@ -8,6 +8,6 @@ API version: {{0\.[0-9]+}} ---> urDeviceGet(.hPlatform = {{.*}}, .DeviceType = UR_DEVICE_TYPE_GPU, .NumEntries = 0, .phDevices = {}, .pNumDevices = {{.*}} (1)) -> UR_RESULT_SUCCESS; ---> urDeviceGet(.hPlatform = {{.*}}, .DeviceType = UR_DEVICE_TYPE_GPU, .NumEntries = 1, .phDevices = {{.*}}, .pNumDevices = nullptr) -> UR_RESULT_SUCCESS; ---> urDeviceGetInfo(.hDevice = {{.*}}, .propName = UR_DEVICE_INFO_TYPE, .propSize = 4, .pPropValue = {{.*}} (UR_DEVICE_TYPE_GPU), .pPropSizeRet = nullptr) -> UR_RESULT_SUCCESS; ----> urDeviceGetInfo(.hDevice = {{.*}}, .propName = UR_DEVICE_INFO_NAME, .propSize = {{.*}}, .pPropValue = {{.*}} (Null Device), .pPropSizeRet = nullptr) -> UR_RESULT_SUCCESS; -Found a Null Device gpu. +---> urDeviceGetInfo(.hDevice = {{.*}}, .propName = UR_DEVICE_INFO_NAME, .propSize = {{.*}}, .pPropValue = {{.*}} (Mock Device), .pPropSizeRet = nullptr) -> UR_RESULT_SUCCESS; +Found a Mock Device gpu. ---> urAdapterRelease(.hAdapter = {{.*}}) -> UR_RESULT_SUCCESS; diff --git a/test/layers/tracing/hello_world.out.match b/test/layers/tracing/hello_world.out.match index cef17b8fdf..cd25b5b9ba 100644 --- a/test/layers/tracing/hello_world.out.match +++ b/test/layers/tracing/hello_world.out.match @@ -18,6 +18,6 @@ function_with_args_begin(8) - urDeviceGetInfo(unimplemented); function_with_args_end(8) - urDeviceGetInfo(...) -> ur_result_t(0); function_with_args_begin(9) - urDeviceGetInfo(unimplemented); function_with_args_end(9) - urDeviceGetInfo(...) -> ur_result_t(0); -Found a Null Device gpu. +Found a Mock Device gpu. function_with_args_begin(10) - urAdapterRelease(unimplemented); function_with_args_end(10) - urAdapterRelease(...) -> ur_result_t(0); diff --git a/test/layers/validation/CMakeLists.txt b/test/layers/validation/CMakeLists.txt index 63f7de7a8d..063c5ffc89 100644 --- a/test/layers/validation/CMakeLists.txt +++ b/test/layers/validation/CMakeLists.txt @@ -14,6 +14,7 @@ function(add_validation_test_executable name) ${PROJECT_NAME}::loader ${PROJECT_NAME}::headers ${PROJECT_NAME}::testing + ${PROJECT_NAME}::mock GTest::gtest_main) endfunction() diff --git a/test/layers/validation/fixtures.hpp b/test/layers/validation/fixtures.hpp index 4dcb553282..bc76f9dedf 100644 --- a/test/layers/validation/fixtures.hpp +++ b/test/layers/validation/fixtures.hpp @@ -6,8 +6,11 @@ #ifndef UR_VALIDATION_TEST_HELPERS_H #define UR_VALIDATION_TEST_HELPERS_H +#include #include + #include +#include struct urTest : ::testing::Test { @@ -123,23 +126,47 @@ struct valAllDevicesTest : valPlatformTest { std::vector devices; }; +// We use this to avoid segfaults in the mock adapter when we're doing stuff +// like double releases in the leak detection tests. +inline ur_result_t genericSuccessCallback(void *) { return UR_RESULT_SUCCESS; }; + +// This returns valid (non-null) handles that we can safely leak. +inline ur_result_t fakeContext_urContextCreate(void *pParams) { + static std::atomic_int handle = 1; + auto params = *static_cast(pParams); + **params.pphContext = reinterpret_cast(handle++); + return UR_RESULT_SUCCESS; +} + struct valDeviceTest : valAllDevicesTest { void SetUp() override { valAllDevicesTest::SetUp(); ASSERT_GE(devices.size(), 1); device = devices[0]; + mock::getCallbacks().set_replace_callback("urContextRetain", + &genericSuccessCallback); + mock::getCallbacks().set_replace_callback("urContextRelease", + &genericSuccessCallback); + mock::getCallbacks().set_replace_callback("urContextCreate", + &fakeContext_urContextCreate); + } + + void TearDown() override { + mock::getCallbacks().resetCallbacks(); + valAllDevicesTest::TearDown(); } ur_device_handle_t device; }; struct valDeviceTestMultithreaded : valDeviceTest, public ::testing::WithParamInterface { - void SetUp() override { valDeviceTest::SetUp(); + threadCount = GetParam(); } + int threadCount; }; diff --git a/test/layers/validation/leaks.cpp b/test/layers/validation/leaks.cpp index 794e8a3ef0..59b6bdb750 100644 --- a/test/layers/validation/leaks.cpp +++ b/test/layers/validation/leaks.cpp @@ -5,20 +5,44 @@ #include "fixtures.hpp" -TEST_F(urTest, testUrAdapterGetLeak) { +#include + +// We need a fake handle for the below adapter leak test. +inline ur_result_t fakeAdapter_urAdapterGet(void *pParams) { + auto params = *static_cast(pParams); + **params.pphAdapters = reinterpret_cast(0x1); + return UR_RESULT_SUCCESS; +} + +class adapterLeakTest : public urTest { + void SetUp() override { + urTest::SetUp(); + mock::getCallbacks().set_replace_callback("urAdapterGet", + &fakeAdapter_urAdapterGet); + mock::getCallbacks().set_replace_callback("urAdapterRetain", + &genericSuccessCallback); + } + + void TearDown() override { + mock::getCallbacks().resetCallbacks(); + urTest::TearDown(); + } +}; + +TEST_F(adapterLeakTest, testUrAdapterGetLeak) { ur_adapter_handle_t adapter = nullptr; ASSERT_EQ(urAdapterGet(1, &adapter, nullptr), UR_RESULT_SUCCESS); ASSERT_NE(nullptr, adapter); } -TEST_F(urTest, testUrAdapterRetainLeak) { +TEST_F(adapterLeakTest, testUrAdapterRetainLeak) { ur_adapter_handle_t adapter = nullptr; ASSERT_EQ(urAdapterGet(1, &adapter, nullptr), UR_RESULT_SUCCESS); ASSERT_NE(nullptr, adapter); ASSERT_EQ(urAdapterRetain(adapter), UR_RESULT_SUCCESS); } -TEST_F(urTest, testUrAdapterRetainNonexistent) { +TEST_F(adapterLeakTest, testUrAdapterRetainNonexistent) { ur_adapter_handle_t adapter = (ur_adapter_handle_t)0xBEEF; ASSERT_EQ(urAdapterRetain(adapter), UR_RESULT_SUCCESS); ASSERT_NE(nullptr, adapter); diff --git a/test/layers/validation/leaks.out.match b/test/layers/validation/leaks.out.match index 4c7d6d3546..431ad52cef 100644 --- a/test/layers/validation/leaks.out.match +++ b/test/layers/validation/leaks.out.match @@ -1,16 +1,16 @@ {{IGNORE}} -[ RUN ] urTest.testUrAdapterGetLeak +[ RUN ] adapterLeakTest.testUrAdapterGetLeak [DEBUG]: Reference count for handle {{[0-9xa-fA-F]+}} changed to 1 [ERROR]: Retained 1 reference(s) to handle {{[0-9xa-fA-F]+}} [ERROR]: Handle {{[0-9xa-fA-F]+}} was recorded for first time here: {{IGNORE}} -[ RUN ] urTest.testUrAdapterRetainLeak +[ RUN ] adapterLeakTest.testUrAdapterRetainLeak [DEBUG]: Reference count for handle {{[0-9xa-fA-F]+}} changed to 1 [DEBUG]: Reference count for handle {{[0-9xa-fA-F]+}} changed to 2 [ERROR]: Retained 2 reference(s) to handle {{[0-9xa-fA-F]+}} [ERROR]: Handle {{[0-9xa-fA-F]+}} was recorded for first time here: {{IGNORE}} -[ RUN ] urTest.testUrAdapterRetainNonexistent +[ RUN ] adapterLeakTest.testUrAdapterRetainNonexistent [ERROR]: Attempting to retain nonexistent handle {{[0-9xa-fA-F]+}} {{IGNORE}} [ RUN ] valDeviceTest.testUrContextCreateLeak diff --git a/test/loader/handles/CMakeLists.txt b/test/loader/handles/CMakeLists.txt index fada9e8ebb..2504c92bae 100644 --- a/test/loader/handles/CMakeLists.txt +++ b/test/loader/handles/CMakeLists.txt @@ -12,6 +12,7 @@ target_link_libraries(test-loader-handles ${PROJECT_NAME}::common ${PROJECT_NAME}::headers ${PROJECT_NAME}::loader + ${PROJECT_NAME}::mock gmock GTest::gtest_main ) diff --git a/test/loader/handles/fixtures.hpp b/test/loader/handles/fixtures.hpp index c903de11ce..8044c90414 100644 --- a/test/loader/handles/fixtures.hpp +++ b/test/loader/handles/fixtures.hpp @@ -8,14 +8,47 @@ #include "ur_api.h" #include +#include #ifndef ASSERT_SUCCESS #define ASSERT_SUCCESS(ACTUAL) ASSERT_EQ(UR_RESULT_SUCCESS, ACTUAL) #endif +ur_result_t replace_urPlatformGet(void *pParams) { + auto params = *static_cast(pParams); + + if (*params.ppNumPlatforms) { + **params.ppNumPlatforms = 1; + } + + if (*params.pphPlatforms && *params.pNumEntries == 1) { + **params.pphPlatforms = reinterpret_cast(0x1); + } + + return UR_RESULT_SUCCESS; +} + +ur_result_t replace_urDeviceGetInfo(void *pParams) { + auto params = *static_cast(pParams); + if (*params.ppropName == UR_DEVICE_INFO_PLATFORM) { + if (*params.ppPropSizeRet) { + **params.ppPropSizeRet = sizeof(ur_platform_handle_t); + } + if (*params.ppPropValue) { + **(reinterpret_cast(params.ppPropValue)) = + reinterpret_cast(0x1); + } + } + return UR_RESULT_SUCCESS; +} + struct LoaderHandleTest : ::testing::Test { void SetUp() override { urLoaderInit(0, nullptr); + mock::getCallbacks().set_replace_callback("urDeviceGetInfo", + &replace_urDeviceGetInfo); + mock::getCallbacks().set_replace_callback("urPlatformGet", + &replace_urPlatformGet); uint32_t nadapters = 0; adapter = nullptr; ASSERT_SUCCESS(urAdapterGet(1, &adapter, &nadapters)); @@ -32,6 +65,7 @@ struct LoaderHandleTest : ::testing::Test { } void TearDown() override { + mock::getCallbacks().resetCallbacks(); urDeviceRelease(device); urAdapterRelease(adapter); urLoaderTearDown(); diff --git a/test/loader/platforms/null_platform.match b/test/loader/platforms/null_platform.match index 29cadc78b5..2e7feb0f08 100644 --- a/test/loader/platforms/null_platform.match +++ b/test/loader/platforms/null_platform.match @@ -1,3 +1,3 @@ [INFO]: urLoaderInit succeeded. [INFO]: urPlatformGet found 1 platforms -[INFO]: Found UR_PLATFORM_NULL +[INFO]: Found UR_PLATFORM_MOCK diff --git a/test/tools/urtrace/CMakeLists.txt b/test/tools/urtrace/CMakeLists.txt index 18212ce818..629982898e 100644 --- a/test/tools/urtrace/CMakeLists.txt +++ b/test/tools/urtrace/CMakeLists.txt @@ -24,9 +24,9 @@ function(add_trace_test name CLI_ARGS) set_tests_properties(${TEST_NAME} PROPERTIES LABELS "urtrace") endfunction() -add_trace_test(null_hello "--libpath $ --null") -add_trace_test(null_hello_no_args "--libpath $ --null --no-args") -add_trace_test(null_hello_filter_device "--libpath $ --null --filter \".*Device.*\"") -add_trace_test(null_hello_profiling "--libpath $ --null --profiling --time-unit ns") -add_trace_test(null_hello_begin "--libpath $ --null --print-begin") -add_trace_test(null_hello_json "--libpath $ --null --json") +add_trace_test(mock_hello "--libpath $ --mock") +add_trace_test(mock_hello_no_args "--libpath $ --mock --no-args") +add_trace_test(mock_hello_filter_device "--libpath $ --mock --filter \".*Device.*\"") +add_trace_test(mock_hello_profiling "--libpath $ --mock --profiling --time-unit ns") +add_trace_test(mock_hello_begin "--libpath $ --mock --print-begin") +add_trace_test(mock_hello_json "--libpath $ --mock --json") diff --git a/test/tools/urtrace/null_hello.match b/test/tools/urtrace/mock_hello.match similarity index 98% rename from test/tools/urtrace/null_hello.match rename to test/tools/urtrace/mock_hello.match index 54c6efb9cb..cdab3c5c81 100644 --- a/test/tools/urtrace/null_hello.match +++ b/test/tools/urtrace/mock_hello.match @@ -9,5 +9,5 @@ urDeviceGet(.hPlatform = {{.*}}, .DeviceType = UR_DEVICE_TYPE_GPU, .NumEntries = urDeviceGet(.hPlatform = {{.*}}, .DeviceType = UR_DEVICE_TYPE_GPU, .NumEntries = 1, .phDevices = {{{.*}}}, .pNumDevices = nullptr) -> UR_RESULT_SUCCESS; urDeviceGetInfo(.hDevice = {{.*}}, .propName = UR_DEVICE_INFO_TYPE, .propSize = {{.*}}, .pPropValue = {{.*}}, .pPropSizeRet = nullptr) -> UR_RESULT_SUCCESS; urDeviceGetInfo(.hDevice = {{.*}}, .propName = UR_DEVICE_INFO_NAME, .propSize = {{.*}}, .pPropValue = {{.*}}, .pPropSizeRet = nullptr) -> UR_RESULT_SUCCESS; -Found a Null Device gpu. +Found a Mock Device gpu. urAdapterRelease(.hAdapter = {{.*}}) -> UR_RESULT_SUCCESS; diff --git a/test/tools/urtrace/null_hello_begin.match b/test/tools/urtrace/mock_hello_begin.match similarity index 99% rename from test/tools/urtrace/null_hello_begin.match rename to test/tools/urtrace/mock_hello_begin.match index bf2d85145a..0fa8e075d6 100644 --- a/test/tools/urtrace/null_hello_begin.match +++ b/test/tools/urtrace/mock_hello_begin.match @@ -18,6 +18,6 @@ begin(8) - urDeviceGetInfo(.hDevice = {{.*}}, .propName = UR_DEVICE_INFO_TYPE, . end(8) - urDeviceGetInfo(.hDevice = {{.*}}, .propName = UR_DEVICE_INFO_TYPE, .propSize = 4, .pPropValue = {{.*}}, .pPropSizeRet = nullptr) -> UR_RESULT_SUCCESS; begin(9) - urDeviceGetInfo(.hDevice = {{.*}}, .propName = UR_DEVICE_INFO_NAME, .propSize = 1023, .pPropValue = {{.*}}, .pPropSizeRet = nullptr); end(9) - urDeviceGetInfo(.hDevice = {{.*}}, .propName = UR_DEVICE_INFO_NAME, .propSize = 1023, .pPropValue = {{.*}}, .pPropSizeRet = nullptr) -> UR_RESULT_SUCCESS; -Found a Null Device gpu. +Found a Mock Device gpu. begin(10) - urAdapterRelease(.hAdapter = {{.*}}); end(10) - urAdapterRelease(.hAdapter = {{.*}}) -> UR_RESULT_SUCCESS; diff --git a/test/tools/urtrace/null_hello_filter_device.match b/test/tools/urtrace/mock_hello_filter_device.match similarity index 96% rename from test/tools/urtrace/null_hello_filter_device.match rename to test/tools/urtrace/mock_hello_filter_device.match index a7c06db739..4460759d7e 100644 --- a/test/tools/urtrace/null_hello_filter_device.match +++ b/test/tools/urtrace/mock_hello_filter_device.match @@ -4,4 +4,4 @@ urDeviceGet(.hPlatform = {{.*}}, .DeviceType = UR_DEVICE_TYPE_GPU, .NumEntries = urDeviceGet(.hPlatform = {{.*}}, .DeviceType = UR_DEVICE_TYPE_GPU, .NumEntries = 1, .phDevices = {{{.*}}}, .pNumDevices = nullptr) -> UR_RESULT_SUCCESS; urDeviceGetInfo(.hDevice = {{.*}}, .propName = UR_DEVICE_INFO_TYPE, .propSize = 4, .pPropValue = {{.*}}, .pPropSizeRet = nullptr) -> UR_RESULT_SUCCESS; urDeviceGetInfo(.hDevice = {{.*}}, .propName = UR_DEVICE_INFO_NAME, .propSize = 1023, .pPropValue = {{.*}}, .pPropSizeRet = nullptr) -> UR_RESULT_SUCCESS; -Found a Null Device gpu. +Found a Mock Device gpu. diff --git a/test/tools/urtrace/null_hello_json.match b/test/tools/urtrace/mock_hello_json.match similarity index 97% rename from test/tools/urtrace/null_hello_json.match rename to test/tools/urtrace/mock_hello_json.match index 5b9377e8d6..82ad2910c7 100644 --- a/test/tools/urtrace/null_hello_json.match +++ b/test/tools/urtrace/mock_hello_json.match @@ -10,8 +10,8 @@ API version: @PROJECT_VERSION_MAJOR@.@PROJECT_VERSION_MINOR@ { "cat": "UR", "ph": "X", "pid": {{.*}}, "tid": {{.*}}, "ts": {{.*}}, "dur": {{.*}}, "name": "urDeviceGet", "args": "(.hPlatform = {{.*}}, .DeviceType = UR_DEVICE_TYPE_GPU, .NumEntries = 0, .phDevices = {}, .pNumDevices = {{.*}} (1))" }, { "cat": "UR", "ph": "X", "pid": {{.*}}, "tid": {{.*}}, "ts": {{.*}}, "dur": {{.*}}, "name": "urDeviceGet", "args": "(.hPlatform = {{.*}}, .DeviceType = UR_DEVICE_TYPE_GPU, .NumEntries = 1, .phDevices = {{{.*}}}, .pNumDevices = nullptr)" }, { "cat": "UR", "ph": "X", "pid": {{.*}}, "tid": {{.*}}, "ts": {{.*}}, "dur": {{.*}}, "name": "urDeviceGetInfo", "args": "(.hDevice = {{.*}}, .propName = UR_DEVICE_INFO_TYPE, .propSize = 4, .pPropValue = {{.*}} (UR_DEVICE_TYPE_GPU), .pPropSizeRet = nullptr)" }, -{ "cat": "UR", "ph": "X", "pid": {{.*}}, "tid": {{.*}}, "ts": {{.*}}, "dur": {{.*}}, "name": "urDeviceGetInfo", "args": "(.hDevice = {{.*}}, .propName = UR_DEVICE_INFO_NAME, .propSize = 1023, .pPropValue = {{.*}} (Null Device), .pPropSizeRet = nullptr)" }, -Found a Null Device gpu. +{ "cat": "UR", "ph": "X", "pid": {{.*}}, "tid": {{.*}}, "ts": {{.*}}, "dur": {{.*}}, "name": "urDeviceGetInfo", "args": "(.hDevice = {{.*}}, .propName = UR_DEVICE_INFO_NAME, .propSize = 1023, .pPropValue = {{.*}} (Mock Device), .pPropSizeRet = nullptr)" }, +Found a Mock Device gpu. { "cat": "UR", "ph": "X", "pid": {{.*}}, "tid": {{.*}}, "ts": {{.*}}, "dur": {{.*}}, "name": "urAdapterRelease", "args": "(.hAdapter = {{.*}})" }, {"name": "", "cat": "", "ph": "", "pid": "", "tid": "", "ts": ""} ] diff --git a/test/tools/urtrace/null_hello_no_args.match b/test/tools/urtrace/mock_hello_no_args.match similarity index 94% rename from test/tools/urtrace/null_hello_no_args.match rename to test/tools/urtrace/mock_hello_no_args.match index 6462f41d02..a36fca577f 100644 --- a/test/tools/urtrace/null_hello_no_args.match +++ b/test/tools/urtrace/mock_hello_no_args.match @@ -9,5 +9,5 @@ urDeviceGet(...) -> UR_RESULT_SUCCESS; urDeviceGet(...) -> UR_RESULT_SUCCESS; urDeviceGetInfo(...) -> UR_RESULT_SUCCESS; urDeviceGetInfo(...) -> UR_RESULT_SUCCESS; -Found a Null Device gpu. +Found a Mock Device gpu. urAdapterRelease(...) -> UR_RESULT_SUCCESS; diff --git a/test/tools/urtrace/null_hello_profiling.match b/test/tools/urtrace/mock_hello_profiling.match similarity index 98% rename from test/tools/urtrace/null_hello_profiling.match rename to test/tools/urtrace/mock_hello_profiling.match index 7bd3bd53c1..fe496aab31 100644 --- a/test/tools/urtrace/null_hello_profiling.match +++ b/test/tools/urtrace/mock_hello_profiling.match @@ -9,5 +9,5 @@ urDeviceGet(.hPlatform = {{.*}}, .DeviceType = UR_DEVICE_TYPE_GPU, .NumEntries = urDeviceGet(.hPlatform = {{.*}}, .DeviceType = UR_DEVICE_TYPE_GPU, .NumEntries = 1, .phDevices = {{{.*}}}, .pNumDevices = nullptr) -> UR_RESULT_SUCCESS; ({{[0-9]+}}ns) urDeviceGetInfo(.hDevice = {{.*}}, .propName = UR_DEVICE_INFO_TYPE, .propSize = {{.*}}, .pPropValue = {{.*}}, .pPropSizeRet = nullptr) -> UR_RESULT_SUCCESS; ({{[0-9]+}}ns) urDeviceGetInfo(.hDevice = {{.*}}, .propName = UR_DEVICE_INFO_NAME, .propSize = {{.*}}, .pPropValue = {{.*}}, .pPropSizeRet = nullptr) -> UR_RESULT_SUCCESS; ({{[0-9]+}}ns) -Found a Null Device gpu. +Found a Mock Device gpu. urAdapterRelease(.hAdapter = {{.*}}) -> UR_RESULT_SUCCESS; ({{[0-9]+}}ns) diff --git a/tools/urtrace/README.md b/tools/urtrace/README.md index b04277b22b..b77bc398d1 100644 --- a/tools/urtrace/README.md +++ b/tools/urtrace/README.md @@ -36,8 +36,8 @@ Here are a few examples: ### Use a custom adapter and also trace function begins `$ urtrace --adapter libur_adapter_cuda.so --begin ./sycl_app` -### Force load the null adapter and look for it in a custom path -`$ urtrace --null --libpath /opt/custom/ ./foo` +### Force load the mock adapter and look for it in a custom path +`$ urtrace --mock --libpath /opt/custom/ ./foo` ### Trace UR calls made by `./myapp --my-arg` and write JSON traces to a file `$ urtrace --json --file myapp.perf ./myapp --my-arg` From b36567830dca5ca1ce5e45e1486a071aa10805c9 Mon Sep 17 00:00:00 2001 From: Aaron Greig Date: Thu, 4 Jul 2024 13:39:51 +0100 Subject: [PATCH 11/15] Fix rebase issues --- include/ur_api.h | 2 +- scripts/core/INTRO.rst | 2 +- scripts/core/registry.yml | 2 +- scripts/templates/ldrddi.cpp.mako | 4 +- source/loader/ur_ldrddi.cpp | 151 ------------------------------ 5 files changed, 5 insertions(+), 156 deletions(-) diff --git a/include/ur_api.h b/include/ur_api.h index 861984d36f..56a6f8f00c 100644 --- a/include/ur_api.h +++ b/include/ur_api.h @@ -226,7 +226,7 @@ typedef enum ur_function_t { UR_FUNCTION_BINDLESS_IMAGES_IMPORT_EXTERNAL_MEMORY_EXP = 226, ///< Enumerator for ::urBindlessImagesImportExternalMemoryExp UR_FUNCTION_BINDLESS_IMAGES_IMPORT_EXTERNAL_SEMAPHORE_EXP = 227, ///< Enumerator for ::urBindlessImagesImportExternalSemaphoreExp UR_FUNCTION_ENQUEUE_NATIVE_COMMAND_EXP = 228, ///< Enumerator for ::urEnqueueNativeCommandExp - UR_FUNCTION_LOADER_CONFIG_SET_MOCKING_ENABLED = 231, ///< Enumerator for ::urLoaderConfigSetMockingEnabled + UR_FUNCTION_LOADER_CONFIG_SET_MOCKING_ENABLED = 229, ///< Enumerator for ::urLoaderConfigSetMockingEnabled /// @cond UR_FUNCTION_FORCE_UINT32 = 0x7fffffff /// @endcond diff --git a/scripts/core/INTRO.rst b/scripts/core/INTRO.rst index 56f1a7db19..1f319d6884 100644 --- a/scripts/core/INTRO.rst +++ b/scripts/core/INTRO.rst @@ -259,7 +259,7 @@ For more information about the usage of mentioned environment variables see `Env Mocking --------------------- A mock UR adapter can be accessed for test purposes by enabling it via -${x}LoaderConfigSetMockingEnabled. +${x}LoaderConfigSetMockingEnabled. The default fallback behavior for entry points in the mock adapter is to simply return ``UR_RESULT_SUCCESS``. For entry points concerning handles, i.e. those diff --git a/scripts/core/registry.yml b/scripts/core/registry.yml index 2b1fcf9f32..07e398db00 100644 --- a/scripts/core/registry.yml +++ b/scripts/core/registry.yml @@ -594,7 +594,7 @@ etors: value: '228' - name: LOADER_CONFIG_SET_MOCKING_ENABLED desc: Enumerator for $xLoaderConfigSetMockingEnabled - value: '231' + value: '229' --- type: enum desc: Defines structure types diff --git a/scripts/templates/ldrddi.cpp.mako b/scripts/templates/ldrddi.cpp.mako index d2795801ac..d1f1018a18 100644 --- a/scripts/templates/ldrddi.cpp.mako +++ b/scripts/templates/ldrddi.cpp.mako @@ -271,7 +271,7 @@ namespace ur_loader del add_local %> %for i, item in enumerate(epilogue): - %if 0 == i and not item['release'] and not th.always_wrap_outputs(obj): + %if 0 == i and not item['release'] and not item['retain'] and not th.always_wrap_outputs(obj): if( ${X}_RESULT_SUCCESS != result ) return result; @@ -281,7 +281,7 @@ namespace ur_loader ##%if item['release']: ##// release loader handle ##${item['factory']}.release( ${item['name']} ); - %if not item['release'] and not '_native_object_' in item['obj'] or th.make_func_name(n, tags, obj) == 'urPlatformCreateWithNativeHandle': + %if not item['release'] and not item['retain'] and not '_native_object_' in item['obj'] or th.make_func_name(n, tags, obj) == 'urPlatformCreateWithNativeHandle': try { %if 'typename' in item: diff --git a/source/loader/ur_ldrddi.cpp b/source/loader/ur_ldrddi.cpp index ac6b9b2915..e1a7576b51 100644 --- a/source/loader/ur_ldrddi.cpp +++ b/source/loader/ur_ldrddi.cpp @@ -122,18 +122,6 @@ __urdlllocal ur_result_t UR_APICALL urAdapterRetain( // forward to device-platform result = pfnAdapterRetain(hAdapter); - if (UR_RESULT_SUCCESS != result) { - return result; - } - - try { - // convert platform handle to loader handle - *hAdapter = reinterpret_cast( - ur_adapter_factory.getInstance(*hAdapter, dditable)); - } catch (std::bad_alloc &) { - result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; - } - return result; } @@ -615,18 +603,6 @@ __urdlllocal ur_result_t UR_APICALL urDeviceRetain( // forward to device-platform result = pfnRetain(hDevice); - if (UR_RESULT_SUCCESS != result) { - return result; - } - - try { - // convert platform handle to loader handle - *hDevice = reinterpret_cast( - ur_device_factory.getInstance(*hDevice, dditable)); - } catch (std::bad_alloc &) { - result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; - } - return result; } @@ -904,18 +880,6 @@ __urdlllocal ur_result_t UR_APICALL urContextRetain( // forward to device-platform result = pfnRetain(hContext); - if (UR_RESULT_SUCCESS != result) { - return result; - } - - try { - // convert platform handle to loader handle - *hContext = reinterpret_cast( - ur_context_factory.getInstance(*hContext, dditable)); - } catch (std::bad_alloc &) { - result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; - } - return result; } @@ -1222,18 +1186,6 @@ __urdlllocal ur_result_t UR_APICALL urMemRetain( // forward to device-platform result = pfnRetain(hMem); - if (UR_RESULT_SUCCESS != result) { - return result; - } - - try { - // convert platform handle to loader handle - *hMem = reinterpret_cast( - ur_mem_factory.getInstance(*hMem, dditable)); - } catch (std::bad_alloc &) { - result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; - } - return result; } @@ -1592,18 +1544,6 @@ __urdlllocal ur_result_t UR_APICALL urSamplerRetain( // forward to device-platform result = pfnRetain(hSampler); - if (UR_RESULT_SUCCESS != result) { - return result; - } - - try { - // convert platform handle to loader handle - *hSampler = reinterpret_cast( - ur_sampler_factory.getInstance(*hSampler, dditable)); - } catch (std::bad_alloc &) { - result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; - } - return result; } @@ -2039,18 +1979,6 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolRetain( // forward to device-platform result = pfnPoolRetain(pPool); - if (UR_RESULT_SUCCESS != result) { - return result; - } - - try { - // convert platform handle to loader handle - *pPool = reinterpret_cast( - ur_usm_pool_factory.getInstance(*pPool, dditable)); - } catch (std::bad_alloc &) { - result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; - } - return result; } @@ -2438,18 +2366,6 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemRetain( // forward to device-platform result = pfnRetain(hPhysicalMem); - if (UR_RESULT_SUCCESS != result) { - return result; - } - - try { - // convert platform handle to loader handle - *hPhysicalMem = reinterpret_cast( - ur_physical_mem_factory.getInstance(*hPhysicalMem, dditable)); - } catch (std::bad_alloc &) { - result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; - } - return result; } @@ -2698,18 +2614,6 @@ __urdlllocal ur_result_t UR_APICALL urProgramRetain( // forward to device-platform result = pfnRetain(hProgram); - if (UR_RESULT_SUCCESS != result) { - return result; - } - - try { - // convert platform handle to loader handle - *hProgram = reinterpret_cast( - ur_program_factory.getInstance(*hProgram, dditable)); - } catch (std::bad_alloc &) { - result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; - } - return result; } @@ -3300,18 +3204,6 @@ __urdlllocal ur_result_t UR_APICALL urKernelRetain( // forward to device-platform result = pfnRetain(hKernel); - if (UR_RESULT_SUCCESS != result) { - return result; - } - - try { - // convert platform handle to loader handle - *hKernel = reinterpret_cast( - ur_kernel_factory.getInstance(*hKernel, dditable)); - } catch (std::bad_alloc &) { - result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; - } - return result; } @@ -3759,18 +3651,6 @@ __urdlllocal ur_result_t UR_APICALL urQueueRetain( // forward to device-platform result = pfnRetain(hQueue); - if (UR_RESULT_SUCCESS != result) { - return result; - } - - try { - // convert platform handle to loader handle - *hQueue = reinterpret_cast( - ur_queue_factory.getInstance(*hQueue, dditable)); - } catch (std::bad_alloc &) { - result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; - } - return result; } @@ -4079,18 +3959,6 @@ __urdlllocal ur_result_t UR_APICALL urEventRetain( // forward to device-platform result = pfnRetain(hEvent); - if (UR_RESULT_SUCCESS != result) { - return result; - } - - try { - // convert platform handle to loader handle - *hEvent = reinterpret_cast( - ur_event_factory.getInstance(*hEvent, dditable)); - } catch (std::bad_alloc &) { - result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; - } - return result; } @@ -6756,19 +6624,6 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferRetainExp( // forward to device-platform result = pfnRetainExp(hCommandBuffer); - if (UR_RESULT_SUCCESS != result) { - return result; - } - - try { - // convert platform handle to loader handle - *hCommandBuffer = reinterpret_cast( - ur_exp_command_buffer_factory.getInstance(*hCommandBuffer, - dditable)); - } catch (std::bad_alloc &) { - result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; - } - return result; } @@ -7514,12 +7369,6 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferRetainCommandExp( // forward to device-platform result = pfnRetainCommandExp(hCommand); - if (UR_RESULT_SUCCESS != result) { - return result; - } - - // TODO: do we need to ref count the loader handles? - return result; } From 39b257ba8365232dc076ee8c1344ac321cc2b72a Mon Sep 17 00:00:00 2001 From: Aaron Greig Date: Fri, 5 Jul 2024 11:58:05 +0100 Subject: [PATCH 12/15] Hopefully fix more CI issues. --- scripts/templates/mockddi.cpp.mako | 14 ++++++++++++-- source/adapters/mock/ur_mock.hpp | 4 ++++ source/adapters/mock/ur_mockddi.cpp | 18 ++++-------------- source/loader/ur_lib.hpp | 2 +- test/conformance/source/environment.cpp | 3 ++- test/layers/validation/fixtures.hpp | 2 +- test/mock/mock.cpp | 17 +++++++---------- 7 files changed, 31 insertions(+), 29 deletions(-) diff --git a/scripts/templates/mockddi.cpp.mako b/scripts/templates/mockddi.cpp.mako index 539d3f8bc1..6b79303c42 100644 --- a/scripts/templates/mockddi.cpp.mako +++ b/scripts/templates/mockddi.cpp.mako @@ -96,15 +96,25 @@ namespace driver *ppMem = mock::createDummyHandle(size); %elif re.search(r"USMPitchedAllocExp$", fname): *ppMem = mock::createDummyHandle(widthInBytes * height); - %else: - %if fname == 'urAdapterGet' or fname == 'urDeviceGet' or fname == 'urPlatformGet': + ## We need a special case for USM free since it doesn't have the handle release tag + %elif re.search(r"USMFree$", fname): + mock::releaseDummyHandle(pMem); + ## adapter, platform and device have special lifetime considerations + %elif 'urAdapter' in fname or 'urDevice' in fname or 'urPlatform' in fname: + %if re.match(r"ur(.*)Get$", fname): <% num_param = th.find_param_name(".*pNum.*", n, tags, obj) + object = re.match(r"ur(.*)Get", fname).group(1) + out_param = "ph" + object + "s" %> if(${num_param}) { *${num_param} = 1; } + if(${out_param}) { + *${out_param} = d_context.${object.lower()}; + } %endif + %else: %for item in epilogue: %if item['release']: mock::releaseDummyHandle(${item['name']}); diff --git a/source/adapters/mock/ur_mock.hpp b/source/adapters/mock/ur_mock.hpp index a4a458fbef..e12012fea3 100644 --- a/source/adapters/mock/ur_mock.hpp +++ b/source/adapters/mock/ur_mock.hpp @@ -25,6 +25,10 @@ class __urdlllocal context_t { ur_dditable_t urDdiTable = {}; context_t(); ~context_t() = default; + + ur_adapter_handle_t adapter = reinterpret_cast(1); + ur_device_handle_t device = reinterpret_cast(2); + ur_platform_handle_t platform = reinterpret_cast(3); }; extern context_t d_context; diff --git a/source/adapters/mock/ur_mockddi.cpp b/source/adapters/mock/ur_mockddi.cpp index 22bfe69b85..f00aa19b53 100644 --- a/source/adapters/mock/ur_mockddi.cpp +++ b/source/adapters/mock/ur_mockddi.cpp @@ -50,9 +50,8 @@ __urdlllocal ur_result_t UR_APICALL urAdapterGet( if (pNumAdapters) { *pNumAdapters = 1; } - // optional output handle if (phAdapters) { - *phAdapters = mock::createDummyHandle(); + *phAdapters = d_context.adapter; } result = UR_RESULT_SUCCESS; } @@ -96,7 +95,6 @@ __urdlllocal ur_result_t UR_APICALL urAdapterRelease( result = replaceCallback(¶ms); } else { - mock::releaseDummyHandle(hAdapter); result = UR_RESULT_SUCCESS; } @@ -139,7 +137,6 @@ __urdlllocal ur_result_t UR_APICALL urAdapterRetain( result = replaceCallback(¶ms); } else { - mock::retainDummyHandle(hAdapter); result = UR_RESULT_SUCCESS; } @@ -301,9 +298,8 @@ __urdlllocal ur_result_t UR_APICALL urPlatformGet( if (pNumPlatforms) { *pNumPlatforms = 1; } - // optional output handle if (phPlatforms) { - *phPlatforms = mock::createDummyHandle(); + *phPlatforms = d_context.platform; } result = UR_RESULT_SUCCESS; } @@ -609,9 +605,8 @@ __urdlllocal ur_result_t UR_APICALL urDeviceGet( if (pNumDevices) { *pNumDevices = 1; } - // optional output handle if (phDevices) { - *phDevices = mock::createDummyHandle(); + *phDevices = d_context.device; } result = UR_RESULT_SUCCESS; } @@ -710,7 +705,6 @@ __urdlllocal ur_result_t UR_APICALL urDeviceRetain( result = replaceCallback(¶ms); } else { - mock::retainDummyHandle(hDevice); result = UR_RESULT_SUCCESS; } @@ -754,7 +748,6 @@ __urdlllocal ur_result_t UR_APICALL urDeviceRelease( result = replaceCallback(¶ms); } else { - mock::releaseDummyHandle(hDevice); result = UR_RESULT_SUCCESS; } @@ -808,10 +801,6 @@ __urdlllocal ur_result_t UR_APICALL urDevicePartition( result = replaceCallback(¶ms); } else { - // optional output handle - if (phSubDevices) { - *phSubDevices = mock::createDummyHandle(); - } result = UR_RESULT_SUCCESS; } @@ -2355,6 +2344,7 @@ __urdlllocal ur_result_t UR_APICALL urUSMFree( result = replaceCallback(¶ms); } else { + mock::releaseDummyHandle(pMem); result = UR_RESULT_SUCCESS; } diff --git a/source/loader/ur_lib.hpp b/source/loader/ur_lib.hpp index 6fafe3a32d..092208c8ff 100644 --- a/source/loader/ur_lib.hpp +++ b/source/loader/ur_lib.hpp @@ -48,7 +48,7 @@ struct ur_loader_config_handle_t_ { std::set &getEnabledLayerNames() { return enabledLayers; } codeloc_data codelocData; - bool enableMock; + bool enableMock = false; }; namespace ur_lib { diff --git a/test/conformance/source/environment.cpp b/test/conformance/source/environment.cpp index 209bc8f9a2..5105a6ae38 100644 --- a/test/conformance/source/environment.cpp +++ b/test/conformance/source/environment.cpp @@ -65,7 +65,8 @@ uur::PlatformEnvironment::PlatformEnvironment(int argc, char **argv) ur_loader_config_handle_t config; if (urLoaderConfigCreate(&config) == UR_RESULT_SUCCESS) { - if (urLoaderConfigEnableLayer(config, "UR_LAYER_FULL_VALIDATION")) { + if (urLoaderConfigEnableLayer(config, "UR_LAYER_FULL_VALIDATION") != + UR_RESULT_SUCCESS) { urLoaderConfigRelease(config); error = "Failed to enable validation layer"; return; diff --git a/test/layers/validation/fixtures.hpp b/test/layers/validation/fixtures.hpp index bc76f9dedf..40ec7fc66d 100644 --- a/test/layers/validation/fixtures.hpp +++ b/test/layers/validation/fixtures.hpp @@ -132,7 +132,7 @@ inline ur_result_t genericSuccessCallback(void *) { return UR_RESULT_SUCCESS; }; // This returns valid (non-null) handles that we can safely leak. inline ur_result_t fakeContext_urContextCreate(void *pParams) { - static std::atomic_int handle = 1; + static std::atomic_int handle = 42; auto params = *static_cast(pParams); **params.pphContext = reinterpret_cast(handle++); return UR_RESULT_SUCCESS; diff --git a/test/mock/mock.cpp b/test/mock/mock.cpp index 0dac92e8b4..0224e550d6 100644 --- a/test/mock/mock.cpp +++ b/test/mock/mock.cpp @@ -51,27 +51,24 @@ TEST(Mock, DefaultBehavior) { ASSERT_EQ(urDeviceRelease(device), UR_RESULT_SUCCESS); } -void checkPreInitAdapter(ur_adapter_handle_t adapter) { - ur_adapter_handle_t preInitAdapter = - reinterpret_cast(0xF00DCAFE); - ASSERT_EQ(adapter, preInitAdapter); -} - ur_result_t beforeUrAdapterGet(void *pParams) { auto params = reinterpret_cast(pParams); - checkPreInitAdapter(**params->pphAdapters); + ur_adapter_handle_t preInitAdapter = + reinterpret_cast(uintptr_t(0xF00DCAFE)); + EXPECT_EQ(**params->pphAdapters, preInitAdapter); return UR_RESULT_SUCCESS; } ur_result_t replaceUrAdapterGet(void *pParams) { auto params = reinterpret_cast(pParams); - **params->pphAdapters = reinterpret_cast(0xDEADBEEF); + **params->pphAdapters = + reinterpret_cast(uintptr_t(0xDEADBEEF)); return UR_RESULT_SUCCESS; } void checkPostInitAdapter(ur_adapter_handle_t adapter) { ur_adapter_handle_t postInitAdapter = - reinterpret_cast(0xDEADBEEF); + reinterpret_cast(uintptr_t(0xDEADBEEF)); ASSERT_EQ(adapter, postInitAdapter); } @@ -102,6 +99,6 @@ TEST(Mock, Callbacks) { mock::getCallbacks().set_after_callback("urAdapterGet", &afterUrAdapterGet); ur_adapter_handle_t adapter = - reinterpret_cast(0xF00DCAFE); + reinterpret_cast(uintptr_t(0xF00DCAFE)); ASSERT_EQ(urAdapterGet(1, &adapter, nullptr), UR_RESULT_SUCCESS); } From e323e526b931981591dff6d6661233a1ed70dab0 Mon Sep 17 00:00:00 2001 From: Aaron Greig Date: Mon, 8 Jul 2024 12:28:49 +0100 Subject: [PATCH 13/15] Fix windows build. --- source/adapters/mock/ur_mock.cpp | 2 +- test/fuzz/CMakeLists.txt | 2 +- test/layers/validation/fixtures.hpp | 5 ++++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/source/adapters/mock/ur_mock.cpp b/source/adapters/mock/ur_mock.cpp index 502b532470..93fff9bd99 100644 --- a/source/adapters/mock/ur_mock.cpp +++ b/source/adapters/mock/ur_mock.cpp @@ -36,7 +36,7 @@ ur_result_t mock_urPlatformGetInfo(void *pParams) { if (*params.ppPropValue) { #if defined(_WIN32) strncpy_s(reinterpret_cast(*params.ppPropValue), - *params.ppropSize, null_platform_name, + *params.ppropSize, mock_platform_name, sizeof(mock_platform_name)); #else strncpy(reinterpret_cast(*params.ppPropValue), diff --git a/test/fuzz/CMakeLists.txt b/test/fuzz/CMakeLists.txt index c3b867928e..e4d0c5e0de 100644 --- a/test/fuzz/CMakeLists.txt +++ b/test/fuzz/CMakeLists.txt @@ -28,7 +28,7 @@ function(add_fuzz_test name label) DisableDeepBind=1) endif() else() - list(APPEND ENV_VARS UR_ADAPTERS_FORCE_LOAD=\"$\") + list(APPEND ENV_VARS UR_ADAPTERS_FORCE_LOAD=\"$\") endif() add_test(NAME ${TEST_TARGET_NAME} diff --git a/test/layers/validation/fixtures.hpp b/test/layers/validation/fixtures.hpp index 40ec7fc66d..9e261f0a1d 100644 --- a/test/layers/validation/fixtures.hpp +++ b/test/layers/validation/fixtures.hpp @@ -134,7 +134,10 @@ inline ur_result_t genericSuccessCallback(void *) { return UR_RESULT_SUCCESS; }; inline ur_result_t fakeContext_urContextCreate(void *pParams) { static std::atomic_int handle = 42; auto params = *static_cast(pParams); - **params.pphContext = reinterpret_cast(handle++); + // There are two casts because windows doesn't implicitly extend the 32 bit + // result of atomic_int::operator++. + **params.pphContext = + reinterpret_cast(static_cast(handle++)); return UR_RESULT_SUCCESS; } From b841dd9ba911ba451afc930ead3c36a4dd919c93 Mon Sep 17 00:00:00 2001 From: Aaron Greig Date: Mon, 8 Jul 2024 15:23:09 +0100 Subject: [PATCH 14/15] Hopefully fix windows link error. --- source/mock/ur_mock_helpers.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/mock/ur_mock_helpers.hpp b/source/mock/ur_mock_helpers.hpp index 06ef4e5739..034d47c02d 100644 --- a/source/mock/ur_mock_helpers.hpp +++ b/source/mock/ur_mock_helpers.hpp @@ -124,6 +124,6 @@ struct callbacks_t { std::unordered_map afterCallbacks; }; -callbacks_t &getCallbacks(); +UR_DLLEXPORT callbacks_t &getCallbacks(); } // namespace mock From 2bc0e0a69ffa6e53c582d7568f39f01d9516ffa1 Mon Sep 17 00:00:00 2001 From: Aaron Greig Date: Tue, 9 Jul 2024 12:27:25 +0100 Subject: [PATCH 15/15] Set mock adapter default info overrides as "before" callbacks. This way any override for these entry points from an application will take precedence over our defaults. --- source/adapters/mock/ur_mock.cpp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/source/adapters/mock/ur_mock.cpp b/source/adapters/mock/ur_mock.cpp index 93fff9bd99..2ba6529158 100644 --- a/source/adapters/mock/ur_mock.cpp +++ b/source/adapters/mock/ur_mock.cpp @@ -89,9 +89,11 @@ ur_result_t mock_urDeviceGetInfo(void *pParams) { context_t::context_t() { mock::getCallbacks().set_replace_callback("urPlatformGetApiVersion", &mock_urPlatformGetApiVersion); - mock::getCallbacks().set_replace_callback("urPlatformGetInfo", - &mock_urPlatformGetInfo); - mock::getCallbacks().set_replace_callback("urDeviceGetInfo", - &mock_urDeviceGetInfo); + // Set the default info stuff as before overrides, this way any application + // passing in an override for them in any slot will take precedence. + mock::getCallbacks().set_before_callback("urPlatformGetInfo", + &mock_urPlatformGetInfo); + mock::getCallbacks().set_before_callback("urDeviceGetInfo", + &mock_urDeviceGetInfo); } } // namespace driver