Skip to content

Commit 0f76890

Browse files
Allow for change of name and location of sycl_complex.hpp
Introduced private header to load SYCL's experimental complex header from the right location. The header and implementations respond to USE_SYCL_FOR_COMPLEX_TYPES preprocessor variable. If set, sycl::ext::oneapi::experimental namespace functions are to be used. Otherwise std:: namespace functions will be used instead for complex types. USE_SYCL_FOR_COMPLEX_TYPES is being set in tensor/CMakeLists.txt
1 parent 6d3be5d commit 0f76890

29 files changed

+228
-69
lines changed

Diff for: dpctl/tensor/CMakeLists.txt

+2-2
Original file line numberDiff line numberDiff line change
@@ -188,9 +188,9 @@ foreach(_src_fn ${_no_fast_math_sources})
188188
)
189189
endforeach()
190190
if (UNIX)
191-
set(_compiler_definitions "USE_STD_ABS_FOR_COMPLEX_TYPES;USE_STD_SQRT_FOR_COMPLEX_TYPES;SYCL_EXT_ONEAPI_COMPLEX")
191+
set(_compiler_definitions "USE_SYCL_FOR_COMPLEX_TYPES")
192192
else()
193-
set(_compiler_definitions "SYCL_EXT_ONEAPI_COMPLEX")
193+
set(_compiler_definitions "USE_SYCL_FOR_COMPLEX_TYPES")
194194
endif()
195195

196196
foreach(_src_fn ${_elementwise_sources})

Diff for: dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp

+2-3
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,11 @@
2828
#include <cstddef>
2929
#include <cstdint>
3030
#include <limits>
31-
#include <sycl/ext/oneapi/experimental/sycl_complex.hpp>
3231
#include <sycl/sycl.hpp>
3332
#include <type_traits>
3433

3534
#include "kernels/elementwise_functions/common.hpp"
35+
#include "sycl_complex.hpp"
3636

3737
#include "utils/offset_utils.hpp"
3838
#include "utils/type_dispatch.hpp"
@@ -50,7 +50,6 @@ namespace abs
5050

5151
namespace py = pybind11;
5252
namespace td_ns = dpctl::tensor::type_dispatch;
53-
namespace exprm_ns = sycl::ext::oneapi::experimental;
5453

5554
using dpctl::tensor::type_utils::is_complex;
5655

@@ -121,7 +120,7 @@ template <typename argT, typename resT> struct AbsFunctor
121120
return q_nan;
122121
}
123122
else {
124-
#ifdef USE_STD_ABS_FOR_COMPLEX_TYPES
123+
#ifdef USE_SYCL_FOR_COMPLEX_TYPES
125124
return exprm_ns::abs(exprm_ns::complex<realT>(z));
126125
#else
127126
return std::hypot(std::real(z), std::imag(z));

Diff for: dpctl/tensor/libtensor/include/kernels/elementwise_functions/acos.hpp

+15-2
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,11 @@
2626
#include <cmath>
2727
#include <cstddef>
2828
#include <cstdint>
29-
#include <sycl/ext/oneapi/experimental/sycl_complex.hpp>
3029
#include <sycl/sycl.hpp>
3130
#include <type_traits>
3231

3332
#include "kernels/elementwise_functions/common.hpp"
33+
#include "sycl_complex.hpp"
3434

3535
#include "utils/offset_utils.hpp"
3636
#include "utils/type_dispatch.hpp"
@@ -48,7 +48,6 @@ namespace acos
4848

4949
namespace py = pybind11;
5050
namespace td_ns = dpctl::tensor::type_dispatch;
51-
namespace exprm_ns = sycl::ext::oneapi::experimental;
5251

5352
using dpctl::tensor::type_utils::is_complex;
5453

@@ -105,6 +104,7 @@ template <typename argT, typename resT> struct AcosFunctor
105104
constexpr realT r_eps =
106105
realT(1) / std::numeric_limits<realT>::epsilon();
107106
if (std::abs(x) > r_eps || std::abs(y) > r_eps) {
107+
#ifdef USE_SYCL_FOR_COMPLEX_TYPES
108108
using sycl_complexT = exprm_ns::complex<realT>;
109109
sycl_complexT log_in =
110110
exprm_ns::log(exprm_ns::complex<realT>(in));
@@ -115,11 +115,24 @@ template <typename argT, typename resT> struct AcosFunctor
115115

116116
realT ry = wx + std::log(realT(2));
117117
return resT{rx, (std::signbit(y)) ? ry : -ry};
118+
#else
119+
resT log_in = std::log(in);
120+
const realT wx = std::real(log_in);
121+
const realT wy = std::imag(log_in);
122+
const realT rx = std::abs(wy);
123+
124+
realT ry = wx + std::log(realT(2));
125+
return resT{rx, (std::signbit(y)) ? ry : -ry};
126+
#endif
118127
}
119128

120129
/* ordinary cases */
130+
#if USE_SYCL_FOR_COMPLEX_TYPES
121131
return exprm_ns::acos(
122132
exprm_ns::complex<realT>(in)); // std::acos(in);
133+
#else
134+
return std::acos(in);
135+
#endif
123136
}
124137
else {
125138
static_assert(std::is_floating_point_v<argT> ||

Diff for: dpctl/tensor/libtensor/include/kernels/elementwise_functions/acosh.hpp

+11-2
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,11 @@
2626
#include <cmath>
2727
#include <cstddef>
2828
#include <cstdint>
29-
#include <sycl/ext/oneapi/experimental/sycl_complex.hpp>
3029
#include <sycl/sycl.hpp>
3130
#include <type_traits>
3231

3332
#include "kernels/elementwise_functions/common.hpp"
33+
#include "sycl_complex.hpp"
3434

3535
#include "utils/offset_utils.hpp"
3636
#include "utils/type_dispatch.hpp"
@@ -48,7 +48,6 @@ namespace acosh
4848

4949
namespace py = pybind11;
5050
namespace td_ns = dpctl::tensor::type_dispatch;
51-
namespace exprm_ns = sycl::ext::oneapi::experimental;
5251

5352
using dpctl::tensor::type_utils::is_complex;
5453

@@ -112,18 +111,28 @@ template <typename argT, typename resT> struct AcoshFunctor
112111
* For large x or y including acos(+-Inf + I*+-Inf)
113112
*/
114113
if (std::abs(x) > r_eps || std::abs(y) > r_eps) {
114+
#ifdef USE_SYCL_FOR_COMPLEX_TYPES
115115
using sycl_complexT = typename exprm_ns::complex<realT>;
116116
const sycl_complexT log_in = exprm_ns::log(sycl_complexT(in));
117117
const realT wx = log_in.real();
118118
const realT wy = log_in.imag();
119+
#else
120+
const resT log_in = std::log(in);
121+
const realT wx = std::real(log_in);
122+
const realT wy = std::imag(log_in);
123+
#endif
119124
const realT rx = std::abs(wy);
120125
realT ry = wx + std::log(realT(2));
121126
acos_in = resT{rx, (std::signbit(y)) ? ry : -ry};
122127
}
123128
else {
124129
/* ordinary cases */
130+
#if USE_SYCL_FOR_COMPLEX_TYPES
125131
acos_in = exprm_ns::acos(
126132
exprm_ns::complex<realT>(in)); // std::acos(in);
133+
#else
134+
acos_in = std::acos(in);
135+
#endif
127136
}
128137

129138
/* Now we calculate acosh(z) */

Diff for: dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp

+13-2
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,10 @@
2626
#pragma once
2727
#include <cstddef>
2828
#include <cstdint>
29-
#include <sycl/ext/oneapi/experimental/sycl_complex.hpp>
3029
#include <sycl/sycl.hpp>
3130
#include <type_traits>
3231

32+
#include "sycl_complex.hpp"
3333
#include "utils/offset_utils.hpp"
3434
#include "utils/type_dispatch.hpp"
3535
#include "utils/type_utils.hpp"
@@ -50,7 +50,6 @@ namespace add
5050
namespace py = pybind11;
5151
namespace td_ns = dpctl::tensor::type_dispatch;
5252
namespace tu_ns = dpctl::tensor::type_utils;
53-
namespace exprm_ns = sycl::ext::oneapi::experimental;
5453

5554
template <typename argT1, typename argT2, typename resT> struct AddFunctor
5655
{
@@ -65,24 +64,36 @@ template <typename argT1, typename argT2, typename resT> struct AddFunctor
6564
if constexpr (tu_ns::is_complex<argT1>::value &&
6665
tu_ns::is_complex<argT2>::value)
6766
{
67+
#ifdef USE_SYCL_FOR_COMPLEX_TYPES
6868
using rT1 = typename argT1::value_type;
6969
using rT2 = typename argT2::value_type;
7070

7171
return exprm_ns::complex<rT1>(in1) + exprm_ns::complex<rT2>(in2);
72+
#else
73+
return in1 + in2;
74+
#endif
7275
}
7376
else if constexpr (tu_ns::is_complex<argT1>::value &&
7477
!tu_ns::is_complex<argT2>::value)
7578
{
79+
#ifdef USE_SYCL_FOR_COMPLEX_TYPES
7680
using rT1 = typename argT1::value_type;
7781

7882
return exprm_ns::complex<rT1>(in1) + in2;
83+
#else
84+
return in1 + in2;
85+
#endif
7986
}
8087
else if constexpr (!tu_ns::is_complex<argT1>::value &&
8188
tu_ns::is_complex<argT2>::value)
8289
{
90+
#ifdef USE_SYCL_FOR_COMPLEX_TYPES
8391
using rT2 = typename argT2::value_type;
8492

8593
return in1 + exprm_ns::complex<rT2>(in2);
94+
#else
95+
return in1 + in2;
96+
#endif
8697
}
8798
else {
8899
return in1 + in2;

Diff for: dpctl/tensor/libtensor/include/kernels/elementwise_functions/asin.hpp

+22-4
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,11 @@
2626
#include <cmath>
2727
#include <cstddef>
2828
#include <cstdint>
29-
#include <sycl/ext/oneapi/experimental/sycl_complex.hpp>
3029
#include <sycl/sycl.hpp>
3130
#include <type_traits>
3231

3332
#include "kernels/elementwise_functions/common.hpp"
33+
#include "sycl_complex.hpp"
3434

3535
#include "utils/offset_utils.hpp"
3636
#include "utils/type_dispatch.hpp"
@@ -48,7 +48,6 @@ namespace asin
4848

4949
namespace py = pybind11;
5050
namespace td_ns = dpctl::tensor::type_dispatch;
51-
namespace exprm_ns = sycl::ext::oneapi::experimental;
5251

5352
using dpctl::tensor::type_utils::is_complex;
5453

@@ -119,26 +118,45 @@ template <typename argT, typename resT> struct AsinFunctor
119118
constexpr realT r_eps =
120119
realT(1) / std::numeric_limits<realT>::epsilon();
121120
if (std::abs(x) > r_eps || std::abs(y) > r_eps) {
121+
#ifdef USE_SYCL_FOR_COMPLEX_TYPES
122122
using sycl_complexT = exprm_ns::complex<realT>;
123123
const sycl_complexT z{x, y};
124124
realT wx, wy;
125125
if (!std::signbit(x)) {
126-
auto log_z = exprm_ns::log(z);
126+
const auto log_z = exprm_ns::log(z);
127127
wx = log_z.real() + std::log(realT(2));
128128
wy = log_z.imag();
129129
}
130130
else {
131-
auto log_mz = exprm_ns::log(-z);
131+
const auto log_mz = exprm_ns::log(-z);
132132
wx = log_mz.real() + std::log(realT(2));
133133
wy = log_mz.imag();
134134
}
135+
#else
136+
const resT z{x, y};
137+
realT wx, wy;
138+
if (!std::signbit(x)) {
139+
const auto log_z = std::log(z);
140+
wx = std::real(log_z) + std::log(realT(2));
141+
wy = std::imag(log_z);
142+
}
143+
else {
144+
const auto log_mz = std::log(-z);
145+
wx = std::real(log_mz) + std::log(realT(2));
146+
wy = std::imag(log_mz);
147+
}
148+
#endif
135149
const realT asinh_re = std::copysign(wx, x);
136150
const realT asinh_im = std::copysign(wy, y);
137151
return resT{asinh_im, asinh_re};
138152
}
139153
/* ordinary cases */
154+
#if USE_SYCL_FOR_COMPLEX_TYPES
140155
return exprm_ns::asin(
141156
exprm_ns::complex<realT>(in)); // std::asin(in);
157+
#else
158+
return std::asin(in);
159+
#endif
142160
}
143161
else {
144162
static_assert(std::is_floating_point_v<argT> ||

Diff for: dpctl/tensor/libtensor/include/kernels/elementwise_functions/asinh.hpp

+11-2
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,11 @@
2626
#include <cmath>
2727
#include <cstddef>
2828
#include <cstdint>
29-
#include <sycl/ext/oneapi/experimental/sycl_complex.hpp>
3029
#include <sycl/sycl.hpp>
3130
#include <type_traits>
3231

3332
#include "kernels/elementwise_functions/common.hpp"
33+
#include "sycl_complex.hpp"
3434

3535
#include "utils/offset_utils.hpp"
3636
#include "utils/type_dispatch.hpp"
@@ -48,7 +48,6 @@ namespace asinh
4848

4949
namespace py = pybind11;
5050
namespace td_ns = dpctl::tensor::type_dispatch;
51-
namespace exprm_ns = sycl::ext::oneapi::experimental;
5251

5352
using dpctl::tensor::type_utils::is_complex;
5453

@@ -108,20 +107,30 @@ template <typename argT, typename resT> struct AsinhFunctor
108107
realT(1) / std::numeric_limits<realT>::epsilon();
109108

110109
if (std::abs(x) > r_eps || std::abs(y) > r_eps) {
110+
#ifdef USE_SYCL_FOR_COMPLEX_TYPES
111111
using sycl_complexT = exprm_ns::complex<realT>;
112112
sycl_complexT log_in = (std::signbit(x))
113113
? exprm_ns::log(sycl_complexT(-in))
114114
: exprm_ns::log(sycl_complexT(in));
115115
realT wx = log_in.real() + std::log(realT(2));
116116
realT wy = log_in.imag();
117+
#else
118+
auto log_in = std::log(std::signbit(x) ? -in : in);
119+
realT wx = std::real(log_in) + std::log(realT(2));
120+
realT wy = std::imag(log_in);
121+
#endif
117122
const realT res_re = std::copysign(wx, x);
118123
const realT res_im = std::copysign(wy, y);
119124
return resT{res_re, res_im};
120125
}
121126

122127
/* ordinary cases */
128+
#if USE_SYCL_FOR_COMPLEX_TYPES
123129
return exprm_ns::asinh(
124130
exprm_ns::complex<realT>(in)); // std::asinh(in);
131+
#else
132+
return std::asinh(in);
133+
#endif
125134
}
126135
else {
127136
static_assert(std::is_floating_point_v<argT> ||

Diff for: dpctl/tensor/libtensor/include/kernels/elementwise_functions/atan.hpp

+5-2
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,11 @@
2727
#include <complex>
2828
#include <cstddef>
2929
#include <cstdint>
30-
#include <sycl/ext/oneapi/experimental/sycl_complex.hpp>
3130
#include <sycl/sycl.hpp>
3231
#include <type_traits>
3332

3433
#include "kernels/elementwise_functions/common.hpp"
34+
#include "sycl_complex.hpp"
3535

3636
#include "utils/offset_utils.hpp"
3737
#include "utils/type_dispatch.hpp"
@@ -49,7 +49,6 @@ namespace atan
4949

5050
namespace py = pybind11;
5151
namespace td_ns = dpctl::tensor::type_dispatch;
52-
namespace exprm_ns = sycl::ext::oneapi::experimental;
5352

5453
using dpctl::tensor::type_utils::is_complex;
5554

@@ -128,8 +127,12 @@ template <typename argT, typename resT> struct AtanFunctor
128127
return resT{atanh_im, atanh_re};
129128
}
130129
/* ordinary cases */
130+
#ifdef USE_SYCL_FOR_COMPLEX_TYPES
131131
return exprm_ns::atan(
132132
exprm_ns::complex<realT>(in)); // std::atan(in);
133+
#else
134+
return std::atan(in);
135+
#endif
133136
}
134137
else {
135138
static_assert(std::is_floating_point_v<argT> ||

Diff for: dpctl/tensor/libtensor/include/kernels/elementwise_functions/atanh.hpp

+5-2
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,11 @@
2727
#include <complex>
2828
#include <cstddef>
2929
#include <cstdint>
30-
#include <sycl/ext/oneapi/experimental/sycl_complex.hpp>
3130
#include <sycl/sycl.hpp>
3231
#include <type_traits>
3332

3433
#include "kernels/elementwise_functions/common.hpp"
34+
#include "sycl_complex.hpp"
3535

3636
#include "utils/offset_utils.hpp"
3737
#include "utils/type_dispatch.hpp"
@@ -49,7 +49,6 @@ namespace atanh
4949

5050
namespace py = pybind11;
5151
namespace td_ns = dpctl::tensor::type_dispatch;
52-
namespace exprm_ns = sycl::ext::oneapi::experimental;
5352

5453
using dpctl::tensor::type_utils::is_complex;
5554

@@ -121,8 +120,12 @@ template <typename argT, typename resT> struct AtanhFunctor
121120
return resT{res_re, res_im};
122121
}
123122
/* ordinary cases */
123+
#ifdef USE_SYCL_FOR_COMPLEX_TYPES
124124
return exprm_ns::atanh(
125125
exprm_ns::complex<realT>(in)); // std::atanh(in);
126+
#else
127+
return std::atanh(in);
128+
#endif
126129
}
127130
else {
128131
static_assert(std::is_floating_point_v<argT> ||

0 commit comments

Comments
 (0)