@@ -732,122 +732,154 @@ class GemmSoftmaxGemmPermuteTunableOp : public tunable::TunableOp<GemmSoftmaxGem
 
 #ifdef USE_COMPOSABLE_KERNEL
 
-template <typename T, bool USE_BIAS, bool USE_MASK>
-auto GetCKGemmSoftmaxGemmPermuteTypeStringAndOps() {
+template <typename U, typename V, typename T, bool USE_BIAS, bool USE_MASK>
+auto GetArgAndRunInvoker(const U& impl, const V& invoker, const GemmSoftmaxGemmPermuteParams<T>* params) {
   constexpr const int kNumBiasBuffer = static_cast<int>(USE_BIAS) + static_cast<int>(USE_MASK);
 
   using Nop = ck::tensor_operation::element_wise::PassThrough;
   using Acc0ElementOp = internal::PreSoftmaxAttentionScoreOp;
 
+  TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
+      !GemmSoftmaxGemmPermuteTunableOp<T>::IsSupportedMode(params->attention),
+      "attention mode is not supported, got ", params->attention->mode);
+  if constexpr (USE_BIAS) {
+    TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
+        params->bias_buffer == nullptr, "biased version only support input with bias");
+  } else {
+    TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
+        params->bias_buffer != nullptr, "non-biased version only support input without bias");
+  }
+  if constexpr (USE_MASK) {
+    TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
+        !GemmSoftmaxGemmPermuteTunableOp<T>::IsSupportedMaskType(params->attention),
+        "mask type is not supported, got ", params->attention->mask_type);
+    TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
+        params->mask_index_buffer == nullptr, "masked version only support input with mask");
+  } else {
+    TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
+        params->mask_index_buffer != nullptr, "non-masked version only support input without mask");
+  }
+
+  auto attn = params->attention;
+  const int& G0 = attn->batch_size;
+  const int& G1 = attn->num_heads;
+  const int& M = attn->sequence_length;
+  const int& N = attn->total_sequence_length;
+  const int& K = attn->head_size;
+  const int& O = attn->v_head_size;
+  {
+    auto [m, n, k, o, batch] = params->GetGemmsMNKOBatch();
+    ORT_ENFORCE(M == m && N == n && K == k && O == o && G0 * G1 == batch, "semantic mismatch");
+  }
+
+  auto [qs, ks, vs] = GetQkvStrides(attn);
+  std::vector<ck::index_t> q_buffer_lengths = {G0, G1, M, K};
+  std::vector<ck::index_t> q_buffer_strides = qs.template ForBNSHCoord<std::vector<ck::index_t>>();
+  std::vector<ck::index_t> k_buffer_lengths = {G0, G1, N, K};
+  std::vector<ck::index_t> k_buffer_strides = ks.template ForBNSHCoord<std::vector<ck::index_t>>();
+  std::vector<ck::index_t> v_buffer_lengths = {G0, G1, O, N};
+  std::vector<ck::index_t> v_buffer_strides = vs.template ForBNHSCoord<std::vector<ck::index_t>>();
+  std::vector<ck::index_t> out_buffer_lengths = {G0, G1, M, O};
+  std::vector<ck::index_t> out_buffer_strides = {M * G1 * O, O, G1 * O, 1};  // permute 0213
+
+  std::array<void*, kNumBiasBuffer> bias_buffers{};
+  std::array<std::vector<ck::index_t>, kNumBiasBuffer> bias_lengths{};
+  std::array<std::vector<ck::index_t>, kNumBiasBuffer> bias_strides{};
+  if constexpr (USE_BIAS) {
+    bias_buffers[0] = const_cast<T*>(params->bias_buffer);
+    bias_lengths[0] = {G0, G1, M, N};  // BN(G0*G1), S(M), T(N)
+    bias_strides[0] = {G1 * M * N, M * N, N, 1};
+  }
+  if constexpr (USE_MASK) {
+    bias_buffers[kNumBiasBuffer - 1] = params->workspace_buffer;
+    bias_lengths[kNumBiasBuffer - 1] = {G0, G1, M, N};  // BN(G0*G1), S(M), T(N)
+    if (params->mask_index_dims.size() == 2) {  // [B,T]
+      bias_strides[kNumBiasBuffer - 1] = {N, 0, 0, 1};
+    } else if (params->mask_index_dims.size() == 3) {  // [B,S,T]
+      bias_strides[kNumBiasBuffer - 1] = {M * N, 0, N, 1};
+    } else if (params->mask_index_dims.size() == 4) {  // [B,1,max_seq_len,max_seq_len] -->convert--> [B,S,T]
+      bias_strides[kNumBiasBuffer - 1] = {M * N, 0, N, 1};
+    } else {
+      ORT_ENFORCE(false, "Unreachable");
+    }
+  }
+
+  auto arg = impl->MakeArgumentPointer(
+      params->q_buffer, params->k_buffer, params->v_buffer, params->out_buffer,
+      bias_buffers,  // Gemm1 bias, as attention mask
+      {},            // Gemm2 bias
+      q_buffer_lengths, q_buffer_strides,
+      k_buffer_lengths, k_buffer_strides,
+      v_buffer_lengths, v_buffer_strides,
+      out_buffer_lengths, out_buffer_strides,
+      bias_lengths, bias_strides,
+      {},
+      {},
+      Nop{},
+      Nop{},
+      Acc0ElementOp{params->scale},
+      Nop{},
+      Nop{});
+
+  TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()),
+                                            impl->GetTypeString(), " does not support the params");
+
+  if constexpr (USE_MASK) {
+    ORT_RETURN_IF_ERROR(GemmSoftmaxGemmPermuteTunableOp<T>::LaunchConvertToFilledMaskValue(params));
+  }
+
+  invoker->Run(arg.get(), StreamConfig{params->StreamHandle()});
+  return Status::OK();
+}
+
+template <typename T, bool USE_BIAS, bool USE_MASK>
+auto GetCKGemmSoftmaxGemmPermuteTypeStringAndOps() {
   using CKDataType = typename CKDataTypeAdaptor<T>::type;
   using D0DataType = typename ck::detail::tuple_concat<
       std::conditional_t<USE_BIAS, ck::Tuple<CKDataType>, ck::Tuple<>>,
       std::conditional_t<USE_MASK, ck::Tuple<CKDataType>, ck::Tuple<>>>::type;
 
-  constexpr static auto MaskingSpec =
+  constexpr static auto MaskingSpecMaskDisabled =
       ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
+  constexpr static auto MaskingSpecMaskOutUpperTriangle =
+      ck::tensor_operation::device::MaskingSpecialization::MaskOutUpperTriangle;
+
+  std::vector<std::pair<std::string, Op<GemmSoftmaxGemmPermuteParams<T>>>>
+      ret;
 
-  std::vector<std::pair<std::string, Op<GemmSoftmaxGemmPermuteParams<T>>>> ret;
 
   for (auto&& impl : internal::GetDeviceBatchedGemmSoftmaxGemmPermuteInstances<
-           CKDataType, D0DataType, internal::F32, internal::PreSoftmaxAttentionScoreOp, MaskingSpec>()) {
+           CKDataType, D0DataType, internal::F32, internal::PreSoftmaxAttentionScoreOp, MaskingSpecMaskDisabled>()) {
     auto type_string = impl->GetTypeString();
 
     auto invoker = impl->MakeInvokerPointer();
     auto op = [impl = std::move(impl), invoker = std::move(invoker)](
                   const GemmSoftmaxGemmPermuteParams<T>* params) -> Status {
       TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
-          !GemmSoftmaxGemmPermuteTunableOp<T>::IsSupportedMode(params->attention),
-          "attention mode is not supported, got ", params->attention->mode);
-      if constexpr (USE_BIAS) {
-        TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
-            params->bias_buffer == nullptr, "biased version only support input with bias");
-      } else {
-        TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
-            params->bias_buffer != nullptr, "non-biased version only support input without bias");
-      }
-      if constexpr (USE_MASK) {
-        TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
-            !GemmSoftmaxGemmPermuteTunableOp<T>::IsSupportedMaskType(params->attention),
-            "mask type is not supported, got ", params->attention->mask_type);
-        TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
-            params->mask_index_buffer == nullptr, "masked version only support input with mask");
-      } else {
-        TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
-            params->mask_index_buffer != nullptr, "non-masked version only support input without mask");
-      }
+          params->attention->is_unidirectional, "unidirectional attention is not supported with MaskingSpecMaskDisabled");
 
-      auto attn = params->attention;
-      const int& G0 = attn->batch_size;
-      const int& G1 = attn->num_heads;
-      const int& M = attn->sequence_length;
-      const int& N = attn->total_sequence_length;
-      const int& K = attn->head_size;
-      const int& O = attn->v_head_size;
-      {
-        auto [m, n, k, o, batch] = params->GetGemmsMNKOBatch();
-        ORT_ENFORCE(M == m && N == n && K == k && O == o && G0 * G1 == batch, "semantic mismatch");
-      }
+      return GetArgAndRunInvoker<decltype(impl), decltype(invoker), T, USE_BIAS, USE_MASK>(impl, invoker, params);
+    };
+    ret.emplace_back(std::make_pair(std::move(type_string), std::move(op)));
+  }
 
-      auto [qs, ks, vs] = GetQkvStrides(attn);
-      std::vector<ck::index_t> q_buffer_lengths = {G0, G1, M, K};
-      std::vector<ck::index_t> q_buffer_strides = qs.template ForBNSHCoord<std::vector<ck::index_t>>();
-      std::vector<ck::index_t> k_buffer_lengths = {G0, G1, N, K};
-      std::vector<ck::index_t> k_buffer_strides = ks.template ForBNSHCoord<std::vector<ck::index_t>>();
-      std::vector<ck::index_t> v_buffer_lengths = {G0, G1, O, N};
-      std::vector<ck::index_t> v_buffer_strides = vs.template ForBNHSCoord<std::vector<ck::index_t>>();
-      std::vector<ck::index_t> out_buffer_lengths = {G0, G1, M, O};
-      std::vector<ck::index_t> out_buffer_strides = {M * G1 * O, O, G1 * O, 1};  // permute 0213
-
-      std::array<void*, kNumBiasBuffer> bias_buffers{};
-      std::array<std::vector<ck::index_t>, kNumBiasBuffer> bias_lengths{};
-      std::array<std::vector<ck::index_t>, kNumBiasBuffer> bias_strides{};
-      if constexpr (USE_BIAS) {
-        bias_buffers[0] = const_cast<T*>(params->bias_buffer);
-        bias_lengths[0] = {G0, G1, M, N};  // BN(G0*G1), S(M), T(N)
-        bias_strides[0] = {G1 * M * N, M * N, N, 1};
-      }
-      if constexpr (USE_MASK) {
-        bias_buffers[kNumBiasBuffer - 1] = params->workspace_buffer;
-        bias_lengths[kNumBiasBuffer - 1] = {G0, G1, M, N};  // BN(G0*G1), S(M), T(N)
-        if (params->mask_index_dims.size() == 2) {  // [B,T]
-          bias_strides[kNumBiasBuffer - 1] = {N, 0, 0, 1};
-        } else if (params->mask_index_dims.size() == 3) {  // [B,S,T]
-          bias_strides[kNumBiasBuffer - 1] = {M * N, 0, N, 1};
-        } else if (params->mask_index_dims.size() == 4) {  // [B,1,max_seq_len,max_seq_len] -->convert--> [B,S,T]
-          bias_strides[kNumBiasBuffer - 1] = {M * N, 0, N, 1};
-        } else {
-          ORT_ENFORCE(false, "Unreachable");
-        }
-      }
+  for (auto&& impl : internal::GetDeviceBatchedGemmSoftmaxGemmPermuteInstances<
+           CKDataType, D0DataType, internal::F32, internal::PreSoftmaxAttentionScoreOp, MaskingSpecMaskOutUpperTriangle>()) {
+    auto type_string = impl->GetTypeString();
 
-      auto arg = impl->MakeArgumentPointer(
-          params->q_buffer, params->k_buffer, params->v_buffer, params->out_buffer,
-          bias_buffers,  // Gemm1 bias, as attention mask
-          {},            // Gemm2 bias
-          q_buffer_lengths, q_buffer_strides,
-          k_buffer_lengths, k_buffer_strides,
-          v_buffer_lengths, v_buffer_strides,
-          out_buffer_lengths, out_buffer_strides,
-          bias_lengths, bias_strides,
-          {},
-          {},
-          Nop{},
-          Nop{},
-          Acc0ElementOp{params->scale},
-          Nop{},
-          Nop{});
-
-      TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()),
-                                                impl->GetTypeString(), " does not support the params");
-
-      if constexpr (USE_MASK) {
-        ORT_RETURN_IF_ERROR(GemmSoftmaxGemmPermuteTunableOp<T>::LaunchConvertToFilledMaskValue(params));
-      }
-      invoker->Run(arg.get(), StreamConfig{params->StreamHandle()});
-      return Status::OK();
+    auto invoker = impl->MakeInvokerPointer();
+    auto op = [impl = std::move(impl), invoker = std::move(invoker)](
+                  const GemmSoftmaxGemmPermuteParams<T>* params) -> Status {
+      TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
+          !params->attention->is_unidirectional, "bidirectional attention is not supported with MaskingSpecMaskOutUpperTriangle");
+      TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
+          params->attention->sequence_length != params->attention->total_sequence_length,
+          "seqence_length != total_seqence_length is not supported with MaskingSpecMaskOutUpperTriangle");
+
+      return GetArgAndRunInvoker<decltype(impl), decltype(invoker), T, USE_BIAS, USE_MASK>(impl, invoker, params);
     };
     ret.emplace_back(std::make_pair(std::move(type_string), std::move(op)));
   }
+
 
   return ret;
 }
 #endif  // USE_COMPOSABLE_KERNEL
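
Note (not part of the diff): the sketch below illustrates how the (type string, op) pairs returned by GetCKGemmSoftmaxGemmPermuteTypeStringAndOps are typically consumed. The ExampleAttentionTunableOp class name and the RegisterOp call are assumptions based on the common ORT TunableOp pattern, not code introduced by this change.

// Hypothetical registration sketch; ExampleAttentionTunableOp and RegisterOp are
// assumed names, shown only to illustrate how the returned ops would be used.
template <typename T>
class ExampleAttentionTunableOp : public tunable::TunableOp<GemmSoftmaxGemmPermuteParams<T>> {
 public:
  ExampleAttentionTunableOp() {
#ifdef USE_COMPOSABLE_KERNEL
    // Each returned op rejects unsupported params via
    // TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF, so every CK instance can be
    // registered up front and tuning simply skips the ones that do not apply.
    for (auto&& [type_string, op] :
         GetCKGemmSoftmaxGemmPermuteTypeStringAndOps<T, /*USE_BIAS=*/true, /*USE_MASK=*/true>()) {
      ORT_UNUSED_PARAMETER(type_string);  // identifies the instance in tuning logs
      this->RegisterOp(std::move(op));
    }
#endif
  }
};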