@@ -115,10 +115,7 @@ ur_integrated_mem_handle_t::ur_integrated_mem_handle_t(
115
115
if (!ownHostPtr) {
116
116
return ;
117
117
}
118
- auto ret = hContext->getDefaultUSMPool ()->free (ptr);
119
- if (ret != UR_RESULT_SUCCESS) {
120
- logger::error (" Failed to free host memory: {}" , ret);
121
- }
118
+ ZE_CALL_NOCHECK (zeMemFree, (hContext->getZeHandle (), ptr));
122
119
});
123
120
}
124
121
@@ -209,7 +206,7 @@ ur_discrete_mem_handle_t::ur_discrete_mem_handle_t(
209
206
device_access_mode_t accessMode)
210
207
: ur_mem_handle_t_(hContext, size, accessMode),
211
208
deviceAllocations (hContext->getPlatform ()->getNumDevices()),
212
- activeAllocationDevice(nullptr ), hostAllocations() {
209
+ activeAllocationDevice(nullptr ), mapToPtr(hostPtr), hostAllocations() {
213
210
if (hostPtr) {
214
211
auto initialDevice = hContext->getDevices ()[0 ];
215
212
UR_CALL_THROWS (migrateBufferTo (initialDevice, hostPtr, size));
@@ -234,10 +231,7 @@ ur_discrete_mem_handle_t::ur_discrete_mem_handle_t(
234
231
if (!ownZePtr) {
235
232
return ;
236
233
}
237
- auto ret = hContext->getDefaultUSMPool ()->free (ptr);
238
- if (ret != UR_RESULT_SUCCESS) {
239
- logger::error (" Failed to free device memory: {}" , ret);
240
- }
234
+ ZE_CALL_NOCHECK (zeMemFree, (hContext->getZeHandle (), ptr));
241
235
});
242
236
}
243
237
}
@@ -246,12 +240,18 @@ ur_discrete_mem_handle_t::~ur_discrete_mem_handle_t() {
246
240
if (!activeAllocationDevice || !writeBackPtr)
247
241
return ;
248
242
249
- auto srcPtr = ur_cast<char *>(
250
- deviceAllocations[activeAllocationDevice->Id .value ()].get ());
243
+ auto srcPtr = getActiveDeviceAlloc ();
251
244
synchronousZeCopy (hContext, activeAllocationDevice, writeBackPtr, srcPtr,
252
245
getSize ());
253
246
}
254
247
248
+ void *ur_discrete_mem_handle_t ::getActiveDeviceAlloc(size_t offset) {
249
+ assert (activeAllocationDevice);
250
+ return ur_cast<char *>(
251
+ deviceAllocations[activeAllocationDevice->Id .value ()].get ()) +
252
+ offset;
253
+ }
254
+
255
255
void *ur_discrete_mem_handle_t ::getDevicePtr(
256
256
ur_device_handle_t hDevice, device_access_mode_t access, size_t offset,
257
257
size_t size, std::function<void (void *src, void *dst, size_t )> migrate) {
@@ -272,10 +272,8 @@ void *ur_discrete_mem_handle_t::getDevicePtr(
272
272
hDevice = activeAllocationDevice;
273
273
}
274
274
275
- char *ptr;
276
275
if (activeAllocationDevice == hDevice) {
277
- ptr = ur_cast<char *>(deviceAllocations[hDevice->Id .value ()].get ());
278
- return ptr + offset;
276
+ return getActiveDeviceAlloc (offset);
279
277
}
280
278
281
279
auto &p2pDevices = hContext->getP2PDevices (hDevice);
@@ -288,9 +286,7 @@ void *ur_discrete_mem_handle_t::getDevicePtr(
288
286
}
289
287
290
288
// TODO: see if it's better to migrate the memory to the specified device
291
- return ur_cast<char *>(
292
- deviceAllocations[activeAllocationDevice->Id .value ()].get ()) +
293
- offset;
289
+ return getActiveDeviceAlloc (offset);
294
290
}
295
291
296
292
void *ur_discrete_mem_handle_t ::mapHostPtr(
@@ -299,55 +295,63 @@ void *ur_discrete_mem_handle_t::mapHostPtr(
299
295
TRACK_SCOPE_LATENCY (" ur_discrete_mem_handle_t::mapHostPtr" );
300
296
// TODO: use async alloc?
301
297
302
- void *ptr;
303
- UR_CALL_THROWS (hContext->getDefaultUSMPool ()->allocate (
304
- hContext, nullptr , nullptr , UR_USM_TYPE_HOST, size, &ptr));
298
+ void *ptr = mapToPtr;
299
+ if (!ptr) {
300
+ UR_CALL_THROWS (hContext->getDefaultUSMPool ()->allocate (
301
+ hContext, nullptr , nullptr , UR_USM_TYPE_HOST, size, &ptr));
302
+ }
305
303
306
- hostAllocations.emplace_back (ptr, size, offset, flags);
304
+ usm_unique_ptr_t mappedPtr =
305
+ usm_unique_ptr_t (ptr, [ownsAlloc = bool (mapToPtr), this ](void *p) {
306
+ if (ownsAlloc) {
307
+ auto ret = hContext->getDefaultUSMPool ()->free (p);
308
+ if (ret != UR_RESULT_SUCCESS) {
309
+ logger::error (" Failed to mapped memory: {}" , ret);
310
+ }
311
+ }
312
+ });
313
+
314
+ hostAllocations.emplace_back (std::move (mappedPtr), size, offset, flags);
307
315
308
316
if (activeAllocationDevice && (flags & UR_MAP_FLAG_READ)) {
309
- auto srcPtr =
310
- ur_cast<char *>(
311
- deviceAllocations[activeAllocationDevice->Id .value ()].get ()) +
312
- offset;
313
- migrate (srcPtr, hostAllocations.back ().ptr , size);
317
+ auto srcPtr = getActiveDeviceAlloc (offset);
318
+ migrate (srcPtr, hostAllocations.back ().ptr .get (), size);
314
319
}
315
320
316
- return hostAllocations.back ().ptr ;
321
+ return hostAllocations.back ().ptr . get () ;
317
322
}
318
323
319
324
void ur_discrete_mem_handle_t::unmapHostPtr (
320
325
void *pMappedPtr,
321
326
std::function<void (void *src, void *dst, size_t )> migrate) {
322
327
TRACK_SCOPE_LATENCY (" ur_discrete_mem_handle_t::unmapHostPtr" );
323
328
324
- for (auto &hostAllocation : hostAllocations) {
325
- if (hostAllocation.ptr == pMappedPtr) {
326
- void *devicePtr = nullptr ;
327
- if (activeAllocationDevice) {
328
- devicePtr =
329
- ur_cast<char *>(
330
- deviceAllocations[activeAllocationDevice->Id .value ()].get ()) +
331
- hostAllocation.offset ;
332
- } else if (!(hostAllocation.flags &
333
- UR_MAP_FLAG_WRITE_INVALIDATE_REGION)) {
334
- devicePtr = ur_cast<char *>(getDevicePtr (
335
- hContext->getDevices ()[0 ], device_access_mode_t ::read_only,
336
- hostAllocation.offset , hostAllocation.size , migrate));
337
- }
329
+ auto hostAlloc =
330
+ std::find_if (hostAllocations.begin (), hostAllocations.end (),
331
+ [pMappedPtr](const host_allocation_desc_t &desc) {
332
+ return desc.ptr .get () == pMappedPtr;
333
+ });
338
334
339
- if (devicePtr ) {
340
- migrate (hostAllocation. ptr , devicePtr, hostAllocation. size ) ;
341
- }
335
+ if (hostAlloc == hostAllocations. end () ) {
336
+ throw UR_RESULT_ERROR_INVALID_ARGUMENT ;
337
+ }
342
338
343
- // TODO: use async free here?
344
- UR_CALL_THROWS (hContext->getDefaultUSMPool ()->free (hostAllocation.ptr ));
345
- return ;
346
- }
339
+ bool shouldMigrateToDevice =
340
+ !(hostAlloc->flags & UR_MAP_FLAG_WRITE_INVALIDATE_REGION);
341
+
342
+ if (!activeAllocationDevice && shouldMigrateToDevice) {
343
+ allocateOnDevice (hContext->getDevices ()[0 ], getSize ());
344
+ }
345
+
346
+ // TODO: tests require that memory is migrated even for
347
+ // UR_MAP_FLAG_WRITE_INVALIDATE_REGION when there is an active device
348
+ // allocation. is this correct?
349
+ if (activeAllocationDevice) {
350
+ migrate (hostAlloc->ptr .get (), getActiveDeviceAlloc (hostAlloc->offset ),
351
+ hostAlloc->size );
347
352
}
348
353
349
- // No mapping found
350
- throw UR_RESULT_ERROR_INVALID_ARGUMENT;
354
+ hostAllocations.erase (hostAlloc);
351
355
}
352
356
353
357
static bool useHostBuffer (ur_context_handle_t hContext) {
@@ -419,8 +423,6 @@ ur_result_t urMemBufferCreate(ur_context_handle_t hContext,
419
423
auto accessMode = getDeviceAccessMode (flags);
420
424
421
425
if (useHostBuffer (hContext)) {
422
- // TODO: assert that if hostPtr is set, either UR_MEM_FLAG_USE_HOST_POINTER
423
- // or UR_MEM_FLAG_ALLOC_COPY_HOST_POINTER is set?
424
426
auto hostPtrAction =
425
427
flags & UR_MEM_FLAG_USE_HOST_POINTER
426
428
? ur_integrated_mem_handle_t ::host_ptr_action_t ::import
0 commit comments