Skip to content

Commit 8bdd54b

Browse files
authored
[TOPI] Fix SME conv2d schedule import and intrin argument (#17040)
Fixes a merge conflict between #16981 and #17003. Change-Id: Ifcc983ef0b8c00250568a048fd682933adfdcde4
1 parent d9240e4 commit 8bdd54b

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

python/tvm/topi/arm_cpu/conv2d.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -729,7 +729,7 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
729729
# pylint: disable=import-outside-toplevel
730730
from tvm.topi.arm_cpu.pstate_attributes import SMEAttributes
731731
from tvm.tir.tensor_intrin.arm_cpu import (
732-
ARM_SME_2SVLx2SVL_TRANSPOSE_INTERLEAVE,
732+
ARM_SME_2SVLx2SVL_FP32_TRANSPOSE_INTERLEAVE,
733733
ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA,
734734
ARM_SME_INIT,
735735
get_sme_gemm_interleaved_mopa_2svlx2svl_intrin,
@@ -743,7 +743,7 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
743743
ko, ki = sch.split(k, factors=(None, tile_K), disable_predication=True)
744744
sch.parallel(b)
745745
sch.reorder(b, ko, mo, ki, mi)
746-
sch.tensorize(ki, ARM_SME_2SVLx2SVL_TRANSPOSE_INTERLEAVE)
746+
sch.tensorize(ki, ARM_SME_2SVLx2SVL_FP32_TRANSPOSE_INTERLEAVE)
747747

748748
# Split and reorder the loops of the GeMM for tensorization
749749
b, m, n, k = sch.get_loops(gemm_block)
@@ -760,7 +760,7 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
760760
sme_gemm_interleaved_intrin_name = ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA + f"_{K_padded}"
761761
tvm.tir.TensorIntrin.register(
762762
sme_gemm_interleaved_intrin_name,
763-
*get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(K_padded),
763+
*get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(K_padded, dtype),
764764
override=True,
765765
)
766766
sch.tensorize(mi, sme_gemm_interleaved_intrin_name)

0 commit comments

Comments
 (0)