Commit 062c35d

Handle subdevice partition correctly

1 parent 461cf6e

8 files changed: 205 additions, 97 deletions

source/adapters/native_cpu/device.cpp

Lines changed: 8 additions & 6 deletions
@@ -98,8 +98,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice,
   case UR_DEVICE_INFO_LINKER_AVAILABLE:
     return ReturnValue(bool{false});
   case UR_DEVICE_INFO_MAX_COMPUTE_UNITS:
-    return ReturnValue(static_cast<uint32_t>(
-        hDevice->tp.num_threads()));
+    return ReturnValue(static_cast<uint32_t>(hDevice->tp.num_threads()));
   case UR_DEVICE_INFO_PARTITION_MAX_SUB_DEVICES:
     return ReturnValue(uint32_t{0});
   case UR_DEVICE_INFO_SUPPORTED_PARTITIONS:
@@ -139,7 +138,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice,
   case UR_DEVICE_INFO_MAX_WORK_ITEM_DIMENSIONS:
     return ReturnValue(uint32_t{3});
   case UR_DEVICE_INFO_PARTITION_TYPE:
-    return ReturnValue(ur_device_partition_property_t{});
+    if (pPropSizeRet) {
+      *pPropSizeRet = 0;
+    }
+    return UR_RESULT_SUCCESS;
   case UR_EXT_DEVICE_INFO_OPENCL_C_VERSION:
     return ReturnValue("");
   case UR_DEVICE_INFO_QUEUE_PROPERTIES:
@@ -159,8 +161,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice,
   case UR_DEVICE_INFO_PREFERRED_VECTOR_WIDTH_FLOAT:
   case UR_DEVICE_INFO_PREFERRED_VECTOR_WIDTH_DOUBLE:
   case UR_DEVICE_INFO_PREFERRED_VECTOR_WIDTH_HALF:
-    // todo: how can we query vector width in a platform
-    // indipendent way?
+    // TODO: How can we query vector width in a platform
+    // independent way?
   case UR_DEVICE_INFO_NATIVE_VECTOR_WIDTH_CHAR:
     return ReturnValue(uint32_t{32});
   case UR_DEVICE_INFO_NATIVE_VECTOR_WIDTH_SHORT:
@@ -266,7 +268,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice,
   case UR_DEVICE_INFO_ATOMIC_64:
     return ReturnValue(bool{1});
   case UR_DEVICE_INFO_BFLOAT16:
-    return ReturnValue(bool{1});
+    return ReturnValue(bool{0});
   case UR_DEVICE_INFO_MEM_CHANNEL_SUPPORT:
     return ReturnValue(bool{0});
   case UR_DEVICE_INFO_IMAGE_SRGB:
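
Note on the UR_DEVICE_INFO_PARTITION_TYPE hunk: instead of returning a single
empty ur_device_partition_property_t, the adapter now reports a zero-size
property, which is how a caller can tell that the device was not created by
partitioning a parent device. A minimal caller sketch (checkPartition is a
hypothetical helper; urDeviceGetInfo and the enums are the entry points shown
above):

#include <ur_api.h>
#include <cstdio>

// Probe whether hDev is a sub-device. With this commit, an unpartitioned
// Native CPU device reports a zero-size partition-type property.
void checkPartition(ur_device_handle_t hDev) {
  size_t size = 0;
  // Pass propSize = 0 and pPropValue = nullptr to query only the size.
  ur_result_t res = urDeviceGetInfo(hDev, UR_DEVICE_INFO_PARTITION_TYPE, 0,
                                    nullptr, &size);
  if (res == UR_RESULT_SUCCESS && size == 0)
    std::printf("device is not a sub-device\n");
}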

source/adapters/native_cpu/device.hpp

Lines changed: 2 additions & 8 deletions
@@ -10,18 +10,12 @@
 
 #pragma once
 
-#include <ur/ur.hpp>
 #include "threadpool.hpp"
+#include <ur/ur.hpp>
 
 struct ur_device_handle_t_ {
   native_cpu::threadpool_t tp;
-  ur_device_handle_t_(ur_platform_handle_t ArgPlt) : Platform(ArgPlt) {
-    tp.start();
-  }
-
-  ~ur_device_handle_t_() {
-    tp.stop();
-  }
+  ur_device_handle_t_(ur_platform_handle_t ArgPlt) : Platform(ArgPlt) {}
 
   ur_platform_handle_t Platform;
 };
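
With the explicit tp.start()/tp.stop() calls removed from the device
constructor and destructor, the device now relies on the thread pool managing
its own worker lifetime. The real native_cpu::threadpool_t lives in
threadpool.hpp and is not part of this commit; the following is only a sketch
of the RAII pattern this change assumes, with hypothetical names throughout:

#include <condition_variable>
#include <functional>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

// Workers are spawned on construction and joined on destruction, so an
// owning object like ur_device_handle_t_ needs no start()/stop() calls.
class raii_threadpool {
  std::vector<std::thread> workers;
  std::queue<std::function<void()>> tasks;
  std::mutex m;
  std::condition_variable cv;
  bool done = false;

public:
  explicit raii_threadpool(unsigned n = std::thread::hardware_concurrency()) {
    for (unsigned i = 0; i < n; i++)
      workers.emplace_back([this] {
        for (;;) {
          std::function<void()> task;
          {
            std::unique_lock<std::mutex> lk(m);
            cv.wait(lk, [this] { return done || !tasks.empty(); });
            if (done && tasks.empty())
              return;
            task = std::move(tasks.front());
            tasks.pop();
          }
          task();
        }
      });
  }

  void schedule(std::function<void()> f) {
    {
      std::lock_guard<std::mutex> lk(m);
      tasks.push(std::move(f));
    }
    cv.notify_one();
  }

  ~raii_threadpool() {
    {
      std::lock_guard<std::mutex> lk(m);
      done = true;
    }
    cv.notify_all();
    for (auto &w : workers)
      w.join(); // implicit stop: drain and join on destruction
  }
};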

source/adapters/native_cpu/enqueue.cpp

Lines changed: 135 additions & 22 deletions
@@ -6,15 +6,17 @@
 //
 //===----------------------------------------------------------------------===//
 #include <array>
+#include <cstddef>
 #include <cstdint>
+#include <vector>
 
 #include "ur_api.h"
 
 #include "common.hpp"
 #include "kernel.hpp"
 #include "memory.hpp"
-#include "threadpool.hpp"
 #include "queue.hpp"
+#include "threadpool.hpp"
 
 namespace native_cpu {
 struct NDRDescT {
@@ -37,9 +39,29 @@ struct NDRDescT {
       GlobalOffset[I] = 0;
     }
   }
+
+  void dump(std::ostream &os) const {
+    os << "GlobalSize: " << GlobalSize[0] << " " << GlobalSize[1] << " "
+       << GlobalSize[2] << "\n";
+    os << "LocalSize: " << LocalSize[0] << " " << LocalSize[1] << " "
+       << LocalSize[2] << "\n";
+    os << "GlobalOffset: " << GlobalOffset[0] << " " << GlobalOffset[1] << " "
+       << GlobalOffset[2] << "\n";
+  }
 };
 } // namespace native_cpu
 
+#ifdef NATIVECPU_USE_OCK
+static native_cpu::state getResizedState(const native_cpu::NDRDescT &ndr,
+                                         size_t itemsPerThread) {
+  native_cpu::state resized_state(
+      ndr.GlobalSize[0], ndr.GlobalSize[1], ndr.GlobalSize[2], itemsPerThread,
+      ndr.LocalSize[1], ndr.LocalSize[2], ndr.GlobalOffset[0],
+      ndr.GlobalOffset[1], ndr.GlobalOffset[2]);
+  return resized_state;
+}
+#endif
+
 UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
     ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel, uint32_t workDim,
     const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
@@ -61,38 +83,25 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
 
   // TODO: add proper error checking
   // TODO: add proper event dep management
-  native_cpu::NDRDescT ndr(workDim, pGlobalWorkOffset, pGlobalWorkSize, pLocalWorkSize);
-  auto& tp = hQueue->device->tp;
+  native_cpu::NDRDescT ndr(workDim, pGlobalWorkOffset, pGlobalWorkSize,
+                           pLocalWorkSize);
+  auto &tp = hQueue->device->tp;
   const size_t numParallelThreads = tp.num_threads();
   hKernel->updateMemPool(numParallelThreads);
   std::vector<std::future<void>> futures;
+  std::vector<std::function<void(size_t, ur_kernel_handle_t_)>> groups;
   auto numWG0 = ndr.GlobalSize[0] / ndr.LocalSize[0];
   auto numWG1 = ndr.GlobalSize[1] / ndr.LocalSize[1];
   auto numWG2 = ndr.GlobalSize[2] / ndr.LocalSize[2];
-  bool isLocalSizeOne =
-      ndr.LocalSize[0] == 1 && ndr.LocalSize[1] == 1 && ndr.LocalSize[2] == 1;
-
-
   native_cpu::state state(ndr.GlobalSize[0], ndr.GlobalSize[1],
                           ndr.GlobalSize[2], ndr.LocalSize[0], ndr.LocalSize[1],
                           ndr.LocalSize[2], ndr.GlobalOffset[0],
                           ndr.GlobalOffset[1], ndr.GlobalOffset[2]);
-  if (isLocalSizeOne) {
-    // If the local size is one, we make the assumption that we are running a
-    // parallel_for over a sycl::range Todo: we could add compiler checks and
-    // kernel properties for this (e.g. check that no barriers are called, no
-    // local memory args).
-
-    auto numWG0 = ndr.GlobalSize[0] / ndr.LocalSize[0];
-    auto numWG1 = ndr.GlobalSize[1] / ndr.LocalSize[1];
-    auto numWG2 = ndr.GlobalSize[2] / ndr.LocalSize[2];
+#ifndef NATIVECPU_USE_OCK
+  hKernel->handleLocalArgs(1, 0);
   for (unsigned g2 = 0; g2 < numWG2; g2++) {
     for (unsigned g1 = 0; g1 < numWG1; g1++) {
       for (unsigned g0 = 0; g0 < numWG0; g0++) {
-#ifdef NATIVECPU_USE_OCK
-        state.update(g0, g1, g2);
-        hKernel->_subhandler(hKernel->_args.data(), &state);
-#else
         for (unsigned local2 = 0; local2 < ndr.LocalSize[2]; local2++) {
           for (unsigned local1 = 0; local1 < ndr.LocalSize[1]; local1++) {
             for (unsigned local0 = 0; local0 < ndr.LocalSize[0]; local0++) {
@@ -101,13 +110,118 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
             }
           }
         }
-#endif
+      }
+    }
+  }
+#else
+  bool isLocalSizeOne =
+      ndr.LocalSize[0] == 1 && ndr.LocalSize[1] == 1 && ndr.LocalSize[2] == 1;
+  if (isLocalSizeOne && ndr.GlobalSize[0] > numParallelThreads) {
+    // If the local size is one, we make the assumption that we are running a
+    // parallel_for over a sycl::range.
+    // Todo: we could add compiler checks and
+    // kernel properties for this (e.g. check that no barriers are called, no
+    // local memory args).
+
+    // Todo: this assumes that dim 0 is the best dimension over which we want to
+    // parallelize
+
+    // Since we also vectorize the kernel, and vectorization happens within the
+    // work group loop, it's better to have a large-ish local size. We can
+    // divide the global range by the number of threads, set that as the local
+    // size and peel everything else.
+
+    size_t new_num_work_groups_0 = numParallelThreads;
+    size_t itemsPerThread = ndr.GlobalSize[0] / numParallelThreads;
+
+    for (unsigned g2 = 0; g2 < numWG2; g2++) {
+      for (unsigned g1 = 0; g1 < numWG1; g1++) {
+        for (unsigned g0 = 0; g0 < new_num_work_groups_0; g0 += 1) {
+          futures.emplace_back(
+              tp.schedule_task([&ndr = std::as_const(ndr), itemsPerThread,
+                                hKernel, g0, g1, g2](size_t) {
+                native_cpu::state resized_state =
+                    getResizedState(ndr, itemsPerThread);
+                resized_state.update(g0, g1, g2);
+                hKernel->_subhandler(hKernel->_args.data(), &resized_state);
+              }));
+        }
+        // Peel the remaining work items. Since the local size is 1, we iterate
+        // over the work groups.
+        for (unsigned g0 = new_num_work_groups_0 * itemsPerThread; g0 < numWG0;
+             g0++) {
+          state.update(g0, g1, g2);
+          hKernel->_subhandler(hKernel->_args.data(), &state);
+        }
+      }
+    }
+
+  } else {
+    // We are running a parallel_for over an nd_range
+
+    if (numWG1 * numWG2 >= numParallelThreads) {
+      // Dimensions 1 and 2 have enough work, split them across the threadpool
+      for (unsigned g2 = 0; g2 < numWG2; g2++) {
+        for (unsigned g1 = 0; g1 < numWG1; g1++) {
+          futures.emplace_back(
+              tp.schedule_task([state, kernel = *hKernel, numWG0, g1, g2,
+                                numParallelThreads](size_t threadId) mutable {
+                for (unsigned g0 = 0; g0 < numWG0; g0++) {
+                  kernel.handleLocalArgs(numParallelThreads, threadId);
+                  state.update(g0, g1, g2);
+                  kernel._subhandler(kernel._args.data(), &state);
+                }
+              }));
+        }
+      }
+    } else {
+      // Split dimension 0 across the threadpool
+      // Here we try to create groups of workgroups in order to reduce
+      // synchronization overhead
+      for (unsigned g2 = 0; g2 < numWG2; g2++) {
+        for (unsigned g1 = 0; g1 < numWG1; g1++) {
+          for (unsigned g0 = 0; g0 < numWG0; g0++) {
+            groups.push_back(
+                [state, g0, g1, g2, numParallelThreads](
                    size_t threadId, ur_kernel_handle_t_ kernel) mutable {
+                  kernel.handleLocalArgs(numParallelThreads, threadId);
+                  state.update(g0, g1, g2);
+                  kernel._subhandler(kernel._args.data(), &state);
+                });
+          }
+        }
+      }
+      auto numGroups = groups.size();
+      auto groupsPerThread = numGroups / numParallelThreads;
+      auto remainder = numGroups % numParallelThreads;
+      for (unsigned thread = 0; thread < numParallelThreads; thread++) {
+        futures.emplace_back(tp.schedule_task(
+            [&groups, thread, groupsPerThread, hKernel](size_t threadId) {
+              for (unsigned i = 0; i < groupsPerThread; i++) {
+                auto index = thread * groupsPerThread + i;
+                groups[index](threadId, *hKernel);
+              }
+            }));
+      }
+
+      // schedule the remaining tasks
+      if (remainder) {
+        futures.emplace_back(
+            tp.schedule_task([&groups, remainder,
+                              scheduled = numParallelThreads * groupsPerThread,
+                              hKernel](size_t threadId) {
+              for (unsigned i = 0; i < remainder; i++) {
+                auto index = scheduled + i;
+                groups[index](threadId, *hKernel);
+              }
+            }));
       }
     }
   }
 
   for (auto &f : futures)
     f.get();
+#endif // NATIVECPU_USE_OCK
   // TODO: we should avoid calling clear here by avoiding using push_back
   // in setKernelArgs.
   hKernel->_args.clear();
@@ -553,4 +667,3 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueWriteHostPipe(
 
   DIE_NO_IMPLEMENTATION;
 }
-
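
The launch path above now picks one of three strategies when the kernel was
compiled through the oneAPI Construction Kit (the NATIVECPU_USE_OCK branch): a
local size of 1 with enough global work is treated as a parallel_for over a
sycl::range and re-tiled so each task runs itemsPerThread work items, with a
serial peel loop for the leftovers; an nd_range with enough work groups in
dimensions 1 and 2 gets one task per (g1, g2) pair; otherwise all work groups
are flattened into `groups` and each thread receives one contiguous batch plus
a single remainder task. A standalone sketch of that decision and its batching
arithmetic (describeSchedule and all values are illustrative, not part of the
commit):

#include <cstddef>
#include <cstdio>

void describeSchedule(size_t numWG0, size_t numWG1, size_t numWG2,
                      size_t localSize0, size_t numParallelThreads) {
  // Dimensions 1 and 2 are assumed to have local size 1 in this sketch.
  bool isLocalSizeOne = localSize0 == 1;
  size_t globalSize0 = numWG0 * localSize0;
  if (isLocalSizeOne && globalSize0 > numParallelThreads) {
    // Re-tile: each task covers itemsPerThread work items; the peel loop
    // handles what the integer division leaves over.
    size_t itemsPerThread = globalSize0 / numParallelThreads;
    size_t peeled = numWG0 - numParallelThreads * itemsPerThread;
    std::printf("range path: %zu items/thread, %zu peeled\n", itemsPerThread,
                peeled);
  } else if (numWG1 * numWG2 >= numParallelThreads) {
    std::printf("nd_range path: %zu tasks, one per (g1, g2)\n",
                numWG1 * numWG2);
  } else {
    // Flatten all work groups, then hand out contiguous batches.
    size_t numGroups = numWG0 * numWG1 * numWG2;
    size_t groupsPerThread = numGroups / numParallelThreads;
    size_t remainder = numGroups % numParallelThreads;
    std::printf("batched path: %zu groups/thread, %zu in remainder task\n",
                groupsPerThread, remainder);
  }
}

int main() {
  describeSchedule(1024, 1, 1, 1, 8); // range path: 128 items/thread, 0 peeled
  describeSchedule(4, 16, 2, 32, 8);  // nd_range path: 32 tasks
  describeSchedule(10, 2, 1, 32, 8);  // batched path: 2 groups/thread, 4 left
}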

source/adapters/native_cpu/kernel.hpp

Lines changed: 8 additions & 9 deletions
@@ -40,25 +40,25 @@ struct ur_kernel_handle_t_ : RefCounted {
   ur_kernel_handle_t_(const char *name, nativecpu_task_t subhandler)
       : _name{name}, _subhandler{std::move(subhandler)} {}
 
-  ur_kernel_handle_t_(const ur_kernel_handle_t_& other) : _name(other._name), _subhandler(other._subhandler),
-  _args(other._args), _localArgInfo(other._localArgInfo), _localMemPool(other._localMemPool), _localMemPoolSize(other._localMemPoolSize) {
+  ur_kernel_handle_t_(const ur_kernel_handle_t_ &other)
+      : _name(other._name), _subhandler(other._subhandler), _args(other._args),
+        _localArgInfo(other._localArgInfo), _localMemPool(other._localMemPool),
+        _localMemPoolSize(other._localMemPoolSize) {
     incrementReferenceCount();
   }
 
   ~ur_kernel_handle_t_() {
-    decrementReferenceCount();
-    if (_refCount == 0) {
+    if (decrementReferenceCount() == 0) {
      free(_localMemPool);
    }
-
  }
 
  const char *_name;
  nativecpu_task_t _subhandler;
  std::vector<native_cpu::NativeCPUArgDesc> _args;
  std::vector<local_arg_info_t> _localArgInfo;
 
-  // To be called before enqueing the kernel.
+  // To be called before enqueueing the kernel.
  void updateMemPool(size_t numParallelThreads) {
    // compute requested size.
    size_t reqSize = 0;
@@ -69,7 +69,7 @@ struct ur_kernel_handle_t_ : RefCounted {
      return;
    }
    // realloc handles nullptr case
-    _localMemPool = (char*)realloc(_localMemPool, reqSize);
+    _localMemPool = (char *)realloc(_localMemPool, reqSize);
    _localMemPoolSize = reqSize;
  }
 
@@ -86,7 +86,6 @@ struct ur_kernel_handle_t_ : RefCounted {
  }
 
 private:
-  char* _localMemPool = nullptr;
+  char *_localMemPool = nullptr;
  size_t _localMemPoolSize = 0;
 };
-
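
The destructor change is worth noting: decrementReferenceCount() now returns
the updated count, so the decrement and the zero test act on a single value
instead of decrementing and then separately re-reading _refCount, which two
concurrent releasers could both observe as zero. The RefCounted base is
defined elsewhere in the adapter and is not shown in this commit; a sketch of
the shape the destructor now assumes:

#include <atomic>
#include <cstdint>

struct RefCounted {
  std::atomic<uint32_t> _refCount{1};
  uint32_t incrementReferenceCount() { return ++_refCount; }
  // Returning the post-decrement value lets the caller test for zero
  // atomically: exactly one releaser observes 0 and runs the cleanup.
  uint32_t decrementReferenceCount() { return --_refCount; }
};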

source/adapters/native_cpu/nativecpu_state.hpp

File mode changed: 100644 → 100755
Lines changed: 5 additions & 0 deletions
@@ -19,6 +19,7 @@ struct state {
   size_t MLocal_id[3];
   size_t MNumGroups[3];
   size_t MGlobalOffset[3];
+  uint32_t NumSubGroups, SubGroup_id, SubGroup_local_id, SubGroup_size;
   state(size_t globalR0, size_t globalR1, size_t globalR2, size_t localR0,
         size_t localR1, size_t localR2, size_t globalO0, size_t globalO1,
         size_t globalO2)
@@ -36,6 +37,10 @@ struct state {
     MLocal_id[0] = 0;
     MLocal_id[1] = 0;
     MLocal_id[2] = 0;
+    NumSubGroups = 32;
+    SubGroup_id = 0;
+    SubGroup_local_id = 0;
+    SubGroup_size = 1;
   }
 
   void update(size_t group0, size_t group1, size_t group2, size_t local0,
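
The new fields give the state object somewhere to answer sub-group queries;
the default SubGroup_size of 1 presumably describes scalar execution, where
each work-item forms its own sub-group. The struct is driven exactly as in the
launch loops above: one state for the whole ND-range, advanced per work group
with update(). A fragment assuming this header is on the include path
(walkGroups and the sizes are illustrative):

#include "nativecpu_state.hpp"
#include <cstddef>

void walkGroups() {
  // Global range 64, local range 16, no offset: four work groups in dim 0.
  native_cpu::state s(64, 1, 1, 16, 1, 1, 0, 0, 0);
  for (size_t g0 = 0; g0 < 64 / 16; g0++) {
    s.update(g0, 0, 0); // point the state at work group (g0, 0, 0)
    // kernel._subhandler(kernel._args.data(), &s) would run here
  }
}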

source/adapters/native_cpu/queue.hpp

Lines changed: 1 addition & 2 deletions
@@ -12,8 +12,7 @@
 #include "device.hpp"
 
 struct ur_queue_handle_t_ : RefCounted {
-  ur_device_handle_t_ *device;
+  ur_device_handle_t_ *const device;
 
   ur_queue_handle_t_(ur_device_handle_t_ *device) : device(device) {}
-
 };
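
Declaring the member as ur_device_handle_t_ *const makes the queue-to-device
binding immutable after construction: the device can still be used and mutated
through the pointer, but the queue can never be re-pointed at a different
device. A small illustration (the device value here is hypothetical):

ur_device_handle_t_ *someDevice = /* obtained from the platform */ nullptr;
ur_queue_handle_t_ q(someDevice);
// q.device->tp;           // fine: members of the bound device stay usable
// q.device = nullptr;     // error: assignment of read-only member 'device'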
