Skip to content

Commit e4643cb

Browse files
Leverage dpctl.tensor.put() implementation (#1529)
* Leverage dpctl.tensor.put impl * Apply review remarks * Update tests for dpnp.put and dpnp.take * Add todo for vals type checking * Apply review remarks * Update examples --------- Co-authored-by: Anton <[email protected]>
1 parent 4df273c commit e4643cb

File tree

10 files changed

+215
-160
lines changed

10 files changed

+215
-160
lines changed

dpnp/backend/include/dpnp_iface_fptr.hpp

-1
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,6 @@ enum class DPNPFuncName : size_t
280280
DPNP_FN_PTP, /**< Used in numpy.ptp() impl */
281281
DPNP_FN_PTP_EXT, /**< Used in numpy.ptp() impl, requires extra parameters */
282282
DPNP_FN_PUT, /**< Used in numpy.put() impl */
283-
DPNP_FN_PUT_EXT, /**< Used in numpy.put() impl, requires extra parameters */
284283
DPNP_FN_PUT_ALONG_AXIS, /**< Used in numpy.put_along_axis() impl */
285284
DPNP_FN_PUT_ALONG_AXIS_EXT, /**< Used in numpy.put_along_axis() impl,
286285
requires extra parameters */

dpnp/backend/kernels/dpnp_krnl_indexing.cpp

-20
Original file line numberDiff line numberDiff line change
@@ -602,17 +602,6 @@ void (*dpnp_put_default_c)(void *,
602602
const size_t) =
603603
dpnp_put_c<_DataType, _IndecesType, _ValueType>;
604604

605-
template <typename _DataType, typename _IndecesType, typename _ValueType>
606-
DPCTLSyclEventRef (*dpnp_put_ext_c)(DPCTLSyclQueueRef,
607-
void *,
608-
void *,
609-
void *,
610-
const size_t,
611-
const size_t,
612-
const size_t,
613-
const DPCTLEventVectorRef) =
614-
dpnp_put_c<_DataType, _IndecesType, _ValueType>;
615-
616605
template <typename _DataType>
617606
DPCTLSyclEventRef
618607
dpnp_put_along_axis_c(DPCTLSyclQueueRef q_ref,
@@ -1007,15 +996,6 @@ void func_map_init_indexing_func(func_map_t &fmap)
1007996
fmap[DPNPFuncName::DPNP_FN_PUT][eft_DBL][eft_DBL] = {
1008997
eft_DBL, (void *)dpnp_put_default_c<double, int64_t, double>};
1009998

1010-
fmap[DPNPFuncName::DPNP_FN_PUT_EXT][eft_INT][eft_INT] = {
1011-
eft_INT, (void *)dpnp_put_ext_c<int32_t, int64_t, int32_t>};
1012-
fmap[DPNPFuncName::DPNP_FN_PUT_EXT][eft_LNG][eft_LNG] = {
1013-
eft_LNG, (void *)dpnp_put_ext_c<int64_t, int64_t, int64_t>};
1014-
fmap[DPNPFuncName::DPNP_FN_PUT_EXT][eft_FLT][eft_FLT] = {
1015-
eft_FLT, (void *)dpnp_put_ext_c<float, int64_t, float>};
1016-
fmap[DPNPFuncName::DPNP_FN_PUT_EXT][eft_DBL][eft_DBL] = {
1017-
eft_DBL, (void *)dpnp_put_ext_c<double, int64_t, double>};
1018-
1019999
fmap[DPNPFuncName::DPNP_FN_PUT_ALONG_AXIS][eft_INT][eft_INT] = {
10201000
eft_INT, (void *)dpnp_put_along_axis_default_c<int32_t>};
10211001
fmap[DPNPFuncName::DPNP_FN_PUT_ALONG_AXIS][eft_LNG][eft_LNG] = {

dpnp/dpnp_algo/dpnp_algo.pxd

-2
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
150150
DPNP_FN_PROD_EXT
151151
DPNP_FN_PTP
152152
DPNP_FN_PTP_EXT
153-
DPNP_FN_PUT
154-
DPNP_FN_PUT_EXT
155153
DPNP_FN_QR
156154
DPNP_FN_QR_EXT
157155
DPNP_FN_RADIANS

dpnp/dpnp_algo/dpnp_algo_indexing.pxi

-68
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ __all__ += [
4141
"dpnp_diagonal",
4242
"dpnp_fill_diagonal",
4343
"dpnp_indices",
44-
"dpnp_put",
4544
"dpnp_put_along_axis",
4645
"dpnp_putmask",
4746
"dpnp_select",
@@ -80,14 +79,6 @@ ctypedef c_dpctl.DPCTLSyclEventRef(*custom_indexing_3in_with_axis_func_ptr_t)(c_
8079
const size_t,
8180
const size_t,
8281
const c_dpctl.DPCTLEventVectorRef)
83-
ctypedef c_dpctl.DPCTLSyclEventRef(*custom_indexing_6in_func_ptr_t)(c_dpctl.DPCTLSyclQueueRef,
84-
void *,
85-
void * ,
86-
void * ,
87-
const size_t,
88-
const size_t,
89-
const size_t,
90-
const c_dpctl.DPCTLEventVectorRef)
9182

9283

9384
cpdef utils.dpnp_descriptor dpnp_choose(utils.dpnp_descriptor x1, list choices1):
@@ -292,65 +283,6 @@ cpdef object dpnp_indices(dimensions):
292283
return dpnp_result
293284

294285

295-
cpdef dpnp_put(dpnp_descriptor x1, object ind, v):
296-
ind_is_list = isinstance(ind, list)
297-
298-
x1_obj = x1.get_array()
299-
300-
if dpnp.isscalar(ind):
301-
ind_size = 1
302-
else:
303-
ind_size = len(ind)
304-
cdef utils.dpnp_descriptor ind_array = utils_py.create_output_descriptor_py((ind_size,),
305-
dpnp.int64,
306-
None,
307-
device=x1_obj.sycl_device,
308-
usm_type=x1_obj.usm_type,
309-
sycl_queue=x1_obj.sycl_queue)
310-
if dpnp.isscalar(ind):
311-
ind_array.get_pyobj()[0] = ind
312-
else:
313-
for i in range(ind_size):
314-
ind_array.get_pyobj()[i] = ind[i]
315-
316-
if dpnp.isscalar(v):
317-
v_size = 1
318-
else:
319-
v_size = len(v)
320-
cdef utils.dpnp_descriptor v_array = utils_py.create_output_descriptor_py((v_size,),
321-
x1.dtype,
322-
None,
323-
device=x1_obj.sycl_device,
324-
usm_type=x1_obj.usm_type,
325-
sycl_queue=x1_obj.sycl_queue)
326-
if dpnp.isscalar(v):
327-
v_array.get_pyobj()[0] = v
328-
else:
329-
for i in range(v_size):
330-
v_array.get_pyobj()[i] = v[i]
331-
332-
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(x1.dtype)
333-
334-
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_PUT_EXT, param1_type, param1_type)
335-
336-
cdef c_dpctl.SyclQueue q = <c_dpctl.SyclQueue> x1_obj.sycl_queue
337-
cdef c_dpctl.DPCTLSyclQueueRef q_ref = q.get_queue_ref()
338-
339-
cdef custom_indexing_6in_func_ptr_t func = <custom_indexing_6in_func_ptr_t > kernel_data.ptr
340-
341-
cdef c_dpctl.DPCTLSyclEventRef event_ref = func(q_ref,
342-
x1.get_data(),
343-
ind_array.get_data(),
344-
v_array.get_data(),
345-
x1.size,
346-
ind_array.size,
347-
v_array.size,
348-
NULL) # dep_events_ref
349-
350-
with nogil: c_dpctl.DPCTLEvent_WaitAndThrow(event_ref)
351-
c_dpctl.DPCTLEvent_Delete(event_ref)
352-
353-
354286
cpdef dpnp_put_along_axis(dpnp_descriptor arr, dpnp_descriptor indices, dpnp_descriptor values, int axis):
355287
cdef shape_type_c arr_shape = arr.shape
356288
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(arr.dtype)

dpnp/dpnp_array.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -963,8 +963,17 @@ def prod(
963963

964964
return dpnp.prod(self, axis, dtype, out, keepdims, initial, where)
965965

966-
# 'ptp',
967-
# 'put',
966+
# 'ptp'
967+
968+
def put(self, indices, vals, /, *, axis=None, mode="wrap"):
969+
"""
970+
Puts values of an array into another array along a given axis.
971+
972+
For full documentation refer to :obj:`numpy.put`.
973+
"""
974+
975+
return dpnp.put(self, indices, vals, axis=axis, mode=mode)
976+
968977
# 'ravel',
969978
# 'real',
970979
# 'repeat',

dpnp/dpnp_iface_indexing.py

+63-16
Original file line numberDiff line numberDiff line change
@@ -417,34 +417,81 @@ def place(x, mask, vals, /):
417417
return call_origin(numpy.place, x, mask, vals, dpnp_inplace=True)
418418

419419

420-
def put(x1, ind, v, mode="raise"):
420+
def put(a, indices, vals, /, *, axis=None, mode="wrap"):
421421
"""
422-
Replaces specified elements of an array with given values.
422+
Puts values of an array into another array along a given axis.
423423
424424
For full documentation refer to :obj:`numpy.put`.
425425
426426
Limitations
427427
-----------
428-
Input array is supported as :obj:`dpnp.ndarray`.
429-
Not supported parameter mode.
428+
Parameters `a` and `indices` are supported either as :class:`dpnp.ndarray`
429+
or :class:`dpctl.tensor.usm_ndarray`.
430+
Parameter `indices` is supported as 1-D array of integer data type.
431+
Parameter `vals` must be broadcastable to the shape of `indices`
432+
and has the same data type as `a` if it is as :class:`dpnp.ndarray`
433+
or :class:`dpctl.tensor.usm_ndarray`.
434+
Parameter `mode` is supported with ``wrap``, the default, and ``clip`` values.
435+
Parameter `axis` is supported as integer only.
436+
Otherwise the function will be executed sequentially on CPU.
437+
438+
See Also
439+
--------
440+
:obj:`dpnp.putmask` : Changes elements of an array based on conditional and input values.
441+
:obj:`dpnp.place` : Change elements of an array based on conditional and input values.
442+
:obj:`dpnp.put_along_axis` : Put values into the destination array by matching 1d index and data slices.
443+
444+
Notes
445+
-----
446+
In contrast to :obj:`numpy.put` `wrap` mode which wraps indices around the array for cyclic operations,
447+
:obj:`dpnp.put` `wrap` mode clamps indices to a fixed range within the array boundaries (-n <= i < n).
448+
449+
Examples
450+
--------
451+
>>> import dpnp as np
452+
>>> x = np.arange(5)
453+
>>> indices = np.array([0, 1])
454+
>>> np.put(x, indices, [-44, -55])
455+
>>> x
456+
array([-44, -55, 2, 3, 4])
457+
458+
>>> x = np.arange(5)
459+
>>> indices = np.array([22])
460+
>>> np.put(x, indices, -5, mode='clip')
461+
>>> x
462+
array([ 0, 1, 2, 3, -5])
463+
430464
"""
431465

432-
x1_desc = dpnp.get_dpnp_descriptor(
433-
x1, copy_when_strides=False, copy_when_nondefault_queue=False
434-
)
435-
if x1_desc:
436-
if mode != "raise":
466+
if dpnp.is_supported_array_type(a) and dpnp.is_supported_array_type(
467+
indices
468+
):
469+
if indices.ndim != 1 or not dpnp.issubdtype(
470+
indices.dtype, dpnp.integer
471+
):
437472
pass
438-
elif type(ind) is not type(v):
473+
elif mode not in ("clip", "wrap"):
439474
pass
440-
elif (
441-
numpy.max(ind) >= x1_desc.size or numpy.min(ind) + x1_desc.size < 0
442-
):
475+
elif axis is not None and not isinstance(axis, int):
476+
raise TypeError(f"`axis` must be of integer type, got {type(axis)}")
477+
# TODO: remove when #1382(dpctl) is solved
478+
elif dpnp.is_supported_array_type(vals) and a.dtype != vals.dtype:
443479
pass
444480
else:
445-
return dpnp_put(x1_desc, ind, v)
481+
if axis is None and a.ndim > 1:
482+
a = dpnp.reshape(a, -1)
483+
dpt_array = dpnp.get_usm_ndarray(a)
484+
dpt_indices = dpnp.get_usm_ndarray(indices)
485+
dpt_vals = (
486+
dpnp.get_usm_ndarray(vals)
487+
if isinstance(vals, dpnp_array)
488+
else vals
489+
)
490+
return dpt.put(
491+
dpt_array, dpt_indices, dpt_vals, axis=axis, mode=mode
492+
)
446493

447-
return call_origin(numpy.put, x1, ind, v, mode, dpnp_inplace=True)
494+
return call_origin(numpy.put, a, indices, vals, mode, dpnp_inplace=True)
448495

449496

450497
def put_along_axis(x1, indices, values, axis):
@@ -557,7 +604,7 @@ def take(x, indices, /, *, axis=None, out=None, mode="wrap"):
557604
or :class:`dpctl.tensor.usm_ndarray`.
558605
Parameter `indices` is supported as 1-D array of integer data type.
559606
Parameter `out` is supported only with default value.
560-
Parameter `mode` is supported with ``wrap``(default) and ``clip`` mode.
607+
Parameter `mode` is supported with ``wrap``, the default, and ``clip`` values.
561608
Providing parameter `axis` is optional when `x` is a 1-D array.
562609
Otherwise the function will be executed sequentially on CPU.
563610

0 commit comments

Comments
 (0)