[SYCL][ext][CUDA] Use float as storage type for tf32 joint matrix #5870


Merged: 32 commits merged into intel:sycl on Jun 8, 2022

Conversation

@hdelan (Contributor) commented Mar 23, 2022

Changing joint_matrix impl to use float as storage type instead of uint32_t for tf32.

@hdelan hdelan requested a review from a team as a code owner March 23, 2022 15:50
@hdelan hdelan requested a review from v-klochkov March 23, 2022 15:50
@hdelan hdelan changed the title Tf32 joint matrix [SYCL][ext][CUDA] Use float as storage type for tf32 joint matrix Mar 23, 2022
@JackAKirk JackAKirk requested a review from dkhaldi March 23, 2022 16:10
@dkhaldi (Contributor) commented Mar 23, 2022

Here, you are changing the spec of joint_matrix to add a new template argument which is the actual data type (tf32) while using the existing type as the storage type (float here). I don't think this is what we discussed before.
The idea is to keep only one type argument to the joint matrix class and introduce a new tf32 type in the form of an empty class that is only accessed and used from within the matrix namespace.

@hdelan (Contributor, author) commented Mar 24, 2022

Hi @dkhaldi, thanks for your response. We have talked about this a bit internally, and we think that each approach has pros and cons:

Approach 1:

Using an extra template parameter on the joint_matrix class to specify the precision. This is the same as the current approach, but with enum class use_tf32 {yes, no}; replaced with something more generic, like

enum class precision { standard /* "default" is a reserved keyword in C++ */, tf32 /* other precisions for single-bit types, etc. */ };

This would default to precision::standard, so the user only needs to concern themselves with the enum when using a non-standard precision.

We could check that the precision parameter is compatible with the underlying type when the joint_matrix is constructed. Semantically this makes it clear that the programmer need only concern themselves with floats, and the implementation will take care of the tf32 precision bits.

It requires an extra template parameter; however, this could be useful down the line when other precisions are offered.

Another benefit of this approach is that there may be multiple mappings of matrix array types to joint_matrix::data registers: since the register type is determined by the precision parameter, the implementation could allow many mappings from array types to a given register type. This gives the implementation a lot of flexibility, and all the programmer needs to know is which combinations of array type and precision are allowed, which can easily be checked at compile time.

Approach 2:

Use an empty tf32 class as the type argument to the joint_matrix class. This avoids adding an extra template parameter, but it has some drawbacks:

  1. It encourages the programmer to treat tf32 as an actual arithmetic type, when in fact it is an empty class.
  2. It does not make the relationship between float and tf32 clear. If the programmer is constructing a joint matrix of type tf32, why should joint_matrix_load take a multi_ptr to a float? The programmer might try to make an accessor to tf32s instead, which would not work, as tf32 is an empty class.
  3. Errors of incompatibility between the storage type and the tf32 type would only be caught at joint_matrix_load, instead of one step earlier at the construction of the joint_matrix. Moreover, the errors are likely to be harder to parse than if they were caught when constructing the joint_matrix.
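To make the trade-off concrete, here is a minimal sketch of the two approaches. All names (joint_matrix_v1, joint_matrix_v2, precision, the fragment size) are hypothetical illustrations, not the actual SYCL extension API:

```cpp
#include <type_traits>

// Approach 1: the element type stays float; an extra template parameter
// selects the precision, and incompatible pairs fail at construction.
enum class precision { standard, tf32 };

template <typename T, int Rows, int Cols, precision P = precision::standard>
struct joint_matrix_v1 {
  static_assert(P != precision::tf32 || std::is_same<T, float>::value,
                "tf32 precision requires float as the underlying type");
  T data[4];  // illustrative per-work-item fragment
};

// Approach 2: an empty tag class is passed as the "element type"; the
// real storage type (float) is only derived internally.
struct tf32 {};  // empty class: cannot be pointed to or indexed directly

template <typename T, int Rows, int Cols>
struct joint_matrix_v2 {
  using storage_t =
      typename std::conditional<std::is_same<T, tf32>::value, float, T>::type;
  storage_t data[4];  // illustrative per-work-item fragment
};
```

With the first variant, a bad type/precision pair is rejected when the matrix is declared; with the second, the mismatch can only surface later, e.g. inside a load function, which corresponds to drawback 3 above.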

Please let me know your thoughts. Thanks

@dkhaldi (Contributor) commented Mar 29, 2022

I started putting together support for tf32. A draft PR can be found here:
#5920
This can give an idea of the changes needed to handle tf32 and of the way to differentiate between the element type and the storage type.
The missing parts are mainly related to SPIR-V, which is why I marked this as a draft. But in your case, since you don't support JIT and don't have element-wise ops yet, I believe adding the empty class will be the only change needed.

@hdelan (Contributor, author) commented Apr 8, 2022

I cannot see the logs for the test-suite run, but locally the test InorderQueue/in_order_get_property.cpp is failing, which is unrelated to this PR. Therefore I think this is ready to merge, if possible.

typename std::enable_if_t<
    Layout == sycl::ext::oneapi::experimental::matrix::matrix_layout::row_major ||
    Layout == sycl::ext::oneapi::experimental::matrix::matrix_layout::col_major>> {
  void load(sycl::ext::oneapi::experimental::matrix::joint_matrix<
-               T, Use, NumRows, NumCols, Layout, sycl::sub_group> &res,
+               S, Use, NumRows, NumCols, Layout, sycl::sub_group> &res,
Contributor

I am tagging @yubingex007-a11y here, as changing the type of the load will be necessary to handle the tf32 case: the type in memory can be different from the type of the joint matrix.
However, @JackAKirk, we should restrict this flexibility to tf32 only.
Can this work in the case of bfloat16, i.e. a load from float to bfloat16?

Contributor

Yeah, the final bfloat16 CUDA impl is ready now using the old API (#5964).

Sounds fine to restrict the flexibility: I think the way this is implemented already restricts it to the tf32 type. If we add subbyte/single-bit cases, then I think those would also hit the case where the type in memory differs from the type of the joint matrix.


@@ -573,6 +604,26 @@ joint_matrix_mad(
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
}

float float_to_tf32(float a) {
Contributor

This is consistent with the fact that indexing an element of a joint matrix of type tf32 yields a float:
joint_matrix<precision::tf32, TM, TK> sub_a(sg);
sub_a.get_wi_data()[i] = float_to_tf32(sub_a.get_wi_data()[i]);
sub_a.get_wi_data()[i] is of type float, but numerically it is tf32 after this conversion.
Please add this clarification as a comment.

// CHECK: tail call i32 @llvm.nvvm.f2tf32.rna(float {{.*}}
// Round a, b to tf32
for (auto i = 0; i < 4; ++i)
sub_a.data[i] = float_to_tf32(sub_a.data[i]);
Contributor

This should be the expected way to perform the rounding, if users want to, but I still find exposing ".data" different from the element-wise indexing we are currently doing.
I would recommend moving from this to the current API:
sub_a.get_wi_data()[i] = float_to_tf32(sub_a.get_wi_data()[i]);

@JackAKirk (Contributor) commented Apr 12, 2022

This is exactly what we will do (but in a future PR). I switched the impl here to use marray for data in joint_matrix, so get_wi_data()[i] will call get_wi_elem, which returns the i-th element of the marray. We will loop over get_wi_data().length() as you do too.
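As a rough illustration of that plan (names are hypothetical stand-ins for the real joint_matrix and sycl::marray types):

```cpp
#include <cstddef>

// Sketch: the per-work-item fragment lives in an marray-like container,
// and get_wi_data() returns a lightweight view whose operator[] forwards
// to the i-th stored element.
template <typename T, std::size_t N>
struct wi_data_view {
  T (&storage)[N];  // stand-in for a reference to sycl::marray<T, N>
  T &operator[](std::size_t i) { return storage[i]; }
  constexpr std::size_t length() const { return N; }
};

template <typename T, std::size_t N>
struct joint_matrix_sketch {
  T data[N];
  wi_data_view<T, N> get_wi_data() { return {data}; }
};

// Element-wise rounding in the style the spec suggests:
//   for (std::size_t i = 0; i < m.get_wi_data().length(); ++i)
//     m.get_wi_data()[i] = round_to_tf32(m.get_wi_data()[i]);
```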

}

// This function just zeros out the bottom 13 bits of the tf32 type
float tf32_to_float(float a) {
Contributor

Is there a use case for this? CUTLASS has this.
If yes, rename it to truncate_to_tf32?

Contributor Author

I have renamed the function.
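For reference, here is a minimal sketch (my own illustration, not the code in this PR) of what a truncating conversion amounts to: tf32 keeps 10 explicit mantissa bits versus float's 23, so truncation clears the bottom 13 bits. Note that the actual hardware conversion (the llvm.nvvm.f2tf32.rna intrinsic checked above) rounds to nearest rather than truncating.

```cpp
#include <cstdint>
#include <cstring>

// Illustrative only: zero the 13 low mantissa bits of a float, which is
// what truncating to tf32 precision amounts to.
float truncate_to_tf32_sketch(float a) {
  uint32_t bits;
  std::memcpy(&bits, &a, sizeof bits);  // bit-cast without UB
  bits &= ~uint32_t{0x1FFF};            // clear the bottom 13 bits
  float out;
  std::memcpy(&out, &bits, sizeof out);
  return out;
}
```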

@hdelan hdelan requested a review from a team as a code owner April 14, 2022 09:36
@dkhaldi (Contributor) previously approved these changes Apr 22, 2022 and left a comment

I am okay with deferring the ".data" change to a future PR to adapt to the current spec syntax, as follows:
sub_a.get_wi_data()[i] = round_to_tf32(sub_a.get_wi_data()[i]);

as addressed here: #5870 (comment)

Thus, this PR LGTM

@hdelan (Contributor, author) commented Apr 22, 2022

Thanks @dkhaldi !

@hdelan hdelan force-pushed the tf32-joint-matrix branch from d3e1247 to 438a9f2 Compare May 9, 2022 15:27
@hdelan hdelan requested review from a team and pvchupin as code owners May 9, 2022 15:27
@hdelan hdelan requested a review from smaslov-intel May 9, 2022 15:27
@hdelan hdelan force-pushed the tf32-joint-matrix branch from 21cc02c to 13b1efb Compare May 10, 2022 15:53
@dkhaldi (Contributor) left a comment

LGTM

@JackAKirk (Contributor) commented

@intel/llvm-reviewers-cuda, any more review for this? If not, it would be super nice if it could be merged within the next 12 hrs or so.

@pvchupin (Contributor) commented Jun 8, 2022

Ping @v-klochkov for review.

@steffenlarsen (Contributor) left a comment

Only a minor question. I am okay with merging as-is and potentially addressing it separately.

} else if constexpr (NumRows == 32 && NumCols == 8) {
__hmma_m32n8k16_ld_c_f32(res.data, src.get(), stride,
get_layout_id<Layout>());
if (std::is_same<S, float>::value) {
Contributor

Is there a reason for this not to be if constexpr?

Contributor

I think if constexpr is C++17, so it may have been removed for that reason. Although I noticed that C++17 is apparently allowed in the extension namespace, which I don't understand.
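For reference, a minimal sketch (hypothetical function, not this PR's code) of what `if constexpr` buys in templates: the discarded branch is not instantiated, so each branch only has to compile for the types that actually reach it, whereas a plain `if` instantiates both branches for every type.

```cpp
#include <type_traits>

// Hypothetical example: pick a per-element-type behavior at compile time.
// With C++17 `if constexpr`, the false branch is discarded entirely.
template <typename S>
int storage_bits(S) {
  if constexpr (std::is_same<S, float>::value) {
    return 32;  // e.g. dispatch to an f32 load intrinsic
  } else {
    return 16;  // e.g. dispatch to a half-precision load intrinsic
  }
}
```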

Contributor

I believe extensions that need to be explicitly included are allowed to use C++17 features.

@JackAKirk (Contributor) commented Jun 8, 2022

Allowed in the sense that tests don't appear to fail due to C++17 in the extension namespace, whereas they do fail due to C++17 constructs in other namespaces!

Contributor

I see, thanks.

@JackAKirk (Contributor) commented Jun 8, 2022

In that case I am happy to ensure C++17 is fully employed where appropriate in this extension in the follow-on PR: #5964

Contributor

> In that case I am happy to ensure C++17 is fully employed where appropriate in this extension in the follow-on PR: #5964

As long as it doesn't bleed into sycl.hpp then it should be fine. There should be a test that fails if it does.

@pvchupin pvchupin merged commit 2340b33 into intel:sycl Jun 8, 2022
pvchupin pushed a commit to intel/llvm-test-suite that referenced this pull request Jun 8, 2022
aelovikov-intel pushed a commit to aelovikov-intel/llvm that referenced this pull request Mar 27, 2023
6 participants