Skip to content

Commit 1805102

Browse files
committed
Implement tensor.isin
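`isin` leverages a kernel very similar to the one used by `searchsorted`. After the search, the returned position is checked: if it equals the number of elements in the searched array, the element is considered absent; otherwise the element found at that position is compared with the value being searched for.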
1 parent a26cac1 commit 1805102

File tree

7 files changed

+671
-0
lines changed

7 files changed

+671
-0
lines changed

dpctl/tensor/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ set(_reduction_sources
112112
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/sum.cpp
113113
)
114114
set(_sorting_sources
115+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/isin.cpp
115116
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/merge_sort.cpp
116117
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/merge_argsort.cpp
117118
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/radix_sort.cpp

dpctl/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,7 @@
198198
)
199199
from ._searchsorted import searchsorted
200200
from ._set_functions import (
201+
isin,
201202
unique_all,
202203
unique_counts,
203204
unique_inverse,
@@ -394,4 +395,5 @@
394395
"top_k",
395396
"dldevice_to_sycl_device",
396397
"sycl_device_to_dldevice",
398+
"isin",
397399
]

dpctl/tensor/_set_functions.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import dpctl.tensor as dpt
2020
import dpctl.utils as du
2121

22+
from ._copy_utils import _empty_like_orderK
2223
from ._tensor_elementwise_impl import _not_equal, _subtract
2324
from ._tensor_impl import (
2425
_copy_usm_ndarray_into_usm_ndarray,
@@ -31,6 +32,7 @@
3132
)
3233
from ._tensor_sorting_impl import (
3334
_argsort_ascending,
35+
_isin,
3436
_searchsorted_left,
3537
_sort_ascending,
3638
)
@@ -624,3 +626,63 @@ def unique_all(x: dpt.usm_ndarray) -> UniqueAllResult:
624626
inv,
625627
_counts,
626628
)
629+
630+
631+
def isin(x, test_elements, /, *, assume_unique=False, invert=False):
    """isin(x, test_elements, /, *, assume_unique=False, invert=False)

    Tests, element by element, whether the values of `x` occur in
    `test_elements`.

    Args:
        x (usm_ndarray):
            Array whose elements are looked up.
        test_elements (usm_ndarray):
            Array of values to search in. It is flattened, promoted to the
            common data type of both inputs if needed, and sorted before
            the search.
        assume_unique (bool, optional):
            Currently has no effect on the computation.
            Default: `False`.
        invert (bool, optional):
            If `True`, the membership result is negated.
            Default: `False`.

    Returns:
        usm_ndarray:
            Boolean array with the same shape as `x`.

    Raises:
        TypeError:
            If `x` or `test_elements` is not a `usm_ndarray`.
        ExecutionPlacementError:
            If the execution queue can not be inferred from the inputs.
    """
    if not isinstance(x, dpt.usm_ndarray):
        raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
    if not isinstance(test_elements, dpt.usm_ndarray):
        raise TypeError(
            f"Expected dpctl.tensor.usm_ndarray, got {type(test_elements)}"
        )

    q = du.get_execution_queue([x.sycl_queue, test_elements.sycl_queue])
    if q is None:
        raise du.ExecutionPlacementError(
            "Execution placement can not be unambiguously "
            "inferred from input arguments."
        )

    # Both operands are processed as 1D arrays; the original shape of `x`
    # is restored on the result at the end.
    x1 = dpt.reshape(x, -1)
    x2 = dpt.reshape(test_elements, -1)

    x1_dt = x1.dtype
    x2_dt = x2.dtype

    _manager = du.SequentialOrderManager[q]

    if x1_dt != x2_dt:
        # The kernel compares values of a single type, so promote whichever
        # operand differs from the common result type.
        dt = dpt.result_type(x1, x2)
        if x1_dt != dt:
            x1_buf = _empty_like_orderK(x1, dt)
            dep_evs = _manager.submitted_events
            ht_ev, ev = _copy_usm_ndarray_into_usm_ndarray(
                src=x1, dst=x1_buf, sycl_queue=q, depends=dep_evs
            )
            _manager.add_event_pair(ht_ev, ev)
            x1 = x1_buf
        if x2_dt != dt:
            x2_buf = _empty_like_orderK(x2, dt)
            dep_evs = _manager.submitted_events
            ht_ev, ev = _copy_usm_ndarray_into_usm_ndarray(
                src=x2, dst=x2_buf, sycl_queue=q, depends=dep_evs
            )
            _manager.add_event_pair(ht_ev, ev)
            x2 = x2_buf

    # The membership kernel locates each needle with a lower-bound binary
    # search, which is only valid on sorted data. Sort the flattened hay
    # (after promotion, so that ordering is established in the type that is
    # actually compared).
    x2 = dpt.sort(x2)

    dst_usm_type = du.get_coerced_usm_type([x1.usm_type, x2.usm_type])

    dst = _empty_like_orderK(x1, dpt.bool, usm_type=dst_usm_type)

    dep_evs = _manager.submitted_events
    ht_ev, s_ev = _isin(
        needles=x1,
        hay=x2,
        dst=dst,
        sycl_queue=q,
        invert=invert,
        depends=dep_evs,
    )
    _manager.add_event_pair(ht_ev, s_ev)
    return dpt.reshape(dst, x.shape)
Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
1+
//=== isin.hpp - ---*-C++-*--/===//
2+
// Implementation of searching for membership in sorted array
3+
//
4+
// Data Parallel Control (dpctl)
5+
//
6+
// Copyright 2020-2025 Intel Corporation
7+
//
8+
// Licensed under the Apache License, Version 2.0 (the "License");
9+
// you may not use this file except in compliance with the License.
10+
// You may obtain a copy of the License at
11+
//
12+
// http://www.apache.org/licenses/LICENSE-2.0
13+
//
14+
// Unless required by applicable law or agreed to in writing, software
15+
// distributed under the License is distributed on an "AS IS" BASIS,
16+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17+
// See the License for the specific language governing permissions and
18+
// limitations under the License.
19+
//
20+
//===----------------------------------------------------------------------===//
21+
///
22+
/// \file
23+
/// This file defines kernels for tensor membership operations.
24+
//===----------------------------------------------------------------------===//
25+
26+
#pragma once
27+
28+
#include <algorithm>
29+
#include <cstddef>
30+
#include <cstdint>
31+
#include <sycl/sycl.hpp>
32+
#include <vector>
33+
34+
#include "kernels/dpctl_tensor_types.hpp"
35+
#include "kernels/sorting/search_sorted_detail.hpp"
36+
#include "utils/offset_utils.hpp"
37+
38+
namespace dpctl
39+
{
40+
namespace tensor
41+
{
42+
namespace kernels
43+
{
44+
45+
using dpctl::tensor::ssize_t;
46+
47+
template <typename T,
48+
typename HayIndexerT,
49+
typename NeedlesIndexerT,
50+
typename OutIndexerT,
51+
typename Compare>
52+
struct IsinFunctor
53+
{
54+
private:
55+
bool invert;
56+
const T *hay_tp;
57+
const T *needles_tp;
58+
bool *out_tp;
59+
std::size_t hay_nelems;
60+
HayIndexerT hay_indexer;
61+
NeedlesIndexerT needles_indexer;
62+
OutIndexerT out_indexer;
63+
64+
public:
65+
IsinFunctor(const bool invert_,
66+
const T *hay_,
67+
const T *needles_,
68+
bool *out_,
69+
const std::size_t hay_nelems_,
70+
const HayIndexerT &hay_indexer_,
71+
const NeedlesIndexerT &needles_indexer_,
72+
const OutIndexerT &out_indexer_)
73+
: invert(invert_), hay_tp(hay_), needles_tp(needles_), out_tp(out_),
74+
hay_nelems(hay_nelems_), hay_indexer(hay_indexer_),
75+
needles_indexer(needles_indexer_), out_indexer(out_indexer_)
76+
{
77+
}
78+
79+
void operator()(sycl::id<1> id) const
80+
{
81+
const Compare comp{};
82+
83+
const std::size_t i = id[0];
84+
const T needle_v = needles_tp[needles_indexer(i)];
85+
86+
// position of the needle_v in the hay array
87+
std::size_t pos{};
88+
89+
static constexpr std::size_t zero(0);
90+
// search in hay in left-closed interval, give `pos` such that
91+
// hay[pos - 1] < needle_v <= hay[pos]
92+
93+
// lower_bound returns the first pos such that bool(hay[pos] <
94+
// needle_v) is false, i.e. needle_v <= hay[pos]
95+
pos = static_cast<std::size_t>(
96+
search_sorted_detail::lower_bound_indexed_impl(
97+
hay_tp, zero, hay_nelems, needle_v, comp, hay_indexer));
98+
bool out = (pos == hay_nelems ? false : hay_tp[pos] == needle_v);
99+
out_tp[out_indexer(i)] = (invert) ? !out : out;
100+
}
101+
};
102+
103+
/// Pointer to a function with the signature of `isin_contig_impl`; the
/// parameter annotations below follow that function's parameter names.
using isin_contig_impl_fp_ptr_t =
    sycl::event (*)(sycl::queue &,      /* exec_q */
                    const bool,         /* invert */
                    const std::size_t,  /* hay_nelems */
                    const std::size_t,  /* needles_nelems */
                    const char *,       /* hay_cp */
                    const ssize_t,      /* hay_offset */
                    const char *,       /* needles_cp */
                    const ssize_t,      /* needles_offset */
                    char *,             /* out_cp */
                    const ssize_t,      /* out_offset */
                    const std::vector<sycl::event> & /* depends */);
115+
116+
template <typename T> class isin_contig_impl_krn;
117+
118+
template <typename T, typename Compare>
119+
sycl::event isin_contig_impl(sycl::queue &exec_q,
120+
const bool invert,
121+
const std::size_t hay_nelems,
122+
const std::size_t needles_nelems,
123+
const char *hay_cp,
124+
const ssize_t hay_offset,
125+
const char *needles_cp,
126+
const ssize_t needles_offset,
127+
char *out_cp,
128+
const ssize_t out_offset,
129+
const std::vector<sycl::event> &depends)
130+
{
131+
const T *hay_tp = reinterpret_cast<const T *>(hay_cp) + hay_offset;
132+
const T *needles_tp =
133+
reinterpret_cast<const T *>(needles_cp) + needles_offset;
134+
135+
bool *out_tp = reinterpret_cast<bool *>(out_cp) + out_offset;
136+
137+
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
138+
cgh.depends_on(depends);
139+
140+
using KernelName = class isin_contig_impl_krn<T>;
141+
142+
sycl::range<1> gRange(needles_nelems);
143+
144+
using TrivialIndexerT = dpctl::tensor::offset_utils::NoOpIndexer;
145+
146+
constexpr TrivialIndexerT hay_indexer{};
147+
constexpr TrivialIndexerT needles_indexer{};
148+
constexpr TrivialIndexerT out_indexer{};
149+
150+
const auto fnctr =
151+
IsinFunctor<T, TrivialIndexerT, TrivialIndexerT, TrivialIndexerT,
152+
Compare>(invert, hay_tp, needles_tp, out_tp, hay_nelems,
153+
hay_indexer, needles_indexer, out_indexer);
154+
155+
cgh.parallel_for<KernelName>(gRange, fnctr);
156+
});
157+
158+
return comp_ev;
159+
}
160+
161+
/// Pointer to a function with the signature of `isin_strided_impl`; the
/// parameter annotations below follow that function's parameter names.
using isin_strided_impl_fp_ptr_t =
    sycl::event (*)(sycl::queue &,      /* exec_q */
                    const bool,         /* invert */
                    const std::size_t,  /* hay_nelems */
                    const std::size_t,  /* needles_nelems */
                    const char *,       /* hay_cp */
                    const ssize_t,      /* hay_offset */
                    const ssize_t,      /* hay_stride */
                    const char *,       /* needles_cp */
                    const ssize_t,      /* needles_offset */
                    char *,             /* out_cp */
                    const ssize_t,      /* out_offset */
                    int,                /* needles_nd */
                    const ssize_t *,    /* packed_shape_strides */
                    const std::vector<sycl::event> & /* depends */);
176+
177+
template <typename T> class isin_strided_impl_krn;
178+
179+
template <typename T, typename Compare>
180+
sycl::event isin_strided_impl(
181+
sycl::queue &exec_q,
182+
const bool invert,
183+
const std::size_t hay_nelems,
184+
const std::size_t needles_nelems,
185+
const char *hay_cp,
186+
const ssize_t hay_offset,
187+
// hay is 1D, so hay_nelems, hay_offset, hay_stride describe strided array
188+
const ssize_t hay_stride,
189+
const char *needles_cp,
190+
const ssize_t needles_offset,
191+
char *out_cp,
192+
const ssize_t out_offset,
193+
const int needles_nd,
194+
// packed_shape_strides is [needles_shape, needles_strides,
195+
// out_strides] has length of 3*needles_nd
196+
const ssize_t *packed_shape_strides,
197+
const std::vector<sycl::event> &depends)
198+
{
199+
const T *hay_tp = reinterpret_cast<const T *>(hay_cp);
200+
const T *needles_tp = reinterpret_cast<const T *>(needles_cp);
201+
202+
bool *out_tp = reinterpret_cast<bool *>(out_cp);
203+
204+
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
205+
cgh.depends_on(depends);
206+
207+
sycl::range<1> gRange(needles_nelems);
208+
209+
using HayIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer;
210+
const HayIndexerT hay_indexer(
211+
/* offset */ hay_offset,
212+
/* size */ hay_nelems,
213+
/* step */ hay_stride);
214+
215+
using NeedlesIndexerT = dpctl::tensor::offset_utils::StridedIndexer;
216+
const ssize_t *needles_shape_strides = packed_shape_strides;
217+
const NeedlesIndexerT needles_indexer(needles_nd, needles_offset,
218+
needles_shape_strides);
219+
using OutIndexerT = dpctl::tensor::offset_utils::UnpackedStridedIndexer;
220+
221+
const ssize_t *out_shape = packed_shape_strides;
222+
const ssize_t *out_strides = packed_shape_strides + 2 * needles_nd;
223+
const OutIndexerT out_indexer(needles_nd, out_offset, out_shape,
224+
out_strides);
225+
226+
const auto fnctr =
227+
IsinFunctor<T, HayIndexerT, NeedlesIndexerT, OutIndexerT, Compare>(
228+
invert, hay_tp, needles_tp, out_tp, hay_nelems, hay_indexer,
229+
needles_indexer, out_indexer);
230+
using KernelName = class isin_strided_impl_krn<T>;
231+
232+
cgh.parallel_for<KernelName>(gRange, fnctr);
233+
});
234+
235+
return comp_ev;
236+
}
237+
238+
} // namespace kernels
239+
} // namespace tensor
240+
} // namespace dpctl

0 commit comments

Comments
 (0)