|
15 | 15 | #include <CL/sycl/detail/type_traits.hpp>
|
16 | 16 |
|
17 | 17 | #include <CL/__spirv/spirv_ops.hpp>
|
| 18 | +#include <sycl/ext/oneapi/experimental/bfloat16.hpp> |
18 | 19 |
|
19 | 20 | // TODO Decide whether to mark functions with this attribute.
|
20 | 21 | #define __NOEXC /*noexcept*/
|
|
26 | 27 | #endif
|
27 | 28 |
|
28 | 29 | __SYCL_INLINE_NAMESPACE(cl) {
|
29 |
| -namespace sycl { |
30 |
| -namespace ext { |
31 |
| -namespace oneapi { |
32 |
| -namespace experimental { |
| 30 | +namespace sycl::ext::oneapi::experimental { |
| 31 | +namespace detail { |
| 32 | +template <size_t N> |
| 33 | +uint32_t to_uint32_t(sycl::marray<bfloat16, N> x, size_t start) { |
| 34 | + uint32_t res; |
| 35 | + std::memcpy(&res, &x[start], sizeof(uint32_t)); |
| 36 | + return res; |
| 37 | +} |
| 38 | +} // namespace detail |
33 | 39 |
|
34 | 40 | // Provides functionality to print data from kernels in a C way:
|
35 | 41 | // - On non-host devices this function is directly mapped to printf from
|
@@ -117,11 +123,154 @@ inline __SYCL_ALWAYS_INLINE
|
117 | 123 |
|
118 | 124 | } // namespace native
|
119 | 125 |
|
120 |
| -} // namespace experimental |
121 |
| -} // namespace oneapi |
122 |
| -} // namespace ext |
| 126 | +template <typename T> |
| 127 | +std::enable_if_t<std::is_same<T, bfloat16>::value, T> fabs(T x) { |
| 128 | +#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) |
| 129 | + return bfloat16::from_bits(__clc_fabs(x.raw())); |
| 130 | +#else |
| 131 | + std::ignore = x; |
| 132 | + throw runtime_error("bfloat16 is not currently supported on the host device.", |
| 133 | + PI_ERROR_INVALID_DEVICE); |
| 134 | +#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) |
| 135 | +} |
| 136 | + |
| 137 | +template <size_t N> |
| 138 | +sycl::marray<bfloat16, N> fabs(sycl::marray<bfloat16, N> x) { |
| 139 | +#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) |
| 140 | + sycl::marray<bfloat16, N> res; |
| 141 | + |
| 142 | + for (size_t i = 0; i < N / 2; i++) { |
| 143 | + auto partial_res = __clc_fabs(detail::to_uint32_t(x, i * 2)); |
| 144 | + std::memcpy(&res[i * 2], &partial_res, sizeof(uint32_t)); |
| 145 | + } |
| 146 | + |
| 147 | + if constexpr (N % 2) { |
| 148 | + res[N - 1] = bfloat16::from_bits(__clc_fabs(x[N - 1].raw())); |
| 149 | + } |
| 150 | + return res; |
| 151 | +#else |
| 152 | + std::ignore = x; |
| 153 | + throw runtime_error("bfloat16 is not currently supported on the host device.", |
| 154 | + PI_ERROR_INVALID_DEVICE); |
| 155 | +#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) |
| 156 | +} |
| 157 | + |
| 158 | +template <typename T> |
| 159 | +std::enable_if_t<std::is_same<T, bfloat16>::value, T> fmin(T x, T y) { |
| 160 | +#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) |
| 161 | + return bfloat16::from_bits(__clc_fmin(x.raw(), y.raw())); |
| 162 | +#else |
| 163 | + std::ignore = x; |
| 164 | + (void)y; |
| 165 | + throw runtime_error("bfloat16 is not currently supported on the host device.", |
| 166 | + PI_ERROR_INVALID_DEVICE); |
| 167 | +#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) |
| 168 | +} |
| 169 | + |
| 170 | +template <size_t N> |
| 171 | +sycl::marray<bfloat16, N> fmin(sycl::marray<bfloat16, N> x, |
| 172 | + sycl::marray<bfloat16, N> y) { |
| 173 | +#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) |
| 174 | + sycl::marray<bfloat16, N> res; |
| 175 | + |
| 176 | + for (size_t i = 0; i < N / 2; i++) { |
| 177 | + auto partial_res = __clc_fmin(detail::to_uint32_t(x, i * 2), |
| 178 | + detail::to_uint32_t(y, i * 2)); |
| 179 | + std::memcpy(&res[i * 2], &partial_res, sizeof(uint32_t)); |
| 180 | + } |
| 181 | + |
| 182 | + if constexpr (N % 2) { |
| 183 | + res[N - 1] = |
| 184 | + bfloat16::from_bits(__clc_fmin(x[N - 1].raw(), y[N - 1].raw())); |
| 185 | + } |
| 186 | + |
| 187 | + return res; |
| 188 | +#else |
| 189 | + std::ignore = x; |
| 190 | + (void)y; |
| 191 | + throw runtime_error("bfloat16 is not currently supported on the host device.", |
| 192 | + PI_ERROR_INVALID_DEVICE); |
| 193 | +#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) |
| 194 | +} |
| 195 | + |
| 196 | +template <typename T> |
| 197 | +std::enable_if_t<std::is_same<T, bfloat16>::value, T> fmax(T x, T y) { |
| 198 | +#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) |
| 199 | + return bfloat16::from_bits(__clc_fmax(x.raw(), y.raw())); |
| 200 | +#else |
| 201 | + std::ignore = x; |
| 202 | + (void)y; |
| 203 | + throw runtime_error("bfloat16 is not currently supported on the host device.", |
| 204 | + PI_ERROR_INVALID_DEVICE); |
| 205 | +#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) |
| 206 | +} |
| 207 | + |
| 208 | +template <size_t N> |
| 209 | +sycl::marray<bfloat16, N> fmax(sycl::marray<bfloat16, N> x, |
| 210 | + sycl::marray<bfloat16, N> y) { |
| 211 | +#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) |
| 212 | + sycl::marray<bfloat16, N> res; |
| 213 | + |
| 214 | + for (size_t i = 0; i < N / 2; i++) { |
| 215 | + auto partial_res = __clc_fmax(detail::to_uint32_t(x, i * 2), |
| 216 | + detail::to_uint32_t(y, i * 2)); |
| 217 | + std::memcpy(&res[i * 2], &partial_res, sizeof(uint32_t)); |
| 218 | + } |
| 219 | + |
| 220 | + if constexpr (N % 2) { |
| 221 | + res[N - 1] = |
| 222 | + bfloat16::from_bits(__clc_fmax(x[N - 1].raw(), y[N - 1].raw())); |
| 223 | + } |
| 224 | + return res; |
| 225 | +#else |
| 226 | + std::ignore = x; |
| 227 | + (void)y; |
| 228 | + throw runtime_error("bfloat16 is not currently supported on the host device.", |
| 229 | + PI_ERROR_INVALID_DEVICE); |
| 230 | +#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) |
| 231 | +} |
| 232 | + |
| 233 | +template <typename T> |
| 234 | +std::enable_if_t<std::is_same<T, bfloat16>::value, T> fma(T x, T y, T z) { |
| 235 | +#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) |
| 236 | + return bfloat16::from_bits(__clc_fma(x.raw(), y.raw(), z.raw())); |
| 237 | +#else |
| 238 | + std::ignore = x; |
| 239 | + (void)y; |
| 240 | + (void)z; |
| 241 | + throw runtime_error("bfloat16 is not currently supported on the host device.", |
| 242 | + PI_ERROR_INVALID_DEVICE); |
| 243 | +#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) |
| 244 | +} |
| 245 | + |
| 246 | +template <size_t N> |
| 247 | +sycl::marray<bfloat16, N> fma(sycl::marray<bfloat16, N> x, |
| 248 | + sycl::marray<bfloat16, N> y, |
| 249 | + sycl::marray<bfloat16, N> z) { |
| 250 | +#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) |
| 251 | + sycl::marray<bfloat16, N> res; |
| 252 | + |
| 253 | + for (size_t i = 0; i < N / 2; i++) { |
| 254 | + auto partial_res = |
| 255 | + __clc_fma(detail::to_uint32_t(x, i * 2), detail::to_uint32_t(y, i * 2), |
| 256 | + detail::to_uint32_t(z, i * 2)); |
| 257 | + std::memcpy(&res[i * 2], &partial_res, sizeof(uint32_t)); |
| 258 | + } |
| 259 | + |
| 260 | + if constexpr (N % 2) { |
| 261 | + res[N - 1] = bfloat16::from_bits( |
| 262 | + __clc_fma(x[N - 1].raw(), y[N - 1].raw(), z[N - 1].raw())); |
| 263 | + } |
| 264 | + return res; |
| 265 | +#else |
| 266 | + std::ignore = x; |
| 267 | + (void)y; |
| 268 | + throw runtime_error("bfloat16 is not currently supported on the host device.", |
| 269 | + PI_ERROR_INVALID_DEVICE); |
| 270 | +#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) |
| 271 | +} |
123 | 272 |
|
124 |
| -} // namespace sycl |
| 273 | +} // namespace sycl::ext::oneapi::experimental |
125 | 274 | } // __SYCL_INLINE_NAMESPACE(cl)
|
126 | 275 |
|
127 | 276 | #undef __SYCL_CONSTANT_AS
|
0 commit comments