@@ -482,6 +482,338 @@ struct matrix_params<
   template <typename Group>
   using joint_matrix_d = joint_matrix<Group, Td, use::accumulator, M, N>;
 };
+
+//////////////////////////////////////////////
+/// AMD Matrix Cores - GFX90A architecture ///
+//////////////////////////////////////////////
+
+template <typename Ta, typename Tc>
+constexpr bool is_combination_valid_amd_gfx90a(size_t sM, size_t sN,
+                                               size_t sK) {
+  return (std::is_same_v<Ta, half> && std::is_same_v<Tc, float> &&
+          ((sM == 32 && sN == 32 && sK == 8) ||
+           (sM == 16 && sN == 16 && sK == 16))) ||
+         (std::is_same_v<Ta, int8_t> && std::is_same_v<Tc, int32_t> &&
+          ((sM == 32 && sN == 32 && sK == 8) ||
+           (sM == 16 && sN == 16 && sK == 16))) ||
+         (std::is_same_v<Ta, bfloat16> && std::is_same_v<Tc, float> &&
+          ((sM == 32 && sN == 32 && sK == 8) ||
+           (sM == 16 && sN == 16 && sK == 16))) ||
+         (std::is_same_v<Ta, double> && std::is_same_v<Tc, double> &&
+          (sM == 16 && sN == 16 && sK == 4));
+}
+
+template <typename Ta, typename Tc>
+constexpr bool are_types_valid_amd_gfx90a() {
+  return (std::is_same_v<Ta, half> && std::is_same_v<Tc, float>) ||
+         (std::is_same_v<Ta, int8_t> && std::is_same_v<Tc, int32_t>) ||
+         (std::is_same_v<Ta, bfloat16> && std::is_same_v<Tc, float>) ||
+         (std::is_same_v<Ta, double> && std::is_same_v<Tc, double>);
+}
+
+// Default-values query:
+// Specialization for when only types are given, need to query only sizes
+template <typename Ta, typename Tb, typename Tc, typename Td>
+struct matrix_params<
+    architecture::amd_gpu_gfx90a, Ta, Tb, Tc, Td, 0, 0, 0,
+    typename std::enable_if_t<(
+        !std::is_same_v<Ta, void> && !std::is_same_v<Tb, void> &&
+        !std::is_same_v<Tc, void> && !std::is_same_v<Td, void> &&
+        std::is_same_v<Ta, Tb> && std::is_same_v<Tc, Td>)>> {
+  static_assert(
+      are_types_valid_amd_gfx90a<Ta, Tc>(),
+      "Invalid types for AMD gfx90a, supported types are half, float, "
+      "int8_t, int32_t, double and bfloat16");
+
+  // Default sizes for AMD gfx90a were chosen to represent a square matrix
+  static constexpr std::size_t M = 16;
+  static constexpr std::size_t N = 16;
+  // double (sizeof == 8) only supports 16x16x4; all other types use K = 16
+  static constexpr std::size_t K = ((sizeof(Ta) == 8) ? 4 : 16);
+
+  template <typename Group, layout Layout>
+  using joint_matrix_a = joint_matrix<Group, Ta, use::a, M, K, Layout>;
+  template <typename Group, layout Layout>
+  using joint_matrix_b = joint_matrix<Group, Tb, use::b, K, N, Layout>;
+  template <typename Group>
+  using joint_matrix_c = joint_matrix<Group, Tc, use::accumulator, M, N>;
+  template <typename Group>
+  using joint_matrix_d = joint_matrix<Group, Td, use::accumulator, M, N>;
+};
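Aside (not part of the diff): a minimal sketch of how the default-values query above is meant to be consumed. The namespace aliases are assumptions of this sketch, as is the primary `matrix_params` template defaulting its size parameters to 0.

```c++
#include <sycl/sycl.hpp>

namespace syclex = sycl::ext::oneapi::experimental;
namespace symatrix = syclex::matrix;

// With only element types given, the gfx90a specialization above fills
// in the default shape: 16x16x16 here, and 16x16x4 when Ta is double.
using p_half = symatrix::matrix_params<syclex::architecture::amd_gpu_gfx90a,
                                       sycl::half, sycl::half, float, float>;
static_assert(p_half::M == 16 && p_half::N == 16 && p_half::K == 16);

using p_double =
    symatrix::matrix_params<syclex::architecture::amd_gpu_gfx90a, double,
                            double, double, double>;
static_assert(p_double::K == 4);
```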
+
+// Validation query
+// Specialization when both types and sizes are given
+template <typename Ta, typename Tb, typename Tc, typename Td, size_t sM,
+          size_t sN, size_t sK>
+struct matrix_params<
+    architecture::amd_gpu_gfx90a, Ta, Tb, Tc, Td, sM, sN, sK,
+    typename std::enable_if_t<(
+        !std::is_same_v<Ta, void> && !std::is_same_v<Tb, void> &&
+        !std::is_same_v<Tc, void> && !std::is_same_v<Td, void> &&
+        std::is_same_v<Ta, Tb> && std::is_same_v<Tc, Td> && sM != 0 &&
+        sN != 0 && sK != 0)>> {
+  static_assert(
+      is_combination_valid_amd_gfx90a<Ta, Tc>(sM, sN, sK),
+      "Invalid parameters for AMD gfx90a, query valid combinations "
+      "using: "
+      "q.get_device().get_info<sycl::info::device::matrix::combinations>()");
+
+  static constexpr std::size_t M = sM;
+  static constexpr std::size_t N = sN;
+  static constexpr std::size_t K = sK;
+
+  template <typename Group, layout Layout>
+  using joint_matrix_a = joint_matrix<Group, Ta, use::a, M, K, Layout>;
+  template <typename Group, layout Layout>
+  using joint_matrix_b = joint_matrix<Group, Tb, use::b, K, N, Layout>;
+  template <typename Group>
+  using joint_matrix_c = joint_matrix<Group, Tc, use::accumulator, M, N>;
+  template <typename Group>
+  using joint_matrix_d = joint_matrix<Group, Td, use::accumulator, M, N>;
+};
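Aside: with explicit sizes the validation specialization is selected instead, so an unsupported shape fails at the `static_assert` rather than deep inside `joint_matrix`. Namespace aliases as in the previous sketch; the `bfloat16` spelling is assumed.

```c++
using bf16 = sycl::ext::oneapi::bfloat16; // spelling assumed

// 32x32x8 with bfloat16 inputs accumulating to float is on the gfx90a
// allow-list, so this compiles.
using p_ok = symatrix::matrix_params<syclex::architecture::amd_gpu_gfx90a,
                                     bf16, bf16, float, float, 32, 32, 8>;

// The same types with 32, 32, 16 are not a valid gfx90a combination and
// would trip the static_assert above.
```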
+
+///////////////////////////////////////////////
+/// CUDA Tensor Cores - sm70, sm72 and sm80 ///
+///////////////////////////////////////////////
+
+template <typename Ta, typename Tc, typename Td>
+constexpr bool are_types_valid_cuda_sm70() {
+  return (std::is_same_v<Ta, half> && std::is_same_v<Tc, float> &&
+          std::is_same_v<Td, float>) ||
+         (std::is_same_v<Ta, half> && std::is_same_v<Tc, half> &&
+          std::is_same_v<Td, half>) ||
+         (std::is_same_v<Ta, half> && std::is_same_v<Tc, float> &&
+          std::is_same_v<Td, half>) ||
+         (std::is_same_v<Ta, half> && std::is_same_v<Tc, half> &&
+          std::is_same_v<Td, float>);
+}
+
+template <typename Ta, typename Tc, typename Td>
+constexpr bool are_types_valid_cuda_sm72() {
+  return (std::is_same_v<Ta, int8_t> && std::is_same_v<Tc, int32_t> &&
+          std::is_same_v<Td, int32_t>) ||
+         (std::is_same_v<Ta, uint8_t> && std::is_same_v<Tc, int32_t> &&
+          std::is_same_v<Td, int32_t>);
+}
+
+template <typename Ta, typename Tc, typename Td>
+constexpr bool are_types_valid_cuda_sm80() {
+  return (std::is_same_v<Ta, precision::tf32> && std::is_same_v<Tc, float> &&
+          std::is_same_v<Td, float>) ||
+         (std::is_same_v<Ta, bfloat16> && std::is_same_v<Tc, float> &&
+          std::is_same_v<Td, float>) ||
+         (std::is_same_v<Ta, double> && std::is_same_v<Tc, double> &&
+          std::is_same_v<Td, double>);
+}
+
+template <typename Ta, typename Tc, typename Td>
+constexpr bool is_combination_valid_cuda_sm70(size_t sM, size_t sN, size_t sK) {
+  return are_types_valid_cuda_sm70<Ta, Tc, Td>() &&
+         ((sM == 8 && sN == 32 && sK == 16) ||
+          (sM == 16 && sN == 16 && sK == 16) ||
+          (sM == 32 && sN == 8 && sK == 16));
+}
+
+template <typename Ta, typename Tc, typename Td>
+constexpr bool is_combination_valid_cuda_sm72(size_t sM, size_t sN, size_t sK) {
+  return are_types_valid_cuda_sm72<Ta, Tc, Td>() &&
+         ((sM == 8 && sN == 32 && sK == 16) ||
+          (sM == 16 && sN == 16 && sK == 16) ||
+          (sM == 32 && sN == 8 && sK == 16));
+}
+
+template <typename Ta, typename Tc, typename Td>
+constexpr bool is_combination_valid_cuda_sm80(size_t sM, size_t sN, size_t sK) {
+  return ((std::is_same_v<Ta, precision::tf32> &&
+           std::is_same_v<Tc, float> && std::is_same_v<Td, float>) &&
+          (sM == 16 && sN == 16 && sK == 8)) ||
+         ((std::is_same_v<Ta, bfloat16> && std::is_same_v<Tc, float> &&
+           std::is_same_v<Td, float>) &&
+          ((sM == 16 && sN == 16 && sK == 16) ||
+           (sM == 8 && sN == 32 && sK == 16) ||
+           (sM == 32 && sN == 8 && sK == 16))) ||
+         ((std::is_same_v<Ta, double> && std::is_same_v<Tc, double> &&
+           std::is_same_v<Td, double>) &&
+          (sM == 8 && sN == 8 && sK == 4));
+}
+
+// Default-values query (nvidia sm70):
+// Specialization for when only types are given, need to query only sizes
+template <typename Ta, typename Tb, typename Tc, typename Td>
+struct matrix_params<
+    architecture::nvidia_gpu_sm_70, Ta, Tb, Tc, Td, 0, 0, 0,
+    typename std::enable_if_t<(
+        !std::is_same_v<Ta, void> && !std::is_same_v<Tb, void> &&
+        !std::is_same_v<Tc, void> && !std::is_same_v<Td, void> &&
+        std::is_same_v<Ta, Tb>)>> {
+  static_assert(
+      are_types_valid_cuda_sm70<Ta, Tc, Td>(),
+      "Invalid types for nvidia sm70, supported types are half and float");
+
+  // Default sizes for nvidia sm70 were chosen to represent a square matrix
+  static constexpr std::size_t M = 16;
+  static constexpr std::size_t N = 16;
+  static constexpr std::size_t K = 16;
+
+  template <typename Group, layout Layout>
+  using joint_matrix_a = joint_matrix<Group, Ta, use::a, M, K, Layout>;
+  template <typename Group, layout Layout>
+  using joint_matrix_b = joint_matrix<Group, Tb, use::b, K, N, Layout>;
+  template <typename Group>
+  using joint_matrix_c = joint_matrix<Group, Tc, use::accumulator, M, N>;
+  template <typename Group>
+  using joint_matrix_d = joint_matrix<Group, Td, use::accumulator, M, N>;
+};
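Aside: a sketch of how the nested aliases would propagate the queried shape into kernel code; the `sub_group` and `row_major` choices here are illustrative assumptions, not part of the diff.

```c++
using p = symatrix::matrix_params<syclex::architecture::nvidia_gpu_sm_70,
                                  sycl::half, sycl::half, float, float>;

// Inside a kernel, the aliases spell the full joint_matrix types so the
// M/N/K from the query never have to be repeated by hand:
//   p::joint_matrix_a<sycl::sub_group, symatrix::layout::row_major> tA;
//   p::joint_matrix_b<sycl::sub_group, symatrix::layout::row_major> tB;
//   p::joint_matrix_c<sycl::sub_group> tC;
```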
+
+// Default-values query (nvidia sm72):
+// Specialization for when only types are given, need to query only sizes
+template <typename Ta, typename Tb, typename Tc, typename Td>
+struct matrix_params<
+    architecture::nvidia_gpu_sm_72, Ta, Tb, Tc, Td, 0, 0, 0,
+    typename std::enable_if_t<(
+        !std::is_same_v<Ta, void> && !std::is_same_v<Tb, void> &&
+        !std::is_same_v<Tc, void> && !std::is_same_v<Td, void> &&
+        std::is_same_v<Ta, Tb>)>> {
+  static_assert(
+      are_types_valid_cuda_sm70<Ta, Tc, Td>() ||
+          are_types_valid_cuda_sm72<Ta, Tc, Td>(),
+      "Invalid types for nvidia sm72, supported types are half, float, "
+      "int8_t, uint8_t and int32_t");
+
+  static constexpr std::size_t M = 16;
+  static constexpr std::size_t N = 16;
+  static constexpr std::size_t K = 16;
+
+  template <typename Group, layout Layout>
+  using joint_matrix_a = joint_matrix<Group, Ta, use::a, M, K, Layout>;
+  template <typename Group, layout Layout>
+  using joint_matrix_b = joint_matrix<Group, Tb, use::b, K, N, Layout>;
+  template <typename Group>
+  using joint_matrix_c = joint_matrix<Group, Tc, use::accumulator, M, N>;
+  template <typename Group>
+  using joint_matrix_d = joint_matrix<Group, Td, use::accumulator, M, N>;
+};
+
+// Default-values query (nvidia sm80):
+// Specialization for when only types are given, need to query only sizes
+template <typename Ta, typename Tb, typename Tc, typename Td>
+struct matrix_params<
+    architecture::nvidia_gpu_sm_80, Ta, Tb, Tc, Td, 0, 0, 0,
+    typename std::enable_if_t<(
+        !std::is_same_v<Ta, void> && !std::is_same_v<Tb, void> &&
+        !std::is_same_v<Tc, void> && !std::is_same_v<Td, void> &&
+        std::is_same_v<Ta, Tb>)>> {
+  static_assert(
+      are_types_valid_cuda_sm70<Ta, Tc, Td>() ||
+          are_types_valid_cuda_sm72<Ta, Tc, Td>() ||
+          are_types_valid_cuda_sm80<Ta, Tc, Td>(),
+      "Invalid types for nvidia sm80, supported types are half, float, "
+      "int8_t, uint8_t, int32_t, double, tf32 and bfloat16");
+
+  static constexpr std::size_t M = (sizeof(Ta) == 8) ? 8 : 16;
+  static constexpr std::size_t N = (sizeof(Ta) == 8) ? 8 : 16;
+  static constexpr std::size_t K =
+      std::is_same_v<Ta, precision::tf32> ? 8 : (sizeof(Ta) == 8 ? 4 : 16);
+
+  template <typename Group, layout Layout>
+  using joint_matrix_a = joint_matrix<Group, Ta, use::a, M, K, Layout>;
+  template <typename Group, layout Layout>
+  using joint_matrix_b = joint_matrix<Group, Tb, use::b, K, N, Layout>;
+  template <typename Group>
+  using joint_matrix_c = joint_matrix<Group, Tc, use::accumulator, M, N>;
+  template <typename Group>
+  using joint_matrix_d = joint_matrix<Group, Td, use::accumulator, M, N>;
+};
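Aside: the sm80 defaults depend on the A/B element type; a compile-time check of the three branches of M/N/K above, with the same assumed aliases as the earlier sketches.

```c++
// tf32 inputs default to 16x16x8; double shrinks the tile to 8x8x4;
// everything else gets 16x16x16.
using p_tf32 =
    symatrix::matrix_params<syclex::architecture::nvidia_gpu_sm_80,
                            symatrix::precision::tf32,
                            symatrix::precision::tf32, float, float>;
static_assert(p_tf32::M == 16 && p_tf32::N == 16 && p_tf32::K == 8);

using p_fp64 =
    symatrix::matrix_params<syclex::architecture::nvidia_gpu_sm_80, double,
                            double, double, double>;
static_assert(p_fp64::M == 8 && p_fp64::N == 8 && p_fp64::K == 4);
```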
+
+// Validation query (nvidia sm70)
+// Specialization when both types and sizes are given
+template <typename Ta, typename Tb, typename Tc, typename Td, size_t sM,
+          size_t sN, size_t sK>
+struct matrix_params<
+    architecture::nvidia_gpu_sm_70, Ta, Tb, Tc, Td, sM, sN, sK,
+    typename std::enable_if_t<(
+        !std::is_same_v<Ta, void> && !std::is_same_v<Tb, void> &&
+        !std::is_same_v<Tc, void> && !std::is_same_v<Td, void> &&
+        std::is_same_v<Ta, Tb> && sM != 0 && sN != 0 && sK != 0)>> {
+  static_assert(
+      is_combination_valid_cuda_sm70<Ta, Tc, Td>(sM, sN, sK),
+      "Invalid parameters for nvidia sm70, query valid combinations "
+      "using: "
+      "q.get_device().get_info<sycl::info::device::matrix::combinations>()");
+
+  static constexpr std::size_t M = sM;
+  static constexpr std::size_t N = sN;
+  static constexpr std::size_t K = sK;
+
+  template <typename Group, layout Layout>
+  using joint_matrix_a = joint_matrix<Group, Ta, use::a, M, K, Layout>;
+  template <typename Group, layout Layout>
+  using joint_matrix_b = joint_matrix<Group, Tb, use::b, K, N, Layout>;
+  template <typename Group>
+  using joint_matrix_c = joint_matrix<Group, Tc, use::accumulator, M, N>;
+  template <typename Group>
+  using joint_matrix_d = joint_matrix<Group, Td, use::accumulator, M, N>;
+};
+
+// Validation query (nvidia sm72)
+// Specialization when both types and sizes are given
+template <typename Ta, typename Tb, typename Tc, typename Td, size_t sM,
+          size_t sN, size_t sK>
+struct matrix_params<
+    architecture::nvidia_gpu_sm_72, Ta, Tb, Tc, Td, sM, sN, sK,
+    typename std::enable_if_t<(
+        !std::is_same_v<Ta, void> && !std::is_same_v<Tb, void> &&
+        !std::is_same_v<Tc, void> && !std::is_same_v<Td, void> &&
+        std::is_same_v<Ta, Tb> && sM != 0 && sN != 0 && sK != 0)>> {
+  static_assert(
+      is_combination_valid_cuda_sm70<Ta, Tc, Td>(sM, sN, sK) ||
+          is_combination_valid_cuda_sm72<Ta, Tc, Td>(sM, sN, sK),
+      "Invalid parameters for nvidia sm72, query valid combinations "
+      "using: "
+      "q.get_device().get_info<sycl::info::device::matrix::combinations>()");
+
+  static constexpr std::size_t M = sM;
+  static constexpr std::size_t N = sN;
+  static constexpr std::size_t K = sK;
+
+  template <typename Group, layout Layout>
+  using joint_matrix_a = joint_matrix<Group, Ta, use::a, M, K, Layout>;
+  template <typename Group, layout Layout>
+  using joint_matrix_b = joint_matrix<Group, Tb, use::b, K, N, Layout>;
+  template <typename Group>
+  using joint_matrix_c = joint_matrix<Group, Tc, use::accumulator, M, N>;
+  template <typename Group>
+  using joint_matrix_d = joint_matrix<Group, Td, use::accumulator, M, N>;
+};
+
+// Validation query (nvidia sm80)
+// Specialization when both types and sizes are given
+template <typename Ta, typename Tb, typename Tc, typename Td, size_t sM,
+          size_t sN, size_t sK>
+struct matrix_params<
+    architecture::nvidia_gpu_sm_80, Ta, Tb, Tc, Td, sM, sN, sK,
+    typename std::enable_if_t<(
+        !std::is_same_v<Ta, void> && !std::is_same_v<Tb, void> &&
+        !std::is_same_v<Tc, void> && !std::is_same_v<Td, void> &&
+        std::is_same_v<Ta, Tb> && sM != 0 && sN != 0 && sK != 0)>> {
+  static_assert(
+      is_combination_valid_cuda_sm70<Ta, Tc, Td>(sM, sN, sK) ||
+          is_combination_valid_cuda_sm72<Ta, Tc, Td>(sM, sN, sK) ||
+          is_combination_valid_cuda_sm80<Ta, Tc, Td>(sM, sN, sK),
+      "Invalid parameters for nvidia sm80, query valid combinations "
+      "using: "
+      "q.get_device().get_info<sycl::info::device::matrix::combinations>()");
+
+  static constexpr std::size_t M = sM;
+  static constexpr std::size_t N = sN;
+  static constexpr std::size_t K = sK;
+
+  template <typename Group, layout Layout>
+  using joint_matrix_a = joint_matrix<Group, Ta, use::a, M, K, Layout>;
+  template <typename Group, layout Layout>
+  using joint_matrix_b = joint_matrix<Group, Tb, use::b, K, N, Layout>;
+  template <typename Group>
+  using joint_matrix_c = joint_matrix<Group, Tc, use::accumulator, M, N>;
+  template <typename Group>
+  using joint_matrix_d = joint_matrix<Group, Td, use::accumulator, M, N>;
+};
+
 } // namespace experimental::matrix
 } // namespace oneapi
 } // namespace ext