@@ -223,17 +223,16 @@ checkValueRange(const T &V) {
223
223
namespace ONEAPI {
224
224
namespace detail {
225
225
template <typename T, class BinaryOperation , int Dims, bool IsUSM,
226
- access::mode AccMode, access:: placeholder IsPlaceholder>
226
+ access::placeholder IsPlaceholder>
227
227
class reduction_impl ;
228
228
229
229
using cl::sycl::detail::enable_if_t ;
230
230
using cl::sycl::detail::queue_impl;
231
231
232
- template <typename KernelName, typename KernelType, int Dims, class Reduction ,
233
- typename OutputT>
232
+ template <typename KernelName, typename KernelType, int Dims, class Reduction >
234
233
enable_if_t <Reduction::has_fast_atomics>
235
234
reduCGFunc (handler &CGH, KernelType KernelFunc, const nd_range<Dims> &Range,
236
- Reduction &Redu, OutputT Out );
235
+ Reduction &Redu);
237
236
238
237
template <typename KernelName, typename KernelType, int Dims, class Reduction >
239
238
enable_if_t <!Reduction::has_fast_atomics>
@@ -258,6 +257,26 @@ size_t reduAuxCGFunc(handler &CGH, size_t NWorkItems, size_t MaxWGSize,
258
257
std::tuple<Reductions...> &ReduTuple,
259
258
std::index_sequence<Is...>);
260
259
260
+ template <typename KernelName, class Reduction >
261
+ std::enable_if_t <!Reduction::is_usm>
262
+ reduSaveFinalResultToUserMem (handler &CGH, Reduction &Redu);
263
+
264
+ template <typename KernelName, class Reduction >
265
+ std::enable_if_t <Reduction::is_usm>
266
+ reduSaveFinalResultToUserMem (handler &CGH, Reduction &Redu);
267
+
268
+ template <typename ... Reduction, size_t ... Is>
269
+ shared_ptr_class<event>
270
+ reduSaveFinalResultToUserMem (shared_ptr_class<detail::queue_impl> Queue,
271
+ bool IsHost, std::tuple<Reduction...> &ReduTuple,
272
+ std::index_sequence<Is...>);
273
+
274
+ template <typename Reduction, typename ... RestT>
275
+ std::enable_if_t <!Reduction::is_usm>
276
+ reduSaveFinalResultToUserMemHelper (std::vector<event> &Events,
277
+ shared_ptr_class<detail::queue_impl> Queue,
278
+ bool IsHost, Reduction &Redu, RestT... Rest);
279
+
261
280
__SYCL_EXPORT size_t reduGetMaxWGSize (shared_ptr_class<queue_impl> Queue,
262
281
size_t LocalMemBytesPerWorkItem);
263
282
@@ -1159,73 +1178,43 @@ class __SYCL_EXPORT handler {
1159
1178
#endif
1160
1179
}
1161
1180
1162
- // / Implements parallel_for() accepting nd_range and 1 reduction variable
1163
- // / having 'read_write' access mode.
1164
- // / This version uses fast sycl::atomic operations to update user's reduction
1181
+ // / Implements parallel_for() accepting nd_range \p Range and one reduction
1182
+ // / object. This version uses fast sycl::atomic operations to update reduction
1165
1183
// / variable at the end of each work-group work.
1184
+ //
1185
+ // If the reduction variable must be initialized with the identity value
1186
+ // before the kernel run, then an additional working accessor is created,
1187
+ // initialized with the identity value and used in the kernel. That working
1188
+ // accessor is then copied to user's accessor or USM pointer after
1189
+ // the kernel run.
1190
+ // For USM pointers without initialize_to_identity properties the same scheme
1191
+ // with working accessor is used as re-using user's USM pointer in the kernel
1192
+ // would require creation of another variant of user's kernel, which does not
1193
+ // seem efficient.
1166
1194
template <typename KernelName = detail::auto_name, typename KernelType,
1167
1195
int Dims, typename Reduction>
1168
- detail::enable_if_t <Reduction::accessor_mode == access ::mode::read_write &&
1169
- Reduction::has_fast_atomics && !Reduction::is_usm>
1170
- parallel_for (nd_range<Dims> Range, Reduction Redu,
1171
- _KERNELFUNCPARAM (KernelFunc)) {
1172
- ONEAPI::detail::reduCGFunc<KernelName>(*this , KernelFunc, Range, Redu,
1173
- Redu.getUserAccessor ());
1174
- }
1175
-
1176
- // / Implements parallel_for() accepting nd_range and 1 reduction variable
1177
- // / having 'read_write' access mode.
1178
- // / This version uses fast sycl::atomic operations to update user's reduction
1179
- // / variable at the end of each work-group work.
1180
- template <typename KernelName = detail::auto_name, typename KernelType,
1181
- int Dims, typename Reduction>
1182
- detail::enable_if_t <Reduction::accessor_mode == access ::mode::read_write &&
1183
- Reduction::has_fast_atomics && Reduction::is_usm>
1184
- parallel_for (nd_range<Dims> Range, Reduction Redu,
1185
- _KERNELFUNCPARAM (KernelFunc)) {
1186
- ONEAPI::detail::reduCGFunc<KernelName>(*this , KernelFunc, Range, Redu,
1187
- Redu.getUSMPointer ());
1188
- }
1189
-
1190
- // / Implements parallel_for() accepting nd_range and 1 reduction variable
1191
- // / having 'discard_write' access mode.
1192
- // / This version uses fast sycl::atomic operations to update user's reduction
1193
- // / variable at the end of each work-group work.
1194
- // /
1195
- // / The reduction variable must be initialized before the kernel is started
1196
- // / because atomic operations only update the value, but never initialize it.
1197
- // / Thus, an additional 'read_write' accessor is created/initialized with
1198
- // / identity value and then passed to the kernel. After running the kernel it
1199
- // / is copied to user's 'discard_write' accessor.
1200
- template <typename KernelName = detail::auto_name, typename KernelType,
1201
- int Dims, typename Reduction>
1202
- detail::enable_if_t <Reduction::accessor_mode == access ::mode::discard_write &&
1203
- Reduction::has_fast_atomics>
1196
+ detail::enable_if_t <Reduction::has_fast_atomics>
1204
1197
parallel_for (nd_range<Dims> Range, Reduction Redu,
1205
1198
_KERNELFUNCPARAM (KernelFunc)) {
1206
1199
shared_ptr_class<detail::queue_impl> QueueCopy = MQueue;
1207
- auto RWAcc = Redu.getReadWriteScalarAcc (*this );
1208
- ONEAPI::detail::reduCGFunc<KernelName>(*this , KernelFunc, Range, Redu,
1209
- RWAcc);
1210
- this ->finalize ();
1200
+ ONEAPI::detail::reduCGFunc<KernelName>(*this , KernelFunc, Range, Redu);
1211
1201
1212
- // Copy from RWAcc to user's reduction accessor.
1213
- handler CopyHandler (QueueCopy, MIsHost);
1214
- CopyHandler.saveCodeLoc (MCodeLoc);
1215
- #ifndef __SYCL_DEVICE_ONLY__
1216
- CopyHandler.associateWithHandler (&RWAcc, access ::target::global_buffer);
1217
- Redu.associateWithHandler (CopyHandler);
1218
- #endif
1219
- CopyHandler.copy (RWAcc, Redu.getUserAccessor ());
1220
- MLastEvent = CopyHandler.finalize ();
1202
+ if (Reduction::is_usm || Redu.initializeToIdentity ()) {
1203
+ this ->finalize ();
1204
+ handler CopyHandler (QueueCopy, MIsHost);
1205
+ CopyHandler.saveCodeLoc (MCodeLoc);
1206
+ ONEAPI::detail::reduSaveFinalResultToUserMem<KernelName>(CopyHandler,
1207
+ Redu);
1208
+ MLastEvent = CopyHandler.finalize ();
1209
+ }
1221
1210
}
1222
1211
1223
1212
// / Defines and invokes a SYCL kernel function for the specified nd_range.
1224
- // / Performs reduction operation specified in \param Redu.
1213
+ // / Performs reduction operation specified in \p Redu.
1225
1214
// /
1226
1215
// / The SYCL kernel function is defined as a lambda function or a named
1227
1216
// / function object type and given an id or item for indexing in the indexing
1228
- // / space defined by range .
1217
+ // / space defined by \p Range .
1229
1218
// / If it is a named function object and the function object type is
1230
1219
// / globally visible, there is no need for the developer to provide
1231
1220
// / a kernel name for it.
@@ -1300,13 +1289,50 @@ class __SYCL_EXPORT handler {
1300
1289
AuxHandler, NWorkItems, MaxWGSize, Redu);
1301
1290
MLastEvent = AuxHandler.finalize ();
1302
1291
} // end while (NWorkItems > 1)
1292
+
1293
+ if (Reduction::is_usm || Redu.hasUserDiscardWriteAccessor ()) {
1294
+ handler CopyHandler (QueueCopy, MIsHost);
1295
+ CopyHandler.saveCodeLoc (MCodeLoc);
1296
+ ONEAPI::detail::reduSaveFinalResultToUserMem<KernelName>(CopyHandler,
1297
+ Redu);
1298
+ MLastEvent = CopyHandler.finalize ();
1299
+ }
1303
1300
}
1304
1301
1305
1302
// This version of parallel_for may handle one or more reductions packed in
1306
1303
// \p Rest argument. Note thought that the last element in \p Rest pack is
1307
1304
// the kernel function.
1308
1305
// TODO: this variant is currently enabled for 2+ reductions only as the
1309
1306
// versions handling 1 reduction variable are more efficient right now.
1307
+ //
1308
+ // Algorithm:
1309
+ // 1) discard_write accessor (DWAcc), InitializeToIdentity = true:
1310
+ // a) Create uninitialized buffer and read_write accessor (RWAcc).
1311
+ // b) discard-write partial sums to RWAcc.
1312
+ // c) Repeat the steps (a) and (b) to get one final sum.
1313
+ // d) Copy RWAcc to DWAcc.
1314
+ // 2) read_write accessor (RWAcc), InitializeToIdentity = false:
1315
+ // a) Create new uninitialized buffer (if #work-groups > 1) and RWAcc or
1316
+ // re-use user's RWAcc (if #work-groups is 1).
1317
+ // b) discard-write to RWAcc (#WG > 1), or update-write (#WG == 1).
1318
+ // c) Repeat the steps (a) and (b) to get one final sum.
1319
+ // 3) read_write accessor (RWAcc), InitializeToIdentity = true:
1320
+ // a) Create new uninitialized buffer (if #work-groups > 1) and RWAcc or
1321
+ // re-use user's RWAcc (if #work-groups is 1).
1322
+ // b) discard-write to RWAcc.
1323
+ // c) Repeat the steps (a) and (b) to get one final sum.
1324
+ // 4) USM pointer, InitializeToIdentity = false:
1325
+ // a) Create new uninitialized buffer (if #work-groups > 1) and RWAcc or
1326
+ // re-use user's USM pointer (if #work-groups is 1).
1327
+ // b) discard-write to RWAcc (#WG > 1) or
1328
+ // update-write to USM pointer (#WG == 1).
1329
+ // c) Repeat the steps (a) and (b) to get one final sum.
1330
+ // 5) USM pointer, InitializeToIdentity = true:
1331
+ // a) Create new uninitialized buffer (if #work-groups > 1) and RWAcc or
1332
+ // re-use user's USM pointer (if #work-groups is 1).
1333
+ // b) discard-write to RWAcc (#WG > 1) or
1334
+ // discard-write to USM pointer (#WG == 1).
1335
+ // c) Repeat the steps (a) and (b) to get one final sum.
1310
1336
template <typename KernelName = detail::auto_name, int Dims,
1311
1337
typename ... RestT>
1312
1338
std::enable_if_t <(sizeof ...(RestT) >= 3 &&
@@ -1348,6 +1374,11 @@ class __SYCL_EXPORT handler {
1348
1374
AuxHandler, NWorkItems, MaxWGSize, ReduTuple, ReduIndices);
1349
1375
MLastEvent = AuxHandler.finalize ();
1350
1376
} // end while (NWorkItems > 1)
1377
+
1378
+ auto CopyEvent = ONEAPI::detail::reduSaveFinalResultToUserMem (
1379
+ QueueCopy, MIsHost, ReduTuple, ReduIndices);
1380
+ if (CopyEvent)
1381
+ MLastEvent = *CopyEvent;
1351
1382
}
1352
1383
1353
1384
// / Hierarchical kernel invocation method of a kernel defined as a lambda
@@ -2085,9 +2116,17 @@ class __SYCL_EXPORT handler {
2085
2116
// Make reduction_impl friend to store buffers and arrays created for it
2086
2117
// in handler from reduction_impl methods.
2087
2118
template <typename T, class BinaryOperation , int Dims, bool IsUSM,
2088
- access::mode AccMode, access:: placeholder IsPlaceholder>
2119
+ access::placeholder IsPlaceholder>
2089
2120
friend class ONEAPI ::detail::reduction_impl;
2090
2121
2122
+ // This method needs to call the method finalize().
2123
+ template <typename Reduction, typename ... RestT>
2124
+ std::enable_if_t <!Reduction::is_usm> friend ONEAPI::detail::
2125
+ reduSaveFinalResultToUserMemHelper (
2126
+ std::vector<event> &Events,
2127
+ shared_ptr_class<detail::queue_impl> Queue, bool IsHost, Reduction &,
2128
+ RestT...);
2129
+
2091
2130
friend void detail::associateWithHandler (handler &,
2092
2131
detail::AccessorBaseHost *,
2093
2132
access ::target);
0 commit comments