Skip to content

Commit 78dde79

Browse files
baowendinbaowending.bwd
and
baowending.bwd
authored
bugfix: fix geneate_dispatch_inc args from parser (#870)
head_dim args from setup.py and parser in aot_build_utils/generate_dispatch_inc.py missmatch Co-authored-by: baowending.bwd <[email protected]>
1 parent 7e06dc0 commit 78dde79

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

aot_build_utils/generate_dispatch_inc.py

+4
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,9 @@ def get_dispatch_inc_str(args: argparse.Namespace) -> str:
9999
parser.add_argument(
100100
"--path", type=str, required=True, help="Path to the dispatch inc file"
101101
)
102+
parser.add_argument(
103+
"--head_dims_sm90", type=str, required=True, nargs="+", help="Head dimensions in format of 'head_dim_qk,head_dim_vo'",
104+
)
102105
parser.add_argument(
103106
"--head_dims", type=int, required=True, nargs="+", help="Head dimensions"
104107
)
@@ -124,6 +127,7 @@ def get_dispatch_inc_str(args: argparse.Namespace) -> str:
124127
help="Mask modes",
125128
)
126129
args = parser.parse_args()
130+
args.head_dims_sm90 = [tuple(map(int, x.split(","))) for x in args.head_dims_sm90]
127131
print(args)
128132
with open(Path(args.path), "w") as f:
129133
f.write(get_dispatch_inc_str(args))

0 commit comments

Comments
 (0)