@@ -729,7 +729,7 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
729
729
# pylint: disable=import-outside-toplevel
730
730
from tvm .topi .arm_cpu .pstate_attributes import SMEAttributes
731
731
from tvm .tir .tensor_intrin .arm_cpu import (
732
- ARM_SME_2SVLx2SVL_TRANSPOSE_INTERLEAVE ,
732
+ ARM_SME_2SVLx2SVL_FP32_TRANSPOSE_INTERLEAVE ,
733
733
ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA ,
734
734
ARM_SME_INIT ,
735
735
get_sme_gemm_interleaved_mopa_2svlx2svl_intrin ,
@@ -743,7 +743,7 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
743
743
ko , ki = sch .split (k , factors = (None , tile_K ), disable_predication = True )
744
744
sch .parallel (b )
745
745
sch .reorder (b , ko , mo , ki , mi )
746
- sch .tensorize (ki , ARM_SME_2SVLx2SVL_TRANSPOSE_INTERLEAVE )
746
+ sch .tensorize (ki , ARM_SME_2SVLx2SVL_FP32_TRANSPOSE_INTERLEAVE )
747
747
748
748
# Split and reorder the loops of the GeMM for tensorization
749
749
b , m , n , k = sch .get_loops (gemm_block )
@@ -760,7 +760,7 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
760
760
sme_gemm_interleaved_intrin_name = ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA + f"_{ K_padded } "
761
761
tvm .tir .TensorIntrin .register (
762
762
sme_gemm_interleaved_intrin_name ,
763
- * get_sme_gemm_interleaved_mopa_2svlx2svl_intrin (K_padded ),
763
+ * get_sme_gemm_interleaved_mopa_2svlx2svl_intrin (K_padded , dtype ),
764
764
override = True ,
765
765
)
766
766
sch .tensorize (mi , sme_gemm_interleaved_intrin_name )
0 commit comments