@@ -31,6 +31,50 @@ ur_command_list_manager::~ur_command_list_manager() {
31
31
ur::level_zero::urDeviceRelease (device);
32
32
}
33
33
34
+ ur_result_t ur_command_list_manager::appendGenericFillUnlocked (
35
+ ur_mem_buffer_t *dst, size_t offset, size_t patternSize,
36
+ const void *pPattern, size_t size, uint32_t numEventsInWaitList,
37
+ const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent,
38
+ ur_command_t commandType) {
39
+
40
+ auto zeSignalEvent = getSignalEvent (phEvent, commandType);
41
+
42
+ auto waitListView = getWaitListView (phEventWaitList, numEventsInWaitList);
43
+
44
+ auto pDst = ur_cast<char *>(dst->getDevicePtr (
45
+ device, ur_mem_buffer_t ::device_access_mode_t ::read_only, offset, size,
46
+ [&](void *src, void *dst, size_t size) {
47
+ ZE2UR_CALL_THROWS (zeCommandListAppendMemoryCopy,
48
+ (zeCommandList.get (), dst, src, size, nullptr ,
49
+ waitListView.num , waitListView.handles ));
50
+ waitListView.clear ();
51
+ }));
52
+
53
+ // PatternSize must be a power of two for zeCommandListAppendMemoryFill.
54
+ // When it's not, the fill is emulated with zeCommandListAppendMemoryCopy.
55
+ if (isPowerOf2 (patternSize)) {
56
+ ZE2UR_CALL (zeCommandListAppendMemoryFill,
57
+ (zeCommandList.get (), pDst, pPattern, patternSize, size,
58
+ zeSignalEvent, waitListView.num , waitListView.handles ));
59
+ } else {
60
+ // Copy pattern into every entry in memory array pointed by Ptr.
61
+ uint32_t numOfCopySteps = size / patternSize;
62
+ const void *src = pPattern;
63
+
64
+ for (uint32_t step = 0 ; step < numOfCopySteps; ++step) {
65
+ void *dst = reinterpret_cast <void *>(reinterpret_cast <uint8_t *>(pDst) +
66
+ step * patternSize);
67
+ ZE2UR_CALL (zeCommandListAppendMemoryCopy,
68
+ (zeCommandList.get (), dst, src, patternSize,
69
+ step == numOfCopySteps - 1 ? zeSignalEvent : nullptr ,
70
+ waitListView.num , waitListView.handles ));
71
+ waitListView.clear ();
72
+ }
73
+ }
74
+
75
+ return UR_RESULT_SUCCESS;
76
+ }
77
+
34
78
ur_result_t ur_command_list_manager::appendGenericCopyUnlocked (
35
79
ur_mem_buffer_t *src, ur_mem_buffer_t *dst, bool blocking, size_t srcOffset,
36
80
size_t dstOffset, size_t size, uint32_t numEventsInWaitList,
@@ -209,6 +253,96 @@ ur_result_t ur_command_list_manager::appendUSMMemcpy(
209
253
return UR_RESULT_SUCCESS;
210
254
}
211
255
256
+ ur_result_t ur_command_list_manager::appendMemBufferFill (
257
+ ur_mem_handle_t hMem, const void *pPattern, size_t patternSize,
258
+ size_t offset, size_t size, uint32_t numEventsInWaitList,
259
+ const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
260
+ TRACK_SCOPE_LATENCY (" ur_command_list_manager::appendMemBufferFill" );
261
+
262
+ auto hBuffer = hMem->getBuffer ();
263
+ UR_ASSERT (offset + size <= hBuffer->getSize (), UR_RESULT_ERROR_INVALID_SIZE);
264
+
265
+ std::scoped_lock<ur_shared_mutex, ur_shared_mutex> lock (this ->Mutex ,
266
+ hBuffer->getMutex ());
267
+
268
+ return appendGenericFillUnlocked (hBuffer, offset, patternSize, pPattern, size,
269
+ numEventsInWaitList, phEventWaitList,
270
+ phEvent, UR_COMMAND_MEM_BUFFER_FILL);
271
+ }
272
+
273
+ ur_result_t ur_command_list_manager::appendUSMFill (
274
+ void *pMem, size_t patternSize, const void *pPattern, size_t size,
275
+ uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
276
+ ur_event_handle_t *phEvent) {
277
+ TRACK_SCOPE_LATENCY (" ur_command_list_manager::appendUSMFill" );
278
+
279
+ std::scoped_lock<ur_shared_mutex> lock (this ->Mutex );
280
+
281
+ ur_usm_handle_t dstHandle (context, size, pMem);
282
+ return appendGenericFillUnlocked (&dstHandle, 0 , patternSize, pPattern, size,
283
+ numEventsInWaitList, phEventWaitList,
284
+ phEvent, UR_COMMAND_USM_FILL);
285
+ }
286
+
287
+ ur_result_t ur_command_list_manager::appendUSMPrefetch (
288
+ const void *pMem, size_t size, ur_usm_migration_flags_t flags,
289
+ uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
290
+ ur_event_handle_t *phEvent) {
291
+ TRACK_SCOPE_LATENCY (" ur_command_list_manager::appendUSMPrefetch" );
292
+
293
+ std::ignore = flags;
294
+
295
+ std::scoped_lock<ur_shared_mutex> lock (this ->Mutex );
296
+
297
+ auto zeSignalEvent = getSignalEvent (phEvent, UR_COMMAND_USM_PREFETCH);
298
+
299
+ auto [pWaitEvents, numWaitEvents] =
300
+ getWaitListView (phEventWaitList, numEventsInWaitList);
301
+
302
+ if (pWaitEvents) {
303
+ ZE2UR_CALL (zeCommandListAppendWaitOnEvents,
304
+ (zeCommandList.get (), numWaitEvents, pWaitEvents));
305
+ }
306
+ // TODO: figure out how to translate "flags"
307
+ ZE2UR_CALL (zeCommandListAppendMemoryPrefetch,
308
+ (zeCommandList.get (), pMem, size));
309
+ if (zeSignalEvent) {
310
+ ZE2UR_CALL (zeCommandListAppendSignalEvent,
311
+ (zeCommandList.get (), zeSignalEvent));
312
+ }
313
+
314
+ return UR_RESULT_SUCCESS;
315
+ }
316
+
317
+ ur_result_t
318
+ ur_command_list_manager::appendUSMAdvise (const void *pMem, size_t size,
319
+ ur_usm_advice_flags_t advice,
320
+ ur_event_handle_t *phEvent) {
321
+ TRACK_SCOPE_LATENCY (" ur_command_list_manager::appendUSMAdvise" );
322
+
323
+ std::scoped_lock<ur_shared_mutex> lock (this ->Mutex );
324
+
325
+ auto zeAdvice = ur_cast<ze_memory_advice_t >(advice);
326
+
327
+ auto zeSignalEvent = getSignalEvent (phEvent, UR_COMMAND_USM_ADVISE);
328
+
329
+ auto [pWaitEvents, numWaitEvents] = getWaitListView (nullptr , 0 );
330
+
331
+ if (pWaitEvents) {
332
+ ZE2UR_CALL (zeCommandListAppendWaitOnEvents,
333
+ (zeCommandList.get (), numWaitEvents, pWaitEvents));
334
+ }
335
+
336
+ ZE2UR_CALL (zeCommandListAppendMemAdvise,
337
+ (zeCommandList.get (), device->ZeDevice , pMem, size, zeAdvice));
338
+
339
+ if (zeSignalEvent) {
340
+ ZE2UR_CALL (zeCommandListAppendSignalEvent,
341
+ (zeCommandList.get (), zeSignalEvent));
342
+ }
343
+ return UR_RESULT_SUCCESS;
344
+ }
345
+
212
346
ur_result_t ur_command_list_manager::appendMemBufferRead (
213
347
ur_mem_handle_t hMem, bool blockingRead, size_t offset, size_t size,
214
348
void *pDst, uint32_t numEventsInWaitList,
0 commit comments