
Commit 9139bdd

[ROCm] CK implementation support causal mask (#18943)
Use `MaskingSpecialization::MaskOutUpperTriangle` to support the causal mask in the CK implementation.
1 parent a2eb967 · commit 9139bdd
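For context, MaskingSpecialization::MaskOutUpperTriangle corresponds to the usual causal-attention constraint: a query at position i may only attend to key positions j <= i, so scores in the strict upper triangle are excluded from the softmax. A minimal CPU-side sketch of that behavior (illustrative only; the ApplyCausalMask helper below is hypothetical and not part of onnxruntime or CK, which apply the mask inside the fused kernel):

#include <cstddef>
#include <limits>
#include <vector>

// Mask out the strict upper triangle of a [seq_len, seq_len] score matrix so
// that the softmax assigns zero weight to future positions.
void ApplyCausalMask(std::vector<float>& scores, std::size_t seq_len) {
  for (std::size_t i = 0; i < seq_len; ++i) {
    for (std::size_t j = i + 1; j < seq_len; ++j) {
      scores[i * seq_len + j] = -std::numeric_limits<float>::infinity();
    }
  }
}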

7 files changed: +364, -113 lines changed


onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh

Lines changed: 30 additions & 1 deletion
@@ -31,7 +31,7 @@ using MaskingSpecialization = ck::tensor_operation::device::MaskingSpecializatio
 
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 
-using ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute; // the interface
+using ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute;  // the interface
 using ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle;  // the implementation
 
 static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
@@ -141,6 +141,35 @@ std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<
 GetDeviceBatchedGemmSoftmaxGemmPermuteInstances<
     F16, ck::Tuple<F16, F16>, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskDisabled>();
 
+template <>
+std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<
+    2, 1, 1, 1, 1,
+    F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>,
+    PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough,
+    MaskingSpecialization::MaskOutUpperTriangle>>>
+GetDeviceBatchedGemmSoftmaxGemmPermuteInstances<
+    F16, ck::Tuple<>, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>();
+
+// fp16, biased, non-masked
+template <>
+std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<
+    2, 1, 1, 1, 1,
+    F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>,
+    PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough,
+    MaskingSpecialization::MaskOutUpperTriangle>>>
+GetDeviceBatchedGemmSoftmaxGemmPermuteInstances<
+    F16, ck::Tuple<F16>, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>();
+
+// fp16, biased, fp16 masked, basically, two bias
+template <>
+std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<
+    2, 1, 1, 1, 1,
+    F16, F16, F16, F16, ck::Tuple<F16, F16>, ck::Tuple<>,
+    PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough,
+    MaskingSpecialization::MaskOutUpperTriangle>>>
+GetDeviceBatchedGemmSoftmaxGemmPermuteInstances<
+    F16, ck::Tuple<F16, F16>, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>();
+
 }  // namespace internal
 }  // namespace rocm
 }  // namespace contrib
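The declarations added to impl.cuh above are paired with definitions in the per-configuration .cu files that follow, so each instance set builds in its own translation unit. A small sketch of that explicit-specialization split, using hypothetical names (Widget, GetInstances) rather than the real CK types:

#include <memory>
#include <vector>

struct Widget {};

// Header side: primary template plus a declaration of an explicit specialization.
template <typename T>
std::vector<std::unique_ptr<Widget>> GetInstances();

template <>
std::vector<std::unique_ptr<Widget>> GetInstances<float>();

// Source side (normally a separate .cu/.cpp file): the definition of that
// specialization, compiled once in its own translation unit.
template <>
std::vector<std::unique_ptr<Widget>> GetInstances<float>() {
  std::vector<std::unique_ptr<Widget>> instances;
  instances.push_back(std::make_unique<Widget>());
  return instances;
}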

onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16.cu

Lines changed: 21 additions & 0 deletions
@@ -32,6 +32,27 @@ GetDeviceBatchedGemmSoftmaxGemmPermuteInstances<
   return instances;
 }
 
+using NonBiasedNonmaskedCausal = DeviceBatchedGemmSoftmaxGemmPermute<
+    2, 1, 1, 1, 1,
+    F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>,
+    PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough,
+    MaskingSpecialization::MaskOutUpperTriangle>;
+
+template <>
+std::vector<std::unique_ptr<NonBiasedNonmaskedCausal>>
+GetDeviceBatchedGemmSoftmaxGemmPermuteInstances<
+    F16, ck::Tuple<>, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>() {
+  std::vector<std::unique_ptr<NonBiasedNonmaskedCausal>> instances;
+  ck::tensor_operation::device::instance::add_device_operation_instances(
+      instances,
+      device_batched_gemm_softmax_gemm_permute_instances<
+          2, 1, 1, 1, 1,
+          F16, ck::Tuple<>, F32, PreSoftmaxAttentionScoreOp,
+          MaskingSpecialization::MaskOutUpperTriangle>{});
+
+  return instances;
+}
+
 }  // namespace internal
 }  // namespace rocm
 }  // namespace contrib
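The add_device_operation_instances call above follows a common CK pattern: a helper appends concrete kernel objects, owned as unique_ptr to the interface type, into the caller-provided vector. A generic sketch of that pattern with made-up types (OpInterface, ConcreteOpA/B, AddInstances are all hypothetical):

#include <memory>
#include <vector>

struct OpInterface {
  virtual ~OpInterface() = default;
};
struct ConcreteOpA : OpInterface {};
struct ConcreteOpB : OpInterface {};

// Append every concrete operation variant to the caller's list of candidates.
void AddInstances(std::vector<std::unique_ptr<OpInterface>>& instances) {
  instances.push_back(std::make_unique<ConcreteOpA>());
  instances.push_back(std::make_unique<ConcreteOpB>());
}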

onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased.cu

Lines changed: 21 additions & 0 deletions
@@ -32,6 +32,27 @@ GetDeviceBatchedGemmSoftmaxGemmPermuteInstances<
   return instances;
 }
 
+using BiasedNonmaskedCausal = DeviceBatchedGemmSoftmaxGemmPermute<
+    2, 1, 1, 1, 1,
+    F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>,
+    PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough,
+    MaskingSpecialization::MaskOutUpperTriangle>;
+
+template <>
+std::vector<std::unique_ptr<BiasedNonmaskedCausal>>
+GetDeviceBatchedGemmSoftmaxGemmPermuteInstances<
+    F16, ck::Tuple<F16>, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>() {
+  std::vector<std::unique_ptr<BiasedNonmaskedCausal>> instances;
+  ck::tensor_operation::device::instance::add_device_operation_instances(
+      instances,
+      device_batched_gemm_softmax_gemm_permute_instances<
+          2, 1, 1, 1, 1,
+          F16, ck::Tuple<F16>, F32, PreSoftmaxAttentionScoreOp,
+          MaskingSpecialization::MaskOutUpperTriangle>{});
+
+  return instances;
+}
+
 }  // namespace internal
 }  // namespace rocm
 }  // namespace contrib

onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased_biased.cu

Lines changed: 21 additions & 0 deletions
@@ -32,6 +32,27 @@ GetDeviceBatchedGemmSoftmaxGemmPermuteInstances<
   return instances;
 }
 
+using BiasedNonmaskedCausal = DeviceBatchedGemmSoftmaxGemmPermute<
+    2, 1, 1, 1, 1,
+    F16, F16, F16, F16, ck::Tuple<F16, F16>, ck::Tuple<>,
+    PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough,
+    MaskingSpecialization::MaskOutUpperTriangle>;
+
+template <>
+std::vector<std::unique_ptr<BiasedNonmaskedCausal>>
+GetDeviceBatchedGemmSoftmaxGemmPermuteInstances<
+    F16, ck::Tuple<F16, F16>, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>() {
+  std::vector<std::unique_ptr<BiasedNonmaskedCausal>> instances;
+  ck::tensor_operation::device::instance::add_device_operation_instances(
+      instances,
+      device_batched_gemm_softmax_gemm_permute_instances<
+          2, 1, 1, 1, 1,
+          F16, ck::Tuple<F16, F16>, F32, PreSoftmaxAttentionScoreOp,
+          MaskingSpecialization::MaskOutUpperTriangle>{});
+
+  return instances;
+}
+
 }  // namespace internal
 }  // namespace rocm
 }  // namespace contrib
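In these specializations the D0 tuple encodes the extra tensors added to the first GEMM's output before the softmax: ck::Tuple<> means no additive input, ck::Tuple<F16> means an attention bias, and ck::Tuple<F16, F16> means bias plus an attention mask already converted to additive form, hence the "basically, two bias" comment. A rough CPU-side sketch of such a conversion, assuming a 0/1 key-padding mask convention (the ConvertToFilledMask helper below is hypothetical and only illustrates the idea; in the pipeline the conversion is done on the GPU by LaunchConvertToFilledMaskValue):

#include <cstddef>
#include <cstdint>
#include <vector>

// 0/1 keep/drop mask -> additive values: kept positions add 0, dropped positions
// add a large negative number so their softmax weight becomes (effectively) zero.
std::vector<float> ConvertToFilledMask(const std::vector<int32_t>& mask,
                                       float filled_value = -10000.0f) {
  std::vector<float> out(mask.size());
  for (std::size_t i = 0; i < mask.size(); ++i) {
    out[i] = (mask[i] != 0) ? 0.0f : filled_value;
  }
  return out;
}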

onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh

Lines changed: 123 additions & 91 deletions
@@ -732,122 +732,154 @@ class GemmSoftmaxGemmPermuteTunableOp : public tunable::TunableOp<GemmSoftmaxGem
 
 #ifdef USE_COMPOSABLE_KERNEL
 
-template <typename T, bool USE_BIAS, bool USE_MASK>
-auto GetCKGemmSoftmaxGemmPermuteTypeStringAndOps() {
+template <typename U, typename V, typename T, bool USE_BIAS, bool USE_MASK>
+auto GetArgAndRunInvoker(const U& impl, const V& invoker, const GemmSoftmaxGemmPermuteParams<T>* params) {
   constexpr const int kNumBiasBuffer = static_cast<int>(USE_BIAS) + static_cast<int>(USE_MASK);
 
   using Nop = ck::tensor_operation::element_wise::PassThrough;
   using Acc0ElementOp = internal::PreSoftmaxAttentionScoreOp;
 
+  TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
+      !GemmSoftmaxGemmPermuteTunableOp<T>::IsSupportedMode(params->attention),
+      "attention mode is not supported, got ", params->attention->mode);
+  if constexpr (USE_BIAS) {
+    TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
+        params->bias_buffer == nullptr, "biased version only support input with bias");
+  } else {
+    TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
+        params->bias_buffer != nullptr, "non-biased version only support input without bias");
+  }
+  if constexpr (USE_MASK) {
+    TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
+        !GemmSoftmaxGemmPermuteTunableOp<T>::IsSupportedMaskType(params->attention),
+        "mask type is not supported, got ", params->attention->mask_type);
+    TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
+        params->mask_index_buffer == nullptr, "masked version only support input with mask");
+  } else {
+    TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
+        params->mask_index_buffer != nullptr, "non-masked version only support input without mask");
+  }
+
+  auto attn = params->attention;
+  const int& G0 = attn->batch_size;
+  const int& G1 = attn->num_heads;
+  const int& M = attn->sequence_length;
+  const int& N = attn->total_sequence_length;
+  const int& K = attn->head_size;
+  const int& O = attn->v_head_size;
+  {
+    auto [m, n, k, o, batch] = params->GetGemmsMNKOBatch();
+    ORT_ENFORCE(M == m && N == n && K == k && O == o && G0 * G1 == batch, "semantic mismatch");
+  }
+
+  auto [qs, ks, vs] = GetQkvStrides(attn);
+  std::vector<ck::index_t> q_buffer_lengths = {G0, G1, M, K};
+  std::vector<ck::index_t> q_buffer_strides = qs.template ForBNSHCoord<std::vector<ck::index_t>>();
+  std::vector<ck::index_t> k_buffer_lengths = {G0, G1, N, K};
+  std::vector<ck::index_t> k_buffer_strides = ks.template ForBNSHCoord<std::vector<ck::index_t>>();
+  std::vector<ck::index_t> v_buffer_lengths = {G0, G1, O, N};
+  std::vector<ck::index_t> v_buffer_strides = vs.template ForBNHSCoord<std::vector<ck::index_t>>();
+  std::vector<ck::index_t> out_buffer_lengths = {G0, G1, M, O};
+  std::vector<ck::index_t> out_buffer_strides = {M * G1 * O, O, G1 * O, 1};  // permute 0213
+
+  std::array<void*, kNumBiasBuffer> bias_buffers{};
+  std::array<std::vector<ck::index_t>, kNumBiasBuffer> bias_lengths{};
+  std::array<std::vector<ck::index_t>, kNumBiasBuffer> bias_strides{};
+  if constexpr (USE_BIAS) {
+    bias_buffers[0] = const_cast<T*>(params->bias_buffer);
+    bias_lengths[0] = {G0, G1, M, N};  // BN(G0*G1), S(M), T(N)
+    bias_strides[0] = {G1 * M * N, M * N, N, 1};
+  }
+  if constexpr (USE_MASK) {
+    bias_buffers[kNumBiasBuffer - 1] = params->workspace_buffer;
+    bias_lengths[kNumBiasBuffer - 1] = {G0, G1, M, N};  // BN(G0*G1), S(M), T(N)
+    if (params->mask_index_dims.size() == 2) {  // [B,T]
+      bias_strides[kNumBiasBuffer - 1] = {N, 0, 0, 1};
+    } else if (params->mask_index_dims.size() == 3) {  // [B,S,T]
+      bias_strides[kNumBiasBuffer - 1] = {M * N, 0, N, 1};
+    } else if (params->mask_index_dims.size() == 4) {  // [B,1,max_seq_len,max_seq_len] -->convert--> [B,S,T]
+      bias_strides[kNumBiasBuffer - 1] = {M * N, 0, N, 1};
+    } else {
+      ORT_ENFORCE(false, "Unreachable");
+    }
+  }
+
+  auto arg = impl->MakeArgumentPointer(
+      params->q_buffer, params->k_buffer, params->v_buffer, params->out_buffer,
+      bias_buffers,  // Gemm1 bias, as attention mask
+      {},  // Gemm2 bias
+      q_buffer_lengths, q_buffer_strides,
+      k_buffer_lengths, k_buffer_strides,
+      v_buffer_lengths, v_buffer_strides,
+      out_buffer_lengths, out_buffer_strides,
+      bias_lengths, bias_strides,
+      {},
+      {},
+      Nop{},
+      Nop{},
+      Acc0ElementOp{params->scale},
+      Nop{},
+      Nop{});
+
+  TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()),
+                                            impl->GetTypeString(), " does not support the params");
+
+  if constexpr (USE_MASK) {
+    ORT_RETURN_IF_ERROR(GemmSoftmaxGemmPermuteTunableOp<T>::LaunchConvertToFilledMaskValue(params));
+  }
+
+  invoker->Run(arg.get(), StreamConfig{params->StreamHandle()});
+  return Status::OK();
+}
+
+template <typename T, bool USE_BIAS, bool USE_MASK>
+auto GetCKGemmSoftmaxGemmPermuteTypeStringAndOps() {
   using CKDataType = typename CKDataTypeAdaptor<T>::type;
   using D0DataType = typename ck::detail::tuple_concat<
       std::conditional_t<USE_BIAS, ck::Tuple<CKDataType>, ck::Tuple<>>,
       std::conditional_t<USE_MASK, ck::Tuple<CKDataType>, ck::Tuple<>>>::type;
 
-  constexpr static auto MaskingSpec =
+  constexpr static auto MaskingSpecMaskDisabled =
       ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
+  constexpr static auto MaskingSpecMaskOutUpperTriangle =
+      ck::tensor_operation::device::MaskingSpecialization::MaskOutUpperTriangle;
+
+  std::vector<std::pair<std::string, Op<GemmSoftmaxGemmPermuteParams<T>>>>
+      ret;
 
-  std::vector<std::pair<std::string, Op<GemmSoftmaxGemmPermuteParams<T>>>> ret;
   for (auto&& impl : internal::GetDeviceBatchedGemmSoftmaxGemmPermuteInstances<
-           CKDataType, D0DataType, internal::F32, internal::PreSoftmaxAttentionScoreOp, MaskingSpec>()) {
+           CKDataType, D0DataType, internal::F32, internal::PreSoftmaxAttentionScoreOp, MaskingSpecMaskDisabled>()) {
     auto type_string = impl->GetTypeString();
 
     auto invoker = impl->MakeInvokerPointer();
     auto op = [impl = std::move(impl), invoker = std::move(invoker)](
                   const GemmSoftmaxGemmPermuteParams<T>* params) -> Status {
       TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
-          !GemmSoftmaxGemmPermuteTunableOp<T>::IsSupportedMode(params->attention),
-          "attention mode is not supported, got ", params->attention->mode);
-      if constexpr (USE_BIAS) {
-        TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
-            params->bias_buffer == nullptr, "biased version only support input with bias");
-      } else {
-        TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
-            params->bias_buffer != nullptr, "non-biased version only support input without bias");
-      }
-      if constexpr (USE_MASK) {
-        TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
-            !GemmSoftmaxGemmPermuteTunableOp<T>::IsSupportedMaskType(params->attention),
-            "mask type is not supported, got ", params->attention->mask_type);
-        TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
-            params->mask_index_buffer == nullptr, "masked version only support input with mask");
-      } else {
-        TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
-            params->mask_index_buffer != nullptr, "non-masked version only support input without mask");
-      }
+          params->attention->is_unidirectional, "unidirectional attention is not supported with MaskingSpecMaskDisabled");
 
-      auto attn = params->attention;
-      const int& G0 = attn->batch_size;
-      const int& G1 = attn->num_heads;
-      const int& M = attn->sequence_length;
-      const int& N = attn->total_sequence_length;
-      const int& K = attn->head_size;
-      const int& O = attn->v_head_size;
-      {
-        auto [m, n, k, o, batch] = params->GetGemmsMNKOBatch();
-        ORT_ENFORCE(M == m && N == n && K == k && O == o && G0 * G1 == batch, "semantic mismatch");
-      }
+      return GetArgAndRunInvoker<decltype(impl), decltype(invoker), T, USE_BIAS, USE_MASK>(impl, invoker, params);
+    };
+    ret.emplace_back(std::make_pair(std::move(type_string), std::move(op)));
+  }
 
-      auto [qs, ks, vs] = GetQkvStrides(attn);
-      std::vector<ck::index_t> q_buffer_lengths = {G0, G1, M, K};
-      std::vector<ck::index_t> q_buffer_strides = qs.template ForBNSHCoord<std::vector<ck::index_t>>();
-      std::vector<ck::index_t> k_buffer_lengths = {G0, G1, N, K};
-      std::vector<ck::index_t> k_buffer_strides = ks.template ForBNSHCoord<std::vector<ck::index_t>>();
-      std::vector<ck::index_t> v_buffer_lengths = {G0, G1, O, N};
-      std::vector<ck::index_t> v_buffer_strides = vs.template ForBNHSCoord<std::vector<ck::index_t>>();
-      std::vector<ck::index_t> out_buffer_lengths = {G0, G1, M, O};
-      std::vector<ck::index_t> out_buffer_strides = {M * G1 * O, O, G1 * O, 1};  // permute 0213
-
-      std::array<void*, kNumBiasBuffer> bias_buffers{};
-      std::array<std::vector<ck::index_t>, kNumBiasBuffer> bias_lengths{};
-      std::array<std::vector<ck::index_t>, kNumBiasBuffer> bias_strides{};
-      if constexpr (USE_BIAS) {
-        bias_buffers[0] = const_cast<T*>(params->bias_buffer);
-        bias_lengths[0] = {G0, G1, M, N};  // BN(G0*G1), S(M), T(N)
-        bias_strides[0] = {G1 * M * N, M * N, N, 1};
-      }
-      if constexpr (USE_MASK) {
-        bias_buffers[kNumBiasBuffer - 1] = params->workspace_buffer;
-        bias_lengths[kNumBiasBuffer - 1] = {G0, G1, M, N};  // BN(G0*G1), S(M), T(N)
-        if (params->mask_index_dims.size() == 2) {  // [B,T]
-          bias_strides[kNumBiasBuffer - 1] = {N, 0, 0, 1};
-        } else if (params->mask_index_dims.size() == 3) {  // [B,S,T]
-          bias_strides[kNumBiasBuffer - 1] = {M * N, 0, N, 1};
-        } else if (params->mask_index_dims.size() == 4) {  // [B,1,max_seq_len,max_seq_len] -->convert--> [B,S,T]
-          bias_strides[kNumBiasBuffer - 1] = {M * N, 0, N, 1};
-        } else {
-          ORT_ENFORCE(false, "Unreachable");
-        }
-      }
+  for (auto&& impl : internal::GetDeviceBatchedGemmSoftmaxGemmPermuteInstances<
+           CKDataType, D0DataType, internal::F32, internal::PreSoftmaxAttentionScoreOp, MaskingSpecMaskOutUpperTriangle>()) {
+    auto type_string = impl->GetTypeString();
 
-      auto arg = impl->MakeArgumentPointer(
-          params->q_buffer, params->k_buffer, params->v_buffer, params->out_buffer,
-          bias_buffers,  // Gemm1 bias, as attention mask
-          {},  // Gemm2 bias
-          q_buffer_lengths, q_buffer_strides,
-          k_buffer_lengths, k_buffer_strides,
-          v_buffer_lengths, v_buffer_strides,
-          out_buffer_lengths, out_buffer_strides,
-          bias_lengths, bias_strides,
-          {},
-          {},
-          Nop{},
-          Nop{},
-          Acc0ElementOp{params->scale},
-          Nop{},
-          Nop{});
-
-      TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()),
-                                                impl->GetTypeString(), " does not support the params");
-
-      if constexpr (USE_MASK) {
-        ORT_RETURN_IF_ERROR(GemmSoftmaxGemmPermuteTunableOp<T>::LaunchConvertToFilledMaskValue(params));
-      }
-      invoker->Run(arg.get(), StreamConfig{params->StreamHandle()});
-      return Status::OK();
+    auto invoker = impl->MakeInvokerPointer();
+    auto op = [impl = std::move(impl), invoker = std::move(invoker)](
+                  const GemmSoftmaxGemmPermuteParams<T>* params) -> Status {
+      TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
+          !params->attention->is_unidirectional, "bidirectional attention is not supported with MaskingSpecMaskOutUpperTriangle");
+      TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
+          params->attention->sequence_length != params->attention->total_sequence_length,
+          "seqence_length != total_seqence_length is not supported with MaskingSpecMaskOutUpperTriangle");
+
+      return GetArgAndRunInvoker<decltype(impl), decltype(invoker), T, USE_BIAS, USE_MASK>(impl, invoker, params);
     };
     ret.emplace_back(std::make_pair(std::move(type_string), std::move(op)));
   }
+
   return ret;
 }
 #endif  // USE_COMPOSABLE_KERNEL
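GetCKGemmSoftmaxGemmPermuteTypeStringAndOps only builds the candidate list; selection happens in the tunable-op machinery, where each (type string, op) pair is tried and candidates whose run-time checks fail (bias/mask presence, is_unidirectional, sequence_length vs. total_sequence_length) return an unsupported status. A simplified stand-in for that selection step, using hypothetical Status and Params types in place of ORT's:

#include <cstddef>
#include <functional>
#include <string>
#include <utility>
#include <vector>

struct Status { bool ok; };  // stand-in for ORT's Status
struct Params {};            // stand-in for GemmSoftmaxGemmPermuteParams<T>
using Op = std::function<Status(const Params*)>;

// Return the index of the first candidate whose run-time checks pass; a real
// tuner would additionally time every supported candidate and keep the fastest.
int SelectFirstSupported(const std::vector<std::pair<std::string, Op>>& candidates,
                         const Params* params) {
  for (std::size_t i = 0; i < candidates.size(); ++i) {
    if (candidates[i].second(params).ok) return static_cast<int>(i);
  }
  return -1;
}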
