[SYCL][ext][CUDA] Use float as storage type for tf32 joint matrix #5870
Conversation
Commits:
- Added bfloat16 in oneapi experimental namespace. (Signed-off-by: jack.kirk <[email protected]>)
- …_BF16_CONVERSION.asciidoc
- Removed aspect reference: can be added once the ext_oneapi_bfloat16 aspect is merged.
Here, you are changing the spec of `joint_matrix` to add a new template argument which is the actual data type (tf32) while using the existing type as the storage type (float here). I don't think this is what we discussed before.
Hi @dkhaldi, thanks for your response. We have talked about this a bit internally and we think that each approach has pros and cons:

Approach 1: Use an extra template parameter in `joint_matrix` for the precision. This would default to the storage type. We could check that the precision parameter is compatible with the underlying type at the construction of the `joint_matrix`. It requires an extra template parameter; however, this could be useful down the line when other precisions are offered. Another benefit of this approach is that there may be multiple mappings of matrix storage types to precisions.

Approach 2: Use an empty tag class such as `precision::tf32` as the matrix element type, and map it to `float` as the storage type.

Please let me know your thoughts. Thanks
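To make the trade-off concrete, here is a rough sketch of the two shapes under discussion; all names and parameter lists are illustrative, not the final API:

```c++
// Approach 1: an extra template parameter carries the precision,
// defaulting to the storage type, so existing code is unaffected.
template <typename T, typename Precision = T /*, use, rows, cols, ... */>
struct joint_matrix_approach1 {
  // storage is T (e.g. float); arithmetic semantics follow Precision
};

// Approach 2: an empty tag type names the precision, and the
// implementation maps it to a storage type (float for tf32).
namespace precision { class tf32 {}; }
template <typename T /*, use, rows, cols, ... */>
struct joint_matrix_approach2 {
  // when T is precision::tf32, elements are stored as float
};
```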
I started putting together support for tf32. A draft PR can be found here:
I cannot see the logs for the test suite run, but locally the test
```diff
 typename std::enable_if_t<Layout == sycl::ext::oneapi::experimental::
                               matrix::matrix_layout::row_major ||
                           Layout == sycl::ext::oneapi::experimental::
                               matrix::matrix_layout::col_major>> {
   void load(sycl::ext::oneapi::experimental::matrix::joint_matrix<
-                T, Use, NumRows, NumCols, Layout, sycl::sub_group> &res,
+                S, Use, NumRows, NumCols, Layout, sycl::sub_group> &res,
```
I am tagging @yubingex007-a11y here, as changing the type of the load will be necessary to handle the tf32 case: the type in memory can be different from the type of the joint matrix.
However, @JackAKirk, we should restrict this flexibility to only tf32.
Can this work in the case of bfloat16? Load from float to bfloat16?
Yeah, the final bfloat16 CUDA impl is ready now using the old API (#5964).
Sounds fine to restrict the flexibility: I think the way this is implemented already restricts it to the tf32 type. If we add subbyte/single-bit cases, then I think those would also hit the situation where the type in memory differs from the type of the joint matrix.
This is currently restricted to only being used by tf32, where the other datatype is float. See https://github.com/intel/llvm/pull/5870/files/618c80750930b0eaec8cde468c880d52ba54c80c#diff-f71a436bdeda598b29caad471fa637a2844a12f38fe4e85b15b2ccb37bd09833R539
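For illustration only, this kind of constraint might be expressed roughly as follows; the trait name and the placement of `precision::tf32` are assumptions on my part, not the exact code from the PR:

```c++
#include <type_traits>

namespace precision { class tf32 {}; }

// Allow the matrix element type S to differ from the memory element
// type T only for the tf32 (matrix) / float (memory) pairing.
template <typename S, typename T>
using enable_if_loadable_t = std::enable_if_t<
    std::is_same<S, T>::value ||
    (std::is_same<S, precision::tf32>::value &&
     std::is_same<T, float>::value)>;
```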
```diff
@@ -573,6 +604,26 @@ joint_matrix_mad(
 #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
 }

+float float_to_tf32(float a) {
```
This is in sync with the fact that element indexing of a joint matrix of type tf32 yields type float.

```c++
joint_matrix<precision::tf32, TM, TK> sub_a(sg);
sub_a.get_wi_data()[i] = float_to_tf32(sub_a.get_wi_data()[i]);
```

`sub_a.get_wi_data()[i]` is of type float, but numerically it is tf32 after this conversion.
Please add this clarification as a comment.
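A minimal sketch of what that clarifying comment could look like at the call site, mirroring the example above (loop bounds and names are illustrative):

```c++
for (auto i = 0; i < sub_a.get_wi_data().length(); ++i)
  // get_wi_data()[i] is of type float; after the conversion the value
  // is still a float, but numerically representable in tf32.
  sub_a.get_wi_data()[i] = float_to_tf32(sub_a.get_wi_data()[i]);
```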
```c++
// CHECK: tail call i32 @llvm.nvvm.f2tf32.rna(float {{.*}}
// Round a, b to tf32
for (auto i = 0; i < 4; ++i)
  sub_a.data[i] = float_to_tf32(sub_a.data[i]);
```
This should be the expected way to perform the rounding, if users want to, but I still find that exposing `.data` differs from the element-wise indexing we are currently doing.
I would recommend moving from this to the current API:

```c++
sub_a.get_wi_data()[i] = float_to_tf32(sub_a.get_wi_data()[i]);
```
This is exactly what we will do (but in a future PR): I switched the impl here to use `marray` for `data` in `joint_matrix`. Then `get_wi_data()[i]` will call `get_wi_elem`, which will return the ith element of the `marray`. We will loop over `get_wi_data().length()` as you do too.
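A rough sketch of the shape described, assuming a hypothetical `wi_data` wrapper over `sycl::marray`; the real class and member names may differ:

```c++
#include <cstddef>
#include <sycl/sycl.hpp>

// Hypothetical view over the per-work-item elements of a joint_matrix.
template <typename T, std::size_t N> class wi_data {
  sycl::marray<T, N> &data;

public:
  wi_data(sycl::marray<T, N> &d) : data(d) {}
  std::size_t length() const { return N; }
  // "get_wi_elem": return a reference to the i-th element.
  T &operator[](std::size_t i) { return data[i]; }
};
```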
```c++
}

// This function just zeros out the bottom 13 bits of the tf32 type
float tf32_to_float(float a) {
```
Is there a use case for this? CUTLASS has this. If yes, rename it to `truncate_to_tf32`?
I have renamed the function.
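For reference, zeroing the bottom 13 bits corresponds to dropping the fp32 mantissa bits that tf32 lacks (tf32 keeps 10 of fp32's 23 mantissa bits). A standalone sketch of that operation, not the code from this PR:

```c++
#include <cstdint>
#include <cstring>

// Truncate a float to tf32 precision by clearing the 13 low mantissa bits.
float truncate_to_tf32(float a) {
  std::uint32_t bits;
  std::memcpy(&bits, &a, sizeof(bits)); // view the float's bit pattern
  bits &= ~std::uint32_t{0x1FFF};       // zero the bottom 13 bits
  float out;
  std::memcpy(&out, &bits, sizeof(out));
  return out;
}
```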
I am okay with deferring the `.data` change to a future PR to adapt to the current spec syntax as follows:

```c++
sub_a.get_wi_data()[i] = round_to_tf32(sub_a.get_wi_data()[i]);
```

as addressed here: #5870 (comment). Thus, this PR LGTM.
Thanks @dkhaldi!
LGTM
@intel/llvm-reviewers-cuda any more review for this? If not, it would be super nice if it could be merged within the next 12 hrs or so.
Ping @v-klochkov for review.
Only a minor question. I am okay with merging as-is and potentially addressing it separately.
```c++
} else if constexpr (NumRows == 32 && NumCols == 8) {
  __hmma_m32n8k16_ld_c_f32(res.data, src.get(), stride,
                           get_layout_id<Layout>());

if (std::is_same<S, float>::value) {
```
Is there a reason for this not to be `if constexpr`?
I think `if constexpr` is C++17, so it may have been removed for that reason. Although I noticed that, for some reason, C++17 is allowed in the extension namespace; I don't understand this?
I believe extensions that need to be explicitly included are allowed to use C++17 features.
Allowed in the sense that tests don't appear to fail due to C++17 in the extension namespace, whereas they do fail due to C++17 constructs in other namespaces!
I see, thanks.
In that case I am happy to ensure C++17 is fully employed where appropriate in this extension in the follow-on PR: #5964
> In that case I am happy to ensure C++17 is fully employed where appropriate in this extension in the follow-on PR: #5964

As long as it doesn't bleed into sycl.hpp then it should be fine. There should be a test that fails if it does.
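On the `if constexpr` point above: a minimal sketch of why it can matter inside a template. With a plain `if`, both branches must compile for every `S`; with `if constexpr`, the untaken branch is discarded at compile time:

```c++
#include <type_traits>

template <typename S> int describe() {
  if constexpr (std::is_same<S, float>::value) {
    return 32; // only instantiated when S is float (C++17)
  } else {
    return 0;  // discarded at compile time when S is float
  }
}
```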
Changing the `joint_matrix` impl to use `float` as the storage type instead of `uint32_t` for tf32.
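As a simplified illustration of that storage change (member name and fragment size are illustrative, not the actual class layout):

```c++
#include <cstdint>

// Before: tf32 fragments held as raw 32-bit integers.
struct tf32_fragment_old { std::uint32_t data[4]; };

// After: held as float, so element access and float_to_tf32 rounding
// operate directly on float values without reinterpretation.
struct tf32_fragment_new { float data[4]; };
```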