[CK_TILE] Move GEMM pipeline tail handling logic to pipelines #2222
base: develop
Conversation
Looks great! Please remove all the now-unused `try_run` functions (they were called by `check_tail`). I think we will need this for the GEMM kernel as well.
template <typename RunFunction>
CK_TILE_HOST_DEVICE static auto
TailHandler(RunFunction run_func, bool has_hot_loop, TailNumber tail_number)
Suggested change:
- TailHandler(RunFunction run_func, bool has_hot_loop, TailNumber tail_number)
+ TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
}
}
#if defined(__HIP_DEVICE_COMPILE__)
// This path should be unreachable in device code if tail_number is always valid.
Suggested change:
- // This path should be unreachable in device code if tail_number is always valid.
+ // This path should be unreachable in device code if tail_number is valid.
template <typename RunFunction>
CK_TILE_HOST_DEVICE static auto
TailHandler(RunFunction run_func, bool has_hot_loop, TailNumber tail_number)
Suggested change:
- TailHandler(RunFunction run_func, bool has_hot_loop, TailNumber tail_number)
+ TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
}

auto check_tail = [&](auto... TNs) {
    (try_run<BaseGemmPipeline, decltype(TNs)::value>(tail_num), ...);
Please remove the unused `try_run`.
Proposed changes
This PR moves the GEMM pipeline tail handling logic into a function inside the pipeline itself, so that user code (e.g. the examples and tests) no longer has to repeat it for every pipeline it can use. In addition, the persistent variant of the grouped GEMM kernel needs to perform tail handling in device code; that is now also implemented within the pipeline code.
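To illustrate the idea, here is a minimal host-only sketch of a pipeline-owned tail handler that turns the runtime `has_hot_loop` and `tail_number` values into compile-time constants and forwards them to a caller-supplied callable. It is not the CK_TILE implementation: `ToyPipeline`, `tail_constant`, `encode`, and the three-value `TailNumber` enum are hypothetical stand-ins, and the `CK_TILE_HOST_DEVICE` qualifier and the device-side unreachable path are left out.

```cpp
#include <type_traits>

// Hypothetical stand-in for the pipeline's tail enumeration (assumption).
enum class TailNumber { Odd, Even, Full };

// Wraps a TailNumber value in a type so it can be passed as a compile-time constant.
template <TailNumber TN>
using tail_constant = std::integral_constant<TailNumber, TN>;

// Toy pipeline that owns the runtime-to-compile-time dispatch.
struct ToyPipeline
{
    template <typename RunFunction>
    static auto TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
    {
        // For a fixed hot-loop constant, pick the matching tail constant.
        auto dispatch_tail = [&](auto hot_loop) {
            switch(tail_number)
            {
            case TailNumber::Odd: return run_func(hot_loop, tail_constant<TailNumber::Odd>{});
            case TailNumber::Even: return run_func(hot_loop, tail_constant<TailNumber::Even>{});
            default: return run_func(hot_loop, tail_constant<TailNumber::Full>{});
            }
        };
        return has_hot_loop ? dispatch_tail(std::true_type{}) : dispatch_tail(std::false_type{});
    }
};

// Example caller: the callable sees both values as constant expressions
// and never has to enumerate the tail cases itself.
inline int encode(bool has_hot_loop, TailNumber tail_number)
{
    const auto run = [](auto hot_loop, auto tail) {
        constexpr bool kHot        = decltype(hot_loop)::value;
        constexpr TailNumber kTail = decltype(tail)::value;
        return (kHot ? 10 : 0) + static_cast<int>(kTail);
    };
    return ToyPipeline::TailHandler(run, has_hot_loop, tail_number);
}
```

In this sketch the caller only supplies `run`; the case analysis lives in one place, which is the property the PR description asks for. The callable is taken by `const` reference, matching the reviewers' suggestion.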
Checklist
Please put an `x` into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.
- `clang-format` on all changed files
Discussion
If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered.