Skip to content

Commit c77344c

Browse files
Allow for change of name and location of sycl_complex.hpp
Introduced a private header to load SYCL's experimental complex header from the right location. The header and implementations respond to the USE_SYCL_FOR_COMPLEX_TYPES preprocessor variable. If it is set, sycl::ext::oneapi::experimental namespace functions are used for complex types; otherwise, std:: namespace functions are used instead. USE_SYCL_FOR_COMPLEX_TYPES is set in tensor/CMakeLists.txt. If USE_SYCL_FOR_COMPLEX_TYPES is not set, std:: functions are used except for the sqrt and abs functions. For abs we use hypot(std::real(z), std::imag(z)), and for sqrt we use a custom implementation on Windows to avoid failure to offload for the single precision type due to unwarranted use of double precision types in the implementation for single precision inputs in MS VC headers.
1 parent 6d3be5d commit c77344c

29 files changed

+237
-70
lines changed

Diff for: dpctl/tensor/CMakeLists.txt

+2-2
Original file line numberDiff line numberDiff line change
@@ -188,9 +188,9 @@ foreach(_src_fn ${_no_fast_math_sources})
188188
)
189189
endforeach()
190190
if (UNIX)
191-
set(_compiler_definitions "USE_STD_ABS_FOR_COMPLEX_TYPES;USE_STD_SQRT_FOR_COMPLEX_TYPES;SYCL_EXT_ONEAPI_COMPLEX")
191+
set(_compiler_definitions "USE_SYCL_FOR_COMPLEX_TYPES")
192192
else()
193-
set(_compiler_definitions "SYCL_EXT_ONEAPI_COMPLEX")
193+
set(_compiler_definitions "USE_SYCL_FOR_COMPLEX_TYPES")
194194
endif()
195195

196196
foreach(_src_fn ${_elementwise_sources})

Diff for: dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp

+2-3
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,11 @@
2828
#include <cstddef>
2929
#include <cstdint>
3030
#include <limits>
31-
#include <sycl/ext/oneapi/experimental/sycl_complex.hpp>
3231
#include <sycl/sycl.hpp>
3332
#include <type_traits>
3433

3534
#include "kernels/elementwise_functions/common.hpp"
35+
#include "sycl_complex.hpp"
3636

3737
#include "utils/offset_utils.hpp"
3838
#include "utils/type_dispatch.hpp"
@@ -50,7 +50,6 @@ namespace abs
5050

5151
namespace py = pybind11;
5252
namespace td_ns = dpctl::tensor::type_dispatch;
53-
namespace exprm_ns = sycl::ext::oneapi::experimental;
5453

5554
using dpctl::tensor::type_utils::is_complex;
5655

@@ -121,7 +120,7 @@ template <typename argT, typename resT> struct AbsFunctor
121120
return q_nan;
122121
}
123122
else {
124-
#ifdef USE_STD_ABS_FOR_COMPLEX_TYPES
123+
#ifdef USE_SYCL_FOR_COMPLEX_TYPES
125124
return exprm_ns::abs(exprm_ns::complex<realT>(z));
126125
#else
127126
return std::hypot(std::real(z), std::imag(z));

Diff for: dpctl/tensor/libtensor/include/kernels/elementwise_functions/acos.hpp

+15-2
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,11 @@
2626
#include <cmath>
2727
#include <cstddef>
2828
#include <cstdint>
29-
#include <sycl/ext/oneapi/experimental/sycl_complex.hpp>
3029
#include <sycl/sycl.hpp>
3130
#include <type_traits>
3231

3332
#include "kernels/elementwise_functions/common.hpp"
33+
#include "sycl_complex.hpp"
3434

3535
#include "utils/offset_utils.hpp"
3636
#include "utils/type_dispatch.hpp"
@@ -48,7 +48,6 @@ namespace acos
4848

4949
namespace py = pybind11;
5050
namespace td_ns = dpctl::tensor::type_dispatch;
51-
namespace exprm_ns = sycl::ext::oneapi::experimental;
5251

5352
using dpctl::tensor::type_utils::is_complex;
5453

@@ -105,6 +104,7 @@ template <typename argT, typename resT> struct AcosFunctor
105104
constexpr realT r_eps =
106105
realT(1) / std::numeric_limits<realT>::epsilon();
107106
if (std::abs(x) > r_eps || std::abs(y) > r_eps) {
107+
#ifdef USE_SYCL_FOR_COMPLEX_TYPES
108108
using sycl_complexT = exprm_ns::complex<realT>;
109109
sycl_complexT log_in =
110110
exprm_ns::log(exprm_ns::complex<realT>(in));
@@ -115,11 +115,24 @@ template <typename argT, typename resT> struct AcosFunctor
115115

116116
realT ry = wx + std::log(realT(2));
117117
return resT{rx, (std::signbit(y)) ? ry : -ry};
118+
#else
119+
resT log_in = std::log(in);
120+
const realT wx = std::real(log_in);
121+
const realT wy = std::imag(log_in);
122+
const realT rx = std::abs(wy);
123+
124+
realT ry = wx + std::log(realT(2));
125+
return resT{rx, (std::signbit(y)) ? ry : -ry};
126+
#endif
118127
}
119128

120129
/* ordinary cases */
130+
#if USE_SYCL_FOR_COMPLEX_TYPES
121131
return exprm_ns::acos(
122132
exprm_ns::complex<realT>(in)); // std::acos(in);
133+
#else
134+
return std::acos(in);
135+
#endif
123136
}
124137
else {
125138
static_assert(std::is_floating_point_v<argT> ||

Diff for: dpctl/tensor/libtensor/include/kernels/elementwise_functions/acosh.hpp

+11-2
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,11 @@
2626
#include <cmath>
2727
#include <cstddef>
2828
#include <cstdint>
29-
#include <sycl/ext/oneapi/experimental/sycl_complex.hpp>
3029
#include <sycl/sycl.hpp>
3130
#include <type_traits>
3231

3332
#include "kernels/elementwise_functions/common.hpp"
33+
#include "sycl_complex.hpp"
3434

3535
#include "utils/offset_utils.hpp"
3636
#include "utils/type_dispatch.hpp"
@@ -48,7 +48,6 @@ namespace acosh
4848

4949
namespace py = pybind11;
5050
namespace td_ns = dpctl::tensor::type_dispatch;
51-
namespace exprm_ns = sycl::ext::oneapi::experimental;
5251

5352
using dpctl::tensor::type_utils::is_complex;
5453

@@ -112,18 +111,28 @@ template <typename argT, typename resT> struct AcoshFunctor
112111
* For large x or y including acos(+-Inf + I*+-Inf)
113112
*/
114113
if (std::abs(x) > r_eps || std::abs(y) > r_eps) {
114+
#ifdef USE_SYCL_FOR_COMPLEX_TYPES
115115
using sycl_complexT = typename exprm_ns::complex<realT>;
116116
const sycl_complexT log_in = exprm_ns::log(sycl_complexT(in));
117117
const realT wx = log_in.real();
118118
const realT wy = log_in.imag();
119+
#else
120+
const resT log_in = std::log(in);
121+
const realT wx = std::real(log_in);
122+
const realT wy = std::imag(log_in);
123+
#endif
119124
const realT rx = std::abs(wy);
120125
realT ry = wx + std::log(realT(2));
121126
acos_in = resT{rx, (std::signbit(y)) ? ry : -ry};
122127
}
123128
else {
124129
/* ordinary cases */
130+
#if USE_SYCL_FOR_COMPLEX_TYPES
125131
acos_in = exprm_ns::acos(
126132
exprm_ns::complex<realT>(in)); // std::acos(in);
133+
#else
134+
acos_in = std::acos(in);
135+
#endif
127136
}
128137

129138
/* Now we calculate acosh(z) */

Diff for: dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp

+13-2
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,10 @@
2626
#pragma once
2727
#include <cstddef>
2828
#include <cstdint>
29-
#include <sycl/ext/oneapi/experimental/sycl_complex.hpp>
3029
#include <sycl/sycl.hpp>
3130
#include <type_traits>
3231

32+
#include "sycl_complex.hpp"
3333
#include "utils/offset_utils.hpp"
3434
#include "utils/type_dispatch.hpp"
3535
#include "utils/type_utils.hpp"
@@ -50,7 +50,6 @@ namespace add
5050
namespace py = pybind11;
5151
namespace td_ns = dpctl::tensor::type_dispatch;
5252
namespace tu_ns = dpctl::tensor::type_utils;
53-
namespace exprm_ns = sycl::ext::oneapi::experimental;
5453

5554
template <typename argT1, typename argT2, typename resT> struct AddFunctor
5655
{
@@ -65,24 +64,36 @@ template <typename argT1, typename argT2, typename resT> struct AddFunctor
6564
if constexpr (tu_ns::is_complex<argT1>::value &&
6665
tu_ns::is_complex<argT2>::value)
6766
{
67+
#ifdef USE_SYCL_FOR_COMPLEX_TYPES
6868
using rT1 = typename argT1::value_type;
6969
using rT2 = typename argT2::value_type;
7070

7171
return exprm_ns::complex<rT1>(in1) + exprm_ns::complex<rT2>(in2);
72+
#else
73+
return in1 + in2;
74+
#endif
7275
}
7376
else if constexpr (tu_ns::is_complex<argT1>::value &&
7477
!tu_ns::is_complex<argT2>::value)
7578
{
79+
#ifdef USE_SYCL_FOR_COMPLEX_TYPES
7680
using rT1 = typename argT1::value_type;
7781

7882
return exprm_ns::complex<rT1>(in1) + in2;
83+
#else
84+
return in1 + in2;
85+
#endif
7986
}
8087
else if constexpr (!tu_ns::is_complex<argT1>::value &&
8188
tu_ns::is_complex<argT2>::value)
8289
{
90+
#ifdef USE_SYCL_FOR_COMPLEX_TYPES
8391
using rT2 = typename argT2::value_type;
8492

8593
return in1 + exprm_ns::complex<rT2>(in2);
94+
#else
95+
return in1 + in2;
96+
#endif
8697
}
8798
else {
8899
return in1 + in2;

Diff for: dpctl/tensor/libtensor/include/kernels/elementwise_functions/asin.hpp

+22-4
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,11 @@
2626
#include <cmath>
2727
#include <cstddef>
2828
#include <cstdint>
29-
#include <sycl/ext/oneapi/experimental/sycl_complex.hpp>
3029
#include <sycl/sycl.hpp>
3130
#include <type_traits>
3231

3332
#include "kernels/elementwise_functions/common.hpp"
33+
#include "sycl_complex.hpp"
3434

3535
#include "utils/offset_utils.hpp"
3636
#include "utils/type_dispatch.hpp"
@@ -48,7 +48,6 @@ namespace asin
4848

4949
namespace py = pybind11;
5050
namespace td_ns = dpctl::tensor::type_dispatch;
51-
namespace exprm_ns = sycl::ext::oneapi::experimental;
5251

5352
using dpctl::tensor::type_utils::is_complex;
5453

@@ -119,26 +118,45 @@ template <typename argT, typename resT> struct AsinFunctor
119118
constexpr realT r_eps =
120119
realT(1) / std::numeric_limits<realT>::epsilon();
121120
if (std::abs(x) > r_eps || std::abs(y) > r_eps) {
121+
#ifdef USE_SYCL_FOR_COMPLEX_TYPES
122122
using sycl_complexT = exprm_ns::complex<realT>;
123123
const sycl_complexT z{x, y};
124124
realT wx, wy;
125125
if (!std::signbit(x)) {
126-
auto log_z = exprm_ns::log(z);
126+
const auto log_z = exprm_ns::log(z);
127127
wx = log_z.real() + std::log(realT(2));
128128
wy = log_z.imag();
129129
}
130130
else {
131-
auto log_mz = exprm_ns::log(-z);
131+
const auto log_mz = exprm_ns::log(-z);
132132
wx = log_mz.real() + std::log(realT(2));
133133
wy = log_mz.imag();
134134
}
135+
#else
136+
const resT z{x, y};
137+
realT wx, wy;
138+
if (!std::signbit(x)) {
139+
const auto log_z = std::log(z);
140+
wx = std::real(log_z) + std::log(realT(2));
141+
wy = std::imag(log_z);
142+
}
143+
else {
144+
const auto log_mz = std::log(-z);
145+
wx = std::real(log_mz) + std::log(realT(2));
146+
wy = std::imag(log_mz);
147+
}
148+
#endif
135149
const realT asinh_re = std::copysign(wx, x);
136150
const realT asinh_im = std::copysign(wy, y);
137151
return resT{asinh_im, asinh_re};
138152
}
139153
/* ordinary cases */
154+
#if USE_SYCL_FOR_COMPLEX_TYPES
140155
return exprm_ns::asin(
141156
exprm_ns::complex<realT>(in)); // std::asin(in);
157+
#else
158+
return std::asin(in);
159+
#endif
142160
}
143161
else {
144162
static_assert(std::is_floating_point_v<argT> ||

Diff for: dpctl/tensor/libtensor/include/kernels/elementwise_functions/asinh.hpp

+11-2
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,11 @@
2626
#include <cmath>
2727
#include <cstddef>
2828
#include <cstdint>
29-
#include <sycl/ext/oneapi/experimental/sycl_complex.hpp>
3029
#include <sycl/sycl.hpp>
3130
#include <type_traits>
3231

3332
#include "kernels/elementwise_functions/common.hpp"
33+
#include "sycl_complex.hpp"
3434

3535
#include "utils/offset_utils.hpp"
3636
#include "utils/type_dispatch.hpp"
@@ -48,7 +48,6 @@ namespace asinh
4848

4949
namespace py = pybind11;
5050
namespace td_ns = dpctl::tensor::type_dispatch;
51-
namespace exprm_ns = sycl::ext::oneapi::experimental;
5251

5352
using dpctl::tensor::type_utils::is_complex;
5453

@@ -108,20 +107,30 @@ template <typename argT, typename resT> struct AsinhFunctor
108107
realT(1) / std::numeric_limits<realT>::epsilon();
109108

110109
if (std::abs(x) > r_eps || std::abs(y) > r_eps) {
110+
#ifdef USE_SYCL_FOR_COMPLEX_TYPES
111111
using sycl_complexT = exprm_ns::complex<realT>;
112112
sycl_complexT log_in = (std::signbit(x))
113113
? exprm_ns::log(sycl_complexT(-in))
114114
: exprm_ns::log(sycl_complexT(in));
115115
realT wx = log_in.real() + std::log(realT(2));
116116
realT wy = log_in.imag();
117+
#else
118+
auto log_in = std::log(std::signbit(x) ? -in : in);
119+
realT wx = std::real(log_in) + std::log(realT(2));
120+
realT wy = std::imag(log_in);
121+
#endif
117122
const realT res_re = std::copysign(wx, x);
118123
const realT res_im = std::copysign(wy, y);
119124
return resT{res_re, res_im};
120125
}
121126

122127
/* ordinary cases */
128+
#if USE_SYCL_FOR_COMPLEX_TYPES
123129
return exprm_ns::asinh(
124130
exprm_ns::complex<realT>(in)); // std::asinh(in);
131+
#else
132+
return std::asinh(in);
133+
#endif
125134
}
126135
else {
127136
static_assert(std::is_floating_point_v<argT> ||

Diff for: dpctl/tensor/libtensor/include/kernels/elementwise_functions/atan.hpp

+5-2
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,11 @@
2727
#include <complex>
2828
#include <cstddef>
2929
#include <cstdint>
30-
#include <sycl/ext/oneapi/experimental/sycl_complex.hpp>
3130
#include <sycl/sycl.hpp>
3231
#include <type_traits>
3332

3433
#include "kernels/elementwise_functions/common.hpp"
34+
#include "sycl_complex.hpp"
3535

3636
#include "utils/offset_utils.hpp"
3737
#include "utils/type_dispatch.hpp"
@@ -49,7 +49,6 @@ namespace atan
4949

5050
namespace py = pybind11;
5151
namespace td_ns = dpctl::tensor::type_dispatch;
52-
namespace exprm_ns = sycl::ext::oneapi::experimental;
5352

5453
using dpctl::tensor::type_utils::is_complex;
5554

@@ -128,8 +127,12 @@ template <typename argT, typename resT> struct AtanFunctor
128127
return resT{atanh_im, atanh_re};
129128
}
130129
/* ordinary cases */
130+
#ifdef USE_SYCL_FOR_COMPLEX_TYPES
131131
return exprm_ns::atan(
132132
exprm_ns::complex<realT>(in)); // std::atan(in);
133+
#else
134+
return std::atan(in);
135+
#endif
133136
}
134137
else {
135138
static_assert(std::is_floating_point_v<argT> ||

Diff for: dpctl/tensor/libtensor/include/kernels/elementwise_functions/atanh.hpp

+5-2
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,11 @@
2727
#include <complex>
2828
#include <cstddef>
2929
#include <cstdint>
30-
#include <sycl/ext/oneapi/experimental/sycl_complex.hpp>
3130
#include <sycl/sycl.hpp>
3231
#include <type_traits>
3332

3433
#include "kernels/elementwise_functions/common.hpp"
34+
#include "sycl_complex.hpp"
3535

3636
#include "utils/offset_utils.hpp"
3737
#include "utils/type_dispatch.hpp"
@@ -49,7 +49,6 @@ namespace atanh
4949

5050
namespace py = pybind11;
5151
namespace td_ns = dpctl::tensor::type_dispatch;
52-
namespace exprm_ns = sycl::ext::oneapi::experimental;
5352

5453
using dpctl::tensor::type_utils::is_complex;
5554

@@ -121,8 +120,12 @@ template <typename argT, typename resT> struct AtanhFunctor
121120
return resT{res_re, res_im};
122121
}
123122
/* ordinary cases */
123+
#ifdef USE_SYCL_FOR_COMPLEX_TYPES
124124
return exprm_ns::atanh(
125125
exprm_ns::complex<realT>(in)); // std::atanh(in);
126+
#else
127+
return std::atanh(in);
128+
#endif
126129
}
127130
else {
128131
static_assert(std::is_floating_point_v<argT> ||

0 commit comments

Comments
 (0)