Skip to content

Commit 746bfe1

Browse files
authored
[SYCL][COMPAT] Added Complex muladd to syclcompat (#12969)
Adds the documentation, functionality and tests for complex multiplication addition to SYCLcompat.
1 parent f6b952d commit 746bfe1

File tree

3 files changed

+83
-1
lines changed

3 files changed

+83
-1
lines changed

sycl/doc/syclcompat/README.md

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1229,7 +1229,7 @@ as a vector of elements, and returning `0` for vector components for which
12291229
`vectorized_sum_abs_diff` calculates the absolute difference for two values
12301230
without modulo overflow for vector types.
12311231

1232-
The functions `cmul`,`cdiv`,`cabs`, and `conj` define complex math operations
1232+
The functions `cmul`,`cdiv`,`cabs`, `cmul_add`, and `conj` define complex math operations
12331233
which accept `sycl::vec<T,2>` arguments representing complex values.
12341234

12351235
```cpp
@@ -1259,6 +1259,16 @@ sycl::vec<T, 2> cdiv(sycl::vec<T, 2> x, sycl::vec<T, 2> y);
12591259

12601260
template <typename T> T cabs(sycl::vec<T, 2> x);
12611261

1262+
template <typename ValueT>
1263+
inline sycl::vec<ValueT, 2> cmul_add(const sycl::vec<ValueT, 2> a,
1264+
const sycl::vec<ValueT, 2> b,
1265+
const sycl::vec<ValueT, 2> c);
1266+
1267+
template <typename ValueT>
1268+
inline sycl::marray<ValueT, 2> cmul_add(const sycl::marray<ValueT, 2> a,
1269+
const sycl::marray<ValueT, 2> b,
1270+
const sycl::marray<ValueT, 2> c);
1271+
12621272
template <typename T> sycl::vec<T, 2> conj(sycl::vec<T, 2> x);
12631273

12641274
template <typename ValueT> inline ValueT reverse_bits(ValueT a);

sycl/include/syclcompat/math.hpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,32 @@ template <typename T> sycl::vec<T, 2> conj(sycl::vec<T, 2> x) {
251251
return sycl::vec<T, 2>(t.real(), t.imag());
252252
}
253253

254+
/// Performs complex number multiply addition.
255+
/// \param [in] a The first value
256+
/// \param [in] b The second value
257+
/// \param [in] c The third value
258+
/// \returns the operation result
259+
template <typename ValueT>
260+
inline sycl::vec<ValueT, 2> cmul_add(const sycl::vec<ValueT, 2> a,
261+
const sycl::vec<ValueT, 2> b,
262+
const sycl::vec<ValueT, 2> c) {
263+
sycl::ext::oneapi::experimental::complex<ValueT> t(a[0], a[1]);
264+
sycl::ext::oneapi::experimental::complex<ValueT> u(b[0], b[1]);
265+
sycl::ext::oneapi::experimental::complex<ValueT> v(c[0], c[1]);
266+
t = t * u + v;
267+
return sycl::vec<ValueT, 2>{t.real(), t.imag()};
268+
}
269+
template <typename ValueT>
270+
inline sycl::marray<ValueT, 2> cmul_add(const sycl::marray<ValueT, 2> a,
271+
const sycl::marray<ValueT, 2> b,
272+
const sycl::marray<ValueT, 2> c) {
273+
sycl::ext::oneapi::experimental::complex<ValueT> t(a[0], a[1]);
274+
sycl::ext::oneapi::experimental::complex<ValueT> u(b[0], b[1]);
275+
sycl::ext::oneapi::experimental::complex<ValueT> v(c[0], c[1]);
276+
t = t * u + v;
277+
return sycl::marray<ValueT, 2>{t.real(), t.imag()};
278+
}
279+
254280
/// A sycl::abs wrapper functors.
255281
struct abs {
256282
template <typename ValueT> auto operator()(const ValueT x) const {

sycl/test-e2e/syclcompat/math/math_complex.cpp

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,46 @@ void kernel_mul(int *result) {
174174
*result = r;
175175
}
176176

177+
void kernel_mul_add(int *result) {
178+
sycl::double2 d1, d2, d3;
179+
sycl::float2 f1, f2, f3;
180+
sycl::marray<double, 2> m_d1, m_d2, m_d3;
181+
sycl::marray<float, 2> m_f1, m_f2, m_f3;
182+
183+
d1 = sycl::double2(5.4, -6.3);
184+
d2 = sycl::double2(-7.2, 8.1);
185+
d3 = sycl::double2(1.0, -1.0);
186+
187+
f1 = sycl::float2(1.8, -2.7);
188+
f2 = sycl::float2(-3.6, 4.5);
189+
f3 = sycl::float2(1.0, -1.0);
190+
191+
bool r = true;
192+
float expect[4] = {13.150000, 88.100000, 6.670001, 16.820000};
193+
194+
auto a1 = syclcompat::cmul_add(d1, d2, d3);
195+
r = r && check(a1, expect);
196+
197+
auto a2 = syclcompat::cmul_add(f1, f2, f3);
198+
r = r && check(a2, expect + 2);
199+
200+
m_d1 = sycl::marray<double, 2>(5.4, -6.3);
201+
m_d2 = sycl::marray<double, 2>(-7.2, 8.1);
202+
m_d3 = sycl::marray<double, 2>(1.0, -1.0);
203+
204+
m_f1 = sycl::marray<float, 2>(1.8, -2.7);
205+
m_f2 = sycl::marray<float, 2>(-3.6, 4.5);
206+
m_f3 = sycl::marray<float, 2>(1.0, -1.0);
207+
208+
auto a3 = syclcompat::cmul_add(d1, d2, d3);
209+
r = r && check(a3, expect);
210+
211+
auto a4 = syclcompat::cmul_add(f1, f2, f3);
212+
r = r && check(a4, expect + 2);
213+
214+
*result = r;
215+
}
216+
177217
void test_abs() {
178218
std::cout << __PRETTY_FUNCTION__ << std::endl;
179219
ComplexLauncher<kernel_abs>().launch();
@@ -191,11 +231,17 @@ void test_conj() {
191231
ComplexLauncher<kernel_conj>().launch();
192232
}
193233

234+
void test_mul_add() {
235+
std::cout << __PRETTY_FUNCTION__ << std::endl;
236+
ComplexLauncher<kernel_mul_add>().launch();
237+
}
238+
194239
int main() {
195240
test_abs();
196241
test_mul();
197242
test_div();
198243
test_conj();
244+
test_mul_add();
199245

200246
return 0;
201247
}

0 commit comments

Comments
 (0)