Skip to content

Commit 00eebe1

Browse files
konradkusiak97Konrad Kusiak
and
Konrad Kusiak
authored
[SYCL][Matrix] Add joint matrix query for CUDA and HIP backends (#12075)
This PR adds joint matrix query for CUDA and HIP backends as described in [sycl/doc/extensions/experimental/sycl_ext_matrix/sycl_ext_oneapi_matrix.asciidoc](https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/experimental/sycl_ext_matrix/sycl_ext_oneapi_matrix.asciidoc#query-interface) --------- Co-authored-by: Konrad Kusiak <[email protected]>
1 parent 62a0010 commit 00eebe1

File tree

8 files changed

+693
-11
lines changed

8 files changed

+693
-11
lines changed

.github/CODEOWNERS

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,9 @@ sycl/plugins/**/hip/ @intel/llvm-reviewers-cuda
5050
# CUDA specific runtime implementations
5151
sycl/include/sycl/ext/oneapi/experimental/cuda/ @intel/llvm-reviewers-cuda
5252

53-
# CUDA device code tests
53+
# CUDA and HIP device code tests
5454
sycl/test/check_device_code/cuda/ @intel/llvm-reviewers-cuda
55+
sycl/test/check_device_code/hip/ @intel/llvm-reviewers-cuda
5556

5657
# XPTI instrumentation utilities
5758
xpti/ @intel/llvm-reviewers-runtime

sycl/include/sycl/ext/oneapi/matrix/static-query-use.hpp

Lines changed: 332 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,338 @@ struct matrix_params<
482482
template <typename Group>
483483
using joint_matrix_d = joint_matrix<Group, Td, use::accumulator, M, N>;
484484
};
485+
486+
//////////////////////////////////////////////
487+
/// AMD Matrix Cores - GFX90A architecture ///
488+
//////////////////////////////////////////////
489+
490+
template <typename Ta, typename Tc>
491+
constexpr bool is_combination_valid_amd_gfx90a(size_t sM, size_t sN,
492+
size_t sK) {
493+
return (std::is_same_v<Ta, half> && std::is_same_v<Tc, float> &&
494+
((sM == 32 && sN == 32 && sK == 8) ||
495+
(sM == 16 && sN == 16 && sK == 16))) ||
496+
(std::is_same_v<Ta, int8_t> && std::is_same_v<Tc, int32_t> &&
497+
((sM == 32 && sN == 32 && sK == 8) ||
498+
(sM == 16 && sN == 16 && sK == 16))) ||
499+
(std::is_same_v<Ta, bfloat16> && std::is_same_v<Tc, float> &&
500+
((sM == 32 && sN == 32 && sK == 8) ||
501+
(sM == 16 && sN == 16 && sK == 16))) ||
502+
(std::is_same_v<Ta, double> && std::is_same_v<Tc, double> &&
503+
(sM == 16 && sN == 16 && sK == 4));
504+
}
505+
506+
template <typename Ta, typename Tc>
507+
constexpr bool are_types_valid_amd_gfx90a() {
508+
return (std::is_same_v<Ta, half> && std::is_same_v<Tc, float>) ||
509+
(std::is_same_v<Ta, int8_t> && std::is_same_v<Tc, int32_t>) ||
510+
(std::is_same_v<Ta, bfloat16> && std::is_same_v<Tc, float>) ||
511+
(std::is_same_v<Ta, double> && std::is_same_v<Tc, double>);
512+
}
513+
514+
// Default-values query:
515+
// Specialization for when only types are given, need to query only sizes
516+
template <typename Ta, typename Tb, typename Tc, typename Td>
517+
struct matrix_params<
518+
architecture::amd_gpu_gfx90a, Ta, Tb, Tc, Td, 0, 0, 0,
519+
typename std::enable_if_t<(
520+
!std::is_same_v<Ta, void> && !std::is_same_v<Tb, void> &&
521+
!std::is_same_v<Tc, void> && !std::is_same_v<Td, void> &&
522+
std::is_same_v<Ta, Tb> && std::is_same_v<Tc, Td>)>> {
523+
static_assert(
524+
are_types_valid_amd_gfx90a<Ta, Tc>(),
525+
"Invalid types for AMD gfx90a, supported types are half, float, "
526+
"int8_t, int32_t, double and bfloat16 ");
527+
528+
// Default sizes for AMD gfx90a were chosen to represent a square matrix
529+
static constexpr std::size_t M = 16;
530+
static constexpr std::size_t N = 16;
531+
static constexpr std::size_t K = ((sizeof(Ta) == 8) ? 16 : 4);
532+
533+
template <typename Group, layout Layout>
534+
using joint_matrix_a = joint_matrix<Group, Ta, use::a, M, K, Layout>;
535+
template <typename Group, layout Layout>
536+
using joint_matrix_b = joint_matrix<Group, Tb, use::b, K, N, Layout>;
537+
template <typename Group>
538+
using joint_matrix_c = joint_matrix<Group, Tc, use::accumulator, M, N>;
539+
template <typename Group>
540+
using joint_matrix_d = joint_matrix<Group, Td, use::accumulator, M, N>;
541+
};
542+
543+
// Validation query
// Specialization when both types and sizes are given
// The user-requested shape (sM, sN, sK) is checked at compile time against
// the MFMA shapes accepted by is_combination_valid_amd_gfx90a; an invalid
// combination fails the static_assert below.
template <typename Ta, typename Tb, typename Tc, typename Td, size_t sM,
          size_t sN, size_t sK>
struct matrix_params<
    architecture::amd_gpu_gfx90a, Ta, Tb, Tc, Td, sM, sN, sK,
    // SFINAE guard: all four types supplied, A/B and C/D types match, and
    // all three sizes are explicitly non-zero (zero selects the
    // default-values specialization instead).
    typename std::enable_if_t<(
        !std::is_same_v<Ta, void> && !std::is_same_v<Tb, void> &&
        !std::is_same_v<Tc, void> && !std::is_same_v<Td, void> &&
        std::is_same_v<Ta, Tb> && std::is_same_v<Tc, Td> && sM != 0 &&
        sN != 0 && sK != 0)>> {
  static_assert(
      is_combination_valid_amd_gfx90a<Ta, Tc>(sM, sN, sK),
      "Invalid parameters for AMD gfx90a, query valid combinations "
      "using: "
      "q.get_device().get_info<sycl::info::device::matrix::combinations>()");

  // Validated sizes are exposed unchanged.
  static constexpr std::size_t M = sM;
  static constexpr std::size_t N = sN;
  static constexpr std::size_t K = sK;

  // Aliases resolving to the joint_matrix types for the validated shape.
  template <typename Group, layout Layout>
  using joint_matrix_a = joint_matrix<Group, Ta, use::a, M, K, Layout>;
  template <typename Group, layout Layout>
  using joint_matrix_b = joint_matrix<Group, Tb, use::b, K, N, Layout>;
  template <typename Group>
  using joint_matrix_c = joint_matrix<Group, Tc, use::accumulator, M, N>;
  template <typename Group>
  using joint_matrix_d = joint_matrix<Group, Td, use::accumulator, M, N>;
};
573+
574+
/////////////////////////////////////////////////
575+
/// CUDA Tensor Cores - sm70, sm72 and sm80 ///
576+
/////////////////////////////////////////////////
577+
578+
template <typename Ta, typename Tc, typename Td>
579+
constexpr bool are_types_valid_cuda_sm70() {
580+
return (std::is_same_v<Ta, half> && std::is_same_v<Tc, float> &&
581+
std::is_same_v<Td, float>) ||
582+
(std::is_same_v<Ta, half> && std::is_same_v<Tc, half> &&
583+
std::is_same_v<Td, half>) ||
584+
(std::is_same_v<Ta, half> && std::is_same_v<Tc, float> &&
585+
std::is_same_v<Td, half>) ||
586+
(std::is_same_v<Ta, half> && std::is_same_v<Tc, half> &&
587+
std::is_same_v<Td, float>);
588+
}
589+
590+
template <typename Ta, typename Tc, typename Td>
constexpr bool are_types_valid_cuda_sm72() {
  // sm72 integer MMA: 8-bit inputs (signed or unsigned) with 32-bit
  // signed accumulator and a result type equal to the accumulator type.
  constexpr bool input_ok =
      std::is_same_v<Ta, int8_t> || std::is_same_v<Ta, uint8_t>;
  return input_ok && std::is_same_v<Tc, int32_t> &&
         std::is_same_v<Td, int32_t>;
}
597+
598+
template <typename Ta, typename Tc, typename Td>
599+
constexpr bool are_types_valid_cuda_sm80() {
600+
return (std::is_same_v<Ta, precision::tf32> && std::is_same_v<Tc, float> &&
601+
std::is_same_v<Td, float>) ||
602+
(std::is_same_v<Ta, bfloat16> && std::is_same_v<Tc, float> &&
603+
std::is_same_v<Td, float>) ||
604+
(std::is_same_v<Ta, double> && std::is_same_v<Tc, double> &&
605+
std::is_same_v<Td, double>);
606+
}
607+
608+
template <typename Ta, typename Tc, typename Td>
609+
constexpr bool is_combination_valid_cuda_sm70(size_t sM, size_t sN, size_t sK) {
610+
return are_types_valid_cuda_sm70<Ta, Tc, Td>() &&
611+
((sM == 8 && sN == 32 && sK == 16) ||
612+
(sM == 16 && sN == 16 && sK == 16) ||
613+
(sM == 32 && sN == 8 && sK == 16));
614+
}
615+
616+
template <typename Ta, typename Tc, typename Td>
617+
constexpr bool is_combination_valid_cuda_sm72(size_t sM, size_t sN, size_t sK) {
618+
return are_types_valid_cuda_sm72<Ta, Tc, Td>() &&
619+
((sM == 8 && sN == 32 && sK == 16) ||
620+
(sM == 16 && sN == 16 && sK == 16) ||
621+
(sM == 32 && sN == 8 && sK == 16));
622+
}
623+
624+
template <typename Ta, typename Tc, typename Td>
625+
constexpr bool is_combination_valid_cuda_sm80(size_t sM, size_t sN, size_t sK) {
626+
return ((std::is_same_v<Ta, precision::tf32> && std::is_same_v<Tc, float> &&
627+
std::is_same_v<Td, float>)&&(sM == 16 && sN == 16 && sK == 8)) ||
628+
((std::is_same_v<Ta, bfloat16> && std::is_same_v<Tc, float> &&
629+
std::is_same_v<Td, float>)&&((sM == 16 && sN == 16 && sK == 16) ||
630+
(sM == 8 && sN == 32 && sK == 16) ||
631+
(sM == 32 && sN == 8 && sK == 16))) ||
632+
((std::is_same_v<Ta, double> && std::is_same_v<Tc, double> &&
633+
std::is_same_v<Td, double>)&&(sM == 8 && sN == 8 && sK == 4));
634+
}
635+
636+
// Default-values query (nvidia sm70):
// Specialization for when only types are given, need to query only sizes
template <typename Ta, typename Tb, typename Tc, typename Td>
struct matrix_params<
    architecture::nvidia_gpu_sm_70, Ta, Tb, Tc, Td, 0, 0, 0,
    // SFINAE guard: all four types supplied and the A/B types match.
    // Unlike the gfx90a query, Tc and Td may differ (e.g. half/float mix).
    typename std::enable_if_t<(
        !std::is_same_v<Ta, void> && !std::is_same_v<Tb, void> &&
        !std::is_same_v<Tc, void> && !std::is_same_v<Td, void> &&
        std::is_same_v<Ta, Tb>)>> {
  static_assert(
      are_types_valid_cuda_sm70<Ta, Tc, Td>(),
      "Invalid types for nvidia sm70, supported types are half and float ");

  // Default sizes for nvidia sm70 were chosen to represent a square matrix
  // (16x16x16 is legal for every sm70 type combination, see
  // is_combination_valid_cuda_sm70).
  static constexpr std::size_t M = 16;
  static constexpr std::size_t N = 16;
  static constexpr std::size_t K = 16;

  // Aliases resolving to the joint_matrix types for this default shape.
  template <typename Group, layout Layout>
  using joint_matrix_a = joint_matrix<Group, Ta, use::a, M, K, Layout>;
  template <typename Group, layout Layout>
  using joint_matrix_b = joint_matrix<Group, Tb, use::b, K, N, Layout>;
  template <typename Group>
  using joint_matrix_c = joint_matrix<Group, Tc, use::accumulator, M, N>;
  template <typename Group>
  using joint_matrix_d = joint_matrix<Group, Td, use::accumulator, M, N>;
};
663+
664+
// Default-values query (nvidia sm72):
665+
// Specialization for when only types are given, need to query only sizes
666+
template <typename Ta, typename Tb, typename Tc, typename Td>
667+
struct matrix_params<
668+
architecture::nvidia_gpu_sm_72, Ta, Tb, Tc, Td, 0, 0, 0,
669+
typename std::enable_if<(
670+
!std::is_same_v<Ta, void> && !std::is_same_v<Tb, void> &&
671+
!std::is_same_v<Tc, void> && !std::is_same_v<Td, void> &&
672+
std::is_same_v<Ta, Tb>)>::type> {
673+
static_assert(
674+
are_types_valid_cuda_sm70<Ta, Tc, Td>() ||
675+
are_types_valid_cuda_sm72<Ta, Tc, Td>(),
676+
"Invalid types for nvidia sm72, supported types are half, float "
677+
"int8_t, uint8_t and int32_t ");
678+
679+
static constexpr std::size_t M = 16;
680+
static constexpr std::size_t N = 16;
681+
static constexpr std::size_t K = 16;
682+
683+
template <typename Group, layout Layout>
684+
using joint_matrix_a = joint_matrix<Group, Ta, use::a, M, K, Layout>;
685+
template <typename Group, layout Layout>
686+
using joint_matrix_b = joint_matrix<Group, Tb, use::b, K, N, Layout>;
687+
template <typename Group>
688+
using joint_matrix_c = joint_matrix<Group, Tc, use::accumulator, M, N>;
689+
template <typename Group>
690+
using joint_matrix_d = joint_matrix<Group, Td, use::accumulator, M, N>;
691+
};
692+
693+
// Default-values query (nvidia sm80):
// Specialization for when only types are given, need to query only sizes
template <typename Ta, typename Tb, typename Tc, typename Td>
struct matrix_params<
    architecture::nvidia_gpu_sm_80, Ta, Tb, Tc, Td, 0, 0, 0,
    // SFINAE guard: all four types supplied and the A/B types match.
    typename std::enable_if_t<(
        !std::is_same_v<Ta, void> && !std::is_same_v<Tb, void> &&
        !std::is_same_v<Tc, void> && !std::is_same_v<Td, void> &&
        std::is_same_v<Ta, Tb>)>> {
  // sm80 accepts every sm70 and sm72 combination plus tf32/bfloat16/double.
  static_assert(
      are_types_valid_cuda_sm70<Ta, Tc, Td>() ||
          are_types_valid_cuda_sm72<Ta, Tc, Td>() ||
          are_types_valid_cuda_sm80<Ta, Tc, Td>(),
      "Invalid types for nvidia sm80, supported types are half, float "
      "int8_t, uint8_t, int32_t, double, tf32 and bfloat16 ");

  // Shape defaults follow is_combination_valid_cuda_sm80:
  //   double (sizeof == 8) -> 8x8x4; tf32 -> 16x16x8; all else -> 16x16x16.
  static constexpr std::size_t M = (sizeof(Ta) == 8) ? 8 : 16;
  static constexpr std::size_t N = (sizeof(Ta) == 8) ? 8 : 16;
  static constexpr std::size_t K =
      std::is_same_v<Ta, precision::tf32> ? 8 : (sizeof(Ta) == 8 ? 4 : 16);

  // Aliases resolving to the joint_matrix types for this default shape.
  template <typename Group, layout Layout>
  using joint_matrix_a = joint_matrix<Group, Ta, use::a, M, K, Layout>;
  template <typename Group, layout Layout>
  using joint_matrix_b = joint_matrix<Group, Tb, use::b, K, N, Layout>;
  template <typename Group>
  using joint_matrix_c = joint_matrix<Group, Tc, use::accumulator, M, N>;
  template <typename Group>
  using joint_matrix_d = joint_matrix<Group, Td, use::accumulator, M, N>;
};
723+
724+
// Validation query (nvidia sm70)
// Specialization when both types and sizes are given
template <typename Ta, typename Tb, typename Tc, typename Td, size_t sM,
          size_t sN, size_t sK>
struct matrix_params<
    architecture::nvidia_gpu_sm_70, Ta, Tb, Tc, Td, sM, sN, sK,
    // SFINAE guard: A/B types match and all three sizes are non-zero
    // (zero sizes select the default-values specialization instead).
    typename std::enable_if_t<(
        !std::is_same_v<Ta, void> && !std::is_same_v<Tb, void> &&
        !std::is_same_v<Tc, void> && !std::is_same_v<Td, void> &&
        std::is_same_v<Ta, Tb> && sM != 0 && sN != 0 && sK != 0)>> {
  // Compile-time rejection of any type/shape combination sm70 can't run.
  static_assert(
      is_combination_valid_cuda_sm70<Ta, Tc, Td>(sM, sN, sK),
      "Invalid parameters for nvidia sm70, query valid combinations "
      "using: "
      "q.get_device().get_info<sycl::info::device::matrix::combinations>()");

  // Validated sizes are exposed unchanged.
  static constexpr std::size_t M = sM;
  static constexpr std::size_t N = sN;
  static constexpr std::size_t K = sK;

  // Aliases resolving to the joint_matrix types for the validated shape.
  template <typename Group, layout Layout>
  using joint_matrix_a = joint_matrix<Group, Ta, use::a, M, K, Layout>;
  template <typename Group, layout Layout>
  using joint_matrix_b = joint_matrix<Group, Tb, use::b, K, N, Layout>;
  template <typename Group>
  using joint_matrix_c = joint_matrix<Group, Tc, use::accumulator, M, N>;
  template <typename Group>
  using joint_matrix_d = joint_matrix<Group, Td, use::accumulator, M, N>;
};
753+
754+
// Validation query (nvidia sm72)
// Specialization when both types and sizes are given
template <typename Ta, typename Tb, typename Tc, typename Td, size_t sM,
          size_t sN, size_t sK>
struct matrix_params<
    architecture::nvidia_gpu_sm_72, Ta, Tb, Tc, Td, sM, sN, sK,
    // SFINAE guard: A/B types match and all three sizes are non-zero
    // (zero sizes select the default-values specialization instead).
    typename std::enable_if_t<(
        !std::is_same_v<Ta, void> && !std::is_same_v<Tb, void> &&
        !std::is_same_v<Tc, void> && !std::is_same_v<Td, void> &&
        std::is_same_v<Ta, Tb> && sM != 0 && sN != 0 && sK != 0)>> {
  // sm72 accepts both the sm70 combinations and its own integer ones.
  static_assert(
      is_combination_valid_cuda_sm70<Ta, Tc, Td>(sM, sN, sK) ||
          is_combination_valid_cuda_sm72<Ta, Tc, Td>(sM, sN, sK),
      "Invalid parameters for nvidia sm72, query valid combinations "
      "using: "
      "q.get_device().get_info<sycl::info::device::matrix::combinations>()");

  // Validated sizes are exposed unchanged.
  static constexpr std::size_t M = sM;
  static constexpr std::size_t N = sN;
  static constexpr std::size_t K = sK;

  // Aliases resolving to the joint_matrix types for the validated shape.
  template <typename Group, layout Layout>
  using joint_matrix_a = joint_matrix<Group, Ta, use::a, M, K, Layout>;
  template <typename Group, layout Layout>
  using joint_matrix_b = joint_matrix<Group, Tb, use::b, K, N, Layout>;
  template <typename Group>
  using joint_matrix_c = joint_matrix<Group, Tc, use::accumulator, M, N>;
  template <typename Group>
  using joint_matrix_d = joint_matrix<Group, Td, use::accumulator, M, N>;
};
784+
785+
// Validation query (nvidia sm80)
// Specialization when both types and sizes are given
template <typename Ta, typename Tb, typename Tc, typename Td, size_t sM,
          size_t sN, size_t sK>
struct matrix_params<
    architecture::nvidia_gpu_sm_80, Ta, Tb, Tc, Td, sM, sN, sK,
    // SFINAE guard: A/B types match and all three sizes are non-zero
    // (zero sizes select the default-values specialization instead).
    typename std::enable_if_t<(
        !std::is_same_v<Ta, void> && !std::is_same_v<Tb, void> &&
        !std::is_same_v<Tc, void> && !std::is_same_v<Td, void> &&
        std::is_same_v<Ta, Tb> && sM != 0 && sN != 0 && sK != 0)>> {
  // sm80 accepts the sm70 and sm72 combinations plus its own
  // (tf32/bfloat16/double) ones.
  static_assert(
      is_combination_valid_cuda_sm70<Ta, Tc, Td>(sM, sN, sK) ||
          is_combination_valid_cuda_sm72<Ta, Tc, Td>(sM, sN, sK) ||
          is_combination_valid_cuda_sm80<Ta, Tc, Td>(sM, sN, sK),
      "Invalid parameters for nvidia sm80, query valid combinations "
      "using: "
      "q.get_device().get_info<sycl::info::device::matrix::combinations>()");

  // Validated sizes are exposed unchanged.
  static constexpr std::size_t M = sM;
  static constexpr std::size_t N = sN;
  static constexpr std::size_t K = sK;

  // Aliases resolving to the joint_matrix types for the validated shape.
  template <typename Group, layout Layout>
  using joint_matrix_a = joint_matrix<Group, Ta, use::a, M, K, Layout>;
  template <typename Group, layout Layout>
  using joint_matrix_b = joint_matrix<Group, Tb, use::b, K, N, Layout>;
  template <typename Group>
  using joint_matrix_c = joint_matrix<Group, Tc, use::accumulator, M, N>;
  template <typename Group>
  using joint_matrix_d = joint_matrix<Group, Td, use::accumulator, M, N>;
};
816+
485817
} // namespace experimental::matrix
486818
} // namespace oneapi
487819
} // namespace ext

0 commit comments

Comments
 (0)