@@ -46,18 +46,36 @@ ur_result_t MsanInterceptor::allocateMemory(ur_context_handle_t Context,
46
46
ur_device_handle_t Device,
47
47
const ur_usm_desc_t *Properties,
48
48
ur_usm_pool_handle_t Pool,
49
- size_t Size , void **ResultPtr) {
49
+ size_t Size , AllocType Type,
50
+ void **ResultPtr) {
50
51
51
52
auto ContextInfo = getContextInfo (Context);
52
- std::shared_ptr<DeviceInfo> DeviceInfo = getDeviceInfo (Device);
53
+ std::shared_ptr<DeviceInfo> DeviceInfo =
54
+ Device ? getDeviceInfo (Device) : nullptr ;
53
55
54
56
void *Allocated = nullptr ;
55
57
56
- UR_CALL (getContext ()->urDdiTable .USM .pfnDeviceAlloc (
57
- Context, Device, Properties, Pool, Size , &Allocated));
58
+ if (Type == AllocType::DEVICE_USM) {
59
+ UR_CALL (getContext ()->urDdiTable .USM .pfnDeviceAlloc (
60
+ Context, Device, Properties, Pool, Size , &Allocated));
61
+ } else if (Type == AllocType::HOST_USM) {
62
+ UR_CALL (getContext ()->urDdiTable .USM .pfnHostAlloc (
63
+ Context, Properties, Pool, Size , &Allocated));
64
+ } else if (Type == AllocType::SHARED_USM) {
65
+ UR_CALL (getContext ()->urDdiTable .USM .pfnSharedAlloc (
66
+ Context, Device, Properties, Pool, Size , &Allocated));
67
+ }
58
68
59
69
*ResultPtr = Allocated;
60
70
71
+ ContextInfo->MaxAllocatedSize =
72
+ std::max (ContextInfo->MaxAllocatedSize , Size );
73
+
74
+ // For host/shared usm, we only record the alloc size.
75
+ if (Type != AllocType::DEVICE_USM) {
76
+ return UR_RESULT_SUCCESS;
77
+ }
78
+
61
79
auto AI =
62
80
std::make_shared<MsanAllocInfo>(MsanAllocInfo{(uptr)Allocated,
63
81
Size ,
@@ -145,6 +163,12 @@ ur_result_t MsanInterceptor::registerProgram(ur_program_handle_t Program) {
145
163
return Result;
146
164
}
147
165
166
+ getContext ()->logger .info (" registerDeviceGlobals" );
167
+ Result = registerDeviceGlobals (Program);
168
+ if (Result != UR_RESULT_SUCCESS) {
169
+ return Result;
170
+ }
171
+
148
172
return Result;
149
173
}
150
174
@@ -213,6 +237,56 @@ ur_result_t MsanInterceptor::registerSpirKernels(ur_program_handle_t Program) {
213
237
return UR_RESULT_SUCCESS;
214
238
}
215
239
240
+ ur_result_t
241
+ MsanInterceptor::registerDeviceGlobals (ur_program_handle_t Program) {
242
+ std::vector<ur_device_handle_t > Devices = GetDevices (Program);
243
+ assert (Devices.size () != 0 && " No devices in registerDeviceGlobals" );
244
+ auto Context = GetContext (Program);
245
+ auto ContextInfo = getContextInfo (Context);
246
+ auto ProgramInfo = getProgramInfo (Program);
247
+ assert (ProgramInfo != nullptr && " unregistered program!" );
248
+
249
+ for (auto Device : Devices) {
250
+ ManagedQueue Queue (Context, Device);
251
+
252
+ size_t MetadataSize;
253
+ void *MetadataPtr;
254
+ auto Result =
255
+ getContext ()->urDdiTable .Program .pfnGetGlobalVariablePointer (
256
+ Device, Program, kSPIR_MsanDeviceGlobalMetadata , &MetadataSize,
257
+ &MetadataPtr);
258
+ if (Result != UR_RESULT_SUCCESS) {
259
+ getContext ()->logger .info (" No device globals" );
260
+ continue ;
261
+ }
262
+
263
+ const uint64_t NumOfDeviceGlobal =
264
+ MetadataSize / sizeof (DeviceGlobalInfo);
265
+ assert ((MetadataSize % sizeof (DeviceGlobalInfo) == 0 ) &&
266
+ " DeviceGlobal metadata size is not correct" );
267
+ std::vector<DeviceGlobalInfo> GVInfos (NumOfDeviceGlobal);
268
+ Result = getContext ()->urDdiTable .Enqueue .pfnUSMMemcpy (
269
+ Queue, true , &GVInfos[0 ], MetadataPtr,
270
+ sizeof (DeviceGlobalInfo) * NumOfDeviceGlobal, 0 , nullptr , nullptr );
271
+ if (Result != UR_RESULT_SUCCESS) {
272
+ getContext ()->logger .error (" Device Global[{}] Read Failed: {}" ,
273
+ kSPIR_MsanDeviceGlobalMetadata , Result);
274
+ return Result;
275
+ }
276
+
277
+ auto DeviceInfo = getMsanInterceptor ()->getDeviceInfo (Device);
278
+ for (size_t i = 0 ; i < NumOfDeviceGlobal; i++) {
279
+ const auto &GVInfo = GVInfos[i];
280
+ UR_CALL (DeviceInfo->Shadow ->EnqueuePoisonShadow (Queue, GVInfo.Addr ,
281
+ GVInfo.Size , 0 ));
282
+ ContextInfo->MaxAllocatedSize =
283
+ std::max (ContextInfo->MaxAllocatedSize , GVInfo.Size );
284
+ }
285
+ }
286
+
287
+ return UR_RESULT_SUCCESS;
288
+ }
289
+
216
290
ur_result_t MsanInterceptor::insertContext (ur_context_handle_t Context,
217
291
std::shared_ptr<ContextInfo> &CI) {
218
292
std::scoped_lock<ur_shared_mutex> Guard (m_ContextMapMutex);
@@ -380,10 +454,14 @@ ur_result_t MsanInterceptor::prepareLaunch(
380
454
}
381
455
382
456
// Set LaunchInfo
457
+ auto ContextInfo = getContextInfo (LaunchInfo.Context );
383
458
LaunchInfo.Data ->GlobalShadowOffset = DeviceInfo->Shadow ->ShadowBegin ;
384
459
LaunchInfo.Data ->GlobalShadowOffsetEnd = DeviceInfo->Shadow ->ShadowEnd ;
385
460
LaunchInfo.Data ->DeviceTy = DeviceInfo->Type ;
386
461
LaunchInfo.Data ->Debug = getOptions ().Debug ? 1 : 0 ;
462
+ UR_CALL (getContext ()->urDdiTable .USM .pfnDeviceAlloc (
463
+ ContextInfo->Handle , DeviceInfo->Handle , nullptr , nullptr ,
464
+ ContextInfo->MaxAllocatedSize , &LaunchInfo.Data ->CleanShadow ));
387
465
388
466
getContext ()->logger .info (
389
467
" launch_info {} (GlobalShadow={}, Device={}, Debug={})" ,
@@ -466,6 +544,11 @@ ur_result_t USMLaunchInfo::initialize() {
466
544
USMLaunchInfo::~USMLaunchInfo () {
467
545
[[maybe_unused]] ur_result_t Result;
468
546
if (Data) {
547
+ if (Data->CleanShadow ) {
548
+ Result = getContext ()->urDdiTable .USM .pfnFree (Context,
549
+ Data->CleanShadow );
550
+ assert (Result == UR_RESULT_SUCCESS);
551
+ }
469
552
Result = getContext ()->urDdiTable .USM .pfnFree (Context, (void *)Data);
470
553
assert (Result == UR_RESULT_SUCCESS);
471
554
}
0 commit comments