diff --git a/dpctl/tensor/_accumulation.py b/dpctl/tensor/_accumulation.py index 5232740b4f..5f5223629a 100644 --- a/dpctl/tensor/_accumulation.py +++ b/dpctl/tensor/_accumulation.py @@ -20,53 +20,14 @@ import dpctl.tensor as dpt import dpctl.tensor._tensor_accumulation_impl as tai import dpctl.tensor._tensor_impl as ti -from dpctl.tensor._type_utils import _to_device_supported_dtype +from dpctl.tensor._type_utils import ( + _default_accumulation_dtype, + _default_accumulation_dtype_fp_types, + _to_device_supported_dtype, +) from dpctl.utils import ExecutionPlacementError -def _default_accumulation_dtype(inp_dt, q): - """Gives default output data type for given input data - type `inp_dt` when accumulation is performed on queue `q` - """ - inp_kind = inp_dt.kind - if inp_kind in "bi": - res_dt = dpt.dtype(ti.default_device_int_type(q)) - if inp_dt.itemsize > res_dt.itemsize: - res_dt = inp_dt - elif inp_kind in "u": - res_dt = dpt.dtype(ti.default_device_int_type(q).upper()) - res_ii = dpt.iinfo(res_dt) - inp_ii = dpt.iinfo(inp_dt) - if inp_ii.min >= res_ii.min and inp_ii.max <= res_ii.max: - pass - else: - res_dt = inp_dt - elif inp_kind in "fc": - res_dt = inp_dt - - return res_dt - - -def _default_accumulation_dtype_fp_types(inp_dt, q): - """Gives default output data type for given input data - type `inp_dt` when accumulation is performed on queue `q` - and the accumulation supports only floating-point data types - """ - inp_kind = inp_dt.kind - if inp_kind in "biu": - res_dt = dpt.dtype(ti.default_device_fp_type(q)) - can_cast_v = dpt.can_cast(inp_dt, res_dt) - if not can_cast_v: - _fp64 = q.sycl_device.has_aspect_fp64 - res_dt = dpt.float64 if _fp64 else dpt.float32 - elif inp_kind in "f": - res_dt = inp_dt - elif inp_kind in "c": - raise ValueError("function not defined for complex types") - - return res_dt - - def _accumulate_common( x, axis, diff --git a/dpctl/tensor/_array_api.py b/dpctl/tensor/_array_api.py index 2a4fd30686..75531564af 100644 --- a/dpctl/tensor/_array_api.py +++ b/dpctl/tensor/_array_api.py @@ -49,7 +49,7 @@ def _isdtype_impl(dtype, kind): raise TypeError(f"Unsupported data type kind: {kind}") -__array_api_version__ = "2022.12" +__array_api_version__ = "2023.12" class Info: @@ -80,6 +80,8 @@ def __init__(self): def capabilities(self): """ + capabilities() + Returns a dictionary of `dpctl`'s capabilities. Returns: @@ -92,12 +94,16 @@ def capabilities(self): def default_device(self): """ + default_device() + Returns the default SYCL device. """ return dpctl.select_default_device() - def default_dtypes(self, device=None): + def default_dtypes(self, *, device=None): """ + default_dtypes(*, device=None) + Returns a dictionary of default data types for `device`. Args: @@ -129,8 +135,10 @@ def default_dtypes(self, device=None): "indexing": dpt.dtype(default_device_index_type(device)), } - def dtypes(self, device=None, kind=None): + def dtypes(self, *, device=None, kind=None): """ + dtypes(*, device=None, kind=None) + Returns a dictionary of all Array API data types of a specified `kind` supported by `device` @@ -193,13 +201,16 @@ def dtypes(self, device=None, kind=None): def devices(self): """ + devices() + Returns a list of supported devices. """ return dpctl.get_devices() def __array_namespace_info__(): - """__array_namespace_info__() + """ + __array_namespace_info__() Returns a namespace with Array API namespace inspection utilities. diff --git a/dpctl/tensor/_reduction.py b/dpctl/tensor/_reduction.py index 2494e130ab..6c0afbeebb 100644 --- a/dpctl/tensor/_reduction.py +++ b/dpctl/tensor/_reduction.py @@ -21,58 +21,11 @@ import dpctl.tensor._tensor_impl as ti import dpctl.tensor._tensor_reductions_impl as tri -from ._type_utils import _to_device_supported_dtype - - -def _default_reduction_dtype(inp_dt, q): - """Gives default output data type for given input data - type `inp_dt` when reduction is performed on queue `q` - """ - inp_kind = inp_dt.kind - if inp_kind in "bi": - res_dt = dpt.dtype(ti.default_device_int_type(q)) - if inp_dt.itemsize > res_dt.itemsize: - res_dt = inp_dt - elif inp_kind in "u": - res_dt = dpt.dtype(ti.default_device_int_type(q).upper()) - res_ii = dpt.iinfo(res_dt) - inp_ii = dpt.iinfo(inp_dt) - if inp_ii.min >= res_ii.min and inp_ii.max <= res_ii.max: - pass - else: - res_dt = inp_dt - elif inp_kind in "f": - res_dt = dpt.dtype(ti.default_device_fp_type(q)) - if res_dt.itemsize < inp_dt.itemsize: - res_dt = inp_dt - elif inp_kind in "c": - res_dt = dpt.dtype(ti.default_device_complex_type(q)) - if res_dt.itemsize < inp_dt.itemsize: - res_dt = inp_dt - - return res_dt - - -def _default_reduction_dtype_fp_types(inp_dt, q): - """Gives default output data type for given input data - type `inp_dt` when reduction is performed on queue `q` - and the reduction supports only floating-point data types - """ - inp_kind = inp_dt.kind - if inp_kind in "biu": - res_dt = dpt.dtype(ti.default_device_fp_type(q)) - can_cast_v = dpt.can_cast(inp_dt, res_dt) - if not can_cast_v: - _fp64 = q.sycl_device.has_aspect_fp64 - res_dt = dpt.float64 if _fp64 else dpt.float32 - elif inp_kind in "f": - res_dt = dpt.dtype(ti.default_device_fp_type(q)) - if res_dt.itemsize < inp_dt.itemsize: - res_dt = inp_dt - elif inp_kind in "c": - raise TypeError("reduction not defined for complex types") - - return res_dt +from ._type_utils import ( + _default_accumulation_dtype, + _default_accumulation_dtype_fp_types, + _to_device_supported_dtype, +) def _reduction_over_axis( @@ -237,7 +190,7 @@ def sum(x, axis=None, dtype=None, keepdims=False): keepdims, tri._sum_over_axis, tri._sum_over_axis_dtype_supported, - _default_reduction_dtype, + _default_accumulation_dtype, ) @@ -299,7 +252,7 @@ def prod(x, axis=None, dtype=None, keepdims=False): keepdims, tri._prod_over_axis, tri._prod_over_axis_dtype_supported, - _default_reduction_dtype, + _default_accumulation_dtype, ) @@ -356,7 +309,7 @@ def logsumexp(x, axis=None, dtype=None, keepdims=False): lambda inp_dt, res_dt, *_: tri._logsumexp_over_axis_dtype_supported( inp_dt, res_dt ), - _default_reduction_dtype_fp_types, + _default_accumulation_dtype_fp_types, ) @@ -413,7 +366,7 @@ def reduce_hypot(x, axis=None, dtype=None, keepdims=False): lambda inp_dt, res_dt, *_: tri._hypot_over_axis_dtype_supported( inp_dt, res_dt ), - _default_reduction_dtype_fp_types, + _default_accumulation_dtype_fp_types, ) diff --git a/dpctl/tensor/_type_utils.py b/dpctl/tensor/_type_utils.py index c16bb6112e..005e452219 100644 --- a/dpctl/tensor/_type_utils.py +++ b/dpctl/tensor/_type_utils.py @@ -733,6 +733,49 @@ def isdtype(dtype, kind): raise TypeError(f"Unsupported data type kind: {kind}") +def _default_accumulation_dtype(inp_dt, q): + """Gives default output data type for given input data + type `inp_dt` when accumulation is performed on queue `q` + """ + inp_kind = inp_dt.kind + if inp_kind in "bi": + res_dt = dpt.dtype(ti.default_device_int_type(q)) + if inp_dt.itemsize > res_dt.itemsize: + res_dt = inp_dt + elif inp_kind in "u": + res_dt = dpt.dtype(ti.default_device_int_type(q).upper()) + res_ii = dpt.iinfo(res_dt) + inp_ii = dpt.iinfo(inp_dt) + if inp_ii.min >= res_ii.min and inp_ii.max <= res_ii.max: + pass + else: + res_dt = inp_dt + elif inp_kind in "fc": + res_dt = inp_dt + + return res_dt + + +def _default_accumulation_dtype_fp_types(inp_dt, q): + """Gives default output data type for given input data + type `inp_dt` when accumulation is performed on queue `q` + and the accumulation supports only floating-point data types + """ + inp_kind = inp_dt.kind + if inp_kind in "biu": + res_dt = dpt.dtype(ti.default_device_fp_type(q)) + can_cast_v = dpt.can_cast(inp_dt, res_dt) + if not can_cast_v: + _fp64 = q.sycl_device.has_aspect_fp64 + res_dt = dpt.float64 if _fp64 else dpt.float32 + elif inp_kind in "f": + res_dt = inp_dt + elif inp_kind in "c": + raise ValueError("function not defined for complex types") + + return res_dt + + __all__ = [ "_find_buf_dtype", "_find_buf_dtype2", @@ -753,4 +796,6 @@ def isdtype(dtype, kind): "WeakIntegralType", "WeakFloatingType", "WeakComplexType", + "_default_accumulation_dtype", + "_default_accumulation_dtype_fp_types", ] diff --git a/dpctl/tests/test_tensor_array_api_inspection.py b/dpctl/tests/test_tensor_array_api_inspection.py index 874ada9363..2fb7532466 100644 --- a/dpctl/tests/test_tensor_array_api_inspection.py +++ b/dpctl/tests/test_tensor_array_api_inspection.py @@ -96,7 +96,7 @@ def test_array_api_inspection_default_dtypes(): info = dpt.__array_namespace_info__() default_dts_nodev = info.default_dtypes() - default_dts_dev = info.default_dtypes(dev) + default_dts_dev = info.default_dtypes(device=dev) assert ( int_dt == default_dts_nodev["integral"] == default_dts_dev["integral"] diff --git a/dpctl/tests/test_usm_ndarray_indexing.py b/dpctl/tests/test_usm_ndarray_indexing.py index 47334acf18..7d431e0d0d 100644 --- a/dpctl/tests/test_usm_ndarray_indexing.py +++ b/dpctl/tests/test_usm_ndarray_indexing.py @@ -491,7 +491,7 @@ def test_mixed_index_getitem(): x = dpt.reshape(dpt.arange(1000, dtype="i4"), (10, 10, 10)) i1b = dpt.ones(10, dtype="?") info = x.__array_namespace__().__array_namespace_info__() - ind_dt = info.default_dtypes(x.device)["indexing"] + ind_dt = info.default_dtypes(device=x.device)["indexing"] i0 = dpt.asarray([0, 2, 3], dtype=ind_dt)[:, dpt.newaxis] i2 = dpt.asarray([3, 4, 7], dtype=ind_dt)[:, dpt.newaxis] y = x[i0, i1b, i2] @@ -503,7 +503,7 @@ def test_mixed_index_setitem(): x = dpt.reshape(dpt.arange(1000, dtype="i4"), (10, 10, 10)) i1b = dpt.ones(10, dtype="?") info = x.__array_namespace__().__array_namespace_info__() - ind_dt = info.default_dtypes(x.device)["indexing"] + ind_dt = info.default_dtypes(device=x.device)["indexing"] i0 = dpt.asarray([0, 2, 3], dtype=ind_dt)[:, dpt.newaxis] i2 = dpt.asarray([3, 4, 7], dtype=ind_dt)[:, dpt.newaxis] v_shape = (3, int(dpt.sum(i1b, dtype="i8"))) diff --git a/dpctl/tests/test_usm_ndarray_reductions.py b/dpctl/tests/test_usm_ndarray_reductions.py index 7da447d315..8b22d049e3 100644 --- a/dpctl/tests/test_usm_ndarray_reductions.py +++ b/dpctl/tests/test_usm_ndarray_reductions.py @@ -406,7 +406,7 @@ def test_logsumexp_complex(): get_queue_or_skip() x = dpt.zeros(1, dtype="c8") - with pytest.raises(TypeError): + with pytest.raises(ValueError): dpt.logsumexp(x) @@ -470,7 +470,7 @@ def test_hypot_complex(): get_queue_or_skip() x = dpt.zeros(1, dtype="c8") - with pytest.raises(TypeError): + with pytest.raises(ValueError): dpt.reduce_hypot(x) diff --git a/dpctl/tests/test_usm_ndarray_searchsorted.py b/dpctl/tests/test_usm_ndarray_searchsorted.py index 4d2e899fe1..41f6ecac7a 100644 --- a/dpctl/tests/test_usm_ndarray_searchsorted.py +++ b/dpctl/tests/test_usm_ndarray_searchsorted.py @@ -12,7 +12,7 @@ def _check(hay_stack, needles, needles_np): assert hay_stack.ndim == 1 info_ = dpt.__array_namespace_info__() - default_dts_dev = info_.default_dtypes(hay_stack.device) + default_dts_dev = info_.default_dtypes(device=hay_stack.device) index_dt = default_dts_dev["indexing"] p_left = dpt.searchsorted(hay_stack, needles, side="left")