Skip to content

[DeviceMSAN] Fix "urEnqueueUSM" APIs #2513

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Jan 6, 2025
258 changes: 257 additions & 1 deletion source/loader/layers/sanitizer/msan/msan_ddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ ur_result_t setupContext(ur_context_handle_t Context, uint32_t numDevices,
UR_CALL(DI->allocShadowMemory(Context));
}
CI->DeviceList.emplace_back(hDevice);
CI->AllocInfosMap[hDevice];
}
return UR_RESULT_SUCCESS;
}
Expand Down Expand Up @@ -110,6 +109,17 @@ ur_result_t urUSMDeviceAlloc(
pool, size, ppMem);
}

///////////////////////////////////////////////////////////////////////////////
/// @brief Intercept function for urUSMFree
__urdlllocal ur_result_t UR_APICALL urUSMFree(
ur_context_handle_t hContext, ///< [in] handle of the context object
void *pMem ///< [in] pointer to USM memory object
) {
getContext()->logger.debug("==== urUSMFree");

return getMsanInterceptor()->releaseMemory(hContext, pMem);
}

///////////////////////////////////////////////////////////////////////////////
/// @brief Intercept function for urProgramCreateWithIL
ur_result_t urProgramCreateWithIL(
Expand Down Expand Up @@ -1271,6 +1281,247 @@ ur_result_t urKernelSetArgMemObj(
return UR_RESULT_SUCCESS;
}

///////////////////////////////////////////////////////////////////////////////
/// @brief Intercept function for urEnqueueUSMFill
ur_result_t UR_APICALL urEnqueueUSMFill(
ur_queue_handle_t hQueue, ///< [in] handle of the queue object
void *pMem, ///< [in][bounds(0, size)] pointer to USM memory object
size_t
patternSize, ///< [in] the size in bytes of the pattern. Must be a power of 2 and less
///< than or equal to width.
const void
*pPattern, ///< [in] pointer with the bytes of the pattern to set.
size_t
size, ///< [in] size in bytes to be set. Must be a multiple of patternSize.
uint32_t numEventsInWaitList, ///< [in] size of the event wait list
const ur_event_handle_t *
phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of
///< events that must be complete before this command can be executed.
///< If nullptr, the numEventsInWaitList must be 0, indicating that this
///< command does not wait on any event to complete.
ur_event_handle_t *
phEvent ///< [out][optional] return an event object that identifies this particular
///< command instance. If phEventWaitList and phEvent are not NULL, phEvent
///< must not refer to an element of the phEventWaitList array.
) {
auto pfnUSMFill = getContext()->urDdiTable.Enqueue.pfnUSMFill;
getContext()->logger.debug("==== urEnqueueUSMFill");

ur_event_handle_t hEvents[2] = {};
UR_CALL(pfnUSMFill(hQueue, pMem, patternSize, pPattern, size,
numEventsInWaitList, phEventWaitList, &hEvents[0]));

const auto Mem = (uptr)pMem;
auto MemInfoItOp = getMsanInterceptor()->findAllocInfoByAddress(Mem);
if (MemInfoItOp) {
auto MemInfo = (*MemInfoItOp)->second;

const auto &DeviceInfo =
getMsanInterceptor()->getDeviceInfo(MemInfo->Device);
const auto MemShadow = DeviceInfo->Shadow->MemToShadow(Mem);

UR_CALL(EnqueueUSMBlockingSet(hQueue, (void *)MemShadow, 0, size, 0,
nullptr, &hEvents[1]));
}

if (phEvent) {
UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(
hQueue, 2, hEvents, phEvent));
}

return UR_RESULT_SUCCESS;
}

///////////////////////////////////////////////////////////////////////////////
/// @brief Intercept function for urEnqueueUSMMemcpy
ur_result_t UR_APICALL urEnqueueUSMMemcpy(
ur_queue_handle_t hQueue, ///< [in] handle of the queue object
bool blocking, ///< [in] blocking or non-blocking copy
void *
pDst, ///< [in][bounds(0, size)] pointer to the destination USM memory object
const void *
pSrc, ///< [in][bounds(0, size)] pointer to the source USM memory object
size_t size, ///< [in] size in bytes to be copied
uint32_t numEventsInWaitList, ///< [in] size of the event wait list
const ur_event_handle_t *
phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of
///< events that must be complete before this command can be executed.
///< If nullptr, the numEventsInWaitList must be 0, indicating that this
///< command does not wait on any event to complete.
ur_event_handle_t *
phEvent ///< [out][optional] return an event object that identifies this particular
///< command instance. If phEventWaitList and phEvent are not NULL, phEvent
///< must not refer to an element of the phEventWaitList array.
) {
auto pfnUSMMemcpy = getContext()->urDdiTable.Enqueue.pfnUSMMemcpy;
getContext()->logger.debug("==== pfnUSMMemcpy");

ur_event_handle_t hEvents[2] = {};
UR_CALL(pfnUSMMemcpy(hQueue, blocking, pDst, pSrc, size,
numEventsInWaitList, phEventWaitList, &hEvents[0]));

const auto Src = (uptr)pSrc, Dst = (uptr)pDst;
auto SrcInfoItOp = getMsanInterceptor()->findAllocInfoByAddress(Src);
auto DstInfoItOp = getMsanInterceptor()->findAllocInfoByAddress(Dst);

if (SrcInfoItOp && DstInfoItOp) {
auto SrcInfo = (*SrcInfoItOp)->second;
auto DstInfo = (*DstInfoItOp)->second;

const auto &DeviceInfo =
getMsanInterceptor()->getDeviceInfo(SrcInfo->Device);
const auto SrcShadow = DeviceInfo->Shadow->MemToShadow(Src);
const auto DstShadow = DeviceInfo->Shadow->MemToShadow(Dst);

UR_CALL(pfnUSMMemcpy(hQueue, blocking, (void *)DstShadow,
(void *)SrcShadow, size, 0, nullptr, &hEvents[1]));
} else if (DstInfoItOp) {
auto DstInfo = (*DstInfoItOp)->second;

const auto &DeviceInfo =
getMsanInterceptor()->getDeviceInfo(DstInfo->Device);
auto DstShadow = DeviceInfo->Shadow->MemToShadow(Dst);

UR_CALL(EnqueueUSMBlockingSet(hQueue, (void *)DstShadow, 0, size, 0,
nullptr, &hEvents[1]));
}

if (phEvent) {
UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(
hQueue, 2, hEvents, phEvent));
}

return UR_RESULT_SUCCESS;
}

///////////////////////////////////////////////////////////////////////////////
/// @brief Intercept function for urEnqueueUSMFill2D
ur_result_t UR_APICALL urEnqueueUSMFill2D(
ur_queue_handle_t hQueue, ///< [in] handle of the queue to submit to.
void *
pMem, ///< [in][bounds(0, pitch * height)] pointer to memory to be filled.
size_t
pitch, ///< [in] the total width of the destination memory including padding.
size_t
patternSize, ///< [in] the size in bytes of the pattern. Must be a power of 2 and less
///< than or equal to width.
const void
*pPattern, ///< [in] pointer with the bytes of the pattern to set.
size_t
width, ///< [in] the width in bytes of each row to fill. Must be a multiple of
///< patternSize.
size_t height, ///< [in] the height of the columns to fill.
uint32_t numEventsInWaitList, ///< [in] size of the event wait list
const ur_event_handle_t *
phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of
///< events that must be complete before the kernel execution.
///< If nullptr, the numEventsInWaitList must be 0, indicating that no wait event.
ur_event_handle_t *
phEvent ///< [out][optional] return an event object that identifies this particular
///< kernel execution instance. If phEventWaitList and phEvent are not
///< NULL, phEvent must not refer to an element of the phEventWaitList array.
) {
auto pfnUSMFill2D = getContext()->urDdiTable.Enqueue.pfnUSMFill2D;
getContext()->logger.debug("==== urEnqueueUSMFill2D");

ur_event_handle_t hEvents[2] = {};
UR_CALL(pfnUSMFill2D(hQueue, pMem, pitch, patternSize, pPattern, width,
height, numEventsInWaitList, phEventWaitList,
&hEvents[0]));

const auto Mem = (uptr)pMem;
auto MemInfoItOp = getMsanInterceptor()->findAllocInfoByAddress(Mem);
if (MemInfoItOp) {
auto MemInfo = (*MemInfoItOp)->second;

const auto &DeviceInfo =
getMsanInterceptor()->getDeviceInfo(MemInfo->Device);
const auto MemShadow = DeviceInfo->Shadow->MemToShadow(Mem);

const char Pattern = 0;
UR_CALL(pfnUSMFill2D(hQueue, (void *)MemShadow, pitch, 1, &Pattern,
width, height, 0, nullptr, &hEvents[1]));
}

if (phEvent) {
UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(
hQueue, 2, hEvents, phEvent));
}

return UR_RESULT_SUCCESS;
}

///////////////////////////////////////////////////////////////////////////////
/// @brief Intercept function for urEnqueueUSMMemcpy2D
ur_result_t UR_APICALL urEnqueueUSMMemcpy2D(
ur_queue_handle_t hQueue, ///< [in] handle of the queue to submit to.
bool blocking, ///< [in] indicates if this operation should block the host.
void *
pDst, ///< [in][bounds(0, dstPitch * height)] pointer to memory where data will
///< be copied.
size_t
dstPitch, ///< [in] the total width of the source memory including padding.
const void *
pSrc, ///< [in][bounds(0, srcPitch * height)] pointer to memory to be copied.
size_t
srcPitch, ///< [in] the total width of the source memory including padding.
size_t width, ///< [in] the width in bytes of each row to be copied.
size_t height, ///< [in] the height of columns to be copied.
uint32_t numEventsInWaitList, ///< [in] size of the event wait list
const ur_event_handle_t *
phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of
///< events that must be complete before the kernel execution.
///< If nullptr, the numEventsInWaitList must be 0, indicating that no wait event.
ur_event_handle_t *
phEvent ///< [out][optional] return an event object that identifies this particular
///< kernel execution instance. If phEventWaitList and phEvent are not
///< NULL, phEvent must not refer to an element of the phEventWaitList array.
) {
auto pfnUSMMemcpy2D = getContext()->urDdiTable.Enqueue.pfnUSMMemcpy2D;
getContext()->logger.debug("==== pfnUSMMemcpy2D");

ur_event_handle_t hEvents[2] = {};
UR_CALL(pfnUSMMemcpy2D(hQueue, blocking, pDst, dstPitch, pSrc, srcPitch,
width, height, numEventsInWaitList, phEventWaitList,
&hEvents[0]));

const auto Src = (uptr)pSrc, Dst = (uptr)pDst;
auto SrcInfoItOp = getMsanInterceptor()->findAllocInfoByAddress(Src);
auto DstInfoItOp = getMsanInterceptor()->findAllocInfoByAddress(Dst);

if (SrcInfoItOp && DstInfoItOp) {
auto SrcInfo = (*SrcInfoItOp)->second;
auto DstInfo = (*DstInfoItOp)->second;

const auto &DeviceInfo =
getMsanInterceptor()->getDeviceInfo(SrcInfo->Device);
const auto SrcShadow = DeviceInfo->Shadow->MemToShadow(Src);
const auto DstShadow = DeviceInfo->Shadow->MemToShadow(Dst);

UR_CALL(pfnUSMMemcpy2D(hQueue, blocking, (void *)DstShadow, dstPitch,
(void *)SrcShadow, srcPitch, width, height, 0,
nullptr, &hEvents[1]));
} else if (DstInfoItOp) {
auto DstInfo = (*DstInfoItOp)->second;

const auto &DeviceInfo =
getMsanInterceptor()->getDeviceInfo(DstInfo->Device);
const auto DstShadow = DeviceInfo->Shadow->MemToShadow(Dst);

const char Pattern = 0;
UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMFill2D(
hQueue, (void *)DstShadow, dstPitch, 1, &Pattern, width, height, 0,
nullptr, &hEvents[1]));
}

if (phEvent) {
UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(
hQueue, 2, hEvents, phEvent));
}

return UR_RESULT_SUCCESS;
}

///////////////////////////////////////////////////////////////////////////////
/// @brief Exported function for filling application's Global table
/// with current process' addresses
Expand Down Expand Up @@ -1429,6 +1680,10 @@ ur_result_t urGetEnqueueProcAddrTable(
pDdiTable->pfnMemUnmap = ur_sanitizer_layer::msan::urEnqueueMemUnmap;
pDdiTable->pfnKernelLaunch =
ur_sanitizer_layer::msan::urEnqueueKernelLaunch;
pDdiTable->pfnUSMFill = ur_sanitizer_layer::msan::urEnqueueUSMFill;
pDdiTable->pfnUSMMemcpy = ur_sanitizer_layer::msan::urEnqueueUSMMemcpy;
pDdiTable->pfnUSMFill2D = ur_sanitizer_layer::msan::urEnqueueUSMFill2D;
pDdiTable->pfnUSMMemcpy2D = ur_sanitizer_layer::msan::urEnqueueUSMMemcpy2D;

return result;
}
Expand All @@ -1446,6 +1701,7 @@ ur_result_t urGetUSMProcAddrTable(
ur_result_t result = UR_RESULT_SUCCESS;

pDdiTable->pfnDeviceAlloc = ur_sanitizer_layer::msan::urUSMDeviceAlloc;
pDdiTable->pfnFree = ur_sanitizer_layer::msan::urUSMFree;

return result;
}
Expand Down
Loading
Loading