Skip to content

Commit 413a9ef

Browse files
authored
[SYCL][CUDA] Add bf16 builtins operating on storage types (#5748)
Add bf16 builtins operating on storage types. Partially implements https://github.com/intel/llvm/pull/5645/files for CUDA (only operations on storage types). This PR includes a bugfix for some NVPTX intrinsics, which will also be pushed upstream. Tests for this are in intel/llvm-test-suite#897.
1 parent 68b089f commit 413a9ef

File tree

4 files changed

+108
-4
lines changed

4 files changed

+108
-4
lines changed

llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

+4-4
Original file line numberDiff line numberDiff line change
@@ -855,13 +855,13 @@ def INT_NVVM_FABS_D : F_MATH_1<"abs.f64 \t$dst, $src0;", Float64Regs,
855855
// Abs, Neg bf16, bf16x2
856856
//
857857

858-
def INT_NVVM_ABS_BF16 : F_MATH_1<"abs.bf16 \t$dst, $dst;", Int16Regs,
858+
def INT_NVVM_ABS_BF16 : F_MATH_1<"abs.bf16 \t$dst, $src0;", Int16Regs,
859859
Int16Regs, int_nvvm_abs_bf16, [hasPTX70, hasSM80]>;
860-
def INT_NVVM_ABS_BF16X2 : F_MATH_1<"abs.bf16x2 \t$dst, $dst;", Int32Regs,
860+
def INT_NVVM_ABS_BF16X2 : F_MATH_1<"abs.bf16x2 \t$dst, $src0;", Int32Regs,
861861
Int32Regs, int_nvvm_abs_bf16x2, [hasPTX70, hasSM80]>;
862-
def INT_NVVM_NEG_BF16 : F_MATH_1<"neg.bf16 \t$dst, $dst;", Int16Regs,
862+
def INT_NVVM_NEG_BF16 : F_MATH_1<"neg.bf16 \t$dst, $src0;", Int16Regs,
863863
Int16Regs, int_nvvm_neg_bf16, [hasPTX70, hasSM80]>;
864-
def INT_NVVM_NEG_BF16X2 : F_MATH_1<"neg.bf16x2 \t$dst, $dst;", Int32Regs,
864+
def INT_NVVM_NEG_BF16X2 : F_MATH_1<"neg.bf16x2 \t$dst, $src0;", Int32Regs,
865865
Int32Regs, int_nvvm_neg_bf16x2, [hasPTX70, hasSM80]>;
866866

867867
//

sycl/include/CL/__spirv/spirv_ops.hpp

+24
Original file line numberDiff line numberDiff line change
@@ -793,6 +793,30 @@ extern SYCL_EXTERNAL __ocl_vec_t<_Float16, 8>
793793
extern SYCL_EXTERNAL __ocl_vec_t<_Float16, 16>
794794
__clc_native_exp2(__ocl_vec_t<_Float16, 16>);
795795

796+
#define __CLC_BF16(...) \
797+
extern SYCL_EXTERNAL __SYCL_EXPORT __VA_ARGS__ __clc_fabs( \
798+
__VA_ARGS__) noexcept; \
799+
extern SYCL_EXTERNAL __SYCL_EXPORT __VA_ARGS__ __clc_fmin( \
800+
__VA_ARGS__, __VA_ARGS__) noexcept; \
801+
extern SYCL_EXTERNAL __SYCL_EXPORT __VA_ARGS__ __clc_fmax( \
802+
__VA_ARGS__, __VA_ARGS__) noexcept; \
803+
extern SYCL_EXTERNAL __SYCL_EXPORT __VA_ARGS__ __clc_fma( \
804+
__VA_ARGS__, __VA_ARGS__, __VA_ARGS__) noexcept;
805+
806+
#define __CLC_BF16_SCAL_VEC(TYPE) \
807+
__CLC_BF16(TYPE) \
808+
__CLC_BF16(__ocl_vec_t<TYPE, 2>) \
809+
__CLC_BF16(__ocl_vec_t<TYPE, 3>) \
810+
__CLC_BF16(__ocl_vec_t<TYPE, 4>) \
811+
__CLC_BF16(__ocl_vec_t<TYPE, 8>) \
812+
__CLC_BF16(__ocl_vec_t<TYPE, 16>)
813+
814+
__CLC_BF16_SCAL_VEC(uint16_t)
815+
__CLC_BF16_SCAL_VEC(uint32_t)
816+
817+
#undef __CLC_BF16_SCAL_VEC
818+
#undef __CLC_BF16
819+
796820
#else // if !__SYCL_DEVICE_ONLY__
797821

798822
template <typename dataT>

sycl/include/CL/sycl.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
#if SYCL_EXT_ONEAPI_BACKEND_LEVEL_ZERO
6161
#include <sycl/ext/oneapi/backend/level_zero.hpp>
6262
#endif
63+
#include <sycl/ext/oneapi/bf16_storage_builtins.hpp>
6364
#include <sycl/ext/oneapi/device_global/properties.hpp>
6465
#include <sycl/ext/oneapi/experimental/builtins.hpp>
6566
#include <sycl/ext/oneapi/filter_selector.hpp>
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
#pragma once
2+
3+
#include <CL/__spirv/spirv_ops.hpp>
4+
#include <CL/sycl/builtins.hpp>
5+
#include <CL/sycl/detail/builtins.hpp>
6+
#include <CL/sycl/detail/generic_type_lists.hpp>
7+
#include <CL/sycl/detail/generic_type_traits.hpp>
8+
#include <CL/sycl/detail/type_traits.hpp>
9+
10+
__SYCL_INLINE_NAMESPACE(cl) {
11+
namespace sycl {
12+
namespace ext {
13+
namespace oneapi {
14+
15+
namespace detail {
16+
17+
template <typename T> struct is_bf16_storage_type {
18+
static constexpr int value = false;
19+
};
20+
21+
template <> struct is_bf16_storage_type<uint16_t> {
22+
static constexpr int value = true;
23+
};
24+
25+
template <> struct is_bf16_storage_type<uint32_t> {
26+
static constexpr int value = true;
27+
};
28+
29+
template <int N> struct is_bf16_storage_type<vec<uint16_t, N>> {
30+
static constexpr int value = true;
31+
};
32+
33+
template <int N> struct is_bf16_storage_type<vec<uint32_t, N>> {
34+
static constexpr int value = true;
35+
};
36+
37+
} // namespace detail
38+
39+
template <typename T>
40+
std::enable_if_t<detail::is_bf16_storage_type<T>::value, T> fabs(T x) {
41+
#ifdef __SYCL_DEVICE_ONLY__
42+
return __clc_fabs(x);
43+
#else
44+
throw runtime_error("bf16 is not supported on host device.",
45+
PI_INVALID_DEVICE);
46+
#endif
47+
}
48+
template <typename T>
49+
std::enable_if_t<detail::is_bf16_storage_type<T>::value, T> fmin(T x, T y) {
50+
#ifdef __SYCL_DEVICE_ONLY__
51+
return __clc_fmin(x, y);
52+
#else
53+
throw runtime_error("bf16 is not supported on host device.",
54+
PI_INVALID_DEVICE);
55+
#endif
56+
}
57+
template <typename T>
58+
std::enable_if_t<detail::is_bf16_storage_type<T>::value, T> fmax(T x, T y) {
59+
#ifdef __SYCL_DEVICE_ONLY__
60+
return __clc_fmax(x, y);
61+
#else
62+
throw runtime_error("bf16 is not supported on host device.",
63+
PI_INVALID_DEVICE);
64+
#endif
65+
}
66+
template <typename T>
67+
std::enable_if_t<detail::is_bf16_storage_type<T>::value, T> fma(T x, T y, T z) {
68+
#ifdef __SYCL_DEVICE_ONLY__
69+
return __clc_fma(x, y, z);
70+
#else
71+
throw runtime_error("bf16 is not supported on host device.",
72+
PI_INVALID_DEVICE);
73+
#endif
74+
}
75+
76+
} // namespace oneapi
77+
} // namespace ext
78+
} // namespace sycl
79+
} // __SYCL_INLINE_NAMESPACE(cl)

0 commit comments

Comments
 (0)