Skip to content

[NFC][SYCL] Use bit_cast for bfloat16 casts. #17256

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 12 additions & 34 deletions sycl/include/sycl/ext/oneapi/bfloat16.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#pragma once

#include <sycl/aliases.hpp> // for half
#include <sycl/bit_cast.hpp> // for bit_cast
#include <sycl/detail/defines_elementary.hpp> // for __DPCPP_SYCL_EXTERNAL
#include <sycl/half_type.hpp> // for half

Expand Down Expand Up @@ -51,11 +52,6 @@ class bfloat16;

namespace detail {
using Bfloat16StorageT = uint16_t;
Bfloat16StorageT bfloat16ToBits(const bfloat16 &Value);
bfloat16 bitsToBfloat16(const Bfloat16StorageT Value);
// Class to convert different data types to Bfloat16
// with different rounding modes.
class ConvertToBfloat16;

template <int N> void BF16VecToFloatVec(const bfloat16 src[N], float dst[N]) {
#if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__))
Expand Down Expand Up @@ -84,12 +80,6 @@ class bfloat16 {
protected:
detail::Bfloat16StorageT value;

friend inline detail::Bfloat16StorageT
detail::bfloat16ToBits(const bfloat16 &Value);
friend inline bfloat16
detail::bitsToBfloat16(const detail::Bfloat16StorageT Value);
friend class detail::ConvertToBfloat16;

public:
bfloat16() = default;
~bfloat16() = default;
Expand Down Expand Up @@ -187,7 +177,7 @@ class bfloat16 {
(__SYCL_CUDA_ARCH__ >= 800)
detail::Bfloat16StorageT res;
asm("neg.bf16 %0, %1;" : "=h"(res) : "h"(lhs.value));
return detail::bitsToBfloat16(res);
return bit_cast<bfloat16>(res);
#elif defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__))
return bfloat16{-__devicelib_ConvertBF16ToFINTEL(lhs.value)};
#else
Expand Down Expand Up @@ -294,19 +284,6 @@ template <int N> void FloatVecToBF16Vec(float src[N], bfloat16 dst[N]) {
#endif
}

// Helper function for getting the internal representation of a bfloat16.
inline Bfloat16StorageT bfloat16ToBits(const bfloat16 &Value) {
return Value.value;
}

// Helper function for creating a float16 from a value with the same type as the
// internal representation.
inline bfloat16 bitsToBfloat16(const Bfloat16StorageT Value) {
bfloat16 res;
res.value = Value;
return res;
}

// Class to convert different data types to Bfloat16
// with different rounding modes.
class ConvertToBfloat16 {
Expand Down Expand Up @@ -348,15 +325,15 @@ class ConvertToBfloat16 {
// +/-infinity and NAN
if (bf16_exp == 0xFF) {
if (!f_mant)
return bitsToBfloat16(bf16_sign ? 0xFF80 : 0x7F80);
return bit_cast<bfloat16, uint16_t>(bf16_sign ? 0xFF80 : 0x7F80);
else
return bitsToBfloat16((bf16_sign << 15) | (bf16_exp << 7) |
bf16_mant);
return bit_cast<bfloat16, uint16_t>((bf16_sign << 15) |
(bf16_exp << 7) | bf16_mant);
}

// +/-0
if (!bf16_exp && !f_mant) {
return bitsToBfloat16(bf16_sign ? 0x8000 : 0x0);
return bit_cast<bfloat16, uint16_t>(bf16_sign ? 0x8000 : 0x0);
}

uint16_t mant_discard = static_cast<uint16_t>(f_mant & 0xFFFF);
Expand Down Expand Up @@ -385,7 +362,8 @@ class ConvertToBfloat16 {
bf16_exp++;
}

return bitsToBfloat16((bf16_sign << 15) | (bf16_exp << 7) | bf16_mant);
return bit_cast<bfloat16, uint16_t>((bf16_sign << 15) | (bf16_exp << 7) |
bf16_mant);
}
}

Expand All @@ -401,7 +379,7 @@ class ConvertToBfloat16 {
size_t msb_pos = get_msb_pos(u);
// return half representation for 1
if (msb_pos == 0)
return bitsToBfloat16(0x3F80);
return bit_cast<bfloat16, uint16_t>(0x3F80);

T mant = u & ((static_cast<T>(1) << msb_pos) - 1);
// Unsigned integral value can be represented by 1.mant * (2^msb_pos),
Expand Down Expand Up @@ -442,7 +420,7 @@ class ConvertToBfloat16 {
}

b_exp += 127;
return bitsToBfloat16((b_exp << 7) | b_mant);
return bit_cast<bfloat16, uint16_t>((b_exp << 7) | b_mant);
}

// Helper function to get BF16 from signed integral data types.
Expand All @@ -459,7 +437,7 @@ class ConvertToBfloat16 {
UTy ui = (i > 0) ? static_cast<UTy>(i) : static_cast<UTy>(-i);
size_t msb_pos = get_msb_pos<UTy>(ui);
if (msb_pos == 0)
return bitsToBfloat16(b_sign ? 0xBF80 : 0x3F80);
return bit_cast<bfloat16, uint16_t>(b_sign ? 0xBF80 : 0x3F80);
UTy mant = ui & ((static_cast<UTy>(1) << msb_pos) - 1);

uint16_t b_exp = msb_pos;
Expand Down Expand Up @@ -495,7 +473,7 @@ class ConvertToBfloat16 {
b_mant = 0;
}
b_exp += 127;
return bitsToBfloat16(b_sign | (b_exp << 7) | b_mant);
return bit_cast<bfloat16, uint16_t>(b_sign | (b_exp << 7) | b_mant);
}

// Helper function to get BF16 from double with RTE rounding modes.
Expand Down
102 changes: 45 additions & 57 deletions sycl/include/sycl/ext/oneapi/experimental/bfloat16_math.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@

#pragma once

#include <sycl/builtins.hpp> // for ceil, cos, exp, exp10, exp2
#include <sycl/detail/memcpy.hpp> // sycl::detail::memcpy
#include <sycl/bit_cast.hpp> // for sycl::bit_cast
#include <sycl/builtins.hpp> // for ceil, cos, exp, exp10, exp2
#include <sycl/detail/memcpy.hpp> // sycl::detail::memcpy
#include <sycl/detail/vector_convert.hpp>
#include <sycl/ext/oneapi/bfloat16.hpp> // for bfloat16, bfloat16ToBits
#include <sycl/ext/oneapi/bfloat16.hpp> // for bfloat16
#include <sycl/marray.hpp> // for marray

#include <cstring> // for size_t
Expand Down Expand Up @@ -46,7 +47,7 @@ constexpr int num_elements_v = sycl::detail::num_elements<T>::value;
// significand has non-zero bits.
template <typename T>
std::enable_if_t<std::is_same_v<T, bfloat16>, bool> isnan(T x) {
oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits(x);
uint16_t XBits = bit_cast<uint16_t>(x);
return (((XBits & 0x7F80) == 0x7F80) && (XBits & 0x7F)) ? true : false;
}

Expand Down Expand Up @@ -90,15 +91,15 @@ template <typename T>
std::enable_if_t<std::is_same_v<T, bfloat16>, T> fabs(T x) {
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \
(__SYCL_CUDA_ARCH__ >= 800)
oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits(x);
return oneapi::detail::bitsToBfloat16(__clc_fabs(XBits));
uint16_t XBits = bit_cast<uint16_t>(x);
return bit_cast<bfloat16>(__clc_fabs(XBits));
#else
if (!isnan(x)) {
const static oneapi::detail::Bfloat16StorageT SignMask = 0x8000;
oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits(x);
constexpr uint16_t SignMask = 0x8000;
uint16_t XBits = bit_cast<uint16_t>(x);
x = ((XBits & SignMask) == SignMask)
? oneapi::detail::bitsToBfloat16(XBits & ~SignMask)
: x;
? bit_cast<bfloat16, uint16_t>(XBits & ~SignMask)
: bit_cast<bfloat16>(x);
}
return x;
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) &&
Expand All @@ -116,9 +117,8 @@ sycl::marray<bfloat16, N> fabs(sycl::marray<bfloat16, N> x) {
}

if (N % 2) {
oneapi::detail::Bfloat16StorageT XBits =
oneapi::detail::bfloat16ToBits(x[N - 1]);
res[N - 1] = oneapi::detail::bitsToBfloat16(__clc_fabs(XBits));
uint16_t XBits = bit_cast<uint16_t>(x[N - 1]);
res[N - 1] = bit_cast<bfloat16>(__clc_fabs(XBits));
}
#else
for (size_t i = 0; i < N; i++) {
Expand Down Expand Up @@ -154,25 +154,22 @@ template <typename T>
std::enable_if_t<std::is_same_v<T, bfloat16>, T> fmin(T x, T y) {
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \
(__SYCL_CUDA_ARCH__ >= 800)
oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits(x);
oneapi::detail::Bfloat16StorageT YBits = oneapi::detail::bfloat16ToBits(y);
return oneapi::detail::bitsToBfloat16(__clc_fmin(XBits, YBits));
uint16_t XBits = bit_cast<uint16_t>(x);
uint16_t YBits = bit_cast<uint16_t>(y);
return bit_cast<bfloat16>(__clc_fmin(XBits, YBits));
#else
static const oneapi::detail::Bfloat16StorageT CanonicalNan = 0x7FC0;
constexpr uint16_t CanonicalNan = 0x7FC0;
if (isnan(x) && isnan(y))
return oneapi::detail::bitsToBfloat16(CanonicalNan);
return bit_cast<bfloat16>(CanonicalNan);

if (isnan(x))
return y;
if (isnan(y))
return x;
oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits(x);
oneapi::detail::Bfloat16StorageT YBits = oneapi::detail::bfloat16ToBits(y);
if (((XBits | YBits) ==
static_cast<oneapi::detail::Bfloat16StorageT>(0x8000)) &&
!(XBits & YBits))
return oneapi::detail::bitsToBfloat16(
static_cast<oneapi::detail::Bfloat16StorageT>(0x8000));
uint16_t XBits = bit_cast<uint16_t>(x);
uint16_t YBits = bit_cast<uint16_t>(y);
if (((XBits | YBits) == static_cast<uint16_t>(0x8000)) && !(XBits & YBits))
return bit_cast<bfloat16>(static_cast<uint16_t>(0x8000));

return (x < y) ? x : y;
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) &&
Expand All @@ -192,11 +189,9 @@ sycl::marray<bfloat16, N> fmin(sycl::marray<bfloat16, N> x,
}

if (N % 2) {
oneapi::detail::Bfloat16StorageT XBits =
oneapi::detail::bfloat16ToBits(x[N - 1]);
oneapi::detail::Bfloat16StorageT YBits =
oneapi::detail::bfloat16ToBits(y[N - 1]);
res[N - 1] = oneapi::detail::bitsToBfloat16(__clc_fmin(XBits, YBits));
uint16_t XBits = bit_cast<uint16_t>(x[N - 1]);
uint16_t YBits = bit_cast<uint16_t>(y[N - 1]);
res[N - 1] = bit_cast<bfloat16>(__clc_fmin(XBits, YBits));
}
#else
for (size_t i = 0; i < N; i++) {
Expand Down Expand Up @@ -237,24 +232,22 @@ template <typename T>
std::enable_if_t<std::is_same_v<T, bfloat16>, T> fmax(T x, T y) {
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \
(__SYCL_CUDA_ARCH__ >= 800)
oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits(x);
oneapi::detail::Bfloat16StorageT YBits = oneapi::detail::bfloat16ToBits(y);
return oneapi::detail::bitsToBfloat16(__clc_fmax(XBits, YBits));
uint16_t XBits = bit_cast<uint16_t>(x);
uint16_t YBits = bit_cast<uint16_t>(y);
return bit_cast<bfloat16>(__clc_fmax(XBits, YBits));
#else
static const oneapi::detail::Bfloat16StorageT CanonicalNan = 0x7FC0;
constexpr uint16_t CanonicalNan = 0x7FC0;
if (isnan(x) && isnan(y))
return oneapi::detail::bitsToBfloat16(CanonicalNan);
return bit_cast<bfloat16>(CanonicalNan);

if (isnan(x))
return y;
if (isnan(y))
return x;
oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits(x);
oneapi::detail::Bfloat16StorageT YBits = oneapi::detail::bfloat16ToBits(y);
if (((XBits | YBits) ==
static_cast<oneapi::detail::Bfloat16StorageT>(0x8000)) &&
!(XBits & YBits))
return oneapi::detail::bitsToBfloat16(0);
uint16_t XBits = bit_cast<uint16_t>(x);
uint16_t YBits = bit_cast<uint16_t>(y);
if (((XBits | YBits) == static_cast<uint16_t>(0x8000)) && !(XBits & YBits))
return bit_cast<bfloat16, uint16_t>(0);

return (x > y) ? x : y;
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) &&
Expand All @@ -274,11 +267,9 @@ sycl::marray<bfloat16, N> fmax(sycl::marray<bfloat16, N> x,
}

if (N % 2) {
oneapi::detail::Bfloat16StorageT XBits =
oneapi::detail::bfloat16ToBits(x[N - 1]);
oneapi::detail::Bfloat16StorageT YBits =
oneapi::detail::bfloat16ToBits(y[N - 1]);
res[N - 1] = oneapi::detail::bitsToBfloat16(__clc_fmax(XBits, YBits));
uint16_t XBits = bit_cast<uint16_t>(x[N - 1]);
uint16_t YBits = bit_cast<uint16_t>(y[N - 1]);
res[N - 1] = bit_cast<bfloat16>(__clc_fmax(XBits, YBits));
}
#else
for (size_t i = 0; i < N; i++) {
Expand Down Expand Up @@ -319,10 +310,10 @@ template <typename T>
std::enable_if_t<std::is_same_v<T, bfloat16>, T> fma(T x, T y, T z) {
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \
(__SYCL_CUDA_ARCH__ >= 800)
oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits(x);
oneapi::detail::Bfloat16StorageT YBits = oneapi::detail::bfloat16ToBits(y);
oneapi::detail::Bfloat16StorageT ZBits = oneapi::detail::bfloat16ToBits(z);
return oneapi::detail::bitsToBfloat16(__clc_fma(XBits, YBits, ZBits));
uint16_t XBits = bit_cast<uint16_t>(x);
uint16_t YBits = bit_cast<uint16_t>(y);
uint16_t ZBits = bit_cast<uint16_t>(z);
return bit_cast<bfloat16>(__clc_fma(XBits, YBits, ZBits));
#else
return sycl::ext::oneapi::bfloat16{sycl::fma(float{x}, float{y}, float{z})};
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) &&
Expand All @@ -344,13 +335,10 @@ sycl::marray<bfloat16, N> fma(sycl::marray<bfloat16, N> x,
}

if (N % 2) {
oneapi::detail::Bfloat16StorageT XBits =
oneapi::detail::bfloat16ToBits(x[N - 1]);
oneapi::detail::Bfloat16StorageT YBits =
oneapi::detail::bfloat16ToBits(y[N - 1]);
oneapi::detail::Bfloat16StorageT ZBits =
oneapi::detail::bfloat16ToBits(z[N - 1]);
res[N - 1] = oneapi::detail::bitsToBfloat16(__clc_fma(XBits, YBits, ZBits));
uint16_t XBits = bit_cast<uint16_t>(x[N - 1]);
uint16_t YBits = bit_cast<uint16_t>(y[N - 1]);
uint16_t ZBits = bit_cast<uint16_t>(z[N - 1]);
res[N - 1] = bit_cast<bfloat16>(__clc_fma(XBits, YBits, ZBits));
}
#else
for (size_t i = 0; i < N; i++) {
Expand Down
Loading