Commit 640334b
Port pytorch#149738 to ROCm (part 1).
This PR ports to release/2.5 the generalization of vectorized elementwise kernels to multiple heterogeneous tensor types. It does not yet include the threadblock mapping that was reverted in the original PR; that will come in a later PR.

Co-authored-by: Jerry Mannil <[email protected]>
1 parent 5ca0bb6

2 files changed: +101 −109
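The core device-side mechanism this port generalizes is a recursive constexpr check that a binary functor's parameter types match an expected pair; the first diff below replaces its hardcoded float comparisons with the FirstParamTy/SecondParamTy template parameters. A minimal standalone sketch of that recursion, with illustrative names rather than the actual ATen declarations:

// Standalone sketch of the recursive constexpr parameter check behind
// check_binary_functor_types_for_specialization; simplified and renamed.
#include <cstddef>
#include <tuple>
#include <type_traits>

template <
    typename TupleLike, // functor argument tuple, e.g. std::tuple<float, float>
    typename FirstParamTy, // expected type of argument 0
    typename SecondParamTy, // expected type of argument 1
    int arity,
    std::size_t arg_num = 0>
struct check_binary_types {
  static constexpr bool check() {
    if constexpr (arity != 2) {
      return false; // only binary functors qualify for the specialization
    } else if constexpr (arg_num == 0) {
      // Recurse to the next argument only if argument 0 matches.
      if constexpr (std::is_same_v<
                        FirstParamTy,
                        std::tuple_element_t<arg_num, TupleLike>>)
        return check_binary_types<
            TupleLike, FirstParamTy, SecondParamTy, arity, arg_num + 1>::
            check();
      else
        return false;
    } else {
      // arg_num == 1: the last argument decides the result.
      return std::is_same_v<
          SecondParamTy, std::tuple_element_t<arg_num, TupleLike>>;
    }
  }
};

// The check runs entirely at compile time:
static_assert(
    check_binary_types<std::tuple<float, float>, float, float, 2>::check());
static_assert(
    !check_binary_types<std::tuple<float, double>, float, float, 2>::check());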

Diff for: aten/src/ATen/native/cuda/CUDALoops.cuh (+99 −107)
@@ -544,8 +544,6 @@ void gpu_kernel_impl_nocast(TensorIteratorBase& iter, const func_t& f) {
 
 #ifdef USE_ROCM
 namespace {
-// Static functor type checker for binary functors with
-// float as the type of both parameters.
 template <
     typename TupleLike,
     typename FirstParamTy,
@@ -554,12 +552,11 @@ template <
     size_t arg_num = 0>
 struct check_binary_functor_types_for_specialization {
   constexpr static inline bool check() {
-    bool current = false;
     if constexpr (arity != 2)
       return false;
     if constexpr (arg_num == 0) {
       using SelectedType = std::tuple_element_t<arg_num, TupleLike>;
-      if constexpr (std::is_same_v<float, SelectedType>)
+      if constexpr (std::is_same_v<FirstParamTy, SelectedType>)
         return check_binary_functor_types_for_specialization<
             TupleLike,
             FirstParamTy,
@@ -568,7 +565,7 @@ struct check_binary_functor_types_for_specialization {
           arg_num + 1>::check();
     } else if constexpr (arg_num == 1) {
       using SelectedType2 = std::tuple_element_t<arg_num, TupleLike>;
-      if constexpr (std::is_same_v<float, SelectedType2>)
+      if constexpr (std::is_same_v<SecondParamTy, SelectedType2>)
         return check_binary_functor_types_for_specialization<
             TupleLike,
             FirstParamTy,
@@ -613,30 +610,91 @@ struct check_binary_functor_types_for_specialization<
 };
 
 // The following is a list of type specializations for vectorized_templated
-// elementwise kernel. It refers to the first and second runtime types of the
-// arguments of a binary functor.
-constexpr int number_of_binary_specializations = 4;
-const std::
-    array<std::array<c10::ScalarType, 2>, number_of_binary_specializations>
-        rt_binary_specializations = {
-            {{c10::CppTypeToScalarType<float>::value,
-              c10::CppTypeToScalarType<BFloat16>::value},
-             {c10::CppTypeToScalarType<BFloat16>::value,
-              c10::CppTypeToScalarType<float>::value},
-             {c10::CppTypeToScalarType<float>::value,
-              c10::CppTypeToScalarType<Half>::value},
-             {c10::CppTypeToScalarType<Half>::value,
-              c10::CppTypeToScalarType<float>::value}}};
+// elementwise kernel. The three types refer to runtime types of the output
+// tensor, first tensor argument, and the second tensor argument used for a
+// binary functor.
+constexpr std::array rt_binary_specializations = {
+    std::array<c10::ScalarType, 3>(
+        {c10::CppTypeToScalarType<float>::value,
+         c10::CppTypeToScalarType<float>::value,
+         c10::CppTypeToScalarType<BFloat16>::value}),
+    std::array<c10::ScalarType, 3>(
+        {c10::CppTypeToScalarType<float>::value,
+         c10::CppTypeToScalarType<BFloat16>::value,
+         c10::CppTypeToScalarType<float>::value}),
+    std::array<c10::ScalarType, 3>(
+        {c10::CppTypeToScalarType<BFloat16>::value,
+         c10::CppTypeToScalarType<BFloat16>::value,
+         c10::CppTypeToScalarType<float>::value}),
+    std::array<c10::ScalarType, 3>(
+        {c10::CppTypeToScalarType<float>::value,
+         c10::CppTypeToScalarType<float>::value,
+         c10::CppTypeToScalarType<Half>::value}),
+    std::array<c10::ScalarType, 3>(
+        {c10::CppTypeToScalarType<float>::value,
+         c10::CppTypeToScalarType<Half>::value,
+         c10::CppTypeToScalarType<float>::value}),
+    std::array<c10::ScalarType, 3>(
+        {c10::CppTypeToScalarType<Half>::value,
+         c10::CppTypeToScalarType<Half>::value,
+         c10::CppTypeToScalarType<float>::value})};
 
 bool check_binary_rt_types_for_specialization(TensorIteratorBase& iter) {
   if (iter.ninputs() != 2)
     return false;
-  for (int i = 0; i < 4; i++)
-    if (iter.input_dtype(0) == rt_binary_specializations[i][0] &&
-        iter.input_dtype(1) == rt_binary_specializations[i][1])
+  for (auto spec : rt_binary_specializations)
+    if (iter.dtype(0) == spec[0] && iter.input_dtype(0) == spec[1] &&
+        iter.input_dtype(1) == spec[2])
      return true;
   return false;
 }
+
+template <int arg_index>
+struct type_specialized_kernel_launcher {
+  template <
+      typename func_t,
+      typename array_t,
+      typename inp_calc_t,
+      typename out_calc_t,
+      typename loader_t,
+      typename storer_t>
+  static void apply(
+      ScalarType ret_t,
+      ScalarType arg0_t,
+      ScalarType arg1_t,
+      int64_t numel,
+      func_t f,
+      array_t data,
+      inp_calc_t input_offset_calculator,
+      out_calc_t output_offset_calculator,
+      loader_t loader,
+      storer_t storer) {
+    if (ret_t == rt_binary_specializations[arg_index][0] &&
+        arg0_t == rt_binary_specializations[arg_index][1] &&
+        arg1_t == rt_binary_specializations[arg_index][2])
+      launch_vectorized_templated_kernel<
+          func_t,
+          array_t,
+          inp_calc_t,
+          out_calc_t,
+          loader_t,
+          storer_t,
+          decltype(c10::impl::ScalarTypeToCPPType<
+                   rt_binary_specializations[arg_index][0]>::t),
+          decltype(c10::impl::ScalarTypeToCPPType<
+                   rt_binary_specializations[arg_index][1]>::t),
+          decltype(c10::impl::ScalarTypeToCPPType<
+                   rt_binary_specializations[arg_index][2]>::t)>(
+          numel,
+          f,
+          data,
+          input_offset_calculator,
+          output_offset_calculator,
+          loader,
+          storer);
+  }
+};
+
 } // namespace
 #endif
 
@@ -666,10 +724,10 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
 #ifdef USE_ROCM
   // Attempt to call specialized vectorized elementwise kernel
   // that enables interleaving.
-  if (false && check_binary_rt_types_for_specialization(iter) &&
+  if (check_binary_rt_types_for_specialization(iter) &&
      memory::can_vectorize_up_to<func_t>(data) > 1) {
-    // constexpr to reduce the amount of kernels (empty) generated for
-    // unrolled templated elementwise and limit which functors are actually
+    // constexpr to reduce the amount of kernels generated for
+    // vectorized templated elementwise and limit which functors are actually
     // applied to the load and store at compile time.
     using func_tuple = typename traits::ArgsTuple;
     if constexpr (
@@ -679,7 +737,7 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
             float,
             float,
             traits::arity,
-            /*current=*/0>::check()) {
+            /*arg_num=*/0>::check()) {
       // If we got here, we know we are in one of the specialized cases. We
       // need to translate the runtime type to a statically known type. This
       // is effectively hoisting to the host the switch over runtime type in
@@ -689,90 +747,24 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
       auto output_offset_calculator = TrivialOffsetCalculator<1>();
       auto loader = memory::LoadWithCast<traits::arity>(iter);
       auto storer = memory::StoreWithCast<1>(iter);
-      if (iter.input_dtype(0) == c10::CppTypeToScalarType<float>::value &&
-          iter.input_dtype(1) == c10::CppTypeToScalarType<BFloat16>::value)
-        launch_vectorized_templated_kernel<
-            func_t,
-            at::detail::Array<char*, ntensors>,
-            decltype(input_offset_calculator),
-            decltype(output_offset_calculator),
-            decltype(loader),
-            decltype(storer),
-            float,
-            float,
-            BFloat16>(
-            numel,
-            f,
-            data,
-            input_offset_calculator,
-            output_offset_calculator,
-            loader,
-            storer);
-      else if (
-          iter.input_dtype(0) == c10::CppTypeToScalarType<BFloat16>::value &&
-          iter.input_dtype(1) == c10::CppTypeToScalarType<float>::value)
-        launch_vectorized_templated_kernel<
-            func_t,
-            at::detail::Array<char*, ntensors>,
-            decltype(input_offset_calculator),
-            decltype(output_offset_calculator),
-            decltype(loader),
-            decltype(storer),
-            float,
-            BFloat16,
-            float>(
-            numel,
-            f,
-            data,
-            input_offset_calculator,
-            output_offset_calculator,
-            loader,
-            storer);
-      else if (
-          iter.input_dtype(0) == c10::CppTypeToScalarType<float>::value &&
-          iter.input_dtype(1) == c10::CppTypeToScalarType<Half>::value)
-        launch_vectorized_templated_kernel<
-            func_t,
-            at::detail::Array<char*, ntensors>,
-            decltype(input_offset_calculator),
-            decltype(output_offset_calculator),
-            decltype(loader),
-            decltype(storer),
-            float,
-            float,
-            Half>(
-            numel,
-            f,
-            data,
-            input_offset_calculator,
-            output_offset_calculator,
-            loader,
-            storer);
-      else if (
-          iter.input_dtype(0) == c10::CppTypeToScalarType<Half>::value &&
-          iter.input_dtype(1) == c10::CppTypeToScalarType<float>::value)
-        launch_vectorized_templated_kernel<
-            func_t,
-            at::detail::Array<char*, ntensors>,
-            decltype(input_offset_calculator),
-            decltype(output_offset_calculator),
-            decltype(loader),
-            decltype(storer),
-            float,
-            Half,
-            float>(
-            numel,
-            f,
-            data,
-            input_offset_calculator,
-            output_offset_calculator,
-            loader,
-            storer);
-      else
-        TORCH_CHECK(false, "unreachable");
+      memory::detail::static_unroll<
+          type_specialized_kernel_launcher,
+          rt_binary_specializations.size()>::
+          with_args(
+              iter.dtype(0),
+              iter.input_dtype(0),
+              iter.input_dtype(1),
+              numel,
+              f,
+              data,
+              input_offset_calculator,
+              output_offset_calculator,
+              loader,
+              storer);
       return;
     }
   }
+
   at::detail::Array<ScalarType, ntensors> dtypes;
   auto inner_strides = iter.get_inner_strides();
   at::detail::Array<int, ntensors> strides;
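The replacement of the four hand-written launch branches above relies on memory::detail::static_unroll, which instantiates a launcher once per entry of rt_binary_specializations at compile time; each instantiation compares the runtime dtypes against its own slot and launches only on a match. A simplified host-side sketch of the pattern (the real static_unroll lives in MemoryAccess.cuh; the toy launch_if_match functor and its runtime_index parameter are hypothetical):

// Compile-time unrolled dispatch, modeled on ATen's static_unroll.
#include <cstdio>
#include <utility>

template <template <int> class func, int end, int current = 0>
struct static_unroll {
  template <typename... Args>
  static inline void with_args(Args&&... args) {
    func<current>::apply(std::forward<Args>(args)...);
    static_unroll<func, end, current + 1>::with_args(args...);
  }
};

// Recursion terminator: nothing left to expand.
template <template <int> class func, int end>
struct static_unroll<func, end, end> {
  template <typename... Args>
  static inline void with_args(Args... args) {}
};

// Toy launcher mirroring type_specialized_kernel_launcher: each
// compile-time slot fires only when the runtime value matches it.
template <int i>
struct launch_if_match {
  static void apply(int runtime_index) {
    if (runtime_index == i)
      std::printf("dispatched to specialization %d\n", i);
  }
};

int main() {
  // Expands launch_if_match<0..3>::apply at compile time; exactly one
  // condition holds at runtime, mimicking the dtype checks in the diff.
  static_unroll<launch_if_match, 4>::with_args(2);
}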

Diff for: aten/src/ATen/native/cuda/MemoryAccess.cuh (+2 −2)

@@ -409,8 +409,8 @@ struct vectorized_templated {
   // float(float,bfloat16) and functor add on float(float,float).
   template <typename scalar_t>
   __device__ inline void store(scalar_t* from, int idx) {
-    using vec_t = aligned_vector<scalar_t, vec_size>;
-    scalar_t* to = reinterpret_cast<scalar_t*>(data[0]) + block_work_size * idx;
+    using vec_t = aligned_vector<CastToT, vec_size>;
+    CastToT* to = reinterpret_cast<CastToT*>(data[0]) + block_work_size * idx;
     vec_t* to_ = reinterpret_cast<vec_t*>(to);
     int thread_idx = threadIdx.x;
 #pragma unroll
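For context on this two-line fix: the vectorized store must address the output buffer in units of CastToT (the stored element type) rather than scalar_t (the functor's compute type), since the two differ in the mixed-type specializations, e.g. computing in float while storing Half. A minimal CUDA sketch of that pattern, assuming a simplified aligned_vector and a hypothetical kernel name:

// Vectorized store where the compute type (float) differs from the
// stored type (__half); a simplified stand-in for the CastToT fix.
#include <cuda_fp16.h>

template <typename scalar_t, int vec_size>
struct alignas(sizeof(scalar_t) * vec_size) aligned_vector {
  scalar_t val[vec_size];
};

template <int vec_size>
__global__ void store_as_half(const float* in, __half* out, int n) {
  // vec_t and the destination pointer are expressed in the *stored*
  // element type; expressing them in float would compute wrong byte
  // offsets and write wrong-width elements.
  using vec_t = aligned_vector<__half, vec_size>;
  int base = (blockIdx.x * blockDim.x + threadIdx.x) * vec_size;
  if (base + vec_size > n)
    return;
  vec_t v;
#pragma unroll
  for (int j = 0; j < vec_size; j++)
    v.val[j] = __float2half(in[base + j]); // cast compute -> storage type
  *reinterpret_cast<vec_t*>(out + base) = v; // one aligned vector write
}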
