Skip to content

Commit afbb289

Browse files
authored
Merge pull request #2520 from zhaomaosu/fix-buffer-shadow
[DevMSAN] Propagate shadow memory in buffer related APIs
2 parents ef70004 + d7c33f8 commit afbb289

File tree

3 files changed

+164
-27
lines changed

3 files changed

+164
-27
lines changed

source/loader/layers/sanitizer/msan/msan_buffer.cpp

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,22 +48,67 @@ ur_result_t EnqueueMemCopyRectHelper(
4848
char *DstOrigin = pDst + DstOffset.x + DstRowPitch * DstOffset.y +
4949
DstSlicePitch * DstOffset.z;
5050

51+
const bool IsDstDeviceUSM = getMsanInterceptor()
52+
->findAllocInfoByAddress((uptr)DstOrigin)
53+
.has_value();
54+
const bool IsSrcDeviceUSM = getMsanInterceptor()
55+
->findAllocInfoByAddress((uptr)SrcOrigin)
56+
.has_value();
57+
58+
ur_device_handle_t Device = GetDevice(Queue);
59+
std::shared_ptr<DeviceInfo> DeviceInfo =
60+
getMsanInterceptor()->getDeviceInfo(Device);
5161
std::vector<ur_event_handle_t> Events;
52-
Events.reserve(Region.depth);
62+
5363
// For now, USM doesn't support 3D memory copy operation, so we can only
5464
// loop call 2D memory copy function to implement it.
5565
for (size_t i = 0; i < Region.depth; i++) {
5666
ur_event_handle_t NewEvent{};
5767
UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMMemcpy2D(
58-
Queue, Blocking, DstOrigin + (i * DstSlicePitch), DstRowPitch,
68+
Queue, false, DstOrigin + (i * DstSlicePitch), DstRowPitch,
5969
SrcOrigin + (i * SrcSlicePitch), SrcRowPitch, Region.width,
6070
Region.height, NumEventsInWaitList, EventWaitList, &NewEvent));
61-
6271
Events.push_back(NewEvent);
72+
73+
// Update shadow memory
74+
if (IsDstDeviceUSM && IsSrcDeviceUSM) {
75+
NewEvent = nullptr;
76+
uptr DstShadowAddr = DeviceInfo->Shadow->MemToShadow(
77+
(uptr)DstOrigin + (i * DstSlicePitch));
78+
uptr SrcShadowAddr = DeviceInfo->Shadow->MemToShadow(
79+
(uptr)SrcOrigin + (i * SrcSlicePitch));
80+
UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMMemcpy2D(
81+
Queue, false, (void *)DstShadowAddr, DstRowPitch,
82+
(void *)SrcShadowAddr, SrcRowPitch, Region.width, Region.height,
83+
NumEventsInWaitList, EventWaitList, &NewEvent));
84+
Events.push_back(NewEvent);
85+
} else if (IsDstDeviceUSM && !IsSrcDeviceUSM) {
86+
uptr DstShadowAddr = DeviceInfo->Shadow->MemToShadow(
87+
(uptr)DstOrigin + (i * DstSlicePitch));
88+
const char Val = 0;
89+
// The OpenCL and Level Zero adapters don't implement urEnqueueUSMFill2D, so
90+
// emulate the operation with urEnqueueUSMFill.
91+
for (size_t HeightIndex = 0; HeightIndex < Region.height;
92+
HeightIndex++) {
93+
NewEvent = nullptr;
94+
UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMFill(
95+
Queue, (void *)(DstShadowAddr + HeightIndex * DstRowPitch),
96+
1, &Val, Region.width, NumEventsInWaitList, EventWaitList,
97+
&NewEvent));
98+
Events.push_back(NewEvent);
99+
}
100+
}
63101
}
64102

65-
UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(
66-
Queue, Events.size(), Events.data(), Event));
103+
if (Blocking) {
104+
UR_CALL(
105+
getContext()->urDdiTable.Event.pfnWait(Events.size(), &Events[0]));
106+
}
107+
108+
if (Event) {
109+
UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(
110+
Queue, Events.size(), &Events[0], Event));
111+
}
67112

68113
return UR_RESULT_SUCCESS;
69114
}
@@ -112,6 +157,12 @@ ur_result_t MemBuffer::getHandle(ur_device_handle_t Device, char *&Handle) {
112157
Size, HostPtr, this);
113158
return URes;
114159
}
160+
161+
// Update shadow memory
162+
std::shared_ptr<DeviceInfo> DeviceInfo =
163+
getMsanInterceptor()->getDeviceInfo(Device);
164+
UR_CALL(DeviceInfo->Shadow->EnqueuePoisonShadow(
165+
Queue, (uptr)Allocation, Size, 0));
115166
}
116167
}
117168

source/loader/layers/sanitizer/msan/msan_ddi.cpp

Lines changed: 107 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -515,6 +515,12 @@ ur_result_t urMemBufferCreate(
515515
UR_CALL(pMemBuffer->getHandle(hDevice, Handle));
516516
UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMMemcpy(
517517
InternalQueue, true, Handle, Host, size, 0, nullptr, nullptr));
518+
519+
// Update shadow memory
520+
std::shared_ptr<DeviceInfo> DeviceInfo =
521+
getMsanInterceptor()->getDeviceInfo(hDevice);
522+
UR_CALL(DeviceInfo->Shadow->EnqueuePoisonShadow(
523+
InternalQueue, (uptr)Handle, size, 0));
518524
}
519525
}
520526

@@ -730,10 +736,29 @@ ur_result_t urEnqueueMemBufferWrite(
730736
if (auto MemBuffer = getMsanInterceptor()->getMemBuffer(hBuffer)) {
731737
ur_device_handle_t Device = GetDevice(hQueue);
732738
char *pDst = nullptr;
739+
std::vector<ur_event_handle_t> Events;
740+
ur_event_handle_t Event{};
733741
UR_CALL(MemBuffer->getHandle(Device, pDst));
734742
UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMMemcpy(
735743
hQueue, blockingWrite, pDst + offset, pSrc, size,
736-
numEventsInWaitList, phEventWaitList, phEvent));
744+
numEventsInWaitList, phEventWaitList, &Event));
745+
Events.push_back(Event);
746+
747+
// Update shadow memory
748+
std::shared_ptr<DeviceInfo> DeviceInfo =
749+
getMsanInterceptor()->getDeviceInfo(Device);
750+
const char Val = 0;
751+
uptr ShadowAddr = DeviceInfo->Shadow->MemToShadow((uptr)pDst + offset);
752+
Event = nullptr;
753+
UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMFill(
754+
hQueue, (void *)ShadowAddr, 1, &Val, size, numEventsInWaitList,
755+
phEventWaitList, &Event));
756+
Events.push_back(Event);
757+
758+
if (phEvent) {
759+
UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(
760+
hQueue, Events.size(), Events.data(), phEvent));
761+
}
737762
} else {
738763
UR_CALL(pfnMemBufferWrite(hQueue, hBuffer, blockingWrite, offset, size,
739764
pSrc, numEventsInWaitList, phEventWaitList,
@@ -893,15 +918,36 @@ ur_result_t urEnqueueMemBufferCopy(
893918

894919
if (SrcBuffer && DstBuffer) {
895920
ur_device_handle_t Device = GetDevice(hQueue);
921+
std::shared_ptr<DeviceInfo> DeviceInfo =
922+
getMsanInterceptor()->getDeviceInfo(Device);
896923
char *SrcHandle = nullptr;
897924
UR_CALL(SrcBuffer->getHandle(Device, SrcHandle));
898925

899926
char *DstHandle = nullptr;
900927
UR_CALL(DstBuffer->getHandle(Device, DstHandle));
901928

929+
std::vector<ur_event_handle_t> Events;
930+
ur_event_handle_t Event{};
902931
UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMMemcpy(
903932
hQueue, false, DstHandle + dstOffset, SrcHandle + srcOffset, size,
904-
numEventsInWaitList, phEventWaitList, phEvent));
933+
numEventsInWaitList, phEventWaitList, &Event));
934+
Events.push_back(Event);
935+
936+
// Update shadow memory
937+
uptr DstShadowAddr =
938+
DeviceInfo->Shadow->MemToShadow((uptr)DstHandle + dstOffset);
939+
uptr SrcShadowAddr =
940+
DeviceInfo->Shadow->MemToShadow((uptr)SrcHandle + srcOffset);
941+
Event = nullptr;
942+
UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMMemcpy(
943+
hQueue, false, (void *)DstShadowAddr, (void *)SrcShadowAddr, size,
944+
numEventsInWaitList, phEventWaitList, &Event));
945+
Events.push_back(Event);
946+
947+
if (phEvent) {
948+
UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(
949+
hQueue, Events.size(), Events.data(), phEvent));
950+
}
905951
} else {
906952
UR_CALL(pfnMemBufferCopy(hQueue, hBufferSrc, hBufferDst, srcOffset,
907953
dstOffset, size, numEventsInWaitList,
@@ -1000,11 +1046,31 @@ ur_result_t urEnqueueMemBufferFill(
10001046

10011047
if (auto MemBuffer = getMsanInterceptor()->getMemBuffer(hBuffer)) {
10021048
char *Handle = nullptr;
1049+
std::vector<ur_event_handle_t> Events;
1050+
ur_event_handle_t Event{};
10031051
ur_device_handle_t Device = GetDevice(hQueue);
10041052
UR_CALL(MemBuffer->getHandle(Device, Handle));
10051053
UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMFill(
10061054
hQueue, Handle + offset, patternSize, pPattern, size,
1007-
numEventsInWaitList, phEventWaitList, phEvent));
1055+
numEventsInWaitList, phEventWaitList, &Event));
1056+
Events.push_back(Event);
1057+
1058+
// Update shadow memory
1059+
std::shared_ptr<DeviceInfo> DeviceInfo =
1060+
getMsanInterceptor()->getDeviceInfo(Device);
1061+
const char Val = 0;
1062+
uptr ShadowAddr =
1063+
DeviceInfo->Shadow->MemToShadow((uptr)Handle + offset);
1064+
Event = nullptr;
1065+
UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMFill(
1066+
hQueue, (void *)ShadowAddr, 1, &Val, size, numEventsInWaitList,
1067+
phEventWaitList, &Event));
1068+
Events.push_back(Event);
1069+
1070+
if (phEvent) {
1071+
UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(
1072+
hQueue, Events.size(), Events.data(), phEvent));
1073+
}
10081074
} else {
10091075
UR_CALL(pfnMemBufferFill(hQueue, hBuffer, pPattern, patternSize, offset,
10101076
size, numEventsInWaitList, phEventWaitList,
@@ -1270,9 +1336,11 @@ ur_result_t UR_APICALL urEnqueueUSMFill(
12701336
auto pfnUSMFill = getContext()->urDdiTable.Enqueue.pfnUSMFill;
12711337
getContext()->logger.debug("==== urEnqueueUSMFill");
12721338

1273-
ur_event_handle_t hEvents[2] = {};
1339+
std::vector<ur_event_handle_t> Events;
1340+
ur_event_handle_t Event{};
12741341
UR_CALL(pfnUSMFill(hQueue, pMem, patternSize, pPattern, size,
1275-
numEventsInWaitList, phEventWaitList, &hEvents[0]));
1342+
numEventsInWaitList, phEventWaitList, &Event));
1343+
Events.push_back(Event);
12761344

12771345
const auto Mem = (uptr)pMem;
12781346
auto MemInfoItOp = getMsanInterceptor()->findAllocInfoByAddress(Mem);
@@ -1283,13 +1351,15 @@ ur_result_t UR_APICALL urEnqueueUSMFill(
12831351
getMsanInterceptor()->getDeviceInfo(MemInfo->Device);
12841352
const auto MemShadow = DeviceInfo->Shadow->MemToShadow(Mem);
12851353

1354+
Event = nullptr;
12861355
UR_CALL(EnqueueUSMBlockingSet(hQueue, (void *)MemShadow, 0, size, 0,
1287-
nullptr, &hEvents[1]));
1356+
nullptr, &Event));
1357+
Events.push_back(Event);
12881358
}
12891359

12901360
if (phEvent) {
12911361
UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(
1292-
hQueue, 2, hEvents, phEvent));
1362+
hQueue, Events.size(), Events.data(), phEvent));
12931363
}
12941364

12951365
return UR_RESULT_SUCCESS;
@@ -1319,9 +1389,11 @@ ur_result_t UR_APICALL urEnqueueUSMMemcpy(
13191389
auto pfnUSMMemcpy = getContext()->urDdiTable.Enqueue.pfnUSMMemcpy;
13201390
getContext()->logger.debug("==== pfnUSMMemcpy");
13211391

1322-
ur_event_handle_t hEvents[2] = {};
1392+
std::vector<ur_event_handle_t> Events;
1393+
ur_event_handle_t Event{};
13231394
UR_CALL(pfnUSMMemcpy(hQueue, blocking, pDst, pSrc, size,
1324-
numEventsInWaitList, phEventWaitList, &hEvents[0]));
1395+
numEventsInWaitList, phEventWaitList, &Event));
1396+
Events.push_back(Event);
13251397

13261398
const auto Src = (uptr)pSrc, Dst = (uptr)pDst;
13271399
auto SrcInfoItOp = getMsanInterceptor()->findAllocInfoByAddress(Src);
@@ -1336,22 +1408,26 @@ ur_result_t UR_APICALL urEnqueueUSMMemcpy(
13361408
const auto SrcShadow = DeviceInfo->Shadow->MemToShadow(Src);
13371409
const auto DstShadow = DeviceInfo->Shadow->MemToShadow(Dst);
13381410

1411+
Event = nullptr;
13391412
UR_CALL(pfnUSMMemcpy(hQueue, blocking, (void *)DstShadow,
1340-
(void *)SrcShadow, size, 0, nullptr, &hEvents[1]));
1413+
(void *)SrcShadow, size, 0, nullptr, &Event));
1414+
Events.push_back(Event);
13411415
} else if (DstInfoItOp) {
13421416
auto DstInfo = (*DstInfoItOp)->second;
13431417

13441418
const auto &DeviceInfo =
13451419
getMsanInterceptor()->getDeviceInfo(DstInfo->Device);
13461420
auto DstShadow = DeviceInfo->Shadow->MemToShadow(Dst);
13471421

1422+
Event = nullptr;
13481423
UR_CALL(EnqueueUSMBlockingSet(hQueue, (void *)DstShadow, 0, size, 0,
1349-
nullptr, &hEvents[1]));
1424+
nullptr, &Event));
1425+
Events.push_back(Event);
13501426
}
13511427

13521428
if (phEvent) {
13531429
UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(
1354-
hQueue, 2, hEvents, phEvent));
1430+
hQueue, Events.size(), Events.data(), phEvent));
13551431
}
13561432

13571433
return UR_RESULT_SUCCESS;
@@ -1387,10 +1463,11 @@ ur_result_t UR_APICALL urEnqueueUSMFill2D(
13871463
auto pfnUSMFill2D = getContext()->urDdiTable.Enqueue.pfnUSMFill2D;
13881464
getContext()->logger.debug("==== urEnqueueUSMFill2D");
13891465

1390-
ur_event_handle_t hEvents[2] = {};
1466+
std::vector<ur_event_handle_t> Events;
1467+
ur_event_handle_t Event{};
13911468
UR_CALL(pfnUSMFill2D(hQueue, pMem, pitch, patternSize, pPattern, width,
1392-
height, numEventsInWaitList, phEventWaitList,
1393-
&hEvents[0]));
1469+
height, numEventsInWaitList, phEventWaitList, &Event));
1470+
Events.push_back(Event);
13941471

13951472
const auto Mem = (uptr)pMem;
13961473
auto MemInfoItOp = getMsanInterceptor()->findAllocInfoByAddress(Mem);
@@ -1402,13 +1479,15 @@ ur_result_t UR_APICALL urEnqueueUSMFill2D(
14021479
const auto MemShadow = DeviceInfo->Shadow->MemToShadow(Mem);
14031480

14041481
const char Pattern = 0;
1482+
Event = nullptr;
14051483
UR_CALL(pfnUSMFill2D(hQueue, (void *)MemShadow, pitch, 1, &Pattern,
1406-
width, height, 0, nullptr, &hEvents[1]));
1484+
width, height, 0, nullptr, &Event));
1485+
Events.push_back(Event);
14071486
}
14081487

14091488
if (phEvent) {
14101489
UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(
1411-
hQueue, 2, hEvents, phEvent));
1490+
hQueue, Events.size(), Events.data(), phEvent));
14121491
}
14131492

14141493
return UR_RESULT_SUCCESS;
@@ -1443,10 +1522,12 @@ ur_result_t UR_APICALL urEnqueueUSMMemcpy2D(
14431522
auto pfnUSMMemcpy2D = getContext()->urDdiTable.Enqueue.pfnUSMMemcpy2D;
14441523
getContext()->logger.debug("==== pfnUSMMemcpy2D");
14451524

1446-
ur_event_handle_t hEvents[2] = {};
1525+
std::vector<ur_event_handle_t> Events;
1526+
ur_event_handle_t Event{};
14471527
UR_CALL(pfnUSMMemcpy2D(hQueue, blocking, pDst, dstPitch, pSrc, srcPitch,
14481528
width, height, numEventsInWaitList, phEventWaitList,
1449-
&hEvents[0]));
1529+
&Event));
1530+
Events.push_back(Event);
14501531

14511532
const auto Src = (uptr)pSrc, Dst = (uptr)pDst;
14521533
auto SrcInfoItOp = getMsanInterceptor()->findAllocInfoByAddress(Src);
@@ -1461,9 +1542,11 @@ ur_result_t UR_APICALL urEnqueueUSMMemcpy2D(
14611542
const auto SrcShadow = DeviceInfo->Shadow->MemToShadow(Src);
14621543
const auto DstShadow = DeviceInfo->Shadow->MemToShadow(Dst);
14631544

1545+
Event = nullptr;
14641546
UR_CALL(pfnUSMMemcpy2D(hQueue, blocking, (void *)DstShadow, dstPitch,
14651547
(void *)SrcShadow, srcPitch, width, height, 0,
1466-
nullptr, &hEvents[1]));
1548+
nullptr, &Event));
1549+
Events.push_back(Event);
14671550
} else if (DstInfoItOp) {
14681551
auto DstInfo = (*DstInfoItOp)->second;
14691552

@@ -1472,14 +1555,16 @@ ur_result_t UR_APICALL urEnqueueUSMMemcpy2D(
14721555
const auto DstShadow = DeviceInfo->Shadow->MemToShadow(Dst);
14731556

14741557
const char Pattern = 0;
1558+
Event = nullptr;
14751559
UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMFill2D(
14761560
hQueue, (void *)DstShadow, dstPitch, 1, &Pattern, width, height, 0,
1477-
nullptr, &hEvents[1]));
1561+
nullptr, &Event));
1562+
Events.push_back(Event);
14781563
}
14791564

14801565
if (phEvent) {
14811566
UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(
1482-
hQueue, 2, hEvents, phEvent));
1567+
hQueue, Events.size(), Events.data(), phEvent));
14831568
}
14841569

14851570
return UR_RESULT_SUCCESS;

source/loader/layers/sanitizer/msan/msan_interceptor.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ ur_result_t MsanInterceptor::allocateMemory(ur_context_handle_t Context,
7575
m_AllocationMap.emplace(AI->AllocBegin, AI);
7676
}
7777

78+
// Update shadow memory
7879
ManagedQueue Queue(Context, Device);
7980
DeviceInfo->Shadow->EnqueuePoisonShadow(Queue, AI->AllocBegin,
8081
AI->AllocSize, 0xff);

0 commit comments

Comments
 (0)