@@ -544,8 +544,6 @@ void gpu_kernel_impl_nocast(TensorIteratorBase& iter, const func_t& f) {
 
 #ifdef USE_ROCM
 namespace {
-// Static functor type checker for binary functors with
-// float as the type of both parameters.
 template <
     typename TupleLike,
     typename FirstParamTy,
@@ -554,12 +552,11 @@ template <
     size_t arg_num = 0>
 struct check_binary_functor_types_for_specialization {
   constexpr static inline bool check() {
-    bool current = false;
     if constexpr (arity != 2)
       return false;
     if constexpr (arg_num == 0) {
       using SelectedType = std::tuple_element_t<arg_num, TupleLike>;
-      if constexpr (std::is_same_v<float, SelectedType>)
+      if constexpr (std::is_same_v<FirstParamTy, SelectedType>)
         return check_binary_functor_types_for_specialization<
             TupleLike,
             FirstParamTy,
@@ -568,7 +565,7 @@ struct check_binary_functor_types_for_specialization {
             arg_num + 1>::check();
     } else if constexpr (arg_num == 1) {
       using SelectedType2 = std::tuple_element_t<arg_num, TupleLike>;
-      if constexpr (std::is_same_v<float, SelectedType2>)
+      if constexpr (std::is_same_v<SecondParamTy, SelectedType2>)
         return check_binary_functor_types_for_specialization<
             TupleLike,
             FirstParamTy,
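The change above generalizes the deleted float-only checker: instead of hard-coding `float` for both parameters, the trait recurses over the functor's argument tuple and compares each slot against `FirstParamTy` / `SecondParamTy`. A minimal, self-contained sketch of the same pattern (illustrative names, not the ones in the tree):

```cpp
#include <cstddef>
#include <tuple>
#include <type_traits>

// Recursive constexpr checker: walk a functor's argument tuple and confirm
// it is a binary functor whose two parameter types match the expected ones.
template <
    typename TupleLike,
    typename FirstParamTy,
    typename SecondParamTy,
    std::size_t arity,
    std::size_t arg_num = 0>
struct check_binary_param_types {
  constexpr static bool check() {
    if constexpr (arity != 2) {
      return false;  // only binary functors qualify
    } else if constexpr (arg_num == 0) {
      // Check the first parameter type, then recurse onto the second.
      if constexpr (std::is_same_v<
                        FirstParamTy,
                        std::tuple_element_t<arg_num, TupleLike>>)
        return check_binary_param_types<
            TupleLike,
            FirstParamTy,
            SecondParamTy,
            arity,
            arg_num + 1>::check();
      return false;
    } else {
      // arg_num == 1: the second parameter type decides the result.
      return std::is_same_v<
          SecondParamTy,
          std::tuple_element_t<arg_num, TupleLike>>;
    }
  }
};

// A (float, float) argument tuple passes; (float, int) does not.
static_assert(check_binary_param_types<
              std::tuple<float, float>, float, float, 2>::check());
static_assert(!check_binary_param_types<
              std::tuple<float, int>, float, float, 2>::check());
```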
@@ -613,30 +610,91 @@ struct check_binary_functor_types_for_specialization<
 };
 
 // The following is a list of type specializations for vectorized_templated
-// elementwise kernel. It refers to the first and second runtime types of the
-// arguments of a binary functor.
-constexpr int number_of_binary_specializations = 4;
-const std::
-    array<std::array<c10::ScalarType, 2>, number_of_binary_specializations>
-        rt_binary_specializations = {
-            {{c10::CppTypeToScalarType<float>::value,
-              c10::CppTypeToScalarType<BFloat16>::value},
-             {c10::CppTypeToScalarType<BFloat16>::value,
-              c10::CppTypeToScalarType<float>::value},
-             {c10::CppTypeToScalarType<float>::value,
-              c10::CppTypeToScalarType<Half>::value},
-             {c10::CppTypeToScalarType<Half>::value,
-              c10::CppTypeToScalarType<float>::value}}};
+// elementwise kernel. The three types refer to runtime types of the output
+// tensor, first tensor argument, and the second tensor argument used for a
+// binary functor.
+constexpr std::array rt_binary_specializations = {
+    std::array<c10::ScalarType, 3>(
+        {c10::CppTypeToScalarType<float>::value,
+         c10::CppTypeToScalarType<float>::value,
+         c10::CppTypeToScalarType<BFloat16>::value}),
+    std::array<c10::ScalarType, 3>(
+        {c10::CppTypeToScalarType<float>::value,
+         c10::CppTypeToScalarType<BFloat16>::value,
+         c10::CppTypeToScalarType<float>::value}),
+    std::array<c10::ScalarType, 3>(
+        {c10::CppTypeToScalarType<BFloat16>::value,
+         c10::CppTypeToScalarType<BFloat16>::value,
+         c10::CppTypeToScalarType<float>::value}),
+    std::array<c10::ScalarType, 3>(
+        {c10::CppTypeToScalarType<float>::value,
+         c10::CppTypeToScalarType<float>::value,
+         c10::CppTypeToScalarType<Half>::value}),
+    std::array<c10::ScalarType, 3>(
+        {c10::CppTypeToScalarType<float>::value,
+         c10::CppTypeToScalarType<Half>::value,
+         c10::CppTypeToScalarType<float>::value}),
+    std::array<c10::ScalarType, 3>(
+        {c10::CppTypeToScalarType<Half>::value,
+         c10::CppTypeToScalarType<Half>::value,
+         c10::CppTypeToScalarType<float>::value})};
 
 bool check_binary_rt_types_for_specialization(TensorIteratorBase& iter) {
   if (iter.ninputs() != 2)
     return false;
-  for (int i = 0; i < 4; i++)
-    if (iter.input_dtype(0) == rt_binary_specializations[i][0] &&
-        iter.input_dtype(1) == rt_binary_specializations[i][1])
+  for (auto spec : rt_binary_specializations)
+    if (iter.dtype(0) == spec[0] && iter.input_dtype(0) == spec[1] &&
+        iter.input_dtype(1) == spec[2])
       return true;
   return false;
 }
+
+template <int arg_index>
+struct type_specialized_kernel_launcher {
+  template <
+      typename func_t,
+      typename array_t,
+      typename inp_calc_t,
+      typename out_calc_t,
+      typename loader_t,
+      typename storer_t>
+  static void apply(
+      ScalarType ret_t,
+      ScalarType arg0_t,
+      ScalarType arg1_t,
+      int64_t numel,
+      func_t f,
+      array_t data,
+      inp_calc_t input_offset_calculator,
+      out_calc_t output_offset_calculator,
+      loader_t loader,
+      storer_t storer) {
+    if (ret_t == rt_binary_specializations[arg_index][0] &&
+        arg0_t == rt_binary_specializations[arg_index][1] &&
+        arg1_t == rt_binary_specializations[arg_index][2])
+      launch_vectorized_templated_kernel<
+          func_t,
+          array_t,
+          inp_calc_t,
+          out_calc_t,
+          loader_t,
+          storer_t,
+          decltype(c10::impl::ScalarTypeToCPPType<
+                   rt_binary_specializations[arg_index][0]>::t),
+          decltype(c10::impl::ScalarTypeToCPPType<
+                   rt_binary_specializations[arg_index][1]>::t),
+          decltype(c10::impl::ScalarTypeToCPPType<
+                   rt_binary_specializations[arg_index][2]>::t)>(
+          numel,
+          f,
+          data,
+          input_offset_calculator,
+          output_offset_calculator,
+          loader,
+          storer);
+  }
+};
+
 } // namespace
 #endif
 
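One detail worth noting: the table moves from `const` to `constexpr`. The launcher reads `rt_binary_specializations[arg_index][k]` inside `decltype(c10::impl::ScalarTypeToCPPType<...>::t)`, and a template argument must be a constant expression; marking the array `constexpr` guarantees that indexing it qualifies. A tiny standalone illustration of the rule:

```cpp
#include <array>

// A constexpr std::array can be indexed in constant expressions, so its
// elements may parameterize templates; a plain const (runtime) table would
// not reliably qualify.
constexpr std::array<int, 3> table = {10, 20, 30};

template <int v>
struct tagged {
  static constexpr int value = v;
};

// table[1] is usable as a non-type template argument.
static_assert(tagged<table[1]>::value == 20);
```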
@@ -666,10 +724,10 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
 #ifdef USE_ROCM
   // Attempt to call specialized vectorized elementwise kernel
   // that enables interleaving.
-  if (false && check_binary_rt_types_for_specialization(iter) &&
+  if (check_binary_rt_types_for_specialization(iter) &&
       memory::can_vectorize_up_to<func_t>(data) > 1) {
-    // constexpr to reduce the amount of kernels (empty) generated for
-    // unrolled templated elementwise and limit which functors are actually
+    // constexpr to reduce the amount of kernels generated for
+    // vectorized templated elementwise and limit which functors are actually
     // applied to the load and store at compile time.
     using func_tuple = typename traits::ArgsTuple;
     if constexpr (
@@ -679,7 +737,7 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
             float,
             float,
             traits::arity,
-            /*current=*/0>::check()) {
+            /*arg_num=*/0>::check()) {
       // If we got here, we know we are in one of the specialized cases. We
       // need to translate the runtime type to a statically known type. This
       // is effectively hoisting to the host the switch over runtime type in
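The "hoisting" the comment describes is the usual dtype-dispatch move: enumerate a closed set of runtime `ScalarType` values on the host and map each to a statically typed instantiation. An illustrative reduction, where `run_typed_kernel` is a hypothetical stand-in for the templated launch:

```cpp
#include <c10/core/ScalarType.h>
#include <c10/util/Exception.h>

// Illustrative only: translate a runtime ScalarType into a statically
// known C++ type by switching over the supported cases on the host.
template <typename StaticT>
void run_typed_kernel() {
  // ... launch a kernel compiled for StaticT ...
}

void dispatch_on_dtype(c10::ScalarType rt) {
  switch (rt) {
    case c10::ScalarType::Float:
      return run_typed_kernel<float>();
    case c10::ScalarType::BFloat16:
      return run_typed_kernel<c10::BFloat16>();
    case c10::ScalarType::Half:
      return run_typed_kernel<c10::Half>();
    default:
      TORCH_CHECK(false, "no specialization for this dtype");
  }
}
```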
@@ -689,90 +747,24 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
       auto output_offset_calculator = TrivialOffsetCalculator<1>();
       auto loader = memory::LoadWithCast<traits::arity>(iter);
       auto storer = memory::StoreWithCast<1>(iter);
-      if (iter.input_dtype(0) == c10::CppTypeToScalarType<float>::value &&
-          iter.input_dtype(1) == c10::CppTypeToScalarType<BFloat16>::value)
-        launch_vectorized_templated_kernel<
-            func_t,
-            at::detail::Array<char*, ntensors>,
-            decltype(input_offset_calculator),
-            decltype(output_offset_calculator),
-            decltype(loader),
-            decltype(storer),
-            float,
-            float,
-            BFloat16>(
-            numel,
-            f,
-            data,
-            input_offset_calculator,
-            output_offset_calculator,
-            loader,
-            storer);
-      else if (
-          iter.input_dtype(0) == c10::CppTypeToScalarType<BFloat16>::value &&
-          iter.input_dtype(1) == c10::CppTypeToScalarType<float>::value)
-        launch_vectorized_templated_kernel<
-            func_t,
-            at::detail::Array<char*, ntensors>,
-            decltype(input_offset_calculator),
-            decltype(output_offset_calculator),
-            decltype(loader),
-            decltype(storer),
-            float,
-            BFloat16,
-            float>(
-            numel,
-            f,
-            data,
-            input_offset_calculator,
-            output_offset_calculator,
-            loader,
-            storer);
-      else if (
-          iter.input_dtype(0) == c10::CppTypeToScalarType<float>::value &&
-          iter.input_dtype(1) == c10::CppTypeToScalarType<Half>::value)
-        launch_vectorized_templated_kernel<
-            func_t,
-            at::detail::Array<char*, ntensors>,
-            decltype(input_offset_calculator),
-            decltype(output_offset_calculator),
-            decltype(loader),
-            decltype(storer),
-            float,
-            float,
-            Half>(
-            numel,
-            f,
-            data,
-            input_offset_calculator,
-            output_offset_calculator,
-            loader,
-            storer);
-      else if (
-          iter.input_dtype(0) == c10::CppTypeToScalarType<Half>::value &&
-          iter.input_dtype(1) == c10::CppTypeToScalarType<float>::value)
-        launch_vectorized_templated_kernel<
-            func_t,
-            at::detail::Array<char*, ntensors>,
-            decltype(input_offset_calculator),
-            decltype(output_offset_calculator),
-            decltype(loader),
-            decltype(storer),
-            float,
-            Half,
-            float>(
-            numel,
-            f,
-            data,
-            input_offset_calculator,
-            output_offset_calculator,
-            loader,
-            storer);
-      else
-        TORCH_CHECK(false, "unreachable");
+      memory::detail::static_unroll<
+          type_specialized_kernel_launcher,
+          rt_binary_specializations.size()>::
+          with_args(
+              iter.dtype(0),
+              iter.input_dtype(0),
+              iter.input_dtype(1),
+              numel,
+              f,
+              data,
+              input_offset_calculator,
+              output_offset_calculator,
+              loader,
+              storer);
       return;
     }
   }
+
   at::detail::Array<ScalarType, ntensors> dtypes;
   auto inner_strides = iter.get_inner_strides();
   at::detail::Array<int, ntensors> strides;
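`memory::detail::static_unroll` is the existing helper that expands `type_specialized_kernel_launcher<0>::apply(...)` through `<N-1>::apply(...)` at compile time; each instantiation compares the runtime dtype triple against its own table entry, so at most one actually launches. A simplified sketch of the pattern (not the actual helper's signature):

```cpp
#include <cstdio>

// Simplified sketch of a compile-time unroll: calls func<0>::apply(args...),
// func<1>::apply(args...), ..., func<end-1>::apply(args...).
template <template <int> class func, int end, int current = 0>
struct static_unroll_sketch {
  template <typename... Args>
  static void with_args(Args&&... args) {
    func<current>::apply(args...);
    static_unroll_sketch<func, end, current + 1>::with_args(args...);
  }
};

// Base case: recursion stops when current == end.
template <template <int> class func, int end>
struct static_unroll_sketch<func, end, end> {
  template <typename... Args>
  static void with_args(Args&&...) {}
};

// Hypothetical demo launcher: fires only when the runtime tag matches its
// compile-time index, mirroring how each type_specialized_kernel_launcher
// guards on its rt_binary_specializations entry.
template <int idx>
struct print_if_match {
  static void apply(int runtime_tag) {
    if (runtime_tag == idx)
      std::printf("dispatched to specialization %d\n", idx);
  }
};

int main() {
  static_unroll_sketch<print_if_match, 6>::with_args(3);  // prints index 3
}
```

Compared with the deleted if/else chain, adding a new dtype triple now only requires a new row in `rt_binary_specializations`; the dispatch code itself does not change.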