@@ -56,9 +56,9 @@ struct MaskedExtractStridedFunctor
56
56
char *dst_data_p,
57
57
size_t orthog_iter_size,
58
58
size_t masked_iter_size,
59
- OrthogIndexerT orthog_src_dst_indexer_,
60
- MaskedSrcIndexerT masked_src_indexer_,
61
- MaskedDstIndexerT masked_dst_indexer_)
59
+ const OrthogIndexerT & orthog_src_dst_indexer_,
60
+ const MaskedSrcIndexerT & masked_src_indexer_,
61
+ const MaskedDstIndexerT & masked_dst_indexer_)
62
62
: src_cp(src_data_p), cumsum_cp(cumsum_data_p), dst_cp(dst_data_p),
63
63
orthog_nelems (orthog_iter_size), masked_nelems(masked_iter_size),
64
64
orthog_src_dst_indexer(orthog_src_dst_indexer_),
@@ -106,13 +106,14 @@ struct MaskedExtractStridedFunctor
106
106
char *dst_cp = nullptr ;
107
107
size_t orthog_nelems = 0 ;
108
108
size_t masked_nelems = 0 ;
109
- OrthogIndexerT
110
- orthog_src_dst_indexer; // has nd, shape, src_strides, dst_strides for
111
- // dimensions that ARE NOT masked
112
- MaskedSrcIndexerT masked_src_indexer; // has nd, shape, src_strides for
113
- // dimensions that ARE masked
114
- MaskedDstIndexerT
115
- masked_dst_indexer; // has 1, dst_strides for dimensions that ARE masked
109
+ // has nd, shape, src_strides, dst_strides for
110
+ // dimensions that ARE NOT masked
111
+ const OrthogIndexerT orthog_src_dst_indexer;
112
+ // has nd, shape, src_strides for
113
+ // dimensions that ARE masked
114
+ const MaskedSrcIndexerT masked_src_indexer;
115
+ // has 1, dst_strides for dimensions that ARE masked
116
+ const MaskedDstIndexerT masked_dst_indexer;
116
117
};
117
118
118
119
template <typename OrthogIndexerT,
@@ -127,9 +128,9 @@ struct MaskedPlaceStridedFunctor
127
128
const char *rhs_data_p,
128
129
size_t orthog_iter_size,
129
130
size_t masked_iter_size,
130
- OrthogIndexerT orthog_dst_rhs_indexer_,
131
- MaskedDstIndexerT masked_dst_indexer_,
132
- MaskedRhsIndexerT masked_rhs_indexer_)
131
+ const OrthogIndexerT & orthog_dst_rhs_indexer_,
132
+ const MaskedDstIndexerT & masked_dst_indexer_,
133
+ const MaskedRhsIndexerT & masked_rhs_indexer_)
133
134
: dst_cp(dst_data_p), cumsum_cp(cumsum_data_p), rhs_cp(rhs_data_p),
134
135
orthog_nelems (orthog_iter_size), masked_nelems(masked_iter_size),
135
136
orthog_dst_rhs_indexer(orthog_dst_rhs_indexer_),
@@ -177,13 +178,14 @@ struct MaskedPlaceStridedFunctor
177
178
const char *rhs_cp = nullptr ;
178
179
size_t orthog_nelems = 0 ;
179
180
size_t masked_nelems = 0 ;
180
- OrthogIndexerT
181
- orthog_dst_rhs_indexer; // has nd, shape, dst_strides, rhs_strides for
182
- // dimensions that ARE NOT masked
183
- MaskedDstIndexerT masked_dst_indexer; // has nd, shape, dst_strides for
184
- // dimensions that ARE masked
185
- MaskedRhsIndexerT
186
- masked_rhs_indexer; // has 1, rhs_strides for dimensions that ARE masked
181
+ // has nd, shape, dst_strides, rhs_strides for
182
+ // dimensions that ARE NOT masked
183
+ const OrthogIndexerT orthog_dst_rhs_indexer;
184
+ // has nd, shape, dst_strides for
185
+ // dimensions that ARE masked
186
+ const MaskedDstIndexerT masked_dst_indexer;
187
+ // has 1, rhs_strides for dimensions that ARE masked
188
+ const MaskedRhsIndexerT masked_rhs_indexer;
187
189
};
188
190
189
191
// ======= Masked extraction ================================
@@ -226,12 +228,12 @@ sycl::event masked_extract_all_slices_strided_impl(
226
228
// using StridedIndexer;
227
229
// using TwoZeroOffsets_Indexer;
228
230
229
- TwoZeroOffsets_Indexer orthog_src_dst_indexer{};
231
+ constexpr TwoZeroOffsets_Indexer orthog_src_dst_indexer{};
230
232
231
233
/* StridedIndexer(int _nd, ssize_t _offset, ssize_t const
232
234
* *_packed_shape_strides) */
233
- StridedIndexer masked_src_indexer (nd, 0 , packed_src_shape_strides);
234
- Strided1DIndexer masked_dst_indexer (0 , dst_size, dst_stride);
235
+ const StridedIndexer masked_src_indexer (nd, 0 , packed_src_shape_strides);
236
+ const Strided1DIndexer masked_dst_indexer (0 , dst_size, dst_stride);
235
237
236
238
sycl::event comp_ev = exec_q.submit ([&](sycl::handler &cgh) {
237
239
cgh.depends_on (depends);
@@ -283,17 +285,16 @@ sycl::event masked_extract_some_slices_strided_impl(
283
285
const char *cumsum_p,
284
286
char *dst_p,
285
287
int orthog_nd,
286
- const ssize_t
287
- *packed_ortho_src_dst_shape_strides, // [ortho_shape, ortho_src_strides,
288
- // ortho_dst_strides], length
289
- // 3*ortho_nd
288
+ // [ortho_shape, ortho_src_strides, ortho_dst_strides],
289
+ // length 3*ortho_nd
290
+ const ssize_t *packed_ortho_src_dst_shape_strides,
290
291
ssize_t ortho_src_offset,
291
292
ssize_t ortho_dst_offset,
292
293
int masked_nd,
293
- const ssize_t *packed_masked_src_shape_strides, // [masked_src_shape,
294
- // masked_src_strides],
295
- // length 2*masked_nd
296
- ssize_t masked_dst_size, // mask_dst is 1D
294
+ // [masked_src_shape, masked_src_strides] ,
295
+ // length 2*masked_nd, mask_dst is 1D
296
+ const ssize_t *packed_masked_src_shape_strides,
297
+ ssize_t masked_dst_size,
297
298
ssize_t masked_dst_stride,
298
299
const std::vector<sycl::event> &depends = {})
299
300
{
@@ -302,13 +303,14 @@ sycl::event masked_extract_some_slices_strided_impl(
302
303
// using StridedIndexer;
303
304
// using TwoOffsets_StridedIndexer;
304
305
305
- TwoOffsets_StridedIndexer orthog_src_dst_indexer{
306
+ const TwoOffsets_StridedIndexer orthog_src_dst_indexer{
306
307
orthog_nd, ortho_src_offset, ortho_dst_offset,
307
308
packed_ortho_src_dst_shape_strides};
308
309
309
- StridedIndexer masked_src_indexer{masked_nd, 0 ,
310
- packed_masked_src_shape_strides};
311
- Strided1DIndexer masked_dst_indexer{0 , masked_dst_size, masked_dst_stride};
310
+ const StridedIndexer masked_src_indexer{masked_nd, 0 ,
311
+ packed_masked_src_shape_strides};
312
+ const Strided1DIndexer masked_dst_indexer{0 , masked_dst_size,
313
+ masked_dst_stride};
312
314
313
315
sycl::event comp_ev = exec_q.submit ([&](sycl::handler &cgh) {
314
316
cgh.depends_on (depends);
@@ -403,12 +405,12 @@ sycl::event masked_place_all_slices_strided_impl(
403
405
ssize_t rhs_stride,
404
406
const std::vector<sycl::event> &depends = {})
405
407
{
406
- TwoZeroOffsets_Indexer orthog_dst_rhs_indexer{};
408
+ constexpr TwoZeroOffsets_Indexer orthog_dst_rhs_indexer{};
407
409
408
410
/* StridedIndexer(int _nd, ssize_t _offset, ssize_t const
409
411
* *_packed_shape_strides) */
410
- StridedIndexer masked_dst_indexer (nd, 0 , packed_dst_shape_strides);
411
- Strided1DCyclicIndexer masked_rhs_indexer (0 , rhs_size, rhs_stride);
412
+ const StridedIndexer masked_dst_indexer (nd, 0 , packed_dst_shape_strides);
413
+ const Strided1DCyclicIndexer masked_rhs_indexer (0 , rhs_size, rhs_stride);
412
414
413
415
sycl::event comp_ev = exec_q.submit ([&](sycl::handler &cgh) {
414
416
cgh.depends_on (depends);
@@ -460,30 +462,29 @@ sycl::event masked_place_some_slices_strided_impl(
460
462
const char *cumsum_p,
461
463
const char *rhs_p,
462
464
int orthog_nd,
463
- const ssize_t
464
- *packed_ortho_dst_rhs_shape_strides, // [ortho_shape, ortho_dst_strides,
465
- // ortho_rhs_strides], length
466
- // 3*ortho_nd
465
+ // [ortho_shape, ortho_dst_strides, ortho_rhs_strides],
466
+ // length 3*ortho_nd
467
+ const ssize_t *packed_ortho_dst_rhs_shape_strides,
467
468
ssize_t ortho_dst_offset,
468
469
ssize_t ortho_rhs_offset,
469
470
int masked_nd,
470
- const ssize_t *packed_masked_dst_shape_strides, // [masked_dst_shape,
471
- // masked_dst_strides],
472
- // length 2*masked_nd
473
- ssize_t masked_rhs_size, // mask_dst is 1D
471
+ // [masked_dst_shape, masked_dst_strides] ,
472
+ // length 2*masked_nd, mask_rhs is 1D
473
+ const ssize_t *packed_masked_dst_shape_strides,
474
+ ssize_t masked_rhs_size,
474
475
ssize_t masked_rhs_stride,
475
476
const std::vector<sycl::event> &depends = {})
476
477
{
477
- TwoOffsets_StridedIndexer orthog_dst_rhs_indexer{
478
+ const TwoOffsets_StridedIndexer orthog_dst_rhs_indexer{
478
479
orthog_nd, ortho_dst_offset, ortho_rhs_offset,
479
480
packed_ortho_dst_rhs_shape_strides};
480
481
481
482
/* StridedIndexer(int _nd, ssize_t _offset, ssize_t const
482
483
* *_packed_shape_strides) */
483
- StridedIndexer masked_dst_indexer{masked_nd, 0 ,
484
- packed_masked_dst_shape_strides};
485
- Strided1DCyclicIndexer masked_rhs_indexer{0 , masked_rhs_size,
486
- masked_rhs_stride};
484
+ const StridedIndexer masked_dst_indexer{masked_nd, 0 ,
485
+ packed_masked_dst_shape_strides};
486
+ const Strided1DCyclicIndexer masked_rhs_indexer{0 , masked_rhs_size,
487
+ masked_rhs_stride};
487
488
488
489
sycl::event comp_ev = exec_q.submit ([&](sycl::handler &cgh) {
489
490
cgh.depends_on (depends);
0 commit comments