@@ -515,6 +515,12 @@ ur_result_t urMemBufferCreate(
515
515
UR_CALL (pMemBuffer->getHandle (hDevice, Handle ));
516
516
UR_CALL (getContext ()->urDdiTable .Enqueue .pfnUSMMemcpy (
517
517
InternalQueue, true , Handle , Host, size, 0 , nullptr , nullptr ));
518
+
519
+ // Update shadow memory
520
+ std::shared_ptr<DeviceInfo> DeviceInfo =
521
+ getMsanInterceptor ()->getDeviceInfo (hDevice);
522
+ UR_CALL (DeviceInfo->Shadow ->EnqueuePoisonShadow (
523
+ InternalQueue, (uptr)Handle , size, 0 ));
518
524
}
519
525
}
520
526
@@ -730,10 +736,29 @@ ur_result_t urEnqueueMemBufferWrite(
730
736
if (auto MemBuffer = getMsanInterceptor ()->getMemBuffer (hBuffer)) {
731
737
ur_device_handle_t Device = GetDevice (hQueue);
732
738
char *pDst = nullptr ;
739
+ std::vector<ur_event_handle_t > Events;
740
+ ur_event_handle_t Event{};
733
741
UR_CALL (MemBuffer->getHandle (Device, pDst));
734
742
UR_CALL (getContext ()->urDdiTable .Enqueue .pfnUSMMemcpy (
735
743
hQueue, blockingWrite, pDst + offset, pSrc, size,
736
- numEventsInWaitList, phEventWaitList, phEvent));
744
+ numEventsInWaitList, phEventWaitList, &Event));
745
+ Events.push_back (Event);
746
+
747
+ // Update shadow memory
748
+ std::shared_ptr<DeviceInfo> DeviceInfo =
749
+ getMsanInterceptor ()->getDeviceInfo (Device);
750
+ const char Val = 0 ;
751
+ uptr ShadowAddr = DeviceInfo->Shadow ->MemToShadow ((uptr)pDst + offset);
752
+ Event = nullptr ;
753
+ UR_CALL (getContext ()->urDdiTable .Enqueue .pfnUSMFill (
754
+ hQueue, (void *)ShadowAddr, 1 , &Val, size, numEventsInWaitList,
755
+ phEventWaitList, &Event));
756
+ Events.push_back (Event);
757
+
758
+ if (phEvent) {
759
+ UR_CALL (getContext ()->urDdiTable .Enqueue .pfnEventsWait (
760
+ hQueue, Events.size (), Events.data (), phEvent));
761
+ }
737
762
} else {
738
763
UR_CALL (pfnMemBufferWrite (hQueue, hBuffer, blockingWrite, offset, size,
739
764
pSrc, numEventsInWaitList, phEventWaitList,
@@ -893,15 +918,36 @@ ur_result_t urEnqueueMemBufferCopy(
893
918
894
919
if (SrcBuffer && DstBuffer) {
895
920
ur_device_handle_t Device = GetDevice (hQueue);
921
+ std::shared_ptr<DeviceInfo> DeviceInfo =
922
+ getMsanInterceptor ()->getDeviceInfo (Device);
896
923
char *SrcHandle = nullptr ;
897
924
UR_CALL (SrcBuffer->getHandle (Device, SrcHandle));
898
925
899
926
char *DstHandle = nullptr ;
900
927
UR_CALL (DstBuffer->getHandle (Device, DstHandle));
901
928
929
+ std::vector<ur_event_handle_t > Events;
930
+ ur_event_handle_t Event{};
902
931
UR_CALL (getContext ()->urDdiTable .Enqueue .pfnUSMMemcpy (
903
932
hQueue, false , DstHandle + dstOffset, SrcHandle + srcOffset, size,
904
- numEventsInWaitList, phEventWaitList, phEvent));
933
+ numEventsInWaitList, phEventWaitList, &Event));
934
+ Events.push_back (Event);
935
+
936
+ // Update shadow memory
937
+ uptr DstShadowAddr =
938
+ DeviceInfo->Shadow ->MemToShadow ((uptr)DstHandle + dstOffset);
939
+ uptr SrcShadowAddr =
940
+ DeviceInfo->Shadow ->MemToShadow ((uptr)SrcHandle + srcOffset);
941
+ Event = nullptr ;
942
+ UR_CALL (getContext ()->urDdiTable .Enqueue .pfnUSMMemcpy (
943
+ hQueue, false , (void *)DstShadowAddr, (void *)SrcShadowAddr, size,
944
+ numEventsInWaitList, phEventWaitList, &Event));
945
+ Events.push_back (Event);
946
+
947
+ if (phEvent) {
948
+ UR_CALL (getContext ()->urDdiTable .Enqueue .pfnEventsWait (
949
+ hQueue, Events.size (), Events.data (), phEvent));
950
+ }
905
951
} else {
906
952
UR_CALL (pfnMemBufferCopy (hQueue, hBufferSrc, hBufferDst, srcOffset,
907
953
dstOffset, size, numEventsInWaitList,
@@ -1000,11 +1046,31 @@ ur_result_t urEnqueueMemBufferFill(
1000
1046
1001
1047
if (auto MemBuffer = getMsanInterceptor ()->getMemBuffer (hBuffer)) {
1002
1048
char *Handle = nullptr ;
1049
+ std::vector<ur_event_handle_t > Events;
1050
+ ur_event_handle_t Event{};
1003
1051
ur_device_handle_t Device = GetDevice (hQueue);
1004
1052
UR_CALL (MemBuffer->getHandle (Device, Handle ));
1005
1053
UR_CALL (getContext ()->urDdiTable .Enqueue .pfnUSMFill (
1006
1054
hQueue, Handle + offset, patternSize, pPattern, size,
1007
- numEventsInWaitList, phEventWaitList, phEvent));
1055
+ numEventsInWaitList, phEventWaitList, &Event));
1056
+ Events.push_back (Event);
1057
+
1058
+ // Update shadow memory
1059
+ std::shared_ptr<DeviceInfo> DeviceInfo =
1060
+ getMsanInterceptor ()->getDeviceInfo (Device);
1061
+ const char Val = 0 ;
1062
+ uptr ShadowAddr =
1063
+ DeviceInfo->Shadow ->MemToShadow ((uptr)Handle + offset);
1064
+ Event = nullptr ;
1065
+ UR_CALL (getContext ()->urDdiTable .Enqueue .pfnUSMFill (
1066
+ hQueue, (void *)ShadowAddr, 1 , &Val, size, numEventsInWaitList,
1067
+ phEventWaitList, &Event));
1068
+ Events.push_back (Event);
1069
+
1070
+ if (phEvent) {
1071
+ UR_CALL (getContext ()->urDdiTable .Enqueue .pfnEventsWait (
1072
+ hQueue, Events.size (), Events.data (), phEvent));
1073
+ }
1008
1074
} else {
1009
1075
UR_CALL (pfnMemBufferFill (hQueue, hBuffer, pPattern, patternSize, offset,
1010
1076
size, numEventsInWaitList, phEventWaitList,
@@ -1270,9 +1336,11 @@ ur_result_t UR_APICALL urEnqueueUSMFill(
1270
1336
auto pfnUSMFill = getContext ()->urDdiTable .Enqueue .pfnUSMFill ;
1271
1337
getContext ()->logger .debug (" ==== urEnqueueUSMFill" );
1272
1338
1273
- ur_event_handle_t hEvents[2 ] = {};
1339
+ std::vector<ur_event_handle_t > Events;
1340
+ ur_event_handle_t Event{};
1274
1341
UR_CALL (pfnUSMFill (hQueue, pMem, patternSize, pPattern, size,
1275
- numEventsInWaitList, phEventWaitList, &hEvents[0 ]));
1342
+ numEventsInWaitList, phEventWaitList, &Event));
1343
+ Events.push_back (Event);
1276
1344
1277
1345
const auto Mem = (uptr)pMem;
1278
1346
auto MemInfoItOp = getMsanInterceptor ()->findAllocInfoByAddress (Mem);
@@ -1283,13 +1351,15 @@ ur_result_t UR_APICALL urEnqueueUSMFill(
1283
1351
getMsanInterceptor ()->getDeviceInfo (MemInfo->Device );
1284
1352
const auto MemShadow = DeviceInfo->Shadow ->MemToShadow (Mem);
1285
1353
1354
+ Event = nullptr ;
1286
1355
UR_CALL (EnqueueUSMBlockingSet (hQueue, (void *)MemShadow, 0 , size, 0 ,
1287
- nullptr , &hEvents[1 ]));
1356
+ nullptr , &Event));
1357
+ Events.push_back (Event);
1288
1358
}
1289
1359
1290
1360
if (phEvent) {
1291
1361
UR_CALL (getContext ()->urDdiTable .Enqueue .pfnEventsWait (
1292
- hQueue, 2 , hEvents , phEvent));
1362
+ hQueue, Events. size (), Events. data () , phEvent));
1293
1363
}
1294
1364
1295
1365
return UR_RESULT_SUCCESS;
@@ -1319,9 +1389,11 @@ ur_result_t UR_APICALL urEnqueueUSMMemcpy(
1319
1389
auto pfnUSMMemcpy = getContext ()->urDdiTable .Enqueue .pfnUSMMemcpy ;
1320
1390
getContext ()->logger .debug (" ==== pfnUSMMemcpy" );
1321
1391
1322
- ur_event_handle_t hEvents[2 ] = {};
1392
+ std::vector<ur_event_handle_t > Events;
1393
+ ur_event_handle_t Event{};
1323
1394
UR_CALL (pfnUSMMemcpy (hQueue, blocking, pDst, pSrc, size,
1324
- numEventsInWaitList, phEventWaitList, &hEvents[0 ]));
1395
+ numEventsInWaitList, phEventWaitList, &Event));
1396
+ Events.push_back (Event);
1325
1397
1326
1398
const auto Src = (uptr)pSrc, Dst = (uptr)pDst;
1327
1399
auto SrcInfoItOp = getMsanInterceptor ()->findAllocInfoByAddress (Src);
@@ -1336,22 +1408,26 @@ ur_result_t UR_APICALL urEnqueueUSMMemcpy(
1336
1408
const auto SrcShadow = DeviceInfo->Shadow ->MemToShadow (Src);
1337
1409
const auto DstShadow = DeviceInfo->Shadow ->MemToShadow (Dst);
1338
1410
1411
+ Event = nullptr ;
1339
1412
UR_CALL (pfnUSMMemcpy (hQueue, blocking, (void *)DstShadow,
1340
- (void *)SrcShadow, size, 0 , nullptr , &hEvents[1 ]));
1413
+ (void *)SrcShadow, size, 0 , nullptr , &Event));
1414
+ Events.push_back (Event);
1341
1415
} else if (DstInfoItOp) {
1342
1416
auto DstInfo = (*DstInfoItOp)->second ;
1343
1417
1344
1418
const auto &DeviceInfo =
1345
1419
getMsanInterceptor ()->getDeviceInfo (DstInfo->Device );
1346
1420
auto DstShadow = DeviceInfo->Shadow ->MemToShadow (Dst);
1347
1421
1422
+ Event = nullptr ;
1348
1423
UR_CALL (EnqueueUSMBlockingSet (hQueue, (void *)DstShadow, 0 , size, 0 ,
1349
- nullptr , &hEvents[1 ]));
1424
+ nullptr , &Event));
1425
+ Events.push_back (Event);
1350
1426
}
1351
1427
1352
1428
if (phEvent) {
1353
1429
UR_CALL (getContext ()->urDdiTable .Enqueue .pfnEventsWait (
1354
- hQueue, 2 , hEvents , phEvent));
1430
+ hQueue, Events. size (), Events. data () , phEvent));
1355
1431
}
1356
1432
1357
1433
return UR_RESULT_SUCCESS;
@@ -1387,10 +1463,11 @@ ur_result_t UR_APICALL urEnqueueUSMFill2D(
1387
1463
auto pfnUSMFill2D = getContext ()->urDdiTable .Enqueue .pfnUSMFill2D ;
1388
1464
getContext ()->logger .debug (" ==== urEnqueueUSMFill2D" );
1389
1465
1390
- ur_event_handle_t hEvents[2 ] = {};
1466
+ std::vector<ur_event_handle_t > Events;
1467
+ ur_event_handle_t Event{};
1391
1468
UR_CALL (pfnUSMFill2D (hQueue, pMem, pitch, patternSize, pPattern, width,
1392
- height, numEventsInWaitList, phEventWaitList,
1393
- &hEvents[ 0 ]) );
1469
+ height, numEventsInWaitList, phEventWaitList, &Event));
1470
+ Events. push_back (Event );
1394
1471
1395
1472
const auto Mem = (uptr)pMem;
1396
1473
auto MemInfoItOp = getMsanInterceptor ()->findAllocInfoByAddress (Mem);
@@ -1402,13 +1479,15 @@ ur_result_t UR_APICALL urEnqueueUSMFill2D(
1402
1479
const auto MemShadow = DeviceInfo->Shadow ->MemToShadow (Mem);
1403
1480
1404
1481
const char Pattern = 0 ;
1482
+ Event = nullptr ;
1405
1483
UR_CALL (pfnUSMFill2D (hQueue, (void *)MemShadow, pitch, 1 , &Pattern,
1406
- width, height, 0 , nullptr , &hEvents[1 ]));
1484
+ width, height, 0 , nullptr , &Event));
1485
+ Events.push_back (Event);
1407
1486
}
1408
1487
1409
1488
if (phEvent) {
1410
1489
UR_CALL (getContext ()->urDdiTable .Enqueue .pfnEventsWait (
1411
- hQueue, 2 , hEvents , phEvent));
1490
+ hQueue, Events. size (), Events. data () , phEvent));
1412
1491
}
1413
1492
1414
1493
return UR_RESULT_SUCCESS;
@@ -1443,10 +1522,12 @@ ur_result_t UR_APICALL urEnqueueUSMMemcpy2D(
1443
1522
auto pfnUSMMemcpy2D = getContext ()->urDdiTable .Enqueue .pfnUSMMemcpy2D ;
1444
1523
getContext ()->logger .debug (" ==== pfnUSMMemcpy2D" );
1445
1524
1446
- ur_event_handle_t hEvents[2 ] = {};
1525
+ std::vector<ur_event_handle_t > Events;
1526
+ ur_event_handle_t Event{};
1447
1527
UR_CALL (pfnUSMMemcpy2D (hQueue, blocking, pDst, dstPitch, pSrc, srcPitch,
1448
1528
width, height, numEventsInWaitList, phEventWaitList,
1449
- &hEvents[0 ]));
1529
+ &Event));
1530
+ Events.push_back (Event);
1450
1531
1451
1532
const auto Src = (uptr)pSrc, Dst = (uptr)pDst;
1452
1533
auto SrcInfoItOp = getMsanInterceptor ()->findAllocInfoByAddress (Src);
@@ -1461,9 +1542,11 @@ ur_result_t UR_APICALL urEnqueueUSMMemcpy2D(
1461
1542
const auto SrcShadow = DeviceInfo->Shadow ->MemToShadow (Src);
1462
1543
const auto DstShadow = DeviceInfo->Shadow ->MemToShadow (Dst);
1463
1544
1545
+ Event = nullptr ;
1464
1546
UR_CALL (pfnUSMMemcpy2D (hQueue, blocking, (void *)DstShadow, dstPitch,
1465
1547
(void *)SrcShadow, srcPitch, width, height, 0 ,
1466
- nullptr , &hEvents[1 ]));
1548
+ nullptr , &Event));
1549
+ Events.push_back (Event);
1467
1550
} else if (DstInfoItOp) {
1468
1551
auto DstInfo = (*DstInfoItOp)->second ;
1469
1552
@@ -1472,14 +1555,16 @@ ur_result_t UR_APICALL urEnqueueUSMMemcpy2D(
1472
1555
const auto DstShadow = DeviceInfo->Shadow ->MemToShadow (Dst);
1473
1556
1474
1557
const char Pattern = 0 ;
1558
+ Event = nullptr ;
1475
1559
UR_CALL (getContext ()->urDdiTable .Enqueue .pfnUSMFill2D (
1476
1560
hQueue, (void *)DstShadow, dstPitch, 1 , &Pattern, width, height, 0 ,
1477
- nullptr , &hEvents[1 ]));
1561
+ nullptr , &Event));
1562
+ Events.push_back (Event);
1478
1563
}
1479
1564
1480
1565
if (phEvent) {
1481
1566
UR_CALL (getContext ()->urDdiTable .Enqueue .pfnEventsWait (
1482
- hQueue, 2 , hEvents , phEvent));
1567
+ hQueue, Events. size (), Events. data () , phEvent));
1483
1568
}
1484
1569
1485
1570
return UR_RESULT_SUCCESS;
0 commit comments