 //
 //===----------------------------------------------------------------------===//

 #include <array>
+#include <cstddef>
 #include <cstdint>
+#include <vector>

 #include "ur_api.h"

 #include "common.hpp"
 #include "kernel.hpp"
 #include "memory.hpp"
-#include "threadpool.hpp"
 #include "queue.hpp"
+#include "threadpool.hpp"

 namespace native_cpu {

@@ -37,9 +39,29 @@ struct NDRDescT {
       GlobalOffset[I] = 0;
     }
   }
+
+  void dump(std::ostream &os) const {
+    os << "GlobalSize: " << GlobalSize[0] << " " << GlobalSize[1] << " "
+       << GlobalSize[2] << "\n";
+    os << "LocalSize: " << LocalSize[0] << " " << LocalSize[1] << " "
+       << LocalSize[2] << "\n";
+    os << "GlobalOffset: " << GlobalOffset[0] << " " << GlobalOffset[1] << " "
+       << GlobalOffset[2] << "\n";
+  }
 };
 } // namespace native_cpu

+#ifdef NATIVECPU_USE_OCK
+static native_cpu::state getResizedState(const native_cpu::NDRDescT &ndr,
+                                         size_t itemsPerThread) {
+  native_cpu::state resized_state(
+      ndr.GlobalSize[0], ndr.GlobalSize[1], ndr.GlobalSize[2], itemsPerThread,
+      ndr.LocalSize[1], ndr.LocalSize[2], ndr.GlobalOffset[0],
+      ndr.GlobalOffset[1], ndr.GlobalOffset[2]);
+  return resized_state;
+}
+#endif
+
 UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
     ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel, uint32_t workDim,
     const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
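getResizedState builds a native_cpu::state whose dimension-0 work-group size is itemsPerThread, keeping the remaining sizes and offsets of the ND-range unchanged; the sycl::range fast path further down uses it so that each worker thread runs one contiguous chunk of items and the leftover items are peeled off serially. The sketch below only illustrates that partitioning arithmetic with assumed example values (globalSize0 and numParallelThreads are hypothetical); it is not part of the patch and does not call the adapter's API.

    #include <cstddef>
    #include <iostream>

    int main() {
      // Assumed example: a 1-D range of 1030 items, local size {1,1,1},
      // and a thread pool with 8 worker threads.
      const std::size_t globalSize0 = 1030;
      const std::size_t numParallelThreads = 8;

      // Each thread gets one "resized" work-group of itemsPerThread items...
      const std::size_t itemsPerThread = globalSize0 / numParallelThreads; // 128
      // ...and the items that do not divide evenly are peeled off serially,
      // starting at numParallelThreads * itemsPerThread.
      const std::size_t peelStart = numParallelThreads * itemsPerThread; // 1024

      std::cout << "itemsPerThread = " << itemsPerThread
                << ", peeled items = " << (globalSize0 - peelStart) << "\n";
      return 0;
    }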
@@ -61,38 +83,25 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(

   // TODO: add proper error checking
   // TODO: add proper event dep management
-  native_cpu::NDRDescT ndr(workDim, pGlobalWorkOffset, pGlobalWorkSize, pLocalWorkSize);
-  auto & tp = hQueue->device->tp;
+  native_cpu::NDRDescT ndr(workDim, pGlobalWorkOffset, pGlobalWorkSize,
+                           pLocalWorkSize);
+  auto &tp = hQueue->device->tp;
   const size_t numParallelThreads = tp.num_threads();
   hKernel->updateMemPool(numParallelThreads);
   std::vector<std::future<void>> futures;
+  std::vector<std::function<void(size_t, ur_kernel_handle_t_)>> groups;
   auto numWG0 = ndr.GlobalSize[0] / ndr.LocalSize[0];
   auto numWG1 = ndr.GlobalSize[1] / ndr.LocalSize[1];
   auto numWG2 = ndr.GlobalSize[2] / ndr.LocalSize[2];
-  bool isLocalSizeOne =
-      ndr.LocalSize[0] == 1 && ndr.LocalSize[1] == 1 && ndr.LocalSize[2] == 1;
-
-
   native_cpu::state state(ndr.GlobalSize[0], ndr.GlobalSize[1],
                           ndr.GlobalSize[2], ndr.LocalSize[0], ndr.LocalSize[1],
                           ndr.LocalSize[2], ndr.GlobalOffset[0],
                           ndr.GlobalOffset[1], ndr.GlobalOffset[2]);
-  if (isLocalSizeOne) {
-    // If the local size is one, we make the assumption that we are running a
-    // parallel_for over a sycl::range Todo: we could add compiler checks and
-    // kernel properties for this (e.g. check that no barriers are called, no
-    // local memory args).
-
-    auto numWG0 = ndr.GlobalSize[0] / ndr.LocalSize[0];
-    auto numWG1 = ndr.GlobalSize[1] / ndr.LocalSize[1];
-    auto numWG2 = ndr.GlobalSize[2] / ndr.LocalSize[2];
+#ifndef NATIVECPU_USE_OCK
+  hKernel->handleLocalArgs(1, 0);
   for (unsigned g2 = 0; g2 < numWG2; g2++) {
     for (unsigned g1 = 0; g1 < numWG1; g1++) {
       for (unsigned g0 = 0; g0 < numWG0; g0++) {
-#ifdef NATIVECPU_USE_OCK
-        state.update(g0, g1, g2);
-        hKernel->_subhandler(hKernel->_args.data(), &state);
-#else
         for (unsigned local2 = 0; local2 < ndr.LocalSize[2]; local2++) {
           for (unsigned local1 = 0; local1 < ndr.LocalSize[1]; local1++) {
             for (unsigned local0 = 0; local0 < ndr.LocalSize[0]; local0++) {
@@ -101,13 +110,118 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
             }
           }
         }
-#endif
+      }
+    }
+  }
+#else
+  bool isLocalSizeOne =
+      ndr.LocalSize[0] == 1 && ndr.LocalSize[1] == 1 && ndr.LocalSize[2] == 1;
+  if (isLocalSizeOne && ndr.GlobalSize[0] > numParallelThreads) {
+    // If the local size is one, we make the assumption that we are running a
+    // parallel_for over a sycl::range.
+    // Todo: we could add compiler checks and
+    // kernel properties for this (e.g. check that no barriers are called, no
+    // local memory args).
+
+    // Todo: this assumes that dim 0 is the best dimension over which we want to
+    // parallelize
+
+    // Since we also vectorize the kernel, and vectorization happens within the
+    // work group loop, it's better to have a large-ish local size. We can
+    // divide the global range by the number of threads, set that as the local
+    // size and peel everything else.
+
+    size_t new_num_work_groups_0 = numParallelThreads;
+    size_t itemsPerThread = ndr.GlobalSize[0] / numParallelThreads;
+
+    for (unsigned g2 = 0; g2 < numWG2; g2++) {
+      for (unsigned g1 = 0; g1 < numWG1; g1++) {
+        for (unsigned g0 = 0; g0 < new_num_work_groups_0; g0 += 1) {
+          futures.emplace_back(
+              tp.schedule_task([&ndr = std::as_const(ndr), itemsPerThread,
+                                hKernel, g0, g1, g2](size_t) {
+                native_cpu::state resized_state =
+                    getResizedState(ndr, itemsPerThread);
+                resized_state.update(g0, g1, g2);
+                hKernel->_subhandler(hKernel->_args.data(), &resized_state);
+              }));
+        }
+        // Peel the remaining work items. Since the local size is 1, we iterate
+        // over the work groups.
+        for (unsigned g0 = new_num_work_groups_0 * itemsPerThread; g0 < numWG0;
+             g0++) {
+          state.update(g0, g1, g2);
+          hKernel->_subhandler(hKernel->_args.data(), &state);
+        }
+      }
+    }
+
+  } else {
+    // We are running a parallel_for over an nd_range
+
+    if (numWG1 * numWG2 >= numParallelThreads) {
+      // Dimensions 1 and 2 have enough work, split them across the threadpool
+      for (unsigned g2 = 0; g2 < numWG2; g2++) {
+        for (unsigned g1 = 0; g1 < numWG1; g1++) {
+          futures.emplace_back(
+              tp.schedule_task([state, kernel = *hKernel, numWG0, g1, g2,
+                                numParallelThreads](size_t threadId) mutable {
+                for (unsigned g0 = 0; g0 < numWG0; g0++) {
+                  kernel.handleLocalArgs(numParallelThreads, threadId);
+                  state.update(g0, g1, g2);
+                  kernel._subhandler(kernel._args.data(), &state);
+                }
+              }));
+        }
+      }
+    } else {
+      // Split dimension 0 across the threadpool
+      // Here we try to create groups of workgroups in order to reduce
+      // synchronization overhead
+      for (unsigned g2 = 0; g2 < numWG2; g2++) {
+        for (unsigned g1 = 0; g1 < numWG1; g1++) {
+          for (unsigned g0 = 0; g0 < numWG0; g0++) {
+            groups.push_back(
+                [state, g0, g1, g2, numParallelThreads](
+                    size_t threadId, ur_kernel_handle_t_ kernel) mutable {
+                  kernel.handleLocalArgs(numParallelThreads, threadId);
+                  state.update(g0, g1, g2);
+                  kernel._subhandler(kernel._args.data(), &state);
+                });
+          }
+        }
+      }
+      auto numGroups = groups.size();
+      auto groupsPerThread = numGroups / numParallelThreads;
+      auto remainder = numGroups % numParallelThreads;
+      for (unsigned thread = 0; thread < numParallelThreads; thread++) {
+        futures.emplace_back(tp.schedule_task(
+            [&groups, thread, groupsPerThread, hKernel](size_t threadId) {
+              for (unsigned i = 0; i < groupsPerThread; i++) {
+                auto index = thread * groupsPerThread + i;
+                groups[index](threadId, *hKernel);
+              }
+            }));
+      }
+
+      // schedule the remaining tasks
+      if (remainder) {
+        futures.emplace_back(
+            tp.schedule_task([&groups, remainder,
+                              scheduled = numParallelThreads * groupsPerThread,
+                              hKernel](size_t threadId) {
+              for (unsigned i = 0; i < remainder; i++) {
+                auto index = scheduled + i;
+                groups[index](threadId, *hKernel);
+              }
+            }));
       }
     }
   }

   for (auto &f : futures)
     f.get();
+#endif // NATIVECPU_USE_OCK
   // TODO: we should avoid calling clear here by avoiding using push_back
   // in setKernelArgs.
   hKernel->_args.clear();
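When the work cannot be split across dimensions 1 and 2, the patch records one closure per work-group in the groups vector and then chunks that vector across the thread pool: each worker receives groupsPerThread consecutive closures and one extra task picks up the remainder. The sketch below reproduces only that index arithmetic with assumed example sizes (13 groups, 4 threads) and runs the closures inline instead of through tp.schedule_task; it is illustrative and not the adapter's code.

    #include <cstddef>
    #include <functional>
    #include <iostream>
    #include <vector>

    int main() {
      // Assumed example: 13 work-groups to distribute over 4 worker threads.
      const std::size_t numParallelThreads = 4;
      std::vector<std::function<void(std::size_t)>> groups;
      for (std::size_t g = 0; g < 13; g++)
        groups.push_back([g](std::size_t threadId) {
          std::cout << "group " << g << " on thread " << threadId << "\n";
        });

      const std::size_t numGroups = groups.size();
      const std::size_t groupsPerThread = numGroups / numParallelThreads; // 3
      const std::size_t remainder = numGroups % numParallelThreads;       // 1

      // Each thread runs a contiguous block of groupsPerThread closures...
      for (std::size_t thread = 0; thread < numParallelThreads; thread++)
        for (std::size_t i = 0; i < groupsPerThread; i++)
          groups[thread * groupsPerThread + i](thread);

      // ...and a final task handles the groups left over by the division.
      for (std::size_t i = 0; i < remainder; i++)
        groups[numParallelThreads * groupsPerThread + i](0);
      return 0;
    }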
@@ -553,4 +667,3 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueWriteHostPipe(

   DIE_NO_IMPLEMENTATION;
 }
-