Skip to content

Commit 6a3fece

Browse files
authored
Merge pull request #2534 from AllanZyne/review/yang/msan_device_global
[DeviceMSAN] Fix GPU crash on device global variables
2 parents cb74dc9 + 5eaac97 commit 6a3fece

File tree

6 files changed

+144
-20
lines changed

6 files changed

+144
-20
lines changed

source/loader/layers/sanitizer/msan/msan_buffer.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ ur_result_t MemBuffer::getHandle(ur_device_handle_t Device, char *&Handle) {
138138
USMDesc.align = getAlignment();
139139
ur_usm_pool_handle_t Pool{};
140140
URes = getMsanInterceptor()->allocateMemory(
141-
Context, Device, &USMDesc, Pool, Size,
141+
Context, Device, &USMDesc, Pool, Size, AllocType::DEVICE_USM,
142142
ur_cast<void **>(&Allocation));
143143
if (URes != UR_RESULT_SUCCESS) {
144144
getContext()->logger.error(
@@ -181,8 +181,8 @@ ur_result_t MemBuffer::getHandle(ur_device_handle_t Device, char *&Handle) {
181181
ur_usm_desc_t USMDesc{};
182182
USMDesc.align = getAlignment();
183183
ur_usm_pool_handle_t Pool{};
184-
URes = getMsanInterceptor()->allocateMemory(
185-
Context, nullptr, &USMDesc, Pool, Size,
184+
URes = getContext()->urDdiTable.USM.pfnHostAlloc(
185+
Context, &USMDesc, Pool, Size,
186186
ur_cast<void **>(&HostAllocation));
187187
if (URes != UR_RESULT_SUCCESS) {
188188
getContext()->logger.error("Failed to allocate {} bytes host "

source/loader/layers/sanitizer/msan/msan_ddi.cpp

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,13 +99,50 @@ ur_result_t urUSMDeviceAlloc(
9999
) {
100100
getContext()->logger.debug("==== urUSMDeviceAlloc");
101101

102-
return getMsanInterceptor()->allocateMemory(hContext, hDevice, pUSMDesc,
103-
pool, size, ppMem);
102+
return getMsanInterceptor()->allocateMemory(
103+
hContext, hDevice, pUSMDesc, pool, size, AllocType::DEVICE_USM, ppMem);
104+
}
105+
106+
///////////////////////////////////////////////////////////////////////////////
107+
/// @brief Intercept function for urUSMHostAlloc
108+
ur_result_t UR_APICALL urUSMHostAlloc(
    ur_context_handle_t hContext,  ///< [in] context object handle
    const ur_usm_desc_t *pUSMDesc, ///< [in][optional] descriptor for the USM allocation
    ur_usm_pool_handle_t pool,     ///< [in][optional] pool obtained from urUSMPoolCreate
    size_t size,                   ///< [in] byte size of the USM object to allocate
    void **ppMem                   ///< [out] receives the pointer to the USM host memory object
) {
    getContext()->logger.debug("==== urUSMHostAlloc");

    // This entry point has no device parameter, so a null device handle is
    // forwarded to the interceptor together with the HOST_USM tag.
    ur_device_handle_t NoDevice = nullptr;
    const ur_result_t Result = getMsanInterceptor()->allocateMemory(
        hContext, NoDevice, pUSMDesc, pool, size, AllocType::HOST_USM, ppMem);
    return Result;
}
123+
124+
///////////////////////////////////////////////////////////////////////////////
125+
/// @brief Intercept function for urUSMSharedAlloc
126+
ur_result_t UR_APICALL urUSMSharedAlloc(
    ur_context_handle_t hContext,  ///< [in] context object handle
    ur_device_handle_t hDevice,    ///< [in] device object handle
    const ur_usm_desc_t *pUSMDesc, ///< [in][optional] descriptor for the USM allocation
    ur_usm_pool_handle_t pool,     ///< [in][optional] pool obtained from urUSMPoolCreate
    size_t size,                   ///< [in] byte size of the USM object to allocate
    void **ppMem                   ///< [out] receives the pointer to the USM shared memory object
) {
    getContext()->logger.debug("==== urUSMSharedAlloc");

    // All arguments are forwarded unchanged; only the SHARED_USM tag is added
    // so the interceptor can pick the matching allocation path.
    const ur_result_t Result = getMsanInterceptor()->allocateMemory(
        hContext, hDevice, pUSMDesc, pool, size, AllocType::SHARED_USM, ppMem);
    return Result;
}
105142

106143
///////////////////////////////////////////////////////////////////////////////
107144
/// @brief Intercept function for urUSMFree
108-
__urdlllocal ur_result_t UR_APICALL urUSMFree(
145+
ur_result_t UR_APICALL urUSMFree(
109146
ur_context_handle_t hContext, ///< [in] handle of the context object
110147
void *pMem ///< [in] pointer to USM memory object
111148
) {
@@ -1748,6 +1785,8 @@ ur_result_t urGetUSMProcAddrTable(
17481785
ur_result_t result = UR_RESULT_SUCCESS;
17491786

17501787
pDdiTable->pfnDeviceAlloc = ur_sanitizer_layer::msan::urUSMDeviceAlloc;
1788+
pDdiTable->pfnHostAlloc = ur_sanitizer_layer::msan::urUSMHostAlloc;
1789+
pDdiTable->pfnSharedAlloc = ur_sanitizer_layer::msan::urUSMSharedAlloc;
17511790
pDdiTable->pfnFree = ur_sanitizer_layer::msan::urUSMFree;
17521791

17531792
return result;

source/loader/layers/sanitizer/msan/msan_interceptor.cpp

Lines changed: 87 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,18 +46,36 @@ ur_result_t MsanInterceptor::allocateMemory(ur_context_handle_t Context,
4646
ur_device_handle_t Device,
4747
const ur_usm_desc_t *Properties,
4848
ur_usm_pool_handle_t Pool,
49-
size_t Size, void **ResultPtr) {
49+
size_t Size, AllocType Type,
50+
void **ResultPtr) {
5051

5152
auto ContextInfo = getContextInfo(Context);
52-
std::shared_ptr<DeviceInfo> DeviceInfo = getDeviceInfo(Device);
53+
std::shared_ptr<DeviceInfo> DeviceInfo =
54+
Device ? getDeviceInfo(Device) : nullptr;
5355

5456
void *Allocated = nullptr;
5557

56-
UR_CALL(getContext()->urDdiTable.USM.pfnDeviceAlloc(
57-
Context, Device, Properties, Pool, Size, &Allocated));
58+
if (Type == AllocType::DEVICE_USM) {
59+
UR_CALL(getContext()->urDdiTable.USM.pfnDeviceAlloc(
60+
Context, Device, Properties, Pool, Size, &Allocated));
61+
} else if (Type == AllocType::HOST_USM) {
62+
UR_CALL(getContext()->urDdiTable.USM.pfnHostAlloc(
63+
Context, Properties, Pool, Size, &Allocated));
64+
} else if (Type == AllocType::SHARED_USM) {
65+
UR_CALL(getContext()->urDdiTable.USM.pfnSharedAlloc(
66+
Context, Device, Properties, Pool, Size, &Allocated));
67+
}
5868

5969
*ResultPtr = Allocated;
6070

71+
ContextInfo->MaxAllocatedSize =
72+
std::max(ContextInfo->MaxAllocatedSize, Size);
73+
74+
// For host/shared usm, we only record the alloc size.
75+
if (Type != AllocType::DEVICE_USM) {
76+
return UR_RESULT_SUCCESS;
77+
}
78+
6179
auto AI =
6280
std::make_shared<MsanAllocInfo>(MsanAllocInfo{(uptr)Allocated,
6381
Size,
@@ -145,6 +163,12 @@ ur_result_t MsanInterceptor::registerProgram(ur_program_handle_t Program) {
145163
return Result;
146164
}
147165

166+
getContext()->logger.info("registerDeviceGlobals");
167+
Result = registerDeviceGlobals(Program);
168+
if (Result != UR_RESULT_SUCCESS) {
169+
return Result;
170+
}
171+
148172
return Result;
149173
}
150174

@@ -213,6 +237,56 @@ ur_result_t MsanInterceptor::registerSpirKernels(ur_program_handle_t Program) {
213237
return UR_RESULT_SUCCESS;
214238
}
215239

240+
/// Registers the device globals of \p Program with the sanitizer.
///
/// For each device returned by GetDevices(Program) this:
///   1. looks up the kSPIR_MsanDeviceGlobalMetadata global variable,
///   2. copies it to the host as an array of DeviceGlobalInfo,
///   3. calls EnqueuePoisonShadow(Queue, Addr, Size, 0) on the device's shadow
///      memory for every entry, and
///   4. raises ContextInfo->MaxAllocatedSize to at least the entry's Size.
///
/// A device for which the metadata variable cannot be obtained, or whose
/// metadata table is empty, is skipped.
///
/// \param Program program whose device globals are registered; it must
///        already be known to the interceptor (asserted).
/// \return UR_RESULT_SUCCESS, or the first error reported by the metadata
///         copy or by EnqueuePoisonShadow.
ur_result_t
MsanInterceptor::registerDeviceGlobals(ur_program_handle_t Program) {
    std::vector<ur_device_handle_t> Devices = GetDevices(Program);
    assert(Devices.size() != 0 && "No devices in registerDeviceGlobals");
    auto Context = GetContext(Program);
    auto ContextInfo = getContextInfo(Context);
    // Only inspected by the assert below; avoid an unused-variable warning
    // when asserts are compiled out.
    [[maybe_unused]] auto ProgramInfo = getProgramInfo(Program);
    assert(ProgramInfo != nullptr && "unregistered program!");

    for (auto Device : Devices) {
        ManagedQueue Queue(Context, Device);

        // Initialise the out-params so they never hold indeterminate values,
        // even if the adapter leaves them untouched.
        size_t MetadataSize = 0;
        void *MetadataPtr = nullptr;
        auto Result =
            getContext()->urDdiTable.Program.pfnGetGlobalVariablePointer(
                Device, Program, kSPIR_MsanDeviceGlobalMetadata, &MetadataSize,
                &MetadataPtr);
        if (Result != UR_RESULT_SUCCESS) {
            getContext()->logger.info("No device globals");
            continue;
        }

        const uint64_t NumOfDeviceGlobal =
            MetadataSize / sizeof(DeviceGlobalInfo);
        assert((MetadataSize % sizeof(DeviceGlobalInfo) == 0) &&
               "DeviceGlobal metadata size is not correct");
        if (NumOfDeviceGlobal == 0) {
            // An empty table has nothing to copy; indexing element 0 of an
            // empty vector would be undefined behaviour.
            getContext()->logger.info("No device globals");
            continue;
        }

        std::vector<DeviceGlobalInfo> GVInfos(NumOfDeviceGlobal);
        Result = getContext()->urDdiTable.Enqueue.pfnUSMMemcpy(
            Queue, true, GVInfos.data(), MetadataPtr,
            sizeof(DeviceGlobalInfo) * NumOfDeviceGlobal, 0, nullptr, nullptr);
        if (Result != UR_RESULT_SUCCESS) {
            getContext()->logger.error("Device Global[{}] Read Failed: {}",
                                       kSPIR_MsanDeviceGlobalMetadata, Result);
            return Result;
        }

        auto DeviceInfo = getMsanInterceptor()->getDeviceInfo(Device);
        for (const auto &GVInfo : GVInfos) {
            UR_CALL(DeviceInfo->Shadow->EnqueuePoisonShadow(Queue, GVInfo.Addr,
                                                            GVInfo.Size, 0));
            // Keep the per-context maximum up to date; prepareLaunch sizes
            // the CleanShadow device allocation from this value.
            ContextInfo->MaxAllocatedSize =
                std::max(ContextInfo->MaxAllocatedSize, GVInfo.Size);
        }
    }

    return UR_RESULT_SUCCESS;
}
289+
216290
ur_result_t MsanInterceptor::insertContext(ur_context_handle_t Context,
217291
std::shared_ptr<ContextInfo> &CI) {
218292
std::scoped_lock<ur_shared_mutex> Guard(m_ContextMapMutex);
@@ -380,10 +454,14 @@ ur_result_t MsanInterceptor::prepareLaunch(
380454
}
381455

382456
// Set LaunchInfo
457+
auto ContextInfo = getContextInfo(LaunchInfo.Context);
383458
LaunchInfo.Data->GlobalShadowOffset = DeviceInfo->Shadow->ShadowBegin;
384459
LaunchInfo.Data->GlobalShadowOffsetEnd = DeviceInfo->Shadow->ShadowEnd;
385460
LaunchInfo.Data->DeviceTy = DeviceInfo->Type;
386461
LaunchInfo.Data->Debug = getOptions().Debug ? 1 : 0;
462+
UR_CALL(getContext()->urDdiTable.USM.pfnDeviceAlloc(
463+
ContextInfo->Handle, DeviceInfo->Handle, nullptr, nullptr,
464+
ContextInfo->MaxAllocatedSize, &LaunchInfo.Data->CleanShadow));
387465

388466
getContext()->logger.info(
389467
"launch_info {} (GlobalShadow={}, Device={}, Debug={})",
@@ -466,6 +544,11 @@ ur_result_t USMLaunchInfo::initialize() {
466544
USMLaunchInfo::~USMLaunchInfo() {
467545
[[maybe_unused]] ur_result_t Result;
468546
if (Data) {
547+
if (Data->CleanShadow) {
548+
Result = getContext()->urDdiTable.USM.pfnFree(Context,
549+
Data->CleanShadow);
550+
assert(Result == UR_RESULT_SUCCESS);
551+
}
469552
Result = getContext()->urDdiTable.USM.pfnFree(Context, (void *)Data);
470553
assert(Result == UR_RESULT_SUCCESS);
471554
}

source/loader/layers/sanitizer/msan/msan_interceptor.hpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ struct ProgramInfo {
121121

122122
struct ContextInfo {
123123
ur_context_handle_t Handle;
124+
size_t MaxAllocatedSize = 1024;
124125
std::atomic<int32_t> RefCount = 1;
125126

126127
std::vector<ur_device_handle_t> DeviceList;
@@ -159,6 +160,11 @@ struct USMLaunchInfo {
159160
ur_result_t initialize();
160161
};
161162

163+
// One entry of the device-global metadata table. registerDeviceGlobals copies
// an array of these from the kSPIR_MsanDeviceGlobalMetadata program global.
// NOTE(review): the field order presumably has to match the layout emitted on
// the device side — confirm before reordering.
struct DeviceGlobalInfo {
164+
// Size of the device global (presumably in bytes — TODO confirm); passed as
// the size argument to EnqueuePoisonShadow and folded into
// ContextInfo::MaxAllocatedSize.
uptr Size;
165+
// Address of the device global; passed as the address argument to
// EnqueuePoisonShadow.
uptr Addr;
166+
};
167+
162168
struct SpirKernelInfo {
163169
uptr KernelName;
164170
uptr Size;
@@ -174,7 +180,7 @@ class MsanInterceptor {
174180
ur_device_handle_t Device,
175181
const ur_usm_desc_t *Properties,
176182
ur_usm_pool_handle_t Pool, size_t Size,
177-
void **ResultPtr);
183+
AllocType Type, void **ResultPtr);
178184
ur_result_t releaseMemory(ur_context_handle_t Context, void *Ptr);
179185

180186
ur_result_t registerProgram(ur_program_handle_t Program);
@@ -261,6 +267,7 @@ class MsanInterceptor {
261267
std::shared_ptr<msan::DeviceInfo> &DeviceInfo);
262268

263269
ur_result_t registerSpirKernels(ur_program_handle_t Program);
270+
ur_result_t registerDeviceGlobals(ur_program_handle_t Program);
264271

265272
private:
266273
std::unordered_map<ur_context_handle_t, std::shared_ptr<msan::ContextInfo>>

source/loader/layers/sanitizer/msan/msan_libdevice.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ struct MsanLaunchInfo {
5353

5454
MsanErrorReport Report;
5555

56-
uint8_t CleanShadow[128] = {};
56+
void *CleanShadow = nullptr;
5757
};
5858

5959
// Based on the observation, only the last 24 bits of the address of the private

source/loader/layers/sanitizer/msan/msan_shadow.cpp

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -227,15 +227,10 @@ ur_result_t MsanShadowMemoryGPU::EnqueueMapShadow(
227227
VirtualMemMaps[MappedPtr].first = PhysicalMem;
228228
}
229229

230-
// We don't need to record virtual memory map for null pointer,
231-
// since it doesn't have an alloc info.
232-
if (Ptr == 0) {
233-
continue;
230+
auto AllocInfoItOp = getMsanInterceptor()->findAllocInfoByAddress(Ptr);
231+
if (AllocInfoItOp) {
232+
VirtualMemMaps[MappedPtr].second.insert((*AllocInfoItOp)->second);
234233
}
235-
236-
auto AllocInfoIt = getMsanInterceptor()->findAllocInfoByAddress(Ptr);
237-
assert(AllocInfoIt);
238-
VirtualMemMaps[MappedPtr].second.insert((*AllocInfoIt)->second);
239234
}
240235

241236
return UR_RESULT_SUCCESS;

0 commit comments

Comments
 (0)