Skip to content

Commit 129ee44

Browse files
[SYCL]: basic support of contexts with multiple devices in Level-Zero (#2440)
Signed-off-by: Sergey V Maslov <[email protected]>
1 parent 628424a commit 129ee44

File tree

2 files changed

+149
-96
lines changed

2 files changed

+149
-96
lines changed

sycl/plugins/level_zero/pi_level_zero.cpp

+129-79
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@
1212
/// \ingroup sycl_pi_level_zero
1313

1414
#include "pi_level_zero.hpp"
15+
#include <algorithm>
1516
#include <cstdarg>
1617
#include <cstdio>
1718
#include <cstring>
@@ -219,9 +220,13 @@ _pi_context::getFreeSlotInExistingOrNewPool(ze_event_pool_handle_t &ZePool,
219220
ZeEventPoolDesc.count = MaxNumEventsPerPool;
220221
ZeEventPoolDesc.flags = ZE_EVENT_POOL_FLAG_KERNEL_TIMESTAMP;
221222

222-
ze_device_handle_t ZeDevice = Device->ZeDevice;
223-
if (ze_result_t ZeRes = zeEventPoolCreate(ZeContext, &ZeEventPoolDesc, 1,
224-
&ZeDevice, &ZeEventPool))
223+
std::vector<ze_device_handle_t> ZeDevices;
224+
std::for_each(Devices.begin(), Devices.end(),
225+
[&](pi_device &D) { ZeDevices.push_back(D->ZeDevice); });
226+
227+
if (ze_result_t ZeRes =
228+
zeEventPoolCreate(ZeContext, &ZeEventPoolDesc, ZeDevices.size(),
229+
&ZeDevices[0], &ZeEventPool))
225230
return ZeRes;
226231
NumEventsAvailableInEventPool[ZeEventPool] = MaxNumEventsPerPool - 1;
227232
NumEventsLiveInEventPool[ZeEventPool] = MaxNumEventsPerPool;
@@ -408,9 +413,9 @@ _pi_queue::resetCommandListFenceEntry(ze_command_list_handle_t ZeCommandList,
408413
ZE_CALL(zeFenceReset(this->ZeCommandListFenceMap[ZeCommandList]));
409414
ZE_CALL(zeCommandListReset(ZeCommandList));
410415
if (MakeAvailable) {
411-
this->Context->Device->ZeCommandListCacheMutex.lock();
412-
this->Context->Device->ZeCommandListCache.push_back(ZeCommandList);
413-
this->Context->Device->ZeCommandListCacheMutex.unlock();
416+
this->Device->ZeCommandListCacheMutex.lock();
417+
this->Device->ZeCommandListCache.push_back(ZeCommandList);
418+
this->Device->ZeCommandListCacheMutex.unlock();
414419
}
415420

416421
return PI_SUCCESS;
@@ -433,7 +438,7 @@ _pi_device::getAvailableCommandList(pi_queue Queue,
433438

434439
// Initially, we need to check if a command list has already been created
435440
// on this device that is available for use. If so, then reuse that
436-
// L0 Command List and Fence for this PI call.
441+
// Level-Zero Command List and Fence for this PI call.
437442
if (Queue->Device->ZeCommandListCache.size() > 0) {
438443
Queue->Device->ZeCommandListCacheMutex.lock();
439444
*ZeCommandList = Queue->Device->ZeCommandListCache.front();
@@ -749,11 +754,25 @@ pi_result piextPlatformCreateWithNativeHandle(pi_native_handle NativeHandle,
749754
assert(Platform);
750755

751756
// Create PI platform from the given Level Zero driver handle.
757+
// TODO: get the platform from the platforms' cache.
752758
auto ZeDriver = pi_cast<ze_driver_handle_t>(NativeHandle);
753759
*Platform = new _pi_platform(ZeDriver);
754760
return PI_SUCCESS;
755761
}
756762

763+
// Get the cached PI device created for the given Level-Zero device handle.
764+
// Returns nullptr if no PI device in PiDevicesCache has a matching ZeDevice.
765+
pi_device _pi_platform::getDeviceFromNativeHandle(ze_device_handle_t ZeDevice) {
766+
767+
std::lock_guard<std::mutex> Lock(this->PiDevicesCacheMutex); // held while PiDevicesCache is searched
768+
auto it = std::find_if(PiDevicesCache.begin(), PiDevicesCache.end(),
769+
[&](pi_device &D) { return D->ZeDevice == ZeDevice; });
770+
if (it != PiDevicesCache.end()) {
771+
return *it;
772+
}
773+
return nullptr;
774+
}
775+
757776
pi_result piDevicesGet(pi_platform Platform, pi_device_type DeviceType,
758777
pi_uint32 NumEntries, pi_device *Devices,
759778
pi_uint32 *NumDevices) {
@@ -1391,6 +1410,7 @@ pi_result piextDeviceCreateWithNativeHandle(pi_native_handle NativeHandle,
13911410
assert(Platform);
13921411

13931412
// Create PI device from the given Level Zero device handle.
1413+
// TODO: get the device from the devices' cache.
13941414
auto ZeDevice = pi_cast<ze_device_handle_t>(NativeHandle);
13951415
*Device = new _pi_device(ZeDevice, Platform);
13961416
return (*Device)->initialize();
@@ -1402,15 +1422,14 @@ pi_result piContextCreate(const pi_context_properties *Properties,
14021422
const void *PrivateInfo, size_t CB,
14031423
void *UserData),
14041424
void *UserData, pi_context *RetContext) {
1405-
if (NumDevices != 1 || !Devices) {
1406-
zePrint("piCreateContext: context should have exactly one Device\n");
1425+
if (!Devices) {
14071426
return PI_INVALID_VALUE;
14081427
}
14091428

14101429
assert(RetContext);
14111430

14121431
try {
1413-
*RetContext = new _pi_context(*Devices);
1432+
*RetContext = new _pi_context(NumDevices, Devices);
14141433
} catch (const std::bad_alloc &) {
14151434
return PI_OUT_OF_HOST_MEMORY;
14161435
} catch (...) {
@@ -1444,9 +1463,10 @@ pi_result piContextGetInfo(pi_context Context, pi_context_info ParamName,
14441463
ReturnHelper ReturnValue(ParamValueSize, ParamValue, ParamValueSizeRet);
14451464
switch (ParamName) {
14461465
case PI_CONTEXT_INFO_DEVICES:
1447-
return ReturnValue(Context->Device);
1466+
return getInfoArray(Context->Devices.size(), ParamValueSize, ParamValue,
1467+
ParamValueSizeRet, &Context->Devices[0]);
14481468
case PI_CONTEXT_INFO_NUM_DEVICES:
1449-
return ReturnValue(pi_uint32{1});
1469+
return ReturnValue(pi_uint32(Context->Devices.size()));
14501470
case PI_CONTEXT_INFO_REFERENCE_COUNT:
14511471
return ReturnValue(pi_uint32{Context->RefCount});
14521472
default:
@@ -1521,7 +1541,8 @@ pi_result piQueueCreate(pi_context Context, pi_device Device,
15211541
if (!Context) {
15221542
return PI_INVALID_CONTEXT;
15231543
}
1524-
if (Context->Device != Device) {
1544+
if (std::find(Context->Devices.begin(), Context->Devices.end(), Device) ==
1545+
Context->Devices.end()) {
15251546
return PI_INVALID_DEVICE;
15261547
}
15271548

@@ -1628,7 +1649,11 @@ pi_result piextQueueCreateWithNativeHandle(pi_native_handle NativeHandle,
16281649
assert(Queue);
16291650

16301651
auto ZeQueue = pi_cast<ze_command_queue_handle_t>(NativeHandle);
1631-
*Queue = new _pi_queue(ZeQueue, Context, Context->Device);
1652+
1653+
// Attach the queue to the "0" device.
1654+
// TODO: see if we need to let user choose the device.
1655+
pi_device Device = Context->Devices[0];
1656+
*Queue = new _pi_queue(ZeQueue, Context, Device);
16321657
return PI_SUCCESS;
16331658
}
16341659

@@ -1641,14 +1666,24 @@ pi_result piMemBufferCreate(pi_context Context, pi_mem_flags Flags, size_t Size,
16411666
assert(RetMem);
16421667

16431668
void *Ptr;
1644-
ze_device_handle_t ZeDevice = Context->Device->ZeDevice;
16451669

1646-
ze_device_mem_alloc_desc_t ZeDesc = {};
1647-
ZeDesc.flags = 0;
1648-
ZeDesc.ordinal = 0;
1649-
ZE_CALL(zeMemAllocDevice(Context->ZeContext, &ZeDesc, Size,
1650-
1, // TODO: alignment
1651-
ZeDevice, &Ptr));
1670+
ze_device_mem_alloc_desc_t ZeDeviceMemDesc = {};
1671+
ZeDeviceMemDesc.flags = 0;
1672+
ZeDeviceMemDesc.ordinal = 0;
1673+
1674+
if (Context->Devices.size() == 1) {
1675+
ZE_CALL(zeMemAllocDevice(Context->ZeContext, &ZeDeviceMemDesc, Size,
1676+
1, // TODO: alignment
1677+
Context->Devices[0]->ZeDevice, &Ptr));
1678+
} else {
1679+
ze_host_mem_alloc_desc_t ZeHostMemDesc = {};
1680+
ZeHostMemDesc.flags = 0;
1681+
ZE_CALL(zeMemAllocShared(Context->ZeContext, &ZeDeviceMemDesc,
1682+
&ZeHostMemDesc, Size,
1683+
1, // TODO: alignment
1684+
nullptr, // not bound to any device
1685+
&Ptr));
1686+
}
16521687

16531688
if ((Flags & PI_MEM_FLAGS_HOST_PTR_USE) != 0 ||
16541689
(Flags & PI_MEM_FLAGS_HOST_PTR_COPY) != 0) {
@@ -1837,9 +1872,17 @@ pi_result piMemImageCreate(pi_context Context, pi_mem_flags Flags,
18371872
ZeImageDesc.arraylevels = pi_cast<uint32_t>(ImageDesc->image_array_size);
18381873
ZeImageDesc.miplevels = ImageDesc->num_mip_levels;
18391874

1875+
// Have the "0" device in context to own the image. Rely on Level-Zero
1876+
// drivers to perform migration as necessary for sharing it across multiple
1877+
// devices in the context.
1878+
//
1879+
// TODO: figure out if we instead need explicit copying for accessing
1880+
// the image from other devices in the context.
1881+
//
1882+
pi_device Device = Context->Devices[0];
18401883
ze_image_handle_t ZeHImage;
1841-
ZE_CALL(zeImageCreate(Context->ZeContext, Context->Device->ZeDevice,
1842-
&ZeImageDesc, &ZeHImage));
1884+
ZE_CALL(zeImageCreate(Context->ZeContext, Device->ZeDevice, &ZeImageDesc,
1885+
&ZeHImage));
18431886

18441887
auto HostPtrOrNull =
18451888
(Flags & PI_MEM_FLAGS_HOST_PTR_USE) ? pi_cast<char *>(HostPtr) : nullptr;
@@ -1926,7 +1969,7 @@ pi_result piProgramCreateWithBinary(pi_context Context, pi_uint32 NumDevices,
19261969
*BinaryStatus = PI_INVALID_VALUE;
19271970
return PI_INVALID_VALUE;
19281971
}
1929-
if (DeviceList[0] != Context->Device)
1972+
if (DeviceList[0] != Context->Devices[0])
19301973
return PI_INVALID_DEVICE;
19311974

19321975
size_t Length = Lengths[0];
@@ -1975,10 +2018,11 @@ pi_result piProgramGetInfo(pi_program Program, pi_program_info ParamName,
19752018
case PI_PROGRAM_INFO_REFERENCE_COUNT:
19762019
return ReturnValue(pi_uint32{Program->RefCount});
19772020
case PI_PROGRAM_INFO_NUM_DEVICES:
1978-
// Level Zero Module is always for a single device.
2021+
// TODO: return true number of devices this program exists for.
19792022
return ReturnValue(pi_uint32{1});
19802023
case PI_PROGRAM_INFO_DEVICES:
1981-
return ReturnValue(Program->Context->Device);
2024+
// TODO: return all devices this program exists for.
2025+
return ReturnValue(Program->Context->Devices[0]);
19822026
case PI_PROGRAM_INFO_BINARY_SIZES: {
19832027
size_t SzBinary;
19842028
if (Program->State == _pi_program::IL ||
@@ -2105,9 +2149,10 @@ pi_result piProgramLink(pi_context Context, pi_uint32 NumDevices,
21052149
void (*PFnNotify)(pi_program Program, void *UserData),
21062150
void *UserData, pi_program *RetProgram) {
21072151

2108-
// We only support one device with Level Zero.
2152+
// We only support one device with Level Zero currently.
2153+
pi_device Device = Context->Devices[0];
21092154
assert(NumDevices == 1);
2110-
assert(DeviceList && DeviceList[0] == Context->Device);
2155+
assert(DeviceList && DeviceList[0] == Device);
21112156
assert(!PFnNotify && !UserData);
21122157

21132158
// Validate input parameters.
@@ -2170,9 +2215,8 @@ pi_result piProgramLink(pi_context Context, pi_uint32 NumDevices,
21702215
// only export symbols.
21712216
Guard.unlock();
21722217
ze_module_handle_t ZeModule;
2173-
pi_result res =
2174-
copyModule(Context->ZeContext, Context->Device->ZeDevice,
2175-
Input->ZeModule, &ZeModule);
2218+
pi_result res = copyModule(Context->ZeContext, Device->ZeDevice,
2219+
Input->ZeModule, &ZeModule);
21762220
if (res != PI_SUCCESS) {
21772221
return res;
21782222
}
@@ -2270,7 +2314,9 @@ static pi_result compileOrBuild(pi_program Program, pi_uint32 NumDevices,
22702314
if ((NumDevices && !DeviceList) || (!NumDevices && DeviceList))
22712315
return PI_INVALID_VALUE;
22722316

2273-
// We only support one device with Level Zero.
2317+
// We only support build to one device with Level Zero now.
2318+
// TODO: we should eventually build to the possibly multiple root
2319+
// devices in the context.
22742320
assert(NumDevices == 1 && DeviceList);
22752321

22762322
// We should have either IL or native device code.
@@ -2307,7 +2353,7 @@ static pi_result compileOrBuild(pi_program Program, pi_uint32 NumDevices,
23072353
ZeModuleDesc.pBuildFlags = Options;
23082354
ZeModuleDesc.pConstants = &ZeSpecConstants;
23092355

2310-
ze_device_handle_t ZeDevice = Program->Context->Device->ZeDevice;
2356+
ze_device_handle_t ZeDevice = DeviceList[0]->ZeDevice;
23112357
ze_context_handle_t ZeContext = Program->Context->ZeContext;
23122358
ze_module_handle_t ZeModule;
23132359
ze_module_build_log_handle_t ZeBuildLog;
@@ -2905,7 +2951,8 @@ pi_result piEventCreate(pi_context Context, pi_event *RetEvent) {
29052951
ze_event_handle_t ZeEvent;
29062952
ze_event_desc_t ZeEventDesc = {};
29072953
// We have to set the SIGNAL & WAIT flags as HOST scope because the
2908-
// L0 plugin implementation waits for the events to complete on the host.
2954+
// Level-Zero plugin implementation waits for the events to complete
2955+
// on the host.
29092956
ZeEventDesc.signal = ZE_EVENT_SCOPE_FLAG_HOST;
29102957
ZeEventDesc.wait = ZE_EVENT_SCOPE_FLAG_HOST;
29112958
ZeEventDesc.index = Index;
@@ -3111,7 +3158,14 @@ pi_result piSamplerCreate(pi_context Context,
31113158
assert(Context);
31123159
assert(RetSampler);
31133160

3114-
ze_device_handle_t ZeDevice = Context->Device->ZeDevice;
3161+
// Have the "0" device in context to own the sampler. Rely on Level-Zero
3162+
// drivers to perform migration as necessary for sharing it across multiple
3163+
// devices in the context.
3164+
//
3165+
// TODO: figure out if we instead need explicit copying for accessing
3166+
// the sampler from other devices in the context.
3167+
//
3168+
pi_device Device = Context->Devices[0];
31153169

31163170
ze_sampler_handle_t ZeSampler;
31173171
ze_sampler_desc_t ZeSamplerDesc = {};
@@ -3199,7 +3253,7 @@ pi_result piSamplerCreate(pi_context Context,
31993253
}
32003254
}
32013255

3202-
ZE_CALL(zeSamplerCreate(Context->ZeContext, ZeDevice,
3256+
ZE_CALL(zeSamplerCreate(Context->ZeContext, Device->ZeDevice,
32033257
&ZeSamplerDesc, // TODO: translate properties
32043258
&ZeSampler));
32053259

@@ -4241,49 +4295,44 @@ pi_result piextUSMFree(pi_context Context, void *Ptr) {
42414295
ze_memory_allocation_properties_t ZeMemoryAllocationProperties = {};
42424296

42434297
// Query memory type of the pointer we're freeing to determine the correct
4244-
// way to do it(directly or via the allocator)
4298+
// way to do it (directly or via an allocator)
42454299
ZE_CALL(zeMemGetAllocProperties(
42464300
Context->ZeContext, Ptr, &ZeMemoryAllocationProperties, &ZeDeviceHandle));
42474301

4248-
// TODO: when support for multiple devices is implemented, here
4249-
// we should do the following:
4250-
// - Find pi_device instance corresponding to ZeDeviceHandle we've just got if
4251-
// exist
4252-
// - Use that pi_device to find the right allocator context and free the
4253-
// pointer.
4254-
4255-
// The allocation doesn't belong to any device for which USM allocator is
4256-
// enabled.
4257-
if (Context->Device->ZeDevice != ZeDeviceHandle) {
4258-
return USMFreeImpl(Context, Ptr);
4259-
}
4260-
4261-
auto DeallocationHelper =
4262-
[Context,
4263-
Ptr](std::unordered_map<pi_device, USMAllocContext> &AllocContextMap) {
4264-
try {
4265-
auto It = AllocContextMap.find(Context->Device);
4266-
if (It == AllocContextMap.end())
4267-
return PI_INVALID_VALUE;
4268-
4269-
// The right context is found, deallocate the pointer
4270-
It->second.deallocate(Ptr);
4271-
} catch (const UsmAllocationException &Ex) {
4272-
return Ex.getError();
4273-
}
4302+
if (ZeDeviceHandle) {
4303+
// All devices in the context are of the same platform.
4304+
auto Platform = Context->Devices[0]->Platform;
4305+
auto Device = Platform->getDeviceFromNativeHandle(ZeDeviceHandle);
4306+
assert(Device);
4307+
4308+
auto DeallocationHelper =
4309+
[Context, Device,
4310+
Ptr](std::unordered_map<pi_device, USMAllocContext> &AllocContextMap) {
4311+
try {
4312+
auto It = AllocContextMap.find(Device);
4313+
if (It == AllocContextMap.end())
4314+
return PI_INVALID_VALUE;
4315+
4316+
// The right context is found, deallocate the pointer
4317+
It->second.deallocate(Ptr);
4318+
} catch (const UsmAllocationException &Ex) {
4319+
return Ex.getError();
4320+
}
42744321

4275-
return PI_SUCCESS;
4276-
};
4322+
return PI_SUCCESS;
4323+
};
42774324

4278-
switch (ZeMemoryAllocationProperties.type) {
4279-
case ZE_MEMORY_TYPE_SHARED:
4280-
return DeallocationHelper(Context->SharedMemAllocContexts);
4281-
case ZE_MEMORY_TYPE_DEVICE:
4282-
return DeallocationHelper(Context->DeviceMemAllocContexts);
4283-
default:
4284-
// Handled below
4285-
break;
4325+
switch (ZeMemoryAllocationProperties.type) {
4326+
case ZE_MEMORY_TYPE_SHARED:
4327+
return DeallocationHelper(Context->SharedMemAllocContexts);
4328+
case ZE_MEMORY_TYPE_DEVICE:
4329+
return DeallocationHelper(Context->DeviceMemAllocContexts);
4330+
default:
4331+
// Handled below
4332+
break;
4333+
}
42864334
}
4335+
42874336
return USMFreeImpl(Context, Ptr);
42884337
}
42894338

@@ -4519,14 +4568,15 @@ pi_result piextUSMGetMemAllocInfo(pi_context Context, const void *Ptr,
45194568
}
45204569
return ReturnValue(MemAllocaType);
45214570
}
4522-
case PI_MEM_ALLOC_DEVICE: {
4571+
case PI_MEM_ALLOC_DEVICE:
45234572
if (ZeDeviceHandle) {
4524-
if (Context->Device->ZeDevice == ZeDeviceHandle) {
4525-
return ReturnValue(Context->Device);
4526-
}
4573+
// All devices in the context are of the same platform.
4574+
auto Platform = Context->Devices[0]->Platform;
4575+
auto Device = Platform->getDeviceFromNativeHandle(ZeDeviceHandle);
4576+
return Device ? ReturnValue(Device) : PI_INVALID_VALUE;
4577+
} else {
4578+
return PI_INVALID_VALUE;
45274579
}
4528-
return PI_INVALID_VALUE;
4529-
}
45304580
case PI_MEM_ALLOC_BASE_PTR: {
45314581
void *Base;
45324582
ZE_CALL(zeMemGetAddressRange(Context->ZeContext, Ptr, &Base, nullptr));

0 commit comments

Comments
 (0)