Skip to content

Commit 468110d

Browse files
More files change to use sycl_complex
1 parent e4ea991 commit 468110d

File tree

10 files changed

+29
-13
lines changed

10 files changed

+29
-13
lines changed

Diff for: dpctl/tensor/libtensor/include/kernels/elementwise_functions/less.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@
2424
//===---------------------------------------------------------------------===//
2525

2626
#pragma once
27-
#include <CL/sycl.hpp>
2827
#include <cstddef>
2928
#include <cstdint>
29+
#include <sycl/sycl.hpp>
3030
#include <type_traits>
3131

3232
#include "utils/math_utils.hpp"

Diff for: dpctl/tensor/libtensor/include/kernels/elementwise_functions/log.hpp

+10-2
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,11 @@
2323
//===---------------------------------------------------------------------===//
2424

2525
#pragma once
26-
#include <CL/sycl.hpp>
2726
#include <cmath>
2827
#include <cstddef>
2928
#include <cstdint>
29+
#include <sycl/ext/oneapi/experimental/sycl_complex.hpp>
30+
#include <sycl/sycl.hpp>
3031
#include <type_traits>
3132

3233
#include "kernels/elementwise_functions/common.hpp"
@@ -47,6 +48,7 @@ namespace log
4748

4849
namespace py = pybind11;
4950
namespace td_ns = dpctl::tensor::type_dispatch;
51+
namespace exprm_ns = sycl::ext::oneapi::experimental;
5052

5153
using dpctl::tensor::type_utils::is_complex;
5254

@@ -65,7 +67,13 @@ template <typename argT, typename resT> struct LogFunctor
6567

6668
resT operator()(const argT &in) const
6769
{
68-
return std::log(in);
70+
if constexpr (is_complex<argT>::value) {
71+
using realT = typename argT::value_type;
72+
return exprm_ns::log(exprm_ns::complex<realT>(in)); // std::log(in);
73+
}
74+
else {
75+
return std::log(in);
76+
}
6977
}
7078
};
7179

Diff for: dpctl/tensor/libtensor/include/kernels/elementwise_functions/log10.hpp

+6-2
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,11 @@
2424
//===---------------------------------------------------------------------===//
2525

2626
#pragma once
27-
#include <CL/sycl.hpp>
2827
#include <cmath>
2928
#include <cstddef>
3029
#include <cstdint>
30+
#include <sycl/ext/oneapi/experimental/sycl_complex.hpp>
31+
#include <sycl/sycl.hpp>
3132
#include <type_traits>
3233

3334
#include "kernels/elementwise_functions/common.hpp"
@@ -48,6 +49,7 @@ namespace log10
4849

4950
namespace py = pybind11;
5051
namespace td_ns = dpctl::tensor::type_dispatch;
52+
namespace exprm_ns = sycl::ext::oneapi::experimental;
5153

5254
using dpctl::tensor::type_utils::is_complex;
5355
using dpctl::tensor::type_utils::vec_cast;
@@ -70,7 +72,9 @@ template <typename argT, typename resT> struct Log10Functor
7072
{
7173
if constexpr (is_complex<argT>::value) {
7274
using realT = typename argT::value_type;
73-
return (std::log(in) / std::log(realT{10}));
75+
// return (std::log(in) / std::log(realT{10}));
76+
return exprm_ns::log(exprm_ns::complex<realT>(in)) /
77+
std::log(realT{10});
7478
}
7579
else {
7680
return std::log10(in);

Diff for: dpctl/tensor/libtensor/include/kernels/elementwise_functions/log1p.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@
2424
//===---------------------------------------------------------------------===//
2525

2626
#pragma once
27-
#include <CL/sycl.hpp>
2827
#include <cmath>
2928
#include <cstddef>
3029
#include <cstdint>
30+
#include <sycl/sycl.hpp>
3131
#include <type_traits>
3232

3333
#include "kernels/elementwise_functions/common.hpp"

Diff for: dpctl/tensor/libtensor/include/kernels/elementwise_functions/log2.hpp

+6-2
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,11 @@
2424
//===---------------------------------------------------------------------===//
2525

2626
#pragma once
27-
#include <CL/sycl.hpp>
2827
#include <cmath>
2928
#include <cstddef>
3029
#include <cstdint>
30+
#include <sycl/ext/oneapi/experimental/sycl_complex.hpp>
31+
#include <sycl/sycl.hpp>
3132
#include <type_traits>
3233

3334
#include "kernels/elementwise_functions/common.hpp"
@@ -48,6 +49,7 @@ namespace log2
4849

4950
namespace py = pybind11;
5051
namespace td_ns = dpctl::tensor::type_dispatch;
52+
namespace exprm_ns = sycl::ext::oneapi::experimental;
5153

5254
using dpctl::tensor::type_utils::is_complex;
5355
using dpctl::tensor::type_utils::vec_cast;
@@ -70,7 +72,9 @@ template <typename argT, typename resT> struct Log2Functor
7072
{
7173
if constexpr (is_complex<argT>::value) {
7274
using realT = typename argT::value_type;
73-
return std::log(in) / std::log(realT{2});
75+
// std::log(in) / std::log(realT{2});
76+
return exprm_ns::log(exprm_ns::complex<realT>(in)) /
77+
std::log(realT{2});
7478
}
7579
else {
7680
return std::log2(in);

Diff for: dpctl/tensor/libtensor/include/kernels/elementwise_functions/logaddexp.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,10 @@
2525
//===---------------------------------------------------------------------===//
2626

2727
#pragma once
28-
#include <CL/sycl.hpp>
2928
#include <cstddef>
3029
#include <cstdint>
3130
#include <limits>
31+
#include <sycl/sycl.hpp>
3232
#include <type_traits>
3333

3434
#include "utils/offset_utils.hpp"

Diff for: dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_and.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@
2525
//===---------------------------------------------------------------------===//
2626

2727
#pragma once
28-
#include <CL/sycl.hpp>
2928
#include <cstddef>
3029
#include <cstdint>
30+
#include <sycl/sycl.hpp>
3131
#include <type_traits>
3232

3333
#include "utils/offset_utils.hpp"

Diff for: dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_not.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@
2525
//===---------------------------------------------------------------------===//
2626

2727
#pragma once
28-
#include <CL/sycl.hpp>
2928
#include <cstddef>
3029
#include <cstdint>
30+
#include <sycl/sycl.hpp>
3131
#include <type_traits>
3232

3333
#include "utils/offset_utils.hpp"

Diff for: dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_or.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@
2525
//===---------------------------------------------------------------------===//
2626

2727
#pragma once
28-
#include <CL/sycl.hpp>
2928
#include <cstddef>
3029
#include <cstdint>
30+
#include <sycl/sycl.hpp>
3131
#include <type_traits>
3232

3333
#include "utils/offset_utils.hpp"

Diff for: dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_xor.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@
2525
//===---------------------------------------------------------------------===//
2626

2727
#pragma once
28-
#include <CL/sycl.hpp>
2928
#include <cstddef>
3029
#include <cstdint>
30+
#include <sycl/sycl.hpp>
3131
#include <type_traits>
3232

3333
#include "utils/offset_utils.hpp"

0 commit comments

Comments
 (0)