Skip to content

Commit 786f7fc

Browse files
Merge pull request #1582 from IntelPython/avoid-copy-in-functor-constructors
No copying of indexers in functor constructors
2 parents 2cf2187 + 47d9a7a commit 786f7fc

File tree

9 files changed

+162
-156
lines changed

9 files changed

+162
-156
lines changed

dpctl/tensor/libtensor/include/kernels/accumulators.hpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,8 @@ inclusive_scan_base_step(sycl::queue &exec_q,
107107
outputT *output,
108108
const size_t s0,
109109
const size_t s1,
110-
IndexerT indexer,
111-
TransformerT transformer,
110+
const IndexerT &indexer,
111+
const TransformerT &transformer,
112112
size_t &n_groups,
113113
const std::vector<sycl::event> &depends = {})
114114
{
@@ -234,8 +234,8 @@ sycl::event inclusive_scan_iter(sycl::queue &exec_q,
234234
outputT *output,
235235
const size_t s0,
236236
const size_t s1,
237-
IndexerT indexer,
238-
TransformerT transformer,
237+
const IndexerT &indexer,
238+
const TransformerT &transformer,
239239
std::vector<sycl::event> &host_tasks,
240240
const std::vector<sycl::event> &depends = {})
241241
{

dpctl/tensor/libtensor/include/kernels/boolean_advanced_indexing.hpp

+52-51
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,9 @@ struct MaskedExtractStridedFunctor
5656
char *dst_data_p,
5757
size_t orthog_iter_size,
5858
size_t masked_iter_size,
59-
OrthogIndexerT orthog_src_dst_indexer_,
60-
MaskedSrcIndexerT masked_src_indexer_,
61-
MaskedDstIndexerT masked_dst_indexer_)
59+
const OrthogIndexerT &orthog_src_dst_indexer_,
60+
const MaskedSrcIndexerT &masked_src_indexer_,
61+
const MaskedDstIndexerT &masked_dst_indexer_)
6262
: src_cp(src_data_p), cumsum_cp(cumsum_data_p), dst_cp(dst_data_p),
6363
orthog_nelems(orthog_iter_size), masked_nelems(masked_iter_size),
6464
orthog_src_dst_indexer(orthog_src_dst_indexer_),
@@ -106,13 +106,14 @@ struct MaskedExtractStridedFunctor
106106
char *dst_cp = nullptr;
107107
size_t orthog_nelems = 0;
108108
size_t masked_nelems = 0;
109-
OrthogIndexerT
110-
orthog_src_dst_indexer; // has nd, shape, src_strides, dst_strides for
111-
// dimensions that ARE NOT masked
112-
MaskedSrcIndexerT masked_src_indexer; // has nd, shape, src_strides for
113-
// dimensions that ARE masked
114-
MaskedDstIndexerT
115-
masked_dst_indexer; // has 1, dst_strides for dimensions that ARE masked
109+
// has nd, shape, src_strides, dst_strides for
110+
// dimensions that ARE NOT masked
111+
const OrthogIndexerT orthog_src_dst_indexer;
112+
// has nd, shape, src_strides for
113+
// dimensions that ARE masked
114+
const MaskedSrcIndexerT masked_src_indexer;
115+
// has 1, dst_strides for dimensions that ARE masked
116+
const MaskedDstIndexerT masked_dst_indexer;
116117
};
117118

118119
template <typename OrthogIndexerT,
@@ -127,9 +128,9 @@ struct MaskedPlaceStridedFunctor
127128
const char *rhs_data_p,
128129
size_t orthog_iter_size,
129130
size_t masked_iter_size,
130-
OrthogIndexerT orthog_dst_rhs_indexer_,
131-
MaskedDstIndexerT masked_dst_indexer_,
132-
MaskedRhsIndexerT masked_rhs_indexer_)
131+
const OrthogIndexerT &orthog_dst_rhs_indexer_,
132+
const MaskedDstIndexerT &masked_dst_indexer_,
133+
const MaskedRhsIndexerT &masked_rhs_indexer_)
133134
: dst_cp(dst_data_p), cumsum_cp(cumsum_data_p), rhs_cp(rhs_data_p),
134135
orthog_nelems(orthog_iter_size), masked_nelems(masked_iter_size),
135136
orthog_dst_rhs_indexer(orthog_dst_rhs_indexer_),
@@ -177,13 +178,14 @@ struct MaskedPlaceStridedFunctor
177178
const char *rhs_cp = nullptr;
178179
size_t orthog_nelems = 0;
179180
size_t masked_nelems = 0;
180-
OrthogIndexerT
181-
orthog_dst_rhs_indexer; // has nd, shape, dst_strides, rhs_strides for
182-
// dimensions that ARE NOT masked
183-
MaskedDstIndexerT masked_dst_indexer; // has nd, shape, dst_strides for
184-
// dimensions that ARE masked
185-
MaskedRhsIndexerT
186-
masked_rhs_indexer; // has 1, rhs_strides for dimensions that ARE masked
181+
// has nd, shape, dst_strides, rhs_strides for
182+
// dimensions that ARE NOT masked
183+
const OrthogIndexerT orthog_dst_rhs_indexer;
184+
// has nd, shape, dst_strides for
185+
// dimensions that ARE masked
186+
const MaskedDstIndexerT masked_dst_indexer;
187+
// has 1, rhs_strides for dimensions that ARE masked
188+
const MaskedRhsIndexerT masked_rhs_indexer;
187189
};
188190

189191
// ======= Masked extraction ================================
@@ -226,12 +228,12 @@ sycl::event masked_extract_all_slices_strided_impl(
226228
// using StridedIndexer;
227229
// using TwoZeroOffsets_Indexer;
228230

229-
TwoZeroOffsets_Indexer orthog_src_dst_indexer{};
231+
constexpr TwoZeroOffsets_Indexer orthog_src_dst_indexer{};
230232

231233
/* StridedIndexer(int _nd, ssize_t _offset, ssize_t const
232234
* *_packed_shape_strides) */
233-
StridedIndexer masked_src_indexer(nd, 0, packed_src_shape_strides);
234-
Strided1DIndexer masked_dst_indexer(0, dst_size, dst_stride);
235+
const StridedIndexer masked_src_indexer(nd, 0, packed_src_shape_strides);
236+
const Strided1DIndexer masked_dst_indexer(0, dst_size, dst_stride);
235237

236238
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
237239
cgh.depends_on(depends);
@@ -283,17 +285,16 @@ sycl::event masked_extract_some_slices_strided_impl(
283285
const char *cumsum_p,
284286
char *dst_p,
285287
int orthog_nd,
286-
const ssize_t
287-
*packed_ortho_src_dst_shape_strides, // [ortho_shape, ortho_src_strides,
288-
// ortho_dst_strides], length
289-
// 3*ortho_nd
288+
// [ortho_shape, ortho_src_strides, ortho_dst_strides],
289+
// length 3*ortho_nd
290+
const ssize_t *packed_ortho_src_dst_shape_strides,
290291
ssize_t ortho_src_offset,
291292
ssize_t ortho_dst_offset,
292293
int masked_nd,
293-
const ssize_t *packed_masked_src_shape_strides, // [masked_src_shape,
294-
// masked_src_strides],
295-
// length 2*masked_nd
296-
ssize_t masked_dst_size, // mask_dst is 1D
294+
// [masked_src_shape, masked_src_strides],
295+
// length 2*masked_nd; masked_dst is 1D
296+
const ssize_t *packed_masked_src_shape_strides,
297+
ssize_t masked_dst_size,
297298
ssize_t masked_dst_stride,
298299
const std::vector<sycl::event> &depends = {})
299300
{
@@ -302,13 +303,14 @@ sycl::event masked_extract_some_slices_strided_impl(
302303
// using StridedIndexer;
303304
// using TwoOffsets_StridedIndexer;
304305

305-
TwoOffsets_StridedIndexer orthog_src_dst_indexer{
306+
const TwoOffsets_StridedIndexer orthog_src_dst_indexer{
306307
orthog_nd, ortho_src_offset, ortho_dst_offset,
307308
packed_ortho_src_dst_shape_strides};
308309

309-
StridedIndexer masked_src_indexer{masked_nd, 0,
310-
packed_masked_src_shape_strides};
311-
Strided1DIndexer masked_dst_indexer{0, masked_dst_size, masked_dst_stride};
310+
const StridedIndexer masked_src_indexer{masked_nd, 0,
311+
packed_masked_src_shape_strides};
312+
const Strided1DIndexer masked_dst_indexer{0, masked_dst_size,
313+
masked_dst_stride};
312314

313315
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
314316
cgh.depends_on(depends);
@@ -403,12 +405,12 @@ sycl::event masked_place_all_slices_strided_impl(
403405
ssize_t rhs_stride,
404406
const std::vector<sycl::event> &depends = {})
405407
{
406-
TwoZeroOffsets_Indexer orthog_dst_rhs_indexer{};
408+
constexpr TwoZeroOffsets_Indexer orthog_dst_rhs_indexer{};
407409

408410
/* StridedIndexer(int _nd, ssize_t _offset, ssize_t const
409411
* *_packed_shape_strides) */
410-
StridedIndexer masked_dst_indexer(nd, 0, packed_dst_shape_strides);
411-
Strided1DCyclicIndexer masked_rhs_indexer(0, rhs_size, rhs_stride);
412+
const StridedIndexer masked_dst_indexer(nd, 0, packed_dst_shape_strides);
413+
const Strided1DCyclicIndexer masked_rhs_indexer(0, rhs_size, rhs_stride);
412414

413415
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
414416
cgh.depends_on(depends);
@@ -460,30 +462,29 @@ sycl::event masked_place_some_slices_strided_impl(
460462
const char *cumsum_p,
461463
const char *rhs_p,
462464
int orthog_nd,
463-
const ssize_t
464-
*packed_ortho_dst_rhs_shape_strides, // [ortho_shape, ortho_dst_strides,
465-
// ortho_rhs_strides], length
466-
// 3*ortho_nd
465+
// [ortho_shape, ortho_dst_strides, ortho_rhs_strides],
466+
// length 3*ortho_nd
467+
const ssize_t *packed_ortho_dst_rhs_shape_strides,
467468
ssize_t ortho_dst_offset,
468469
ssize_t ortho_rhs_offset,
469470
int masked_nd,
470-
const ssize_t *packed_masked_dst_shape_strides, // [masked_dst_shape,
471-
// masked_dst_strides],
472-
// length 2*masked_nd
473-
ssize_t masked_rhs_size, // mask_dst is 1D
471+
// [masked_dst_shape, masked_dst_strides],
472+
// length 2*masked_nd; masked_rhs is 1D
473+
const ssize_t *packed_masked_dst_shape_strides,
474+
ssize_t masked_rhs_size,
474475
ssize_t masked_rhs_stride,
475476
const std::vector<sycl::event> &depends = {})
476477
{
477-
TwoOffsets_StridedIndexer orthog_dst_rhs_indexer{
478+
const TwoOffsets_StridedIndexer orthog_dst_rhs_indexer{
478479
orthog_nd, ortho_dst_offset, ortho_rhs_offset,
479480
packed_ortho_dst_rhs_shape_strides};
480481

481482
/* StridedIndexer(int _nd, ssize_t _offset, ssize_t const
482483
* *_packed_shape_strides) */
483-
StridedIndexer masked_dst_indexer{masked_nd, 0,
484-
packed_masked_dst_shape_strides};
485-
Strided1DCyclicIndexer masked_rhs_indexer{0, masked_rhs_size,
486-
masked_rhs_stride};
484+
const StridedIndexer masked_dst_indexer{masked_nd, 0,
485+
packed_masked_dst_shape_strides};
486+
const Strided1DCyclicIndexer masked_rhs_indexer{0, masked_rhs_size,
487+
masked_rhs_stride};
487488

488489
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
489490
cgh.depends_on(depends);

dpctl/tensor/libtensor/include/kernels/clip.hpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -235,14 +235,14 @@ template <typename T, typename IndexerT> class ClipStridedFunctor
235235
const T *min_p = nullptr;
236236
const T *max_p = nullptr;
237237
T *dst_p = nullptr;
238-
IndexerT indexer;
238+
const IndexerT indexer;
239239

240240
public:
241241
ClipStridedFunctor(const T *x_p_,
242242
const T *min_p_,
243243
const T *max_p_,
244244
T *dst_p_,
245-
IndexerT indexer_)
245+
const IndexerT &indexer_)
246246
: x_p(x_p_), min_p(min_p_), max_p(max_p_), dst_p(dst_p_),
247247
indexer(indexer_)
248248
{
@@ -298,7 +298,7 @@ sycl::event clip_strided_impl(sycl::queue &q,
298298
sycl::event clip_ev = q.submit([&](sycl::handler &cgh) {
299299
cgh.depends_on(depends);
300300

301-
FourOffsets_StridedIndexer indexer{
301+
const FourOffsets_StridedIndexer indexer{
302302
nd, x_offset, min_offset, max_offset, dst_offset, shape_strides};
303303

304304
cgh.parallel_for<clip_strided_kernel<T, FourOffsets_StridedIndexer>>(

dpctl/tensor/libtensor/include/kernels/copy_and_cast.hpp

+32-31
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,10 @@ class GenericCopyFunctor
7878
private:
7979
const srcT *src_ = nullptr;
8080
dstT *dst_ = nullptr;
81-
IndexerT indexer_;
81+
const IndexerT indexer_;
8282

8383
public:
84-
GenericCopyFunctor(const srcT *src_p, dstT *dst_p, IndexerT indexer)
84+
GenericCopyFunctor(const srcT *src_p, dstT *dst_p, const IndexerT &indexer)
8585
: src_(src_p), dst_(dst_p), indexer_(indexer)
8686
{
8787
}
@@ -169,8 +169,8 @@ copy_and_cast_generic_impl(sycl::queue &q,
169169
cgh.depends_on(depends);
170170
cgh.depends_on(additional_depends);
171171

172-
TwoOffsets_StridedIndexer indexer{nd, src_offset, dst_offset,
173-
shape_and_strides};
172+
const TwoOffsets_StridedIndexer indexer{nd, src_offset, dst_offset,
173+
shape_and_strides};
174174
const srcTy *src_tp = reinterpret_cast<const srcTy *>(src_p);
175175
dstTy *dst_tp = reinterpret_cast<dstTy *>(dst_p);
176176

@@ -460,8 +460,8 @@ copy_and_cast_nd_specialized_impl(sycl::queue &q,
460460

461461
sycl::event copy_and_cast_ev = q.submit([&](sycl::handler &cgh) {
462462
using IndexerT = TwoOffsets_FixedDimStridedIndexer<nd>;
463-
IndexerT indexer{shape, src_strides, dst_strides, src_offset,
464-
dst_offset};
463+
const IndexerT indexer{shape, src_strides, dst_strides, src_offset,
464+
dst_offset};
465465
const srcTy *src_tp = reinterpret_cast<const srcTy *>(src_p);
466466
dstTy *dst_tp = reinterpret_cast<dstTy *>(dst_p);
467467

@@ -515,12 +515,12 @@ class GenericCopyFromHostFunctor
515515
private:
516516
const AccessorT src_acc_;
517517
dstTy *dst_ = nullptr;
518-
IndexerT indexer_;
518+
const IndexerT indexer_;
519519

520520
public:
521521
GenericCopyFromHostFunctor(const AccessorT &src_acc,
522522
dstTy *dst_p,
523-
IndexerT indexer)
523+
const IndexerT &indexer)
524524
: src_acc_(src_acc), dst_(dst_p), indexer_(indexer)
525525
{
526526
}
@@ -618,7 +618,7 @@ void copy_and_cast_from_host_impl(
618618

619619
sycl::accessor npy_acc(npy_buf, cgh, sycl::read_only);
620620

621-
TwoOffsets_StridedIndexer indexer{
621+
const TwoOffsets_StridedIndexer indexer{
622622
nd, src_offset - src_min_nelem_offset, dst_offset,
623623
const_cast<const ssize_t *>(shape_and_strides)};
624624

@@ -666,14 +666,14 @@ class GenericCopyForReshapeFunctor
666666
private:
667667
const Ty *src_p = nullptr;
668668
Ty *dst_p = nullptr;
669-
SrcIndexerT src_indexer_;
670-
DstIndexerT dst_indexer_;
669+
const SrcIndexerT src_indexer_;
670+
const DstIndexerT dst_indexer_;
671671

672672
public:
673673
GenericCopyForReshapeFunctor(const char *src_ptr,
674674
char *dst_ptr,
675-
SrcIndexerT src_indexer,
676-
DstIndexerT dst_indexer)
675+
const SrcIndexerT &src_indexer,
676+
const DstIndexerT &dst_indexer)
677677
: src_p(reinterpret_cast<const Ty *>(src_ptr)),
678678
dst_p(reinterpret_cast<Ty *>(dst_ptr)), src_indexer_(src_indexer),
679679
dst_indexer_(dst_indexer)
@@ -747,8 +747,8 @@ copy_for_reshape_generic_impl(sycl::queue &q,
747747
const ssize_t *dst_shape_and_strides = const_cast<const ssize_t *>(
748748
packed_shapes_and_strides + (2 * src_nd));
749749

750-
StridedIndexer src_indexer{src_nd, 0, src_shape_and_strides};
751-
StridedIndexer dst_indexer{dst_nd, 0, dst_shape_and_strides};
750+
const StridedIndexer src_indexer{src_nd, 0, src_shape_and_strides};
751+
const StridedIndexer dst_indexer{dst_nd, 0, dst_shape_and_strides};
752752

753753
using KernelName =
754754
copy_for_reshape_generic_kernel<Ty, StridedIndexer, StridedIndexer>;
@@ -864,14 +864,14 @@ class StridedCopyForRollFunctor
864864
private:
865865
const Ty *src_p = nullptr;
866866
Ty *dst_p = nullptr;
867-
SrcIndexerT src_indexer_;
868-
DstIndexerT dst_indexer_;
867+
const SrcIndexerT src_indexer_;
868+
const DstIndexerT dst_indexer_;
869869

870870
public:
871871
StridedCopyForRollFunctor(const Ty *src_ptr,
872872
Ty *dst_ptr,
873-
SrcIndexerT src_indexer,
874-
DstIndexerT dst_indexer)
873+
const SrcIndexerT &src_indexer,
874+
const DstIndexerT &dst_indexer)
875875
: src_p(src_ptr), dst_p(dst_ptr), src_indexer_(src_indexer),
876876
dst_indexer_(dst_indexer)
877877
{
@@ -946,14 +946,15 @@ sycl::event copy_for_roll_strided_impl(sycl::queue &q,
946946
// USM array of size 3 * nd
947947
// [ common_shape; src_strides; dst_strides ]
948948

949-
StridedIndexer src_indexer{nd, src_offset, packed_shapes_and_strides};
950-
LeftRolled1DTransformer left_roll_transformer{shift, nelems};
949+
const StridedIndexer src_indexer{nd, src_offset,
950+
packed_shapes_and_strides};
951+
const LeftRolled1DTransformer left_roll_transformer{shift, nelems};
951952

952953
using CompositeIndexerT =
953954
CompositionIndexer<StridedIndexer, LeftRolled1DTransformer>;
954955

955-
CompositeIndexerT rolled_src_indexer(src_indexer,
956-
left_roll_transformer);
956+
const CompositeIndexerT rolled_src_indexer(src_indexer,
957+
left_roll_transformer);
957958

958959
UnpackedStridedIndexer dst_indexer{nd, dst_offset,
959960
packed_shapes_and_strides,
@@ -1024,12 +1025,12 @@ sycl::event copy_for_roll_contig_impl(sycl::queue &q,
10241025
sycl::event copy_for_roll_ev = q.submit([&](sycl::handler &cgh) {
10251026
cgh.depends_on(depends);
10261027

1027-
NoOpIndexer src_indexer{};
1028-
LeftRolled1DTransformer roller{shift, nelems};
1028+
constexpr NoOpIndexer src_indexer{};
1029+
const LeftRolled1DTransformer roller{shift, nelems};
10291030

1030-
CompositionIndexer<NoOpIndexer, LeftRolled1DTransformer>
1031+
const CompositionIndexer<NoOpIndexer, LeftRolled1DTransformer>
10311032
left_rolled_src_indexer{src_indexer, roller};
1032-
NoOpIndexer dst_indexer{};
1033+
constexpr NoOpIndexer dst_indexer{};
10331034

10341035
using KernelName = copy_for_roll_contig_kernel<Ty>;
10351036

@@ -1119,11 +1120,11 @@ sycl::event copy_for_roll_ndshift_strided_impl(
11191120
const ssize_t *shifts_ptr =
11201121
packed_shapes_and_strides_and_shifts + 3 * nd;
11211122

1122-
RolledNDIndexer src_indexer{nd, shape_ptr, src_strides_ptr, shifts_ptr,
1123-
src_offset};
1123+
const RolledNDIndexer src_indexer{nd, shape_ptr, src_strides_ptr,
1124+
shifts_ptr, src_offset};
11241125

1125-
UnpackedStridedIndexer dst_indexer{nd, dst_offset, shape_ptr,
1126-
dst_strides_ptr};
1126+
const UnpackedStridedIndexer dst_indexer{nd, dst_offset, shape_ptr,
1127+
dst_strides_ptr};
11271128

11281129
using KernelName = copy_for_roll_strided_kernel<Ty, RolledNDIndexer,
11291130
UnpackedStridedIndexer>;

0 commit comments

Comments
 (0)