diff --git a/aot_build_utils/generate.py b/aot_build_utils/generate.py index fe9b748a1..2d8fdb614 100644 --- a/aot_build_utils/generate.py +++ b/aot_build_utils/generate.py @@ -275,7 +275,7 @@ def write_if_different(path: Path, content: str) -> None: ) parser.add_argument( "--use_fp16_qk_reductions", - type=lambda x: x if isinstance(x, int) else int(x.lower() == "true"), + type=lambda x: x if isinstance(x, int) else int(x.lower() == "true" or x.lower() == "on"), required=True, nargs="+", help="Allow fp16 qk reductions", @@ -289,30 +289,30 @@ def write_if_different(path: Path, content: str) -> None: ) parser.add_argument( "--enable_f16", - type=lambda x: x if isinstance(x, int) else x.lower() == "true", + type=lambda x: x if isinstance(x, int) else (x.lower() == "true" or x.lower() == "on"), required=True, - nargs="+", + nargs="?", help="Enable fp16", ) parser.add_argument( "--enable_bf16", - type=lambda x: x if isinstance(x, int) else x.lower() == "true", + type=lambda x: x if isinstance(x, int) else (x.lower() == "true" or x.lower() == "on"), required=True, - nargs="+", + nargs="?", help="Enable bf16", ) parser.add_argument( "--enable_fp8_e4m3", - type=lambda x: x if isinstance(x, int) else x.lower() == "true", + type=lambda x: x if isinstance(x, int) else (x.lower() == "true" or x.lower() == "on"), default=True, - nargs="+", + nargs="?", help="Enable fp8_e4m3", ) parser.add_argument( "--enable_fp8_e5m2", - type=lambda x: x if isinstance(x, int) else x.lower() == "true", + type=lambda x: x if isinstance(x, int) else (x.lower() == "true" or x.lower() == "on"), default=True, - nargs="+", + nargs="?", help="Enable fp8_e5m2", ) args = parser.parse_args()